File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed
algoperf/workloads/wmt/wmt_jax Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -282,7 +282,6 @@ def init_model_fn(
282
282
initial_params = initial_variables ['params' ]
283
283
self ._param_shapes = param_utils .jax_param_shapes (initial_params )
284
284
self ._param_types = param_utils .jax_param_types (self ._param_shapes )
285
- initial_params = jax_sharding_utils .shard_along_batch_dim (initial_params )
286
285
return initial_params , None
287
286
288
287
def is_output_params (self , param_key : spec .ParameterKey ) -> bool :
@@ -340,16 +339,16 @@ def _build_input_queue(
340
339
split : str ,
341
340
data_dir : str ,
342
341
global_batch_size : int ,
343
- repeat_final_dataset : Optional [bool ] = None ,
344
342
num_batches : Optional [int ] = None ,
343
+ repeat_final_dataset : Optional [bool ] = None ,
345
344
):
346
345
it = super ()._build_input_queue (
347
346
data_rng ,
348
347
split ,
349
348
data_dir ,
350
349
global_batch_size ,
351
- repeat_final_dataset ,
352
350
num_batches ,
351
+ repeat_final_dataset ,
353
352
)
354
353
f = functools .partial (
355
354
jax .device_put , device = jax_sharding_utils .get_batch_dim_sharding ()
You can’t perform that action at this time.
0 commit comments