@@ -79,6 +79,7 @@ def __init__(
         self._total_rejected_logps = torch.tensor(0.0, device=device)
         self._total_chosen_rewards = torch.tensor(0.0, device=device)
         self._total_rejected_rewards = torch.tensor(0.0, device=device)
+        self._total_rewards_accuracy = torch.tensor(0.0, device=device)
         self._total_aux_loss = torch.tensor(0.0, device=device) if args.load_balancing_loss else None
 
         if args.packing:
@@ -161,13 +162,15 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:
         self._total_rejected_logps.zero_()
         self._total_chosen_rewards.zero_()
         self._total_rejected_rewards.zero_()
+        self._total_rewards_accuracy.zero_()
         if self._total_aux_loss is not None:
             self._total_aux_loss.zero_()
         total_loss = self._total_loss
         total_chosen_logps = self._total_chosen_logps
         total_rejected_logps = self._total_rejected_logps
         total_chosen_rewards = self._total_chosen_rewards
         total_rejected_rewards = self._total_rejected_rewards
+        total_rewards_accuracy = self._total_rewards_accuracy
         total_aux_loss = self._total_aux_loss
 
         for micro_batch_idx, micro_batch in enumerate(micro_batches):
@@ -209,6 +212,9 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:
             if self.args.loss_type.computes_reward_metrics:
                 total_chosen_rewards += chosen_rewards.mean().detach() / num_micro_batches
                 total_rejected_rewards += rejected_rewards.mean().detach() / num_micro_batches
+                total_rewards_accuracy += (
+                    chosen_rewards > rejected_rewards
+                ).float().mean().detach() / num_micro_batches
             if total_aux_loss is not None and aux_loss is not None:
                 total_aux_loss += aux_loss.detach() / num_micro_batches
 
@@ -222,14 +228,13 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:
         self.record_metric("train/logps_rejected", total_rejected_logps, ReduceType.mean)
 
         if self.args.loss_type.computes_reward_metrics:
-            accuracy = (total_chosen_rewards > total_rejected_rewards).float()
             margin = total_chosen_rewards - total_rejected_rewards
             self.record_metric("train/rewards_chosen", total_chosen_rewards, ReduceType.mean)
             self.record_metric("train/rewards_rejected", total_rejected_rewards, ReduceType.mean)
             self.record_metric(
                 "train/rewards_average", (total_chosen_rewards + total_rejected_rewards) / 2, ReduceType.mean
             )
-            self.record_metric("train/rewards_accuracy", total_rewards_accuracy, ReduceType.mean)
+            self.record_metric("train/rewards_accuracy", total_rewards_accuracy, ReduceType.mean)
             self.record_metric("train/rewards_margin", margin, ReduceType.mean)
 
         chosen_lengths = (batch["chosen_labels"] != -100).sum()
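The substance of the change: previously, train/rewards_accuracy compared the micro-batch-averaged chosen and rejected rewards, which collapses to a single 0/1 value per optimizer step; the new code averages the per-pair comparisons, giving the actual fraction of correctly ranked pairs. A minimal sketch with made-up reward tensors (not taken from this PR) illustrating why the two disagree:

import torch

# Illustrative per-pair rewards for a batch of 4 preference pairs (made-up values).
chosen_rewards = torch.tensor([2.0, 0.1, 0.2, 0.3])
rejected_rewards = torch.tensor([0.0, 0.5, 0.6, 0.7])

# Old metric: compare the batch means -- a single 0/1 value per step.
old_accuracy = (chosen_rewards.mean() > rejected_rewards.mean()).float()
print(old_accuracy.item())  # 1.0: one outlier pair masks three misranked pairs

# New metric: average the per-pair comparisons first.
new_accuracy = (chosen_rewards > rejected_rewards).float().mean()
print(new_accuracy.item())  # 0.25: the true fraction of correctly ranked pairs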
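The new accumulator follows the same pattern as the other running metrics: each micro-batch contributes its mean divided by num_micro_batches, so the sum over micro-batches equals the full-batch mean when micro-batches are equal-sized (uneven splits would skew it slightly). A quick sanity check of that identity, assuming equal-sized micro-batches:

import torch

# Made-up per-pair correctness indicators for one full batch of 8 pairs.
per_pair_correct = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0])
num_micro_batches = 2

full_batch_mean = per_pair_correct.mean()  # 0.625
accumulated = sum(
    mb.mean() / num_micro_batches for mb in per_pair_correct.chunk(num_micro_batches)
)
assert torch.allclose(full_batch_mean, accumulated)  # holds for equal-sized micro-batches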