Commit 0aaad1e

Zoe Kendall committed
extract metrics update logic into a helper method
This change allows users to customize what happens in the step functions while reusing the existing metrics-update logic, without having to duplicate it.
1 parent 5df8fb9 commit 0aaad1e
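For illustration, here is a rough sketch (not part of this commit) of the kind of custom train_step the helper enables on the JAX backend, following the pattern from the Keras "customizing fit() with JAX" guide. Everything except the self.update_metrics_variables(...) call is assumed boilerplate: CustomModel and loss_fn are hypothetical names, and sample_weight handling is omitted for brevity.

import jax
import keras


class CustomModel(keras.Model):
    def train_step(self, state, data):
        # On the JAX backend the step is functional: all variable values
        # come in through `state` and go back out in the returned state.
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        def loss_fn(trainable_variables, non_trainable_variables):
            y_pred, non_trainable_variables = self.stateless_call(
                trainable_variables, non_trainable_variables, x, training=True
            )
            loss = self.compute_loss(x=x, y=y, y_pred=y_pred)
            return loss, (y_pred, non_trainable_variables)

        # Custom behavior could be added around the gradient computation.
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables, non_trainable_variables
        )
        trainable_variables, optimizer_variables = (
            self.optimizer.stateless_apply(
                optimizer_variables, grads, trainable_variables
            )
        )

        # Reuse the helper extracted by this commit instead of duplicating
        # the StatelessScope bookkeeping it now encapsulates.
        logs, metrics_variables = self.update_metrics_variables(
            metrics_variables, loss, x, y, y_pred, None
        )

        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        )
        return logs, state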

File tree

1 file changed: +34 −36 lines changed

keras/src/backend/jax/trainer.py

Lines changed: 34 additions & 36 deletions
@@ -85,6 +85,34 @@ def compute_loss_and_updates(
             metrics_variables,
         )
 
+    def update_metrics_variables(
+        self,
+        metrics_variables,
+        unscaled_loss,
+        x,
+        y,
+        y_pred,
+        sample_weight
+    ):
+        with backend.StatelessScope(
+            state_mapping=[
+                (ref_v, v)
+                for ref_v, v in zip(self.metrics_variables, metrics_variables)
+            ]
+        ) as scope:
+            self._loss_tracker.update_state(
+                unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
+            )
+            logs = self.compute_metrics(x, y, y_pred, sample_weight)
+
+        new_metrics_variables = []
+        for ref_v in self.metrics_variables:
+            new_v = scope.get_current_value(ref_v)
+            if new_v is None:
+                new_v = ref_v.value
+            new_metrics_variables.append(new_v)
+        return logs, new_metrics_variables
+
     def train_step(self, state, data):
         (
             trainable_variables,
@@ -117,24 +145,9 @@ def train_step(self, state, data):
             optimizer_variables, grads, trainable_variables
         )
 
-        with backend.StatelessScope(
-            state_mapping=[
-                (ref_v, v)
-                for ref_v, v in zip(self.metrics_variables, metrics_variables)
-            ]
-        ) as scope:
-            self._loss_tracker.update_state(
-                unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
-            )
-            logs = self.compute_metrics(x, y, y_pred, sample_weight)
-
-        new_metrics_variables = []
-        for ref_v in self.metrics_variables:
-            new_v = scope.get_current_value(ref_v)
-            if new_v is None:
-                new_v = ref_v.value
-            new_metrics_variables.append(new_v)
-        metrics_variables = new_metrics_variables
+        logs, metrics_variables = self.update_metrics_variables(
+            metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
+        )
 
         state = self._enforce_jax_state_sharding(
             trainable_variables,
@@ -164,24 +177,9 @@ def test_step(self, state, data):
             aux
         )
 
-        with backend.StatelessScope(
-            state_mapping=[
-                (ref_v, v)
-                for ref_v, v in zip(self.metrics_variables, metrics_variables)
-            ]
-        ) as scope:
-            self._loss_tracker.update_state(
-                unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
-            )
-            logs = self.compute_metrics(x, y, y_pred, sample_weight)
-
-        new_metrics_variables = []
-        for ref_v in self.metrics_variables:
-            new_v = scope.get_current_value(ref_v)
-            if new_v is None:
-                new_v = ref_v.value
-            new_metrics_variables.append(new_v)
-        metrics_variables = new_metrics_variables
+        logs, metrics_variables = self.update_metrics_variables(
+            metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
+        )
 
         (
             trainable_variables,
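The helper's body is the standard Keras 3 stateless-update pattern: variable updates made inside a backend.StatelessScope are recorded in the scope rather than applied in place, and scope.get_current_value() reads them back afterwards, falling back to the old value for variables the update never touched (hence the None check in the diff). Below is a minimal standalone demo of that pattern, assuming Keras 3, where the same class is exported publicly as keras.StatelessScope.

import keras

# A single Mean metric stands in for the model's full metric state.
metric = keras.metrics.Mean()
metric_values = [v.value for v in metric.variables]

with keras.StatelessScope(
    state_mapping=list(zip(metric.variables, metric_values))
) as scope:
    metric.update_state(2.0)  # recorded by the scope, not applied in place

new_values = []
for ref_v in metric.variables:
    new_v = scope.get_current_value(ref_v)
    if new_v is None:  # variable was never assigned inside the scope
        new_v = ref_v.value
    new_values.append(new_v)
print(new_values)  # updated total and count for the Mean metric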
