Skip to content

Commit bfd72bb

Browse files
Merge pull request #883 from mlcommons/jit_switch
Small fix to prevent OOM in Imagenet VIT
2 parents 47c8d2b + cfbaf7a commit bfd72bb

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

algoperf/workloads/imagenet_resnet/imagenet_v2.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import functools
77
from typing import Dict, Iterator, Tuple
88

9-
import jax
109
import tensorflow_datasets as tfds
1110

12-
from algoperf import data_utils, jax_sharding_utils, spec
11+
from algoperf import data_utils, spec
1312
from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
1413

1514

@@ -47,10 +46,4 @@ def _decode_example(example: Dict[str, float]) -> Dict[str, float]:
4746
if framework == 'pytorch':
4847
it = map(data_utils.shard, it)
4948

50-
elif framework == 'jax':
51-
f = functools.partial(
52-
jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding()
53-
)
54-
it = map(f, it)
55-
5649
return it

0 commit comments

Comments
 (0)