The Self-Taught Reasoner


Yapping

Earlier this semester, in a seminar on LLM reasoning and planning, I talked about STaR, a method that teaches LLMs to reason through finetuning. Well, that “explanation” kinda oversimplifies a lot of things in the paper. When I first read it, I had little to no context on LLM reasoning and RL, but the background reading I had to do just to understand the authors’ choices was worth it.

The Self-Taught Reasoner

In “Emergent Abilities of Large Language Models” (Wei et al., 2022), the authors showed that a sufficiently large language model trained on a variety of data has many “emergent abilities”: abilities that only show up once the model is scaled up, like reasoning and few-shot learning. In the chain-of-thought prompting paper (Wei et al., 2022), they showed that when a model is prompted with triplets of question, rationale and answer, it picks up the pattern from these in-context examples and answers the prompted question with a rationale of its own followed by an answer. While chain-of-thought prompting only seems to pay off at around 100B parameters and above, a smaller model can learn a similar level of reasoning by finetuning on a dataset of questions, rationales and answers, and that is what this paper does, eh, if we simplify it. The idea that sets this paper apart from others is the rationalization of answers: when the model gets a question wrong, it is given the correct answer as a hint and asked to produce a rationale that leads to it.
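To make the prompting setup concrete, here is a minimal sketch of how (question, rationale, answer) triplets could be turned into a few-shot chain-of-thought prompt. The `Q:`/`A:` layout with a closing “The answer is …” sentence follows the general style of chain-of-thought exemplars, but the exact format, the helper name `build_cot_prompt`, and the example questions are assumptions made up for illustration.

```python
def build_cot_prompt(exemplars, question):
    """Format (question, rationale, answer) triplets as a few-shot CoT prompt.

    Hypothetical helper: the Q:/A: layout is an illustrative choice,
    not the exact format used in the papers.
    """
    shots = [f"Q: {q}\nA: {r} The answer is {a}." for q, r, a in exemplars]
    # The new question ends with an open "A:" so the model continues
    # with its own rationale followed by an answer.
    shots.append(f"Q: {question}\nA:")
    return "\n\n".join(shots)


exemplars = [
    ("Ana has 3 apples and buys 2 more. How many apples does she have?",
     "Ana starts with 3 apples and buys 2 more, so she has 3 + 2 = 5.",
     "5"),
]
print(build_cot_prompt(exemplars, "A box holds 4 pens. How many pens are in 3 boxes?"))
```

In STaR, the small set of examples with hand-written rationales plays the role of `exemplars` here.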

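Let’s go through the whole process step by step. At the start, we have a pretrained LLM $M$, a dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{D}$ of questions $x$ with answers $y$, and a small prompt set $\mathcal{P} = \{(x_i^p, r_i^p, y_i^p)\}_{i=1}^{P}$ with $P \ll D$, whose examples also come with intermediate rationales $r$ and are used as few-shot examples when prompting the model.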

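def star(M, D, P, N):
    # generate, add_hint and finetune are placeholders for the actual sampling, prompting and training code
    M_n = M  # M_0 is the pretrained model
    for n in range(1, N + 1):
        D_n, D_n_rat = [], []
        for x_i, y_i in D:
            # use the model from the previous step, M_{n-1}, prompted with the few-shot examples in P
            r_hat, y_hat = generate(M_n, P, x_i)
            if y_hat == y_i:
                D_n.append((x_i, r_hat, y_i))  # keep the sample if the generated answer matches the ground truth
            else:
                # rationalization: add a hint to the question by concatenating the correct answer
                r_hat_rat, y_hat_rat = generate(M_n, P, add_hint(x_i, y_i))
                if y_hat_rat == y_i:
                    # keep it only if the rationalized answer is correct; store it with the original, hint-free question
                    D_n_rat.append((x_i, r_hat_rat, y_i))
        # finetune the original model M (not M_{n-1}) on both sets; STaR restarts from M every iteration
        M_n = finetune(M, D_n + D_n_rat)
    return M_n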

I promise the Python pseudocode looks scarier than it is; the idea is very simple: generate, keep what’s correct, rationalize what’s wrong, and finetune on both.
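To see why the rationalization step matters, here is a toy, fully runnable simulation of the loop. Everything in it is a hypothetical stand-in: `ToyModel` is not an LLM but a solver that only knows some arithmetic operators, “finetuning” simply adds every operator used in a kept rationale, and `add_hint` concatenates the answer to the question in a made-up format. It does keep two details of STaR: rationalized examples are stored with the original, hint-free question, and each iteration restarts from the base model.

```python
import operator

# Hypothetical toy "LLM": it can only reason with operators it has learned.
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}


def add_hint(question, answer):
    # Rationalization hint: concatenate the ground-truth answer to the question.
    return f"{question} | hint: {answer}"


class ToyModel:
    def __init__(self, known_ops):
        self.known_ops = frozenset(known_ops)

    def generate(self, question):
        """Return (rationale, answer) for a question like '3 * 4'."""
        problem, _, hint = question.partition(" | hint: ")
        a, op, b = problem.split()
        rationale = f"apply {op} to {a} and {b}"
        if op in self.known_ops:
            return rationale, OPS[op](int(a), int(b))
        if hint:
            # Given the answer, the toy model justifies it backwards.
            return rationale, int(hint)
        return "no idea", None  # cannot solve an unknown operator without a hint


def finetune(base, examples):
    # Toy "finetuning": learn every operator used in a kept rationale,
    # always starting again from the base model.
    learned = {rationale.split()[1] for _, rationale, _ in examples}
    return ToyModel(base.known_ops | learned)


def star(base, dataset, n_iters, rationalize=True):
    model = base
    for _ in range(n_iters):
        kept = []
        for x, y in dataset:
            r_hat, y_hat = model.generate(x)
            if y_hat == y:
                kept.append((x, r_hat, y))
            elif rationalize:
                r_rat, y_rat = model.generate(add_hint(x, y))
                if y_rat == y:
                    kept.append((x, r_rat, y))  # stored with the hint-free question
        model = finetune(base, kept)
    return model


dataset = [("2 + 3", 5), ("7 - 4", 3), ("3 * 4", 12), ("6 * 2", 12)]
base = ToyModel({"+", "-"})
print(sorted(star(base, dataset, n_iters=2, rationalize=False).known_ops))  # ['+', '-']
print(sorted(star(base, dataset, n_iters=2, rationalize=True).known_ops))   # ['*', '+', '-']
```

Without rationalization the model can only reinforce what it already gets right, so multiplication is never learned; with it, the failed `*` questions become training data once the hint reveals their answers. Real rationalizations can be wrong or post-hoc in ways this toy ignores.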

Why self-taught?

You may ask yourself why the model is trained through bootstrapping. It could instead be finetuned on the reasoning traces of a bigger model like GPT-4 or DeepSeek-R1, which is surely a much simpler approach. I have two theories:

  1. Distillation of bias:

    When the student model is finetuned on the reasoning traces of the teacher model, it also picks up the teacher’s biases, along with the many other features the teacher has.

  2. New reasoning mechanisms:

    During rationalization, the search space for the next token is much smaller than during rationale generation, because the final answer is already known. And since the quality of rationales and rationalizations is never checked directly (only the final answer is), the model implicitly learns to prioritize rationales that lead to correct answers, which gives it a strong, answer-grounded training signal.
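The STaR paper itself gives a more formal version of this intuition: it views the generate, filter, finetune step as an approximation to an RL-style policy-gradient objective whose only reward is whether the final answer is correct. With $p_M$ the model’s distribution over a rationale and answer given a question, and $\mathbb{1}$ the indicator function:

$$
J(M, X, Y) = \sum_i \mathbb{E}_{\hat{r}_i, \hat{y}_i \sim p_M(\cdot \mid x_i)} \, \mathbb{1}(\hat{y}_i = y_i)
$$

$$
\nabla J(M, X, Y) = \sum_i \mathbb{E}_{\hat{r}_i, \hat{y}_i \sim p_M(\cdot \mid x_i)} \left[ \mathbb{1}(\hat{y}_i = y_i) \cdot \nabla \log p_M(\hat{y}_i, \hat{r}_i \mid x_i) \right]
$$

The indicator zeroes out the gradient for every sampled rationale that ends in a wrong answer, so finetuning only on the filtered examples is, roughly, a sample-based estimate of this gradient. The rationale $\hat{r}_i$ only appears inside $p_M$, never in the reward. This covers the plain generation step; hinted rationalizations are sampled given the hint rather than from $p_M(\cdot \mid x_i)$, so they do not fit this objective exactly.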

References

2022

  1. Emergent Abilities of Large Language Models
    Jason Wei, Yi Tay, Rishi Bommasani, and 13 more authors
    Transactions on Machine Learning Research, 2022
    Survey Certification
  2. Chain-of-Thought Prompting Elicits Reasoning in Large Language Models
    Jason Wei, Xuezhi Wang, Dale Schuurmans, and 6 more authors
    In Proceedings of the 36th International Conference on Neural Information Processing Systems, New Orleans, LA, USA, 2022