@@ -85,6 +85,28 @@ def compute_loss_and_updates(
85
85
metrics_variables ,
86
86
)
87
87
88
def update_metrics_variables(
    self, metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
):
    """Statelessly update metric state; return logs and new variable values.

    Maps each reference metric variable onto its current (traced) value
    inside a `StatelessScope`, updates the loss tracker and the model
    metrics there, then harvests the post-update values from the scope.

    Args:
        metrics_variables: Current values for `self.metrics_variables`.
        unscaled_loss: Loss value before any scaling, fed to the tracker.
        x: Batch inputs; its first leaf's leading dim is the batch size.
        y: Batch targets.
        y_pred: Model predictions for the batch.
        sample_weight: Per-sample weights forwarded to `compute_metrics`.

    Returns:
        Tuple `(logs, new_metrics_variables)` where `logs` is the metrics
        dict and `new_metrics_variables` mirrors `self.metrics_variables`.
    """
    mapping = list(zip(self.metrics_variables, metrics_variables))
    with backend.StatelessScope(state_mapping=mapping) as scope:
        # Weight the loss by batch size (leading dim of the first input leaf).
        batch_size = tree.flatten(x)[0].shape[0]
        self._loss_tracker.update_state(
            unscaled_loss, sample_weight=batch_size
        )
        logs = self.compute_metrics(x, y, y_pred, sample_weight)

    # Harvest updated values after the scope closes; variables the scope
    # never touched keep their prior value.
    updated_values = []
    for ref_v in self.metrics_variables:
        current = scope.get_current_value(ref_v)
        updated_values.append(ref_v.value if current is None else current)
    return logs, updated_values
109
+
88
110
def train_step (self , state , data ):
89
111
(
90
112
trainable_variables ,
@@ -117,24 +139,9 @@ def train_step(self, state, data):
117
139
optimizer_variables , grads , trainable_variables
118
140
)
119
141
120
- with backend .StatelessScope (
121
- state_mapping = [
122
- (ref_v , v )
123
- for ref_v , v in zip (self .metrics_variables , metrics_variables )
124
- ]
125
- ) as scope :
126
- self ._loss_tracker .update_state (
127
- unscaled_loss , sample_weight = tree .flatten (x )[0 ].shape [0 ]
128
- )
129
- logs = self .compute_metrics (x , y , y_pred , sample_weight )
130
-
131
- new_metrics_variables = []
132
- for ref_v in self .metrics_variables :
133
- new_v = scope .get_current_value (ref_v )
134
- if new_v is None :
135
- new_v = ref_v .value
136
- new_metrics_variables .append (new_v )
137
- metrics_variables = new_metrics_variables
142
+ logs , metrics_variables = self .update_metrics_variables (
143
+ metrics_variables , unscaled_loss , x , y , y_pred , sample_weight
144
+ )
138
145
139
146
state = self ._enforce_jax_state_sharding (
140
147
trainable_variables ,
@@ -164,24 +171,9 @@ def test_step(self, state, data):
164
171
aux
165
172
)
166
173
167
- with backend .StatelessScope (
168
- state_mapping = [
169
- (ref_v , v )
170
- for ref_v , v in zip (self .metrics_variables , metrics_variables )
171
- ]
172
- ) as scope :
173
- self ._loss_tracker .update_state (
174
- unscaled_loss , sample_weight = tree .flatten (x )[0 ].shape [0 ]
175
- )
176
- logs = self .compute_metrics (x , y , y_pred , sample_weight )
177
-
178
- new_metrics_variables = []
179
- for ref_v in self .metrics_variables :
180
- new_v = scope .get_current_value (ref_v )
181
- if new_v is None :
182
- new_v = ref_v .value
183
- new_metrics_variables .append (new_v )
184
- metrics_variables = new_metrics_variables
174
+ logs , metrics_variables = self .update_metrics_variables (
175
+ metrics_variables , unscaled_loss , x , y , y_pred , sample_weight
176
+ )
185
177
186
178
(
187
179
trainable_variables ,
0 commit comments