Skip to content

Commit e01505b

Browse files
Fix rewards_accuracy to use per-sample comparison instead of scalar
Previously, the metric compared the mean chosen reward against the mean rejected reward, which always yielded 0 or 1. It now computes per-sample accuracy and averages it, matching dpo_tune_cache.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 72c7715 commit e01505b

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

open_instruct/olmo_core_train_modules.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
self._total_rejected_logps = torch.tensor(0.0, device=device)
8080
self._total_chosen_rewards = torch.tensor(0.0, device=device)
8181
self._total_rejected_rewards = torch.tensor(0.0, device=device)
82+
self._total_rewards_accuracy = torch.tensor(0.0, device=device)
8283
self._total_aux_loss = torch.tensor(0.0, device=device) if args.load_balancing_loss else None
8384

8485
if args.packing:
@@ -161,13 +162,15 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:
161162
self._total_rejected_logps.zero_()
162163
self._total_chosen_rewards.zero_()
163164
self._total_rejected_rewards.zero_()
165+
self._total_rewards_accuracy.zero_()
164166
if self._total_aux_loss is not None:
165167
self._total_aux_loss.zero_()
166168
total_loss = self._total_loss
167169
total_chosen_logps = self._total_chosen_logps
168170
total_rejected_logps = self._total_rejected_logps
169171
total_chosen_rewards = self._total_chosen_rewards
170172
total_rejected_rewards = self._total_rejected_rewards
173+
total_rewards_accuracy = self._total_rewards_accuracy
171174
total_aux_loss = self._total_aux_loss
172175

173176
for micro_batch_idx, micro_batch in enumerate(micro_batches):
@@ -209,6 +212,9 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:
209212
if self.args.loss_type.computes_reward_metrics:
210213
total_chosen_rewards += chosen_rewards.mean().detach() / num_micro_batches
211214
total_rejected_rewards += rejected_rewards.mean().detach() / num_micro_batches
215+
total_rewards_accuracy += (
216+
chosen_rewards > rejected_rewards
217+
).float().mean().detach() / num_micro_batches
212218
if total_aux_loss is not None and aux_loss is not None:
213219
total_aux_loss += aux_loss.detach() / num_micro_batches
214220

@@ -222,14 +228,13 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:
222228
self.record_metric("train/logps_rejected", total_rejected_logps, ReduceType.mean)
223229

224230
if self.args.loss_type.computes_reward_metrics:
225-
accuracy = (total_chosen_rewards > total_rejected_rewards).float()
226231
margin = total_chosen_rewards - total_rejected_rewards
227232
self.record_metric("train/rewards_chosen", total_chosen_rewards, ReduceType.mean)
228233
self.record_metric("train/rewards_rejected", total_rejected_rewards, ReduceType.mean)
229234
self.record_metric(
230235
"train/rewards_average", (total_chosen_rewards + total_rejected_rewards) / 2, ReduceType.mean
231236
)
232-
self.record_metric("train/rewards_accuracy", accuracy, ReduceType.mean)
237+
self.record_metric("train/rewards_accuracy", total_rewards_accuracy, ReduceType.mean)
233238
self.record_metric("train/rewards_margin", margin, ReduceType.mean)
234239

235240
chosen_lengths = (batch["chosen_labels"] != -100).sum()

0 commit comments

Comments (0)