
Commit 2462886

Merge branch 'release' into release_unit2_improvements
1 parent fac4675 commit 2462886

File tree

4 files changed: +63 −50 lines changed


chapters/en/_toctree.yml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
   - local: chapter1/4
     title: How do Transformers work?
   - local: chapter1/5
-    title: How 🤗 Transformers solve tasks
+    title: Solving Tasks with Transformers
   - local: chapter1/6
     title: Transformer Architectures
   - local: chapter1/7

chapters/en/chapter12/3a.mdx

Lines changed: 62 additions & 49 deletions
@@ -10,11 +10,11 @@ Let's deepen our understanding of GRPO so that we can improve our model's traini

 GRPO directly evaluates the model-generated responses by comparing them within groups of generations to optimize the policy model, instead of training a separate value model (critic). This approach leads to a significant reduction in computational cost!

-GRPO can be applied to any verifiable task where the correctness of the response can be determined. For instance, in math reasoning, the correctness of the response can be easily verified by comparing it to the ground truth.
+GRPO can be applied to any verifiable task where the correctness of the response can be determined. For instance, in math reasoning, the correctness of the response can be easily verified by comparing it to the ground truth.

 Before diving into the technical details, let's visualize how GRPO works at a high level:

-![deep](./img/2.jpg)
+![deep](https://huggingface.co/reasoning-course/images/resolve/main/grpo/16.png)

 Now that we have a visual overview, let's break down how GRPO works step by step.

@@ -28,49 +28,56 @@ Let's walk through each step of the algorithm in detail:

 The first step is to generate multiple possible answers for each question. This creates a diverse set of outputs that can be compared against each other.

-For each question $q$, the model will generate $G$ outputs (group size) from the trained policy:{ ${o_1, o_2, o_3, \dots, o_G}\pi_{\theta_{\text{old}}}$ }, $G=8$ where each $o_i$ represents one completion from the model.
+For each question \\( q \\), the model generates \\( G \\) outputs (the group size) from the old policy: \\( \{o_1, o_2, o_3, \dots, o_G\} \sim \pi_{\theta_{\text{old}}} \\), with \\( G=8 \\) in our example, where each \\( o_i \\) represents one completion from the model.

-#### Example:
+#### Example

 To make this concrete, let's look at a simple arithmetic problem:

-- **Question** $q$ : $\text{Calculate}\space2 + 2 \times 6$
-- **Outputs** $(G = 8)$: $\{o_1:14 \text{ (correct)}, o_2:16 \text{ (wrong)}, o_3:10 \text{ (wrong)}, \ldots, o_8:14 \text{ (correct)}\}$
+**Question**
+
+\\( q \\) : \\( \text{Calculate}\space2 + 2 \times 6 \\)
+
+**Outputs**
+
+\\( (G = 8) \\): \\( \{o_1:14 \text{ (correct)}, o_2:16 \text{ (wrong)}, o_3:10 \text{ (wrong)}, \ldots, o_8:14 \text{ (correct)}\} \\)

 Notice how some of the generated answers are correct (14) while others are wrong (16 or 10). This diversity is crucial for the next step.
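
Here is a minimal sketch of what this group sampling step could look like with 🤗 Transformers. The model name and generation settings are placeholders for illustration, not the configuration used later in the course.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

question = "Calculate 2 + 2 * 6"
inputs = tokenizer(question, return_tensors="pt")

# Step 1: sample G completions for the same question (the "group")
group_size = 8  # G
outputs = model.generate(
    **inputs,
    do_sample=True,          # sampling gives us a diverse group
    temperature=1.0,
    max_new_tokens=64,
    num_return_sequences=group_size,
)

# Strip the prompt tokens and decode each completion o_1 ... o_G
completions = tokenizer.batch_decode(
    outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
for i, completion in enumerate(completions, start=1):
    print(f"o_{i}: {completion}")
```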

 ### Step 2: Advantage Calculation

 Once we have multiple responses, we need a way to determine which ones are better than others. This is where the advantage calculation comes in.

-#### Reward Distribution:
+#### Reward Distribution

 First, we assign a reward score to each generated response. In this example we'll use a reward model, but as we learnt in the previous section, we can use any function that returns a reward.

-Assign a RM score to each of the generated responses based on the correctness $r_i$ *(e.g. 1 for correct response, 0 for wrong response)* then for each of the $r_i$ calculate the following Advantage value
+Assign a reward score \\( r_i \\) to each generated response based on its correctness *(e.g. 1 for a correct response, 0 for a wrong one)*, then for each \\( r_i \\) calculate the following advantage value.

-#### Advantage Value Formula:
+#### Advantage Value Formula

 The key insight of GRPO is that we don't need absolute measures of quality - we can compare outputs within the same group. This is done using standardization:

 $$A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \ldots, r_G\})}{\text{std}(\{r_1, r_2, \ldots, r_G\})}$$

-#### Example:
+#### Example

 Continuing with the same arithmetic example, imagine we have 8 responses, 4 of which are correct and the rest wrong. Therefore:
-- Group Average: $mean(r_i) = 0.5$
-- Std: $std(r_i) = 0.53$
-- Advantage Value:
-  - Correct response: $A_i = \frac{1 - 0.5}{0.53}= 0.94$
-  - Wrong response: $A_i = \frac{0 - 0.5}{0.53}= -0.94$

-#### Interpretation:
+| Metric | Value |
+|--------|-------|
+| Group Average | \\( mean(r_i) = 0.5 \\) |
+| Standard Deviation | \\( std(r_i) = 0.53 \\) |
+| Advantage Value (Correct response) | \\( A_i = \frac{1 - 0.5}{0.53}= 0.94 \\) |
+| Advantage Value (Wrong response) | \\( A_i = \frac{0 - 0.5}{0.53}= -0.94 \\) |
+
+#### Interpretation

 Now that we have calculated the advantage values, let's understand what they mean:

-This standardization (i.e. $A_i$ weighting) allows the model to assess each response's relative performance, guiding the optimization process to favour responses that are better than average (high reward) and discourage those that are worse. For instance if $A_i > 0$, then the $o_i$ is better response than the average level within its group; and if $A_i < 0$, then the $o_i$ then the quality of the response is less than the average (i.e. poor quality/performance).
+This standardization (i.e. \\( A_i \\) weighting) allows the model to assess each response's relative performance, guiding the optimization process to favour responses that are better than average (high reward) and to discourage those that are worse. For instance, if \\( A_i > 0 \\), then \\( o_i \\) is a better response than the average within its group; and if \\( A_i < 0 \\), then the quality of \\( o_i \\) is below the group average (i.e. poor quality/performance).

-For the example above, if $A_i = 0.94 \text{(correct output)}$ then during optimization steps its generation probability will be increased.
+For the example above, if \\( A_i = 0.94 \\) (correct output), then during the optimization steps its generation probability will be increased.
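
To sanity-check the numbers above, here is a small sketch of the reward-to-advantage computation. The binary reward function is an assumption for this arithmetic example; in practice you would plug in whichever verifiable reward function you defined in the previous section.

```python
import torch

# Hypothetical completions for "Calculate 2 + 2 * 6" (ground truth: 14)
completions = ["14", "16", "10", "14", "0", "14", "20", "14"]
ground_truth = "14"

# Binary reward: 1 for a correct response, 0 for a wrong one
rewards = torch.tensor(
    [1.0 if c.strip() == ground_truth else 0.0 for c in completions]
)

# Group-relative advantages: standardize the rewards within the group
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

print(rewards)     # tensor([1., 0., 0., 1., 0., 1., 0., 1.])
print(advantages)  # correct answers ≈ +0.94, wrong answers ≈ -0.94
```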

 With our advantage values calculated, we're now ready to update the policy.

@@ -80,7 +87,7 @@ The final step is to use these advantage values to update our model so that it b

 The target function for policy update is:

-$$J_{GRPO}(\theta) = \left[\frac{1}{G} \sum_{i=1}^{G} \min \left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)} A_i \text{clip}\left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}, 1 - \epsilon, 1 + \epsilon \right) A_i \right)\right]- \beta D_{KL}(\pi_{\theta} || \pi_{ref})$$
+$$J_{GRPO}(\theta) = \left[\frac{1}{G} \sum_{i=1}^{G} \min \left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)} A_i, \ \text{clip}\left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}, 1 - \epsilon, 1 + \epsilon \right) A_i \right)\right] - \beta D_{KL}(\pi_{\theta} \|\| \pi_{ref})$$

 This formula might look intimidating at first, but it's built from several components that each serve an important purpose. Let's break them down one by one.

@@ -92,70 +99,76 @@ The GRPO update function combines several techniques to ensure stable and effect

 The probability ratio is defined as:

-$\left(\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}\right)$
+\\( \left(\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}\right) \\)

 Intuitively, the formula compares how much the new model's response probability differs from the old model's response probability, while incorporating a preference for responses that improve the expected outcome.

-#### Interpretation:
-- If $\text{ratio} > 1$, the new model assigns a higher probability to response $o_i$ than the old model.
-- If $\text{ratio} < 1$, the new model assigns a lower probability to $o_i$
+#### Interpretation
+
+- If \\( \text{ratio} > 1 \\), the new model assigns a higher probability to response \\( o_i \\) than the old model.
+- If \\( \text{ratio} < 1 \\), the new model assigns a lower probability to \\( o_i \\).
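
In practice this ratio is usually computed from per-token log-probabilities rather than raw probabilities. The tensor values below are made up purely for illustration:

```python
import torch

# Hypothetical per-token log-probabilities of the sampled completion
# under the old policy and under the policy currently being updated
old_per_token_logps = torch.tensor([-0.69, -1.20, -0.35])
new_per_token_logps = torch.tensor([-0.36, -1.10, -0.40])

# ratio = pi_theta(o_i|q) / pi_theta_old(o_i|q), computed in log-space for stability
ratio = torch.exp(new_per_token_logps - old_per_token_logps)
print(ratio)  # >1 where the new policy makes the token more likely, <1 where less likely
```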

 This ratio allows us to control how much the model changes at each step, which leads us to the next component.

 ### 2. Clip Function

 The clipping function is defined as:

-$\text{clip}\left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}, 1 - \epsilon, 1 + \epsilon\right)$
+\\( \text{clip}\left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}, 1 - \epsilon, 1 + \epsilon\right) \\)

-Limit the ratio discussed above to be within $[1 - \epsilon, 1 + \epsilon]$ to avoid/control drastic changes or crazy updates and stepping too far off from the old policy. In other words, it limit how much the probability ratio can increase to help maintaining stability by avoiding updates that push the new model too far from the old one.
+This limits the ratio discussed above to the interval \\( [1 - \epsilon, 1 + \epsilon] \\), preventing drastic updates that step too far away from the old policy. In other words, it limits how much the probability ratio can change, which helps maintain stability by avoiding updates that push the new model too far from the old one.
+
+#### Example (ε = 0.2)

-#### Example $\space \text{suppose}(\epsilon = 0.2)$
 Let's look at two different scenarios to better understand this clipping function (a short numeric sketch follows the cases):

 - **Case 1**: the new policy assigns a probability of 0.9 to a specific response while the old policy assigned 0.5. The response is being reinforced by the new policy, but only within a controlled limit: clipping ties its hands so the update doesn't become drastic.
-  - $\text{Ratio}: \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)} = \frac{0.9}{0.5} = 1.8 → \text{Clip}\space1.2$ (upper bound limit 1.2)
+  - \\( \text{Ratio}: \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)} = \frac{0.9}{0.5} = 1.8 → \text{Clip}\space1.2 \\) (upper bound 1.2)
 - **Case 2**: the new policy is not in favour of a response (lower probability, e.g. 0.2), meaning the response was not beneficial and the model is penalized for it; clipping likewise bounds how far the ratio can drop.
-  - $\text{Ratio}: \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)} = \frac{0.2}{0.5} = 0.4 →\text{Clip}\space0.8$ (lower bound limit 0.8)
-#### Interpretation:
+  - \\( \text{Ratio}: \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)} = \frac{0.2}{0.5} = 0.4 →\text{Clip}\space0.8 \\) (lower bound 0.8)
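
The two cases above can be reproduced in a couple of lines; the probabilities are the illustrative numbers from the example, not values from a real training run:

```python
import torch

epsilon = 0.2
old_probs = torch.tensor([0.5, 0.5])   # pi_theta_old(o_i|q) for Case 1 and Case 2
new_probs = torch.tensor([0.9, 0.2])   # pi_theta(o_i|q) for Case 1 and Case 2

ratio = new_probs / old_probs                            # tensor([1.8000, 0.4000])
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)   # tensor([1.2000, 0.8000])
print(ratio, clipped)
```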
+
+#### Interpretation
+
 - The formula encourages the new model to favour responses that the old model underweighted **if they improve the outcome**.
-- If the old model already favoured a response with a high probability, the new model can still reinforce it **but only within a controlled limit $[1 - \epsilon, 1 + \epsilon]$, $\text{(e.g., }\epsilon = 0.2, \space \text{so} \space [0.8-1.2])$**.
+- If the old model already favoured a response with a high probability, the new model can still reinforce it **but only within a controlled limit \\( [1 - \epsilon, 1 + \epsilon] \\), \\( \text{(e.g., }\epsilon = 0.2, \space \text{so} \space [0.8-1.2]) \\)**.
 - If the old model overestimated a response that performs poorly, the new model is **discouraged** from maintaining that high probability.
-- Therefore, intuitively, By incorporating the probability ratio, the objective function ensures that updates to the policy are proportional to the advantage $A_i$ while being moderated to prevent drastic changes. T
+- Therefore, intuitively, by incorporating the probability ratio, the objective function ensures that updates to the policy are proportional to the advantage \\( A_i \\) while being moderated to prevent drastic changes.

 While the clipping function helps prevent drastic changes, we need one more safeguard to ensure our model doesn't deviate too far from its original behavior.

 ### 3. KL Divergence

 The KL divergence term is:

-$\beta D_{KL}(\pi_{\theta} || \pi_{ref})$
+\\( \beta D_{KL}(\pi_{\theta} \|\| \pi_{ref}) \\)

-In the KL divergence term, the $\pi_{ref}$ is basically the pre-update model's output, `per_token_logps` and $\pi_{\theta}$ is the new model's output, `new_per_token_logps`. Theoretically, KL divergence is minimized to prevent the model from deviating too far from its original behavior during optimization. This helps strike a balance between improving performance based on the reward signal and maintaining coherence. In this context, minimizing KL divergence reduces the risk of the model generating nonsensical text or, in the case of mathematical reasoning, producing extremely incorrect answers.
+In the KL divergence term, \\( \pi_{ref} \\) is basically the pre-update model's output, `per_token_logps`, and \\( \pi_{\theta} \\) is the new model's output, `new_per_token_logps`. Theoretically, KL divergence is minimized to prevent the model from deviating too far from its original behavior during optimization. This helps strike a balance between improving performance based on the reward signal and maintaining coherence. In this context, minimizing KL divergence reduces the risk of the model generating nonsensical text or, in the case of mathematical reasoning, producing extremely incorrect answers.

 #### Interpretation
+
 - A KL divergence penalty keeps the model's outputs close to its original distribution, preventing extreme shifts.
 - Instead of drifting towards completely irrational outputs, the model would refine its understanding while still allowing some exploration.

 #### Math Definition
+
 For those interested in the mathematical details, let's look at the formal definition:

 Recall that KL divergence is defined as follows:
-$$D_{KL}(P || Q) = \sum_{x \in X} P(x) \log \frac{P(x)}{Q(x)}$$
+$$D_{KL}(P \|\| Q) = \sum_{x \in X} P(x) \log \frac{P(x)}{Q(x)}$$
 In RLHF, the two distributions of interest are often the distribution of the new model version, P(x), and a distribution of the reference policy, Q(x).
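
As a rough sketch, a per-token approximation of this KL term can be computed directly from the log-probabilities mentioned above. The estimator below, \\( e^{\Delta} - \Delta - 1 \\), is one common non-negative approximation; the variable names are illustrative only:

```python
import torch

# Hypothetical per-token log-probabilities of the sampled tokens
new_per_token_logps = torch.tensor([-0.36, -1.10, -0.40])  # pi_theta
ref_per_token_logps = torch.tensor([-0.69, -1.20, -0.35])  # pi_ref

# One common per-token estimator of D_KL(pi_theta || pi_ref):
# exp(delta) - delta - 1 with delta = log pi_ref - log pi_theta (always >= 0)
delta = ref_per_token_logps - new_per_token_logps
per_token_kl = torch.exp(delta) - delta - 1

beta = 0.04  # KL coefficient used in the DeepSeekMath paper
kl_penalty = beta * per_token_kl.mean()
print(per_token_kl, kl_penalty)
```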

-#### The Role of $\beta$ Parameter
+#### The Role of the β Parameter

-The coefficient $\beta$ controls how strongly we enforce the KL divergence constraint:
+The coefficient \\( \beta \\) controls how strongly we enforce the KL divergence constraint:

-- **Higher $\beta$ (Stronger KL Penalty)**
+- **Higher β (Stronger KL Penalty)**
   - More constraint on policy updates. The model remains close to its reference distribution.
   - Can slow down adaptation: The model may struggle to explore better responses.
-- **Lower $\beta$ (Weaker KL Penalty)**
+- **Lower β (Weaker KL Penalty)**
   - More freedom to update policy: The model can deviate more from the reference.
   - Faster adaptation but risk of instability: The model might learn reward-hacking behaviors.
   - Over-optimization risk: If the reward model is flawed, the policy might generate nonsensical outputs.
-- **Original** [DeepSeekMath](https://arxiv.org/abs/2402.03300) paper set this $\beta= 0.04$
+- The **original** [DeepSeekMath](https://arxiv.org/abs/2402.03300) paper set \\( \beta = 0.04 \\).

 Now that we understand the components of GRPO, let's see how they work together in a complete example.

@@ -169,30 +182,30 @@ $$\text{Q: Calculate}\space2 + 2 \times 6$$

 ### Step 1: Group Sampling

-First, we generate multiple responses from our model:
+First, we generate multiple responses from our model.

-Generate $(G = 8)$ responses, $4$ of which are correct answer ($14, \text{reward=} 1$) and $4$ incorrect $\text{(reward= 0)}$, Therefore:
+Generate \\( G = 8 \\) responses, \\( 4 \\) of which give the correct answer (\\( 14, \text{reward} = 1 \\)) and \\( 4 \\) of which are incorrect (\\( \text{reward} = 0 \\)). Therefore:

 $$\{o_1:14 \text{ (correct)}, o_2:10 \text{ (wrong)}, o_3:16 \text{ (wrong)}, \ldots, o_G:14 \text{ (correct)}\}$$

 ### Step 2: Advantage Calculation

 Next, we calculate the advantage values to determine which responses are better than average:

-- Group Average:
-$$mean(r_i) = 0.5$$
-- Std: $$std(r_i) = 0.53$$
-- Advantage Value:
-  - Correct response: $A_i = \frac{1 - 0.5}{0.53}= 0.94$
-  - Wrong response: $A_i = \frac{0 - 0.5}{0.53}= -0.94$
+| Statistic | Value |
+|-----------|-------|
+| Group Average | \\( mean(r_i) = 0.5 \\) |
+| Standard Deviation | \\( std(r_i) = 0.53 \\) |
+| Advantage Value (Correct response) | \\( A_i = \frac{1 - 0.5}{0.53}= 0.94 \\) |
+| Advantage Value (Wrong response) | \\( A_i = \frac{0 - 0.5}{0.53}= -0.94 \\) |

 ### Step 3: Policy Update

 Finally, we update our model to reinforce the correct responses:

-- Assuming the probability of old policy ($\pi_{\theta_{old}}$) for a correct output $o_1$ is $0.5$ and the new policy increases it to $0.7$ then:
+- Assuming the old policy (\\( \pi_{\theta_{old}} \\)) assigns a probability of \\( 0.5 \\) to a correct output \\( o_1 \\) and the new policy increases it to \\( 0.7 \\), then:
 $$\text{Ratio}: \frac{0.7}{0.5} = 1.4 →\text{after Clip}\space1.2 \space (\epsilon = 0.2)$$
-- Then when the target function is re-weighted, the model tends to reinforce the generation of correct output, and the $\text{KL Divergence}$ limits the deviation from the reference policy.
+- Then, when the target function is re-weighted, the model tends to reinforce the generation of the correct output, and the \\( \text{KL Divergence} \\) term limits the deviation from the reference policy.

 With the theoretical understanding in place, let's see how GRPO can be implemented in code.
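
Before moving on to the implementation, here is a hedged sketch of how the pieces above (probability ratio, clipping, advantages, and the KL penalty) could be assembled into a single per-token objective. It is a simplified illustration rather than the exact code of any particular trainer, and all names are placeholders:

```python
import torch

def grpo_loss(new_per_token_logps, old_per_token_logps, ref_per_token_logps,
              advantages, epsilon=0.2, beta=0.04):
    """Simplified per-token GRPO objective, returned as a loss to minimize.

    Illustrative shapes: log-prob tensors are (G, seq_len); advantages is (G,).
    """
    # 1. Probability ratio pi_theta / pi_theta_old, computed in log-space
    ratio = torch.exp(new_per_token_logps - old_per_token_logps)

    # 2. Clipped surrogate: take the minimum of the unclipped and clipped terms
    adv = advantages.unsqueeze(-1)  # broadcast each response's advantage over its tokens
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * adv
    surrogate = torch.min(unclipped, clipped)

    # 3. Per-token KL estimate against the reference policy
    delta = ref_per_token_logps - new_per_token_logps
    per_token_kl = torch.exp(delta) - delta - 1

    # Maximize the surrogate minus the KL penalty -> minimize the negative
    per_token_objective = surrogate - beta * per_token_kl
    return -per_token_objective.mean()

# Toy usage with random log-probabilities for a group of G = 8 completions
G, seq_len = 8, 5
logps = torch.log(torch.rand(G, seq_len))
loss = grpo_loss(logps + 0.1 * torch.randn_like(logps), logps, logps,
                 advantages=torch.tensor([0.94, -0.94] * 4))
print(loss)
```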

chapters/en/chapter12/img/1.png

-57.4 KB
Binary file not shown.

chapters/en/chapter12/img/2.jpg

-43.3 KB
Binary file not shown.
