Skip to content

Commit 6715342

Browse files
committed
fix wmt jax
1 parent 4573499 commit 6715342

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

algoperf/workloads/wmt/wmt_jax/workload.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,6 @@ def init_model_fn(
282282
initial_params = initial_variables['params']
283283
self._param_shapes = param_utils.jax_param_shapes(initial_params)
284284
self._param_types = param_utils.jax_param_types(self._param_shapes)
285-
initial_params = jax_sharding_utils.shard_along_batch_dim(initial_params)
286285
return initial_params, None
287286

288287
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
@@ -340,16 +339,16 @@ def _build_input_queue(
340339
split: str,
341340
data_dir: str,
342341
global_batch_size: int,
343-
repeat_final_dataset: Optional[bool] = None,
344342
num_batches: Optional[int] = None,
343+
repeat_final_dataset: Optional[bool] = None,
345344
):
346345
it = super()._build_input_queue(
347346
data_rng,
348347
split,
349348
data_dir,
350349
global_batch_size,
351-
repeat_final_dataset,
352350
num_batches,
351+
repeat_final_dataset,
353352
)
354353
f = functools.partial(
355354
jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding()

0 commit comments

Comments
 (0)