Skip to content

Commit 29d204e

Browse files
author
Zoe Kendall
committed
Extract metrics-update logic into a helper method.
This change allows users to customize what happens in the step function while still reusing the existing metrics-update logic, without needing to duplicate it.
1 parent 5df8fb9 commit 29d204e

File tree

1 file changed

+28
-36
lines changed

1 file changed

+28
-36
lines changed

keras/src/backend/jax/trainer.py

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,28 @@ def compute_loss_and_updates(
8585
metrics_variables,
8686
)
8787

88+
def _update_metrics_variables(
89+
self, metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
90+
):
91+
with backend.StatelessScope(
92+
state_mapping=[
93+
(ref_v, v)
94+
for ref_v, v in zip(self.metrics_variables, metrics_variables)
95+
]
96+
) as scope:
97+
self._loss_tracker.update_state(
98+
unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
99+
)
100+
logs = self.compute_metrics(x, y, y_pred, sample_weight)
101+
102+
new_metrics_variables = []
103+
for ref_v in self.metrics_variables:
104+
new_v = scope.get_current_value(ref_v)
105+
if new_v is None:
106+
new_v = ref_v.value
107+
new_metrics_variables.append(new_v)
108+
return logs, new_metrics_variables
109+
88110
def train_step(self, state, data):
89111
(
90112
trainable_variables,
@@ -117,24 +139,9 @@ def train_step(self, state, data):
117139
optimizer_variables, grads, trainable_variables
118140
)
119141

120-
with backend.StatelessScope(
121-
state_mapping=[
122-
(ref_v, v)
123-
for ref_v, v in zip(self.metrics_variables, metrics_variables)
124-
]
125-
) as scope:
126-
self._loss_tracker.update_state(
127-
unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
128-
)
129-
logs = self.compute_metrics(x, y, y_pred, sample_weight)
130-
131-
new_metrics_variables = []
132-
for ref_v in self.metrics_variables:
133-
new_v = scope.get_current_value(ref_v)
134-
if new_v is None:
135-
new_v = ref_v.value
136-
new_metrics_variables.append(new_v)
137-
metrics_variables = new_metrics_variables
142+
logs, metrics_variables = self._update_metrics_variables(
143+
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
144+
)
138145

139146
state = self._enforce_jax_state_sharding(
140147
trainable_variables,
@@ -164,24 +171,9 @@ def test_step(self, state, data):
164171
aux
165172
)
166173

167-
with backend.StatelessScope(
168-
state_mapping=[
169-
(ref_v, v)
170-
for ref_v, v in zip(self.metrics_variables, metrics_variables)
171-
]
172-
) as scope:
173-
self._loss_tracker.update_state(
174-
unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
175-
)
176-
logs = self.compute_metrics(x, y, y_pred, sample_weight)
177-
178-
new_metrics_variables = []
179-
for ref_v in self.metrics_variables:
180-
new_v = scope.get_current_value(ref_v)
181-
if new_v is None:
182-
new_v = ref_v.value
183-
new_metrics_variables.append(new_v)
184-
metrics_variables = new_metrics_variables
174+
logs, metrics_variables = self._update_metrics_variables(
175+
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
176+
)
185177

186178
(
187179
trainable_variables,

0 commit comments

Comments
 (0)