What objective is STaR optimizing?


STaR: Bootstrapping Reasoning With Reasoning is a cool paper. It presents a method for training a reasoning model to produce better reasoning chains for problems it had trouble solving before.

In this post, I will show that the training objective of STaR is the normal ELBo from latent variable modeling, which has been studied pretty extensively in QA and in NLP more broadly.

Problem setup

STaR trains a generative model that, given a word problem $x$, produces a rationale $r$ and an answer $y$:

\begin{equation*}
p(r, y \mid x) = p(y \mid x, r)\, p(r \mid x).
\end{equation*}

Both models on the RHS can be parameterized by the same LLM.

Example from the GSM8K dataset:

Problem: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

Rationale and Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. #### 72
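
To make this concrete, here is a minimal sketch of how a single causal LM can score both factors, $\log p(r \mid x)$ and $\log p(y \mid x, r)$, on the example above. It assumes a Hugging Face causal LM; gpt2 and the prompt format are stand-ins, not the model or format used in STaR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is a stand-in; any causal LM checkpoint would do for this sketch.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def logprob(context: str, continuation: str) -> float:
    """Sum of token log-probs of `continuation` given `context` under the LM."""
    ctx = tok(context, return_tensors="pt").input_ids
    cont = tok(continuation, add_special_tokens=False, return_tensors="pt").input_ids
    ids = torch.cat([ctx, cont], dim=1)
    with torch.no_grad():
        logits = model(ids).logits
    logps = torch.log_softmax(logits[:, :-1], dim=-1)        # position t predicts token t+1
    token_logps = logps.gather(-1, ids[:, 1:, None]).squeeze(-1)
    return token_logps[0, ctx.shape[1] - 1:].sum().item()    # keep only continuation tokens

x = "Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\nA:"
r = " Natalia sold 48/2 = 24 clips in May. Natalia sold 48+24 = 72 clips altogether in April and May."
y = " #### 72"

# log p(r, y | x) = log p(r | x) + log p(y | x, r), both from the same LM.
log_joint = logprob(x, r) + logprob(x + r, y)
print(log_joint)
```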

Ideally, we would optimize the marginal likelihood (evidence)

\begin{equation}
\log p(y\mid x) = \log \sum_r p(y, r\mid x).
\end{equation}

This is intractable, since we have to marginalize (sum) over all rationales $r$.

Instead, we can optimize a tractable lower bound on the marginal likelihood (evidence), called the evidence lower bound (ELBo). We introduce a posterior rationalizer, $q(r \mid x, y)$, that generates rationales given both the problem and the true answer. The posterior rationalizer serves as a crutch that lets us bypass the intractable marginalization.

\begin{align}
\text{ELBo} :=& \log p(y\mid x) - KL[q(r\mid x,y) || p(r \mid x, y)]\\
=& E_{q(r\mid x,y)}[\log p(y \mid x, r)] - KL[q(r \mid x,y) || p(r \mid x)].
\end{align}
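
If this identity is unfamiliar, it follows in a few lines from Jensen's inequality:

\begin{align*}
\log p(y \mid x)
&= \log \sum_r q(r \mid x, y)\, \frac{p(y, r \mid x)}{q(r \mid x, y)}\\
&\ge E_{q(r \mid x, y)}\left[\log \frac{p(y \mid x, r)\, p(r \mid x)}{q(r \mid x, y)}\right]\\
&= E_{q(r \mid x, y)}[\log p(y \mid x, r)] - KL[q(r \mid x, y) || p(r \mid x)],
\end{align*}

where the slack in Jensen's inequality is exactly $KL[q(r \mid x, y) || p(r \mid x, y)]$, which is the first line of the definition above.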

How does this line up with the method presented in STaR?

STaR presents two methods for training the generative model. The first simply samples rationales from the generative model, $p(r\mid x)$. The second employs a rationalizer that is given the answer as a hint, $q(r \mid x, y)$. We will show that both can be written as variations of the ELBo.

ELBo (generative model)

Let’s first take a look at sampling rationales from the generative model. We reproduce Equation (1) from STaR here, which describes STaR without the posterior rationalizer (slight difference: we focus on a single datapoint $x, y^*$ and ignore the model parameters).

\begin{equation}
J(x, y^*) = E_{p(r, y \mid x)}[1(y = y^*)].
\end{equation}

Let’s try to recover this objective by setting the posterior rationalizer $q(r\mid x, y^*) = p(r\mid x)$ in the normal ELBo:

\begin{align}
&E_{q(r\mid x,y^*)}[\log p(y^* \mid x, r)] - KL[q(r \mid x,y^*) || p(r \mid x)]\\
&= E_{p(r\mid x)}[\log p(y^* \mid x, r)] - KL[p(r \mid x) || p(r \mid x)]\\
&= E_{p(r\mid x)}[\log p(y^* \mid x, r)]\\
&= E_{p(r\mid x)}[\log E_{p(y\mid x,r)} 1(y = y^*)]\\
&\ge E_{p(r, y\mid x)}[\log 1(y = y^*)].
\end{align}

There are two interesting things to note here:

  1. The rewards are log-scaled in this derivation, while they are not in STaR (compare equations 4 and 9).
  2. Pulling the expectation over $y$ outside the log (to get equation 9) gives a lower bound on the ELBo, and hence an even looser lower bound on the evidence, due to another application of Jensen’s inequality.

In general, pulling the expectation out through the log introduces additional bias, and applying a Monte Carlo approximation of the expectation introduces additional variance. It’s possible that ignoring the rationales that do not lead to correct sampled answers counteracts this additional bias and variance in helpful ways.
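
To see the two quantities side by side, here is a toy numerical sketch with made-up categorical distributions for $p(r \mid x)$ and $p(y \mid x, r)$ (the numbers are purely illustrative): it computes equation 7 exactly and estimates equation 4 by Monte Carlo, keeping exactly the rationales that STaR’s filter would keep.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up toy model: 3 rationales, 2 answers (purely illustrative numbers).
p_r_given_x = np.array([0.5, 0.3, 0.2])           # p(r | x)
p_y_given_xr = np.array([[0.9, 0.1],              # p(y | x, r), one row per rationale
                         [0.4, 0.6],
                         [0.2, 0.8]])
y_star = 0

# Exact value of equation 7: E_{p(r|x)}[log p(y* | x, r)].
eq7 = float(np.sum(p_r_given_x * np.log(p_y_given_xr[:, y_star])))

# Monte Carlo estimate of equation 4: E_{p(r,y|x)}[1(y = y*)],
# which is what STaR's sample-and-filter step effectively targets.
n = 10_000
r_samples = rng.choice(3, size=n, p=p_r_given_x)
y_samples = np.array([rng.choice(2, p=p_y_given_xr[r]) for r in r_samples])
kept = r_samples[y_samples == y_star]             # rationales STaR would train on
eq4_mc = kept.size / n

print(f"equation 7 (exact): {eq7:.3f}")
print(f"equation 4 (Monte Carlo): {eq4_mc:.3f}")
```

Estimating equation 7 directly would instead score every sampled rationale by $\log p(y^* \mid x, r)$, with no filtering.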

ELBo (rationalizer)

When employing the rationalizer $q(r \mid x, y^*)$ that sees true answers before producing rationales, the objective should be the exact ELBo presented earlier,

\begin{equation}
E_{q(r\mid x, y^*)}[\log p(y^* \mid x, r)] - KL[q(r \mid x, y^*) || p(r \mid x)].
\end{equation}

Is this what STaR optimizes? The short answer is yes, with the same caveats as in the previous approach. Namely, the expectation with respect to $y$ is pulled out, so the model is only trained on rationales whose sampled answer satisfies $y = y^*$, as opposed to weighting every rationale by $p(y^* \mid x, r)$.

Let’s translate the STaR pseudocode from Algorithm 1 in their paper and point out what each step corresponds to. At each iteration of STaR,

\begin{align}
r, y &\sim p(r, y\mid x) & \text{[Sample from prior]}\\
r', y' &\sim q(r\mid x, y^*) & \text{[Sample from rationalizer]}\\
D &= \{(r, y) \mid y = y^*\} \cup \{(r', y') \mid y \ne y^*,\, y' = y^*\} & \text{[Filter out incorrect rationales]}\\
&\text{Then train on } D
\end{align}

This corresponds roughly to setting

\begin{equation}
q_{\text{star}}(r \mid x, y^*) \propto \sum_y 1(y = y^*)\, p(y \mid x, r)\, p(r \mid x) + 1(y \ne y^*)\, q(r \mid x, y^*).
\end{equation}

STaR only adds the rationalizer’s rationales if the prior’s rationales are incorrect. This helps keep the rationales easy for the prior to model. The KL term in equation 10 should also achieve this, if $q$ is trained through the ELBo.
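
For reference, here is a hedged Python sketch of one STaR iteration matching the pseudocode above; `generate`, `generate_with_hint` (the rationalizer), and `finetune` are hypothetical helpers standing in for the actual LLM sampling and fine-tuning machinery.

```python
from typing import Callable, List, Tuple

def star_iteration(
    dataset: List[Tuple[str, str]],                              # pairs (x, y_star)
    generate: Callable[[str], Tuple[str, str]],                  # hypothetical: x -> (r, y), i.e. sample p(r, y | x)
    generate_with_hint: Callable[[str, str], Tuple[str, str]],   # hypothetical: (x, y_star) -> (r', y'), i.e. the rationalizer
    finetune: Callable[[List[Tuple[str, str, str]]], None],      # hypothetical fine-tuning step
) -> None:
    """One outer iteration of the STaR loop, following Algorithm 1."""
    train_set: List[Tuple[str, str, str]] = []
    for x, y_star in dataset:
        # r, y ~ p(r, y | x): let the prior attempt the problem.
        r, y = generate(x)
        if y == y_star:
            # Keep prior rationales that reach the correct answer.
            train_set.append((x, r, y))
        else:
            # Otherwise fall back to the rationalizer q(r | x, y*), which sees the answer.
            r_hint, y_hint = generate_with_hint(x, y_star)
            if y_hint == y_star:
                train_set.append((x, r_hint, y_hint))
    # Fine-tune on the filtered rationales (the paper restarts from the base model each iteration).
    finetune(train_set)
```

The two `if` checks are the indicator-based filtering from equations 4 and 9; replacing them with weights $p(y^* \mid x, r)$ would bring the loop closer to the exact ELBo in equation 10.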

What is the point of analyzing things as latent variable models?

Formal frameworks serve to guide development by making the tradeoffs between different choices clear and composable. In the case of STaR, the latent variable view could help us improve the method by analyzing the bias-variance tradeoffs of different design choices. It would also give us a principled way of conditioning on previous rationales. For example, if it is difficult to find a correct rationale even after conditioning on $y^*$, we could introduce a rationale editor $q(r \mid x, r_{\text{old}}, y)$ that can incorporate feedback.