
Commit c789805

Unscale loss value in TF (#20610)
1 parent: b1e4057

4 files changed: +30 -19 lines


keras/src/backend/tensorflow/distribute_test.py

Lines changed: 5 additions & 3 deletions
@@ -162,8 +162,8 @@ def test_correctness_with_fit_and_regularizer(self):
         )
         model = models.Model(inputs, layer(inputs))
         model.compile(loss="mse", optimizer="sgd")
-        model.fit(x, y, batch_size=batch_size, epochs=1)
-
+        history = model.fit(x, y, batch_size=batch_size, epochs=1)
+        expected_loss = history.history["loss"]
         expected_weights = keras.ops.convert_to_numpy(layer.kernel)

         # Runs with a mirrored strategy.
@@ -177,8 +177,10 @@ def test_correctness_with_fit_and_regularizer(self):
         )
         model = models.Model(inputs, layer(inputs))
         model.compile(loss="mse", optimizer="sgd")
-        model.fit(x, y, batch_size=batch_size, epochs=1)
+        history = model.fit(x, y, batch_size=batch_size, epochs=1)
         weights = strategy.run(lambda: layer.kernel.value).values
+
+        self.assertAllClose(history.history["loss"], expected_loss)
         for w in weights:
             self.assertAllClose(
                 keras.ops.convert_to_numpy(w), expected_weights

keras/src/backend/tensorflow/trainer.py

Lines changed: 7 additions & 15 deletions
@@ -5,12 +5,11 @@
 import tensorflow as tf
 from tensorflow.python.eager import context as tf_context

-from keras.src import backend as backend_module
 from keras.src import callbacks as callbacks_module
 from keras.src import metrics as metrics_module
-from keras.src import ops as ops_module
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
+from keras.src.losses import loss as loss_module
 from keras.src.trainers import trainer as base_trainer
 from keras.src.trainers.data_adapters import array_slicing
 from keras.src.trainers.data_adapters import data_adapter_utils
@@ -66,7 +65,8 @@ def train_step(self, data):
                 training=True,
             )
             self._loss_tracker.update_state(
-                loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
+                loss_module.unscale_loss_for_distribution(loss),
+                sample_weight=tf.shape(tree.flatten(x)[0])[0],
             )
             if self.optimizer is not None:
                 loss = self.optimizer.scale_loss(loss)
@@ -93,7 +93,8 @@ def test_step(self, data):
             x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
+            loss_module.unscale_loss_for_distribution(loss),
+            sample_weight=tf.shape(tree.flatten(x)[0])[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

@@ -710,17 +711,8 @@ def _maybe_symbolic_build(self, iterator=None, data_batch=None):
             self._symbolic_build(data_batch=data_batch)

     def _aggregate_additional_loss(self, loss):
-        if not backend_module.is_float_dtype(loss.dtype):
-            loss = ops_module.cast(loss, dtype=backend_module.floatx())
-        loss = ops_module.sum(loss)
-
-        # Scales the loss by the number of replicas in the strategy.
-        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
-        if num_replicas > 1:
-            loss = ops_module.multiply(
-                loss, ops_module.cast(1.0 / num_replicas, loss.dtype)
-            )
-        return loss
+        loss = super()._aggregate_additional_loss(loss)
+        return loss_module.scale_loss_for_distribution(loss)


 class TFEpochIterator(EpochIterator):
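
For context, a minimal arithmetic sketch of the bookkeeping these hunks change, assuming (as the commit title and the unscale call imply) that the loss reaching `train_step` under `tf.distribute` has already been divided by `num_replicas_in_sync` for gradient aggregation; the replica count and values below are illustrative only, not Keras internals:

# Illustrative sketch only, not the Keras implementation.
num_replicas = 4          # hypothetical replica count
true_batch_loss = 2.0     # loss a single-device run would report

# Assumption: under tf.distribute the per-replica loss is pre-scaled by
# 1 / num_replicas so that summing gradients across replicas reproduces
# the global gradient.
scaled_loss = true_batch_loss / num_replicas        # what train_step receives: 0.5

# Previously the tracker recorded scaled_loss directly, so history["loss"]
# under a multi-replica strategy was num_replicas times too small.
reported_before = scaled_loss                        # 0.5

# With unscale_loss_for_distribution applied before update_state, the
# reported loss matches the single-device value again.
reported_after = scaled_loss * num_replicas          # 2.0 == true_batch_loss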

keras/src/losses/loss.py

Lines changed: 15 additions & 0 deletions
@@ -239,3 +239,18 @@ def scale_loss_for_distribution(value):
             value, ops.cast(1.0 / num_replicas, value.dtype)
         )
     return value
+
+
+def unscale_loss_for_distribution(value):
+    """Unscales the given value by the number of replicas in the strategy.
+
+    Currently, this function is only effective when using the tensorflow backend
+    and `tf.distribute`.
+    """
+    if backend.backend() == "tensorflow":
+        import tensorflow as tf
+
+        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+        if num_replicas > 1:
+            value = ops.multiply(value, ops.cast(num_replicas, value.dtype))
+    return value
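
A rough usage sketch of the new helper, assuming the TensorFlow backend and a strategy that actually resolves to more than one replica (with a single visible device, `num_replicas_in_sync` is 1 and both helpers return the value unchanged):

import tensorflow as tf

from keras.src.losses import loss as loss_module

# Assumes at least two visible devices; MirroredStrategy on a single
# device yields num_replicas_in_sync == 1 and both calls are no-ops.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    raw = tf.constant(6.0)
    scaled = loss_module.scale_loss_for_distribution(raw)        # raw / num_replicas
    restored = loss_module.unscale_loss_for_distribution(scaled) # back to raw
    # With 2 replicas: scaled == 3.0 and restored == 6.0, i.e. the two
    # helpers are inverses of each other inside a strategy scope.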

keras/src/trainers/compile_utils.py

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,7 @@
 from keras.src import ops
 from keras.src import tree
 from keras.src.backend.common.keras_tensor import KerasTensor
+from keras.src.losses import loss as loss_module
 from keras.src.utils.naming import get_object_name
 from keras.src.utils.tracking import Tracker

@@ -799,7 +800,8 @@ def resolve_path(path, object):
                 # Record *unweighted* individual losses.
                 if metric:
                     metric.update_state(
-                        value, sample_weight=tree.flatten(y_p)[0].shape[0]
+                        loss_module.unscale_loss_for_distribution(value),
+                        sample_weight=tree.flatten(y_p)[0].shape[0],
                     )
                 if loss_weight is not None:
                     value = ops.multiply(value, loss_weight)
