@@ -85,6 +85,34 @@ def compute_loss_and_updates(
85
85
metrics_variables ,
86
86
)
87
87
88
+ def update_metrics_variables (
89
+ self ,
90
+ metrics_variables ,
91
+ unscaled_loss ,
92
+ x ,
93
+ y ,
94
+ y_pred ,
95
+ sample_weight
96
+ ):
97
+ with backend .StatelessScope (
98
+ state_mapping = [
99
+ (ref_v , v )
100
+ for ref_v , v in zip (self .metrics_variables , metrics_variables )
101
+ ]
102
+ ) as scope :
103
+ self ._loss_tracker .update_state (
104
+ unscaled_loss , sample_weight = tree .flatten (x )[0 ].shape [0 ]
105
+ )
106
+ logs = self .compute_metrics (x , y , y_pred , sample_weight )
107
+
108
+ new_metrics_variables = []
109
+ for ref_v in self .metrics_variables :
110
+ new_v = scope .get_current_value (ref_v )
111
+ if new_v is None :
112
+ new_v = ref_v .value
113
+ new_metrics_variables .append (new_v )
114
+ return logs , new_metrics_variables
115
+
88
116
def train_step (self , state , data ):
89
117
(
90
118
trainable_variables ,
@@ -117,24 +145,9 @@ def train_step(self, state, data):
117
145
optimizer_variables , grads , trainable_variables
118
146
)
119
147
120
- with backend .StatelessScope (
121
- state_mapping = [
122
- (ref_v , v )
123
- for ref_v , v in zip (self .metrics_variables , metrics_variables )
124
- ]
125
- ) as scope :
126
- self ._loss_tracker .update_state (
127
- unscaled_loss , sample_weight = tree .flatten (x )[0 ].shape [0 ]
128
- )
129
- logs = self .compute_metrics (x , y , y_pred , sample_weight )
130
-
131
- new_metrics_variables = []
132
- for ref_v in self .metrics_variables :
133
- new_v = scope .get_current_value (ref_v )
134
- if new_v is None :
135
- new_v = ref_v .value
136
- new_metrics_variables .append (new_v )
137
- metrics_variables = new_metrics_variables
148
+ logs , metrics_variables = self .update_metrics_variables (
149
+ metrics_variables , unscaled_loss , x , y , y_pred , sample_weight
150
+ )
138
151
139
152
state = self ._enforce_jax_state_sharding (
140
153
trainable_variables ,
@@ -164,24 +177,9 @@ def test_step(self, state, data):
164
177
aux
165
178
)
166
179
167
- with backend .StatelessScope (
168
- state_mapping = [
169
- (ref_v , v )
170
- for ref_v , v in zip (self .metrics_variables , metrics_variables )
171
- ]
172
- ) as scope :
173
- self ._loss_tracker .update_state (
174
- unscaled_loss , sample_weight = tree .flatten (x )[0 ].shape [0 ]
175
- )
176
- logs = self .compute_metrics (x , y , y_pred , sample_weight )
177
-
178
- new_metrics_variables = []
179
- for ref_v in self .metrics_variables :
180
- new_v = scope .get_current_value (ref_v )
181
- if new_v is None :
182
- new_v = ref_v .value
183
- new_metrics_variables .append (new_v )
184
- metrics_variables = new_metrics_variables
180
+ logs , metrics_variables = self .update_metrics_variables (
181
+ metrics_variables , unscaled_loss , x , y , y_pred , sample_weight
182
+ )
185
183
186
184
(
187
185
trainable_variables ,
0 commit comments