chapters/en/chapter12/3a.mdx
Let's deepen our understanding of GRPO so that we can improve our model's training.
GRPO directly evaluates model-generated responses by comparing them within groups of generations to optimize the policy model, instead of training a separate value model (critic). This approach leads to a significant reduction in computational cost!
GRPO can be applied to any verifiable task where the correctness of the response can be determined. For instance, in math reasoning, the correctness of the response can be easily verified by comparing it to the ground truth.
Before diving into the technical details, let's visualize how GRPO works at a high level:
Now that we have a visual overview, let's break down how GRPO works step by step.
Let's walk through each step of the algorithm in detail:
The first step is to generate multiple possible answers for each question. This creates a diverse set of outputs that can be compared against each other.
For each question \\( q \\), the model generates \\( G \\) outputs (the group size) from the old policy: \\( \{o_1, o_2, o_3, \dots, o_G\} \sim \pi_{\theta_{\text{old}}} \\), with \\( G = 8 \\) in this example, where each \\( o_i \\) represents one completion from the model.
#### Example
To make this concrete, let's look at a simple arithmetic problem:
Notice how some of the generated answers are correct (14) while others are wrong (16 or 10). This diversity is crucial for the next step.
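
To make the sampling step concrete in code, here is a minimal sketch using the `transformers` library; the model name and the question are placeholders, and a real GRPO setup would batch this over many questions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model: any instruction-tuned causal LM works for this sketch.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

question = "What is 7 + 7?"  # placeholder question; the chapter's example differs
inputs = tokenizer(question, return_tensors="pt")

G = 8  # group size: number of completions sampled for the same question
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        do_sample=True,           # sample (rather than greedy decode) so the group is diverse
        temperature=0.8,
        max_new_tokens=64,
        num_return_sequences=G,   # the G completions o_1, ..., o_G
    )

# Strip the prompt tokens and keep only the generated completions.
completions = tokenizer.batch_decode(
    outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
for i, completion in enumerate(completions, start=1):
    print(f"o_{i}: {completion}")
```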
### Step 2: Advantage Calculation
Once we have multiple responses, we need a way to determine which ones are better than others. This is where the advantage calculation comes in.
#### Reward Distribution
First, we assign a reward score to each generated response. In this example, we'll use a reward model, but as we learned in the previous section, we can use any reward-returning function.
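
For instance, a verifiable correctness reward can be a plain Python function rather than a learned model. The sketch below is a toy verifier for the arithmetic example; the ground-truth value and the string matching are deliberately simplistic.

```python
def correctness_reward(completion: str, ground_truth: str = "14") -> float:
    """Toy verifier: 1.0 if the completion contains the ground-truth answer, else 0.0.

    Real verifiable-reward setups parse the final answer more carefully,
    e.g. with a math expression parser, before comparing to the ground truth.
    """
    return 1.0 if ground_truth in completion else 0.0

print([correctness_reward(c) for c in ["The answer is 14", "It is 16", "10"]])
# -> [1.0, 0.0, 0.0]
```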
Assign an RM score \\( r_i \\) to each generated response based on its correctness *(e.g. 1 for a correct response, 0 for a wrong one)*, then for each \\( r_i \\) calculate the following advantage value.
#### Advantage Value Formula
The key insight of GRPO is that we don't need absolute measures of quality - we can compare outputs within the same group. This is done using standardization:
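
Concretely, each reward is standardized against its own group's statistics (this follows the GRPO paper's formulation):

\\( A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \dots, r_G\})}{\text{std}(\{r_1, r_2, \dots, r_G\})} \\)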
Now that we have calculated the advantage values, let's understand what they mean:
This standardization (i.e. the \\( A_i \\) weighting) lets the model assess each response's relative performance, guiding the optimization process to favour responses that are better than average (high reward) and discourage those that are worse. For instance, if \\( A_i > 0 \\), then \\( o_i \\) is a better response than the average within its group; and if \\( A_i < 0 \\), then the quality of \\( o_i \\) is below the group average (i.e. poor quality/performance).
For the example above, if \\( A_i = 0.94 \\) (the correct output), then during the optimization steps its generation probability will be increased.
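
A minimal sketch of this computation, using illustrative binary rewards (not the chapter's exact numbers); the small constant in the denominator simply guards against a zero standard deviation when every reward in the group is identical.

```python
import torch

# One reward per completion in the group (illustrative values).
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0])

# Standardize within the group: above-average responses get A_i > 0, below-average get A_i < 0.
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(advantages)
```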
With our advantage values calculated, we're now ready to update the policy.
The final step is to use these advantage values to update our model so that it becomes more likely to generate the responses with higher advantages.
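
For reference, the objective being maximized has the following shape (written here at the sequence level, following the notation of the GRPO paper):

\\( J_{GRPO}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^{G} \left( \min \left( \frac{\pi_{\theta}(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)} A_i, \ \text{clip} \left( \frac{\pi_{\theta}(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)}, 1 - \epsilon, 1 + \epsilon \right) A_i \right) - \beta D_{KL}(\pi_{\theta} \|\| \pi_{ref}) \right) \right] \\)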
This formula might look intimidating at first, but it's built from several components that each serve an important purpose. Let's break them down one by one.
The GRPO update function combines several techniques to ensure stable and effective learning.
Intuitively, the formula compares how much the new model's response probability differs from the old model's response probability while incorporating a preference for responses that improve the expected outcome.
#### Interpretation
- If \\( \text{ratio} > 1 \\), the new model assigns a higher probability to response \\( o_i \\) than the old model.
- If \\( \text{ratio} < 1 \\), the new model assigns a lower probability to \\( o_i \\) than the old model.
This ratio allows us to control how much the model changes at each step, which leads us to the next component.
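
In practice, the ratio is computed from log-probabilities rather than raw probabilities, which is numerically more stable. A tiny sketch with illustrative values:

```python
import torch

# Log-probabilities of the same sampled responses under the new and old policies (illustrative).
new_logps = torch.tensor([-0.7, -1.2, -0.3])
old_logps = torch.tensor([-0.9, -1.0, -0.3])

ratio = torch.exp(new_logps - old_logps)  # pi_theta(o_i | q) / pi_theta_old(o_i | q)
print(ratio)  # > 1 where the new policy likes the response more, < 1 where it likes it less
```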
Limit the ratio discussed above to lie within \\( [1 - \epsilon, 1 + \epsilon] \\) to avoid drastic updates that step too far away from the old policy. In other words, clipping limits how much the probability ratio can change, which helps maintain stability by preventing updates that push the new model too far from the old one.
#### Example (ε = 0.2)
Let's look at two different scenarios to better understand this clipping function:
**Case 1**: If the new policy assigns a probability of 0.9 to a specific response while the old policy assigned 0.5, the response is being reinforced by the new policy, but only within a controlled limit: the clipping ties the update's hands so it cannot become drastic.

**Case 2**: If the new policy is not in favour of a response (a lower probability, e.g. 0.2), and the response is indeed not beneficial, then increasing its probability would have been incorrect, and the model is penalized instead.

- The formula encourages the new model to favour responses that the old model underweighted **if they improve the outcome**.
- If the old model already favoured a response with a high probability, the new model can still reinforce it, **but only within the controlled limit \\( [1 - \epsilon, 1 + \epsilon] \\) (e.g. with \\( \epsilon = 0.2 \\), the ratio stays in \\( [0.8, 1.2] \\))**.

Therefore, intuitively, by incorporating the probability ratio, the objective function ensures that updates to the policy are proportional to the advantage \\( A_i \\) while being moderated to prevent drastic changes.
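
Putting the ratio, the clip, and the advantage together, the clipped part of the objective can be sketched as follows (the function and variable names are illustrative, not from any particular library):

```python
import torch

def clipped_objective(new_logps, old_logps, advantages, epsilon=0.2):
    """Per-sample clipped policy term: min(ratio * A, clip(ratio, 1-eps, 1+eps) * A)."""
    ratio = torch.exp(new_logps - old_logps)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    # Taking the element-wise minimum means a response can never gain more than the clipped amount.
    return torch.min(unclipped, clipped)

# Case 1 from above: new prob 0.9 vs old prob 0.5 gives ratio 1.8, clipped to 1.2.
print(clipped_objective(torch.log(torch.tensor([0.9])),
                        torch.log(torch.tensor([0.5])),
                        torch.tensor([0.94])))  # 1.2 * 0.94 = 1.128
```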
While the clipping function helps prevent drastic changes, we need one more safeguard to ensure our model doesn't deviate too far from its original behavior.
### 3. KL Divergence
The KL divergence term is:
\\( \beta D_{KL}(\pi_{\theta} \|\| \pi_{ref}) \\)
In the KL divergence term, \\( \pi_{ref} \\) is basically the pre-update model's output, `per_token_logps`, and \\( \pi_{\theta} \\) is the new model's output, `new_per_token_logps`. Theoretically, the KL divergence is minimized to prevent the model from deviating too far from its original behavior during optimization. This helps strike a balance between improving performance based on the reward signal and maintaining coherence. In this context, minimizing the KL divergence reduces the risk of the model generating nonsensical text or, in the case of mathematical reasoning, producing extremely incorrect answers.
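
A sketch of how this term is typically estimated per token, using the variable names mentioned above and the unbiased estimator from the GRPO paper; the numeric values and the \\( \beta \\) coefficient here are illustrative.

```python
import torch

# Log-probabilities of the sampled tokens under the reference (pre-update) model
# and under the current policy (illustrative values).
per_token_logps = torch.tensor([-1.0, -0.5, -2.0])      # reference model, pi_ref
new_per_token_logps = torch.tensor([-0.9, -0.7, -2.1])  # current policy, pi_theta

log_ratio = per_token_logps - new_per_token_logps
# Unbiased estimator exp(x) - x - 1, which is always >= 0.
per_token_kl = torch.exp(log_ratio) - log_ratio - 1

beta = 0.04  # illustrative KL coefficient
kl_penalty = beta * per_token_kl.mean()
print(per_token_kl, kl_penalty)
```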
#### Interpretation
- A KL divergence penalty keeps the model's outputs close to its original distribution, preventing extreme shifts.
- Instead of drifting towards completely irrational outputs, the model refines its understanding while still allowing some exploration.
#### Math Definition
For those interested in the mathematical details, let's look at the formal definition:
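
Following the GRPO paper, the KL term is estimated with the expression below (shown at the sequence level for brevity); it is always non-negative and equals zero only when the two policies agree on \\( o_i \\):

\\( D_{KL}(\pi_{\theta} \|\| \pi_{ref}) = \frac{\pi_{ref}(o_i \mid q)}{\pi_{\theta}(o_i \mid q)} - \log \frac{\pi_{ref}(o_i \mid q)}{\pi_{\theta}(o_i \mid q)} - 1 \\)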
Then, when the target function is re-weighted, the model tends to reinforce the generation of correct outputs, and the KL divergence limits the deviation from the reference policy.
With the theoretical understanding in place, let's see how GRPO can be implemented in code.