diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index 7954ac1d..bc468766 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -155,10 +155,23 @@ def compute_loss( prob_ratio = torch.clamp( prob_ratio, max=max_negative_advantage_importance_sampling_weight ) - policy_loss = -torch.min( - prob_ratio * advantages, - torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages, - ) + if _config.get("gmpo", True): + advantage_signs = -torch.sign(advantages) + signed_logprob_diff = logprob_diff * advantage_signs + signed_logprob_diff_clamp = torch.clamp( + signed_logprob_diff, -epsilon, epsilon_high + ) + signed_logprob_diff_max = torch.max( + signed_logprob_diff, signed_logprob_diff_clamp + ) + logprobs_diff_max = advantage_signs * signed_logprob_diff_max + prob_ratio = torch.exp(logprobs_diff_max) + policy_loss = -advantages * prob_ratio + else: + policy_loss = -torch.min( + prob_ratio * advantages, + torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages, + ) if upper_bound := _config.get("truncated_importance_sampling", None): policy_loss *= torch.clamp(prob_ratio, max=upper_bound) if ref_logprobs is not None: