Migrate JAX workloads from pmap to jit #5110
CI.yml
on: pull_request
fastmri: 5m 12s
wmt_jax: 11m 48s
wmt_pytorch: 12m 43s
imagenet_jax: 5m 11s
imagenet_pytorch: 4m 0s
criteo_jax: 3m 51s
criteo_pytorch: 3m 48s
speech_jax: 5m 24s
speech_pytorch: 4m 40s
ogbg: 4m 0s
pytest-params: 7m 44s
pytest-baselines: 4m 28s