diff --git a/keras/src/losses/loss.py b/keras/src/losses/loss.py
index beeb016f5063..5177180d1117 100644
--- a/keras/src/losses/loss.py
+++ b/keras/src/losses/loss.py
@@ -14,7 +14,8 @@ class Loss(KerasSaveable):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`.
-            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"`, `"mean"`
+            or `None`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
@@ -92,7 +93,7 @@ def _obj_type(self):
 
 
 def standardize_reduction(reduction):
-    allowed = {"sum_over_batch_size", "sum", None, "none"}
+    allowed = {"sum_over_batch_size", "sum", None, "none", "mean"}
     if reduction not in allowed:
         raise ValueError(
             "Invalid value for argument `reduction`. "
@@ -132,7 +133,7 @@ def reduce_values(values, reduction="sum_over_batch_size"):
     ):
         return values
     loss = ops.sum(values)
-    if reduction == "sum_over_batch_size":
+    if reduction in ("mean", "sum_over_batch_size"):
         loss /= ops.cast(
             ops.prod(ops.convert_to_tensor(ops.shape(values), dtype="int32")),
             loss.dtype,
@@ -177,7 +178,7 @@ def apply_mask(sample_weight, mask, dtype, reduction):
    """Applies any mask on predictions to sample weights."""
    if mask is not None:
        mask = ops.cast(mask, dtype=dtype)
-        if reduction == "sum_over_batch_size":
+        if reduction in ("mean", "sum_over_batch_size"):
            # Valid entries have weight `total/valid`, while invalid ones
            # have 0. When summed over batch, they will be reduced to:
            #
diff --git a/keras/src/losses/loss_test.py b/keras/src/losses/loss_test.py
index 3f13bc96725b..003dd1e7b4a4 100644
--- a/keras/src/losses/loss_test.py
+++ b/keras/src/losses/loss_test.py
@@ -69,7 +69,7 @@ def test_reduction(self):
         self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
         self.assertAllClose(np.sum((y_true - y_pred) ** 2), loss)
 
-        # sum_over_batch_size
+        # sum_over_batch_size or mean
         loss_fn = ExampleLoss(reduction="sum_over_batch_size")
         loss = loss_fn(y_true, y_pred)
         self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
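
Usage note (not part of the patch): a minimal sketch of the new behavior, assuming a Keras build that includes this change. It checks that `reduction="mean"` is accepted and yields the same scalar as `"sum_over_batch_size"`, which this patch treats as equivalent.

import numpy as np
import keras

y_true = np.array([[0.0, 1.0], [1.0, 0.0]])
y_pred = np.array([[0.1, 0.9], [0.8, 0.2]])

# With this change, "mean" is an accepted alias: reduce_values() divides the
# summed per-sample losses by the element count for both reductions.
loss_mean = keras.losses.MeanSquaredError(reduction="mean")
loss_sobs = keras.losses.MeanSquaredError(reduction="sum_over_batch_size")

assert np.isclose(float(loss_mean(y_true, y_pred)),
                  float(loss_sobs(y_true, y_pred)))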