Skip to content

Commit e716ac9

Browse files
committed
fix criteo1tb
1 parent 74b95fb commit e716ac9

File tree

3 files changed

+3 -7 lines changed (3 additions, 7 deletions)

3 files changed

+3 -7 lines changed (3 additions, 7 deletions)

algoperf/workloads/criteo1tb/criteo1tb_jax/models.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -223,11 +223,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
223 | 223 |   top_mlp_input = nn.relu(top_mlp_input)
224 | 224 |   if self.use_layer_norm:
225 | 225 |   top_mlp_input = nn.LayerNorm()(top_mlp_input)
226 |     | - if (
227 |     | - dropout_rate is not None
228 |     | - and dropout_rate > 0.0
229 |     | - and layer_idx == num_layers_top - 2
230 |     | - ):
    | 226 | + if dropout_rate is not None and layer_idx == num_layers_top - 2:
231 | 227 |   top_mlp_input = Dropout(dropout_rate, deterministic=not train)(
232 | 228 |   top_mlp_input, rate=dropout_rate
233 | 229 |   )

algoperf/workloads/imagenet_resnet/imagenet_jax/models.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -89,10 +89,8 @@ def __call__(
89 | 89 |   x: spec.Tensor,
90 | 90 |   update_batch_norm: bool = True,
91 | 91 |   use_running_average_bn: Optional[bool] = None,
92 |    | - dropout_rate: float = 0.0,
93 | 92 |   ) -> spec.Tensor:
94 | 93 |   conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
95 |    | - del dropout_rate # unused
96 | 94 |   # Preserve default behavior for backwards compatibility
97 | 95 |   if use_running_average_bn is None:
98 | 96 |   use_running_average_bn = not update_batch_norm

algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -154,9 +154,11 @@ def model_fn(
154 | 154 |   rng: spec.RandomState,
155 | 155 |   update_batch_norm: bool,
156 | 156 |   use_running_average_bn: Optional[bool] = None,
    | 157 | + dropout_rate: Optional[float] = None,
157 | 158 |   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
158 | 159 |   del mode
159 | 160 |   del rng
    | 161 | + del dropout_rate
160 | 162 |   variables = {'params': params, **model_state}
161 | 163 |   if update_batch_norm:
162 | 164 |   logits, new_model_state = self._model.apply(

0 commit comments

Comments
 (0)