
CI: 06/06/25 upstream sync #460


Closed
wants to merge 1,907 commits into from

1907 commits
aa63a15
[tree_util] raise more informative error when pytree equality check f…
jakevdp May 23, 2025
d4ab826
[Mosaic GPU] Add support for copy_gmem_to_smem in Warp semantics.
justinjfu May 23, 2025
f429162
Merge pull request #28978 from matthiasdiener:patch-1
Google-ML-Automation May 23, 2025
c4a90c1
[Mosaic GPU] Add barrier transformation support to tcgen05_mma.
justinjfu May 23, 2025
f5a9d46
Move jax/_src/custom_dce.py to its own BUILD rule
May 23, 2025
9153ab7
Transfer library: poison outstanding buffer fetches upon connection f…
pschuh May 23, 2025
57d07e1
This is a change to patch some internal Google builds while we comple…
ZacCranko May 23, 2025
ae2f943
TSAN CI, make jax build/test step fail if missing deps wheels
vfdev-5 May 23, 2025
704c3c6
Merge pull request #28979 from jakevdp:pytree-err
Google-ML-Automation May 23, 2025
292dea6
Simplify attention VJP definition
rdyro May 23, 2025
966bcb9
[ragged-paged-attn] Implement static kv cache quantization. (The scal…
Google-ML-Automation May 23, 2025
2b9d7c8
Move jax/_src/tree.py to its own build rule
May 23, 2025
4eb3220
Update index.rst
dlwh May 23, 2025
d0195f2
Move jax/_src/sourcemap to its own build rule
May 24, 2025
0833cc2
Use block_until_ready to fix races in TPU interpret mode tests
jburnim May 24, 2025
29f9905
Merge pull request #28992 from dlwh:marin-link
Google-ML-Automation May 24, 2025
f28565d
Update XlaCallModule so tests are compatible with DCE.
mrguenther May 24, 2025
8b54a6d
Update XLA dependency to use revision
Google-ML-Automation May 24, 2025
f6b8cb6
Enter into the right mesh context during shmap DCE
yashk2810 May 25, 2025
37068b6
Update XLA dependency to use revision
Google-ML-Automation May 25, 2025
c22bba2
Clarify that upper bound takes precedence in jnp.clip where bounds ar…
May 25, 2025
f9c7a14
[Mosaic GPU] Add missing allocator config and skips in one of our dis…
apaszke May 26, 2025
fae05bd
[Pallas:MGPU] Support remote async copies and use them in the collect…
apaszke May 26, 2025
4bfd163
Update XLA dependency to use revision
Google-ML-Automation May 26, 2025
444e952
Fix a test which blocks the openxla change.
mooskagh May 26, 2025
c1e8f25
[Mosaic GPU] Use PTX ISA version = min(ptxas, LLVM)
andportnoy May 7, 2025
f35d708
[pallas] The `cf` dialect is now always available
superbobry May 27, 2025
2cbec58
Merge pull request #28595 from andportnoy:mosaic-gpu-ptx-isa-from-ptx…
Google-ML-Automation May 27, 2025
3aa4e36
Update XLA dependency to use revision
Google-ML-Automation May 27, 2025
f68aab1
[Mosaic GPU] Work around MLIR recognizing strided<[1]> as identity la…
apaszke May 27, 2025
b44b963
[Pallas:MGPU] Make sure that lowering errors mention the offending line
apaszke May 27, 2025
9a7f9f1
[Pallas:MGPU] Add a missing warpgroup barrier before warp core_map
apaszke May 27, 2025
4f717d3
[pallas:mosaic_gpu] `Barrier` and `ClusterBarrier` are now `kw_only=T…
superbobry May 27, 2025
c13de5c
Move jax/_src/custom_derivatives.py to its own BUILD rule
May 27, 2025
8124cb6
[Pallas] Require parallel dimensions to form a prefix of the grid in …
Google-ML-Automation May 27, 2025
71edce4
Skip //third_party/py/jax/tests/pallas:mgpu_ragged_dot_test_gpu_h100 …
belitskiy May 27, 2025
a57b4a1
#sdy remove redundant call to sdy-round-trip-export in JAX export.
bartchr808 May 27, 2025
487eeb4
[Mosaic GPU] Add tests for the Blackwell matmul kernel
apaszke May 27, 2025
f5ffd7f
[Mosaic GPU] Fix missing symbol errors in OSS collective kernels
apaszke May 27, 2025
ee727f9
[Mosaic GPU][NFC] Refactor the body of the matmul kernel
apaszke May 27, 2025
10cdbb7
Block until ready for PGLE test
Google-ML-Automation May 27, 2025
fce93d2
Fix handling of input None in custom_transpose.
dfm May 27, 2025
3c926a2
Update XLA dependency to use revision
Google-ML-Automation May 27, 2025
3b3c338
#sdy Remove redundant sdy export since it's now done as part of `Mlir…
tomnatan30 May 27, 2025
6f0b993
[Pallas:MGPU] Add an unsafe flag that disables automatic WG-barrier i…
apaszke May 27, 2025
e258708
Fix sempahore typo in JAX
apivovarov May 27, 2025
1d10a48
Merge pull request #29001 from johannahaffner:test-clip
Google-ML-Automation May 27, 2025
0f4da0c
Merge pull request #28955 from jax-ml:prevent-partial-eval-dce-effects
Google-ML-Automation May 27, 2025
c09b1bb
Update lock files for jaxlib 0.6.1
hawkinsp May 27, 2025
0caeb98
Merge pull request #29043 from hawkinsp:locks
Google-ML-Automation May 27, 2025
669f08a
Reshape ragged_all_to_all to correct shape before concatenating
ghpvnist May 27, 2025
69c4317
[pallas] Fix `broadcast_in_dim` fuser eval rule.
chr1sj0nes May 28, 2025
b07aa27
Automated Code Change
Google-ML-Automation May 28, 2025
6004c7b
[Mosaic GPU] Make the Blackwell matmul kernel persistent
apaszke May 28, 2025
27e4a74
Move jax/_src/attrs.py to its own BUILD rule
May 28, 2025
1ff8a65
[Mosaic GPU] Perform a cluster barrier before deallocating collective…
apaszke May 28, 2025
f7adde5
[Mosaic GPU] Improve the error message when PTX version inference fails
apaszke May 28, 2025
5635717
Update XLA dependency to use revision
Google-ML-Automation May 28, 2025
0b17f6c
[Mosaic GPU] Implement FragmentedArray.__getitem__ for arbitrary tile…
apaszke May 28, 2025
39f0906
[Mosaic GPU] Use a second warpgroup to store the MMA outputs
apaszke May 28, 2025
360799e
[Mosaic GPU] Reduce SMEM pressure of the GMEM store
apaszke May 28, 2025
98e6041
[Mosaic GPU] Implement a new MMA/TMEM read pipelined matmul kernel
apaszke May 28, 2025
0d0393f
Set the mesh in SPMDAxisContext to be a concrete mesh so that pallas/…
yashk2810 May 28, 2025
fd28b2f
Fix JAX PGLE test
frgossen May 28, 2025
30eecf6
[pallas:triton] Removed the `Triton` prefix from `TritonCompilerParams`
superbobry May 28, 2025
de491b9
[pallas:mosaic_gpu] Added the missing resource estimation rule for `p…
superbobry May 28, 2025
ba64c02
[better-errors] if a non-jaxtype is returned, say it's a return problem
mattjj May 27, 2025
f32ce04
Merge pull request #29030 from dfm:custom-transpose-nones
Google-ML-Automation May 28, 2025
bfc20eb
Merge pull request #28665 from mattjj:smap-systematic
Google-ML-Automation May 28, 2025
68fcf15
Skip TPU metadata server query when not using TPU.
jenriver May 28, 2025
b940738
Add visibility registration for `jax._src.sharding_impls`
May 28, 2025
da4ea8d
Merge pull request #29069 from jenriver:skip_mds_for_cpu
Google-ML-Automation May 28, 2025
1994074
Make CI job names to be shorter
nitins17 May 28, 2025
36eeceb
Update actions to adhere to best practices
MichaelHudgins May 28, 2025
ea4049f
Merge pull request #29074 from MichaelHudgins:actions-fixes
Google-ML-Automation May 28, 2025
5aa3395
[Mosaic GPU] Rework CUDA_ROOT logic a bit
apaszke May 28, 2025
c5b908c
Change all us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-b…
quoctruong May 28, 2025
2dc69da
Integrate LLVM at llvm/llvm-project@2b8bff6f66fd
Google-ML-Automation May 28, 2025
22b4f26
Merge pull request #29047 from mattjj:returning-non-jaxtype
Google-ML-Automation May 28, 2025
37a9ac2
[Pallas Fuser] Add support for basic PRNG op fusion
sharadmv May 29, 2025
a4a31ec
A more numerically stable implementation of logaddexp2
DanisNone May 29, 2025
8982881
[mosaic_gpu] Use `DIScopeForLLVMFuncOpPass` from MLIR instead of its …
superbobry May 29, 2025
bc33d0e
Update XLA dependency to use revision
Google-ML-Automation May 29, 2025
50253f1
[pallas:mosaic_gpu] `emit_pipeline` now allows specifying a carry
superbobry May 29, 2025
5f11054
Merge pull request #29081 from DanisNone:main
Google-ML-Automation May 29, 2025
770eff0
Apply extensive input to extensive output forwarding in scan.
dfm May 23, 2025
38ecd13
[CI] Move k8s actions test files out of .github directory
MichaelHudgins May 29, 2025
67c5e28
cloud_tpu_init: Remove verbose logging.
jenriver May 29, 2025
42977e5
[Pallas/Fuser] Add custom_vjp_call rule for physicalize
Google-ML-Automation May 29, 2025
d1a1346
Merge pull request #28985 from dfm:scan-fwd-ext-traceable
Google-ML-Automation May 29, 2025
f60aa11
Merge pull request #29090 from jenriver:disable_logging
Google-ML-Automation May 29, 2025
605b8c0
Expose `GSPMDSharding` via `jex` as a temporary measure.
yashk2810 May 29, 2025
63c1b8a
Merge pull request #28972 from vfdev-5:fix-tsan-314-jax-build-step
Google-ML-Automation May 29, 2025
64ef37a
[pallas:mosaic] Enabled more lowering rules for all kernel types
superbobry May 29, 2025
1e334cf
Add dtype arg collective_matmul_mgpu.py to support bfloat16
hanzlfs May 28, 2025
f823faf
Update `compile_options_proto_cc` deps to new proto dir.
zacmustin May 29, 2025
6ddbdd9
[Mosaic GPU] Fix collective argument to infer_tmem_layout
justinjfu May 29, 2025
57fe3f2
[Mosaic GPU] Check that the device order in the mesh follows logical_ids
apaszke May 29, 2025
da845de
Lock down more permissions and update default usage for some workflows
MichaelHudgins May 29, 2025
448c07d
Reverts 42977e51816b9eb42c7360abe05f56cad70e894a
Google-ML-Automation May 29, 2025
4e5725b
Merge pull request #29076 from hanzlfs:zhonglin/mosaic/collective_matmul
Google-ML-Automation May 29, 2025
3abdf56
[Pallas Fuser] Use lu transformation to physicalize fwd/bwd functions…
sharadmv May 29, 2025
976aa7a
[pallas] Added `pl.loop` -- a decorator for writing stateless loops
superbobry May 29, 2025
a808fe8
[Mosaic GPU] Add non-collective blackwell matmul example
justinjfu May 29, 2025
8176920
[Pallas] Add a base class for custom BufferedRef implementations.
justinjfu May 29, 2025
5b13729
[Mosaic GPU] Add support for inout arguments
apaszke May 29, 2025
7eec8e1
[hijax] all pre-existing Box tests passing, still using typechange env
mattjj May 17, 2025
2c838d4
rename layout to format, part 1
froystig May 29, 2025
da106b9
Merge pull request #29101 from mattjj:hijax
Google-ML-Automation May 29, 2025
663e50f
[Mosaic] Make 1D tiling agnostic to large 2nd minor flags.
WindQAQ May 29, 2025
7ff6f0d
rename `Array.layout` to `Array.format`
froystig May 30, 2025
75b2c7e
[Mosaic GPU] Move the semaphore implementation to Mosaic
apaszke May 30, 2025
70f5aa4
Update XLA dependency to use revision
Google-ML-Automation May 30, 2025
a940776
[pallas:mosaic_gpu] Unconditionally emit line info for Mosaic GPU ker…
superbobry May 30, 2025
6564a4b
Remove Mac x86 from the installation instructions.
hawkinsp May 30, 2025
2d4baf4
Merge pull request #29097 from MichaelHudgins:more-actions-fixes
Google-ML-Automation May 30, 2025
7b01f6d
[typing] adjust axis annotation for ufunc.reduce
jakevdp May 30, 2025
73c016a
Don't sort replicated and unreduced axes wrt mesh axis names as they …
yashk2810 May 30, 2025
5a066bc
Merge pull request #29113 from hawkinsp:install
Google-ML-Automation May 30, 2025
213985a
replace mentions of `Compiled.input_layouts` with `Compiled.input_for…
froystig May 30, 2025
d15253e
[Mosaic] Support interleaved packing on TPUv4-.
WindQAQ May 30, 2025
91be698
Merge pull request #29120 from jakevdp:ufunc-annotation
Google-ML-Automation May 30, 2025
c2e7d61
[pallas] Expose TPUInterpretParams in jax.experimental.pallas.tpu
jburnim May 30, 2025
6ba11c1
Pass list rather than generator to donate_argnums
jakevdp May 30, 2025
f13a560
Update XLA dependency to use revision
Google-ML-Automation May 30, 2025
8a22011
Merge pull request #29128 from jakevdp:shard-map-generator
Google-ML-Automation May 30, 2025
69bcb0d
[mutable-arrays] don't let scan AD hoist mutable operations
mattjj May 30, 2025
581cb62
[mutable-arrays] add basic tests for vmap + mutable array
mattjj May 30, 2025
753ae57
Fix a rare numerical flake in svd_test seen on TPU v6e.
hawkinsp May 30, 2025
ebf0588
Merge pull request #29130 from mattjj:mutable-array-vmap2
Google-ML-Automation May 30, 2025
22f04d9
Refactor jax/_src/api.py and associated files in preparation for movi…
May 30, 2025
67bf8f9
Add experimental array serialization for nested pytrees
rdyro May 30, 2025
26228f5
Allow setting non-string TPU runtime flags. For example:
pschuh May 30, 2025
6f0f0ad
fix incorrect TODO
May 30, 2025
6c18aa8
[Mosaic] Move i1 broadcast lowering logic to Mosaic.
WindQAQ May 30, 2025
3c04713
Automated Code Change
Google-ML-Automation May 31, 2025
ff6892b
Update XLA dependency to use revision
Google-ML-Automation May 31, 2025
5cca31f
test_binary_ufunc_reduce now also tests behavior with the initial and…
DanisNone May 30, 2025
0a1ada8
Allow specifying non-differentiable arguments by name
Google-ML-Automation Jun 1, 2025
88dbf60
Update XLA dependency to use revision
Google-ML-Automation Jun 1, 2025
107efde
Reverts 73c016a534af51614741d70d36c2c75ca59f2dcc
yashk2810 Jun 1, 2025
52e5a87
Introduce profiler_options in the documentation.
sannidhyachauhan Jun 2, 2025
1914815
Automated Code Change
Google-ML-Automation Jun 2, 2025
b782b46
Update XLA dependency to use revision
Google-ML-Automation Jun 2, 2025
27e454d
[JAX] Use `util.fun_name` to determine `WrappedFun.__name__` instead …
ZacharyGarrett Jun 2, 2025
8f5dae4
[jaxlib] Use SafeStaticInit in more places.
hawkinsp Jun 2, 2025
a964f54
Update partial eval to avoid DCEing a specific set of effects.
dfm Jun 2, 2025
73aabb4
Bump the minimum NumPy and SciPy versions.
hawkinsp Jun 2, 2025
432de62
Merge pull request #29165 from dfm:pe-effects
Google-ML-Automation Jun 2, 2025
e1b59e5
Merge pull request #29084 from MichaelHudgins:actions
Google-ML-Automation Jun 2, 2025
8625207
Raise a better error when inputs sharded on explicit mesh axes are cl…
yashk2810 Jun 2, 2025
a99ca73
Merge pull request #29127 from mattjj:scan-vjp-mutable-hoist
Google-ML-Automation Jun 2, 2025
d62d94c
Add a pretty printing rule for custom_lin_p.
dfm Jun 2, 2025
2c3018d
Merge pull request #29166 from hawkinsp:minver
Google-ML-Automation Jun 2, 2025
0edfc72
Merge pull request #29169 from dfm:pp-custom-lin
Google-ML-Automation Jun 2, 2025
8eaa9bf
[cleanup] inline uses of NumpyComplexWarning
jakevdp Jun 2, 2025
980f5dc
always compile Pallas calls, enabling `pallas_call` under `disable_jit`
froystig Jun 2, 2025
a43ccbb
Fix native tiling logic in infer_vector_layout.
Google-ML-Automation Jun 2, 2025
9a32fab
Merge pull request #29170 from jakevdp:complex-warning
Google-ML-Automation Jun 2, 2025
0347b66
Merge pull request #29168 from froystig:pallas-call-eager
Google-ML-Automation Jun 2, 2025
62ab725
Update workflow files to use new ml-build containers.
quoctruong Jun 2, 2025
f6e6118
Merge pull request #29144 from DanisNone:main
Google-ML-Automation Jun 2, 2025
e9925ee
Enable profiler_test for TPU's
cliveverghese Jun 2, 2025
3e52872
Clean up some unused GPU linear algebra kernels.
dfm Jun 2, 2025
674fb5b
Simplify `jnp.isclose`
soraros Jun 1, 2025
6f0c2a8
Clean up some unused GPU sparse kernels.
dfm Jun 2, 2025
94037a8
Maintain the dtype of the input on the output in `broadcast_one_to_all`.
yashk2810 Jun 2, 2025
9e4ff92
[pallas] Added a note on `pl.loop` to the changelog
superbobry Jun 2, 2025
81de911
[pallas:mosaic_gpu] `plgpu.nd_loop` is now a decorator similar to `pl…
superbobry Jun 2, 2025
3ede957
[Mosaic GPU] Add reduction support for TCGEN05 layout.
justinjfu Jun 2, 2025
2f32a79
Clean up unused GPU RNN kernels.
dfm Jun 2, 2025
31017c5
When the size of the remainder array is 0, don't append it to the rem…
yashk2810 Jun 2, 2025
3545339
Merge pull request #29153 from soraros:simplify-isclose
Google-ML-Automation Jun 2, 2025
4367d7c
Move jax/_src/extend/* to its own build rule
Jun 3, 2025
41fd7a7
Move jax/_src/custom_partitioning_sharding_rule.py to its own build rule
Jun 3, 2025
2193c59
Update XLA dependency to use revision
Google-ML-Automation Jun 3, 2025
d30b176
[Mosaic GPU] Add support for tiled loads and stores of `f8` data types.
bchetioui Jun 3, 2025
0a5924c
[pallas:mosaic_gpu] Dropped the `GPU` prefix from `GPUShapeDtypeStruct`
superbobry Jun 3, 2025
d0d0815
[pallas:mosaic] Removed the `TPU` prefix from `TPUCompilerParams` and…
superbobry Jun 3, 2025
87641cc
[pallas:mosaic] Dropped the `TPU` prefix from the recently added `TPU…
superbobry Jun 3, 2025
6241a2a
Propagate layouts correctly via mutable arrays
yashk2810 Jun 3, 2025
d17b292
Don't canonicalize in `__eq__` if `other` is a PartitionSpec since it…
yashk2810 Jun 3, 2025
cecf2f6
[imports] avoid top-level imports in jax.numpy sources
jakevdp Jun 3, 2025
0b89b23
Merge pull request #29186 from jakevdp:jax-numpy-imports
Google-ML-Automation Jun 3, 2025
e24f780
[Mosaic GPU] Add lowering for `2xf32 -> 2xf8e4m3fn` conversions.
bchetioui Jun 3, 2025
cda50f5
[JAX] Remove the redundant pjit BUILD target.
Jun 3, 2025
554cc01
[Mosaic GPU] Add BUILD rules for blackwell matmul kernel
justinjfu Jun 3, 2025
7dd0344
[jaxlib] Add `PyClient::Compile` method that returns an unloaded `PyE…
danielsuo Jun 3, 2025
6cd196a
Prototype of cross-host device transfers in IFRT-PJRT.
emilyfertig Jun 3, 2025
7e0913f
[Mosaic GPU] Fix `bitcast` logic in `shfl_bfly`.
bchetioui Jun 3, 2025
b7adddf
Reduce block sizes in attention to prevent running out of shared memo…
rdyro Jun 3, 2025
e20b3a4
Fix sharding-in-types + lax.map usage when batch_size usage has a rem…
yashk2810 Jun 3, 2025
1216dac
Resurrect _pjit_lower's cache because it's important for python dispa…
yashk2810 Jun 3, 2025
b6a1575
[pallas] In TPU interpret mode, run kernels in parallel over Megacore…
jburnim Jun 3, 2025
3c1df03
[Pallas] Add forward-compatible i1 broadcast.
WindQAQ Jun 3, 2025
ab84dde
Update more uses of `backend.compile` to `backend.compile_and_load`.
danielsuo Jun 3, 2025
31fde29
Fix pgle test breakage
yashk2810 Jun 4, 2025
8519fd2
[Pallas] Fix missing sub lowering rule for sparsecore.
justinjfu Jun 4, 2025
002078b
Only infer sharding from input in full_like (in eager mode) if the in…
yashk2810 Jun 4, 2025
7d93eee
Make experimental pytree_serialization visible in OSS jax build
rdyro Jun 4, 2025
c222fb6
Update XLA dependency to use revision
Google-ML-Automation Jun 4, 2025
b7833e9
#sdy Fallback to GSPMD in JAX export if the loaded module was lowered…
bartchr808 Jun 4, 2025
d34f1dd
[Pallas/Mosaic GPU] Expose the new `TCGEN05_ROW` layout.
bchetioui Jun 4, 2025
d2d6211
#sdy Have JAX export compat tests also run on Shardy.
bartchr808 Jun 4, 2025
b218617
[Mosaic GPU] Use the `mosaic_gpu.sliceSMEM` MLIR op when using WG sem…
dimitar-asenov Jun 4, 2025
6e75a04
Raise `NotImplementedError` instead of `ValueError` when using Shardy…
ZixuanJiang Jun 4, 2025
b9658ed
[jaxlib] Bind 'compile' to `xla::PyClient::Compile` rather than `xla:…
danielsuo Jun 4, 2025
bf635d8
[jax::compiler] Bind `compiler.backend_compile` to `xla::PyClient::Co…
danielsuo Jun 4, 2025
2226be4
Fix typos discovered by codespell
cclauss May 31, 2025
8d8cc2b
Reverts 6cd196a5db22b8db0ed4000e4cf67ad748bf52f3
yashk2810 Jun 4, 2025
705bcbc
Merge pull request #29132 from hawkinsp:svd
Google-ML-Automation Jun 4, 2025
e19e18d
Add not-implemented sharding rule in `third_party/py/jax/_src/cudnn/f…
ZixuanJiang Jun 4, 2025
b3db374
Make `unreduced` argument in `PartitionSpec` a `set | frozenset` inst…
yashk2810 Jun 4, 2025
8c34865
[Mosaic GPU] Add a test for TMA multicasts in pallas. This also effec…
Rifur13 Jun 4, 2025
704eb71
jnp.array: avoid call to stack
jakevdp Jun 3, 2025
032afca
Merge pull request #29146 from cclauss:codespell
Google-ML-Automation Jun 4, 2025
2acbbcc
Add a general system for keeping track of quasi-dynamic data (QDD).
dougalm May 30, 2025
1554de5
Fix documentation for the CLI `up` command in the debugger.
Google-ML-Automation Jun 4, 2025
08530bc
Link c-api raw buffer support into jaxlib.
pschuh Jun 4, 2025
1ccc387
[Pallas Fuser] Add basic reshape push rule
sharadmv Jun 5, 2025
9f7e802
[Rollback] Roll-forward with fix and test: prototype of cross-host de…
emilyfertig Jun 5, 2025
6ad1b11
Merge pull request #29245 from jax-ml:quasi-dynamic-data
Google-ML-Automation Jun 5, 2025
c149078
fix sharding-in-types + from_edtype
mattjj Jun 5, 2025
5f05406
Merge pull request #29258 from mattjj:andy-fix
Google-ML-Automation Jun 5, 2025
c5ef4e5
Update XLA dependency to use revision
Google-ML-Automation Jun 5, 2025
986c411
[Pallas/Mosaic GPU] Expose the new `TCGEN05_COL` layout.
bchetioui Jun 5, 2025
af8e2e3
doc: clarified lack of gpu support for schur and sqrtm
YousefElbrolosy May 29, 2025
b7a3250
Don't repeatedly recompute a tuple of axis names for a membership test.
hawkinsp Jun 5, 2025
59c0171
[Mosaic GPU] Move `should_have_transforms` to `inference_utils`.
dimitar-asenov Jun 5, 2025
f8ab209
[Pallas][Mosaic GPU] Use separate allocations for collective TMEM.
justinjfu Jun 5, 2025
9c16c6b
Merge pull request #29224 from jakevdp:array-impl
Google-ML-Automation Jun 5, 2025
1c4bb50
Cache get_vma because it's the same thing we do for `get_sharding` an…
yashk2810 Jun 5, 2025
0b899a9
lax_numpy: move array and asarray to their own submodule
jakevdp Jun 5, 2025
ef9d3f8
skip pytype on slow file
mattjj Jun 5, 2025
9fc670e
[cleanup] remove core.gensym, and Var.suffix
mattjj Jun 5, 2025
f51effa
Merge pull request #29087 from YousefElbrolosy:doc/sqrtm-schur-not-su…
Google-ML-Automation Jun 5, 2025
4353f34
Merge pull request #29246 from jakevdp:array-refactor
Google-ML-Automation Jun 5, 2025
18d0da9
Make sure unsupported transfers between multi-process CPU arrays and …
emilyfertig Jun 5, 2025
5c33588
Merge pull request #29273 from mattjj:no-more-gensym
Google-ML-Automation Jun 5, 2025
960f9c5
Don't recompute source_info.current() in DynamicJaxprTracer.
hawkinsp Jun 5, 2025
ff50b5f
[Pallas][Mosaic GPU] Support column slicing on TMEM.
justinjfu Jun 5, 2025
013b1b1
[Mosaic GPU] Fix `2xf32 -> 2xf8e4m3fn` conversion.
bchetioui Jun 5, 2025
fc43122
[Pallas][Mosaic GPU] Skip tcgen05 reduce test on WG semantics.
justinjfu Jun 5, 2025
c30ed91
[Pallas][Mosaic GPU] Add support for load/broadcast using TCGEN05 ROW…
justinjfu Jun 5, 2025
a5ce3ad
lax.top_k: raise error if indices will overflow
jakevdp Jun 4, 2025
f4597d3
Merge pull request #29254 from jakevdp:top-k-overflow
Google-ML-Automation Jun 6, 2025
50d93ee
Bring back tree concat optimization for np.array(...)
pschuh Jun 6, 2025
fd43650
[Mosaic] Adds both direct (where hardware can) support for int8 Trans…
Google-ML-Automation Jun 6, 2025
20d641f
[Mosaic GPU] Add support for lowering `2xbf16 -> 2xf8e4m3fn` converts.
bchetioui Jun 6, 2025
e66745d
Fix logic for checking supported cross-host device transfers, since t…
emilyfertig Jun 6, 2025
32 changes: 23 additions & 9 deletions .bazelrc
@@ -31,6 +31,9 @@ build -c opt
build --output_filter=DONT_MATCH_ANYTHING

build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
build --copt=-DNB_DOMAIN=jax

build --legacy_external_runfiles=false

# #############################################################################
# Platform Specific configs below. These are automatically picked up by Bazel
@@ -97,6 +100,7 @@ build:windows --incompatible_strict_action_env=true
# #############################################################################
build:nonccl --define=no_nccl_support=true

build --repo_env USE_PYWRAP_RULES=1
build:posix --copt=-fvisibility=hidden
build:posix --copt=-Wno-sign-compare
build:posix --cxxopt=-std=c++17
@@ -130,19 +134,21 @@ build:clang --copt=-Wno-gnu-offsetof-extensions
build:clang --copt=-Qunused-arguments
# Error on struct/class mismatches, since this causes link failures on Windows.
build:clang --copt=-Werror=mismatched-tags
# Required when building with clang>=19, see jax-ml/jax#27091
build:clang --copt=-Wno-error=c23-extensions

# Configs for CUDA
build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda

# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# This config is used for building targets with CUDA libraries from stubs.
@@ -238,6 +244,9 @@ build:ci_linux_aarch64_base --config=clang --verbose_failures=true
build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10"
build:ci_linux_aarch64_base --color=yes

# This appears to help avoid a timeout in CI for linalg_test.
build:ci_linux_aarch64_base --test_env=OMP_NUM_THREADS=8

build:ci_linux_aarch64 --config=ci_linux_aarch64_base
build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
@@ -260,8 +269,8 @@ build:ci_darwin_arm64 --color=yes
# Windows x86 CI configs
build:ci_windows_amd64 --config=avx_windows
build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
build:ci_windows_amd64 --color=yes

@@ -321,6 +330,8 @@ build:rbe_linux_x86_64 --config=ci_linux_x86_64
build:rbe_linux_x86_64_cuda --config=rbe_linux_x86_64_base
build:rbe_linux_x86_64_cuda --config=ci_linux_x86_64_cuda
build:rbe_linux_x86_64_cuda --repo_env=REMOTE_GPU_TESTING=1
# Speed up CUDA repos creation by downloading ".tar" dists from the mirror.
build:rbe_linux_x86_64_cuda --repo_env=USE_CUDA_TAR_ARCHIVE_FILES=1

# RBE configs for Windows
# Set the remote worker pool
@@ -329,9 +340,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst
build:rbe_windows_amd64 --config=rbe

# Set the host, execution, and target platform
build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"

build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
build:rbe_windows_amd64 --enable_runfiles
@@ -371,6 +382,9 @@ build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/
build:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64
build:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base

# Avoids a timeout in linalg_test on ARM.
build:rbe_cross_compile_linux_aarch64 --test_env=OMP_NUM_THREADS=8

# Mac x86
build:cross_compile_darwin_x86_64 --config=cross_compile_base
build:cross_compile_darwin_x86_64 --config=nonccl
@@ -410,7 +424,7 @@ build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base
#############################################################################

build:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3"
build:debug --config debug_symbols -c fastbuild
build:debug --config=debug_symbols -c fastbuild

# Load `.jax_configure.bazelrc` file written by build.py
try-import %workspace%/.jax_configure.bazelrc
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug-report.yml
@@ -24,7 +24,7 @@ body:

[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues

[Raw report]: http://github.com/jax-ml/jax/issues/new
[Raw report]: https://github.com/jax-ml/jax/issues/new?template=none
- type: textarea
attributes:
label: Description
20 changes: 20 additions & 0 deletions .github/actionlint.yaml
@@ -0,0 +1,20 @@
# Configuration related to self-hosted runner.
self-hosted-runner:
labels:
- "linux-x86-n2-32" # Linux X86 runner using the 32 vcpu n2-standard-32 machine.
- "linux-x86-n2-64" # Linux X86 runner using the 64 vcpu n2-standard-64 machine.
- "linux-x86-g2-16-l4-1gpu" # Linux X86 GPU runner using g2-standard-16 machine with 1 NVIDIA L4 GPU attached.
- "linux-x86-g2-48-l4-4gpu" # Linux X86 GPU runner using g2-standard-48 machine with 4 NVIDIA L4 GPUs attached.
- "linux-x86-ct5lp-224-8tpu" # Linux X86 TPU runner using ct5lp-hightpu-8t machine with 2x4 topology.
- "linux-arm64-c4a-16" # Linux ARM64 CPU Runner using the 16 vcpu c4a-standard-16 machine.
- "linux-arm64-c4a-64" # Linux ARM64 CPU Runner using the 64 vcpu c4a-standard-64 machine.
- "windows-x86-n2-16" # Windows X86 runner using n2-standard-16 machine.
- "windows-x86-n2-64" # Windows X86 runner using n2-standard-64 machine.
- "linux-x86-a4-224-b200-1gpu" # Linux X86 GPU runner using 1 B200 GPU and 1/8 the resources of a a4-highgpu-8g machine
- "linux-x86-a3-8g-h100-8gpu" # Linux X86 GPU runner using a3-highgpu-8g machine with 8 NVIDIA H100 GPUs attached.
- "linux-x86-ct6e-180-8tpu" # Linux X86 TPU runner using ct6e-hightpu-8t machine with 2x4 topology.
- "linux-x86-ct6e-180-4tpu" # Linux X86 TPU runner using ct6e-hightpu-4t machine with 2x2 topology.
- "linux-x86-ct4p-240-4tpu" # Linux X86 TPU runner using ct4p-hightpu-4t machine with 2x2x1 topology.
- "linux-x86-n2-128" # Linux X86 runner using the 128 vcpu n2-standard-128 machine.
- "linux-x86-n2-16" # Linux X86 runner using the 16 vcpu n2-standard-16 machine.
- "linux-x86_64-cirrascale-64-8gpu-amd-mi250" # AMD runner
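
These labels are what workflows consume through `runs-on:`. As a rough, hypothetical sketch (the job name and final step are illustrative, not part of this PR), a workflow targeting the single-L4 runner would look like:

jobs:
  gpu-smoke-test:                      # hypothetical job name
    runs-on: linux-x86-g2-16-l4-1gpu   # label declared in actionlint.yaml above
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          persist-credentials: false
      - name: GPU sanity check         # illustrative step
        run: nvidia-smi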
4 changes: 3 additions & 1 deletion .github/workflows/asan.yaml
@@ -13,7 +13,7 @@ on:
- main
paths:
- '**/workflows/asan.yaml'

permissions: {}
jobs:
asan:
# Don't execute in fork due to runner type
@@ -41,11 +41,13 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: jax
persist-credentials: false
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: python/cpython
path: cpython
ref: v3.13.0
persist-credentials: false
- name: Build CPython with ASAN enabled
env:
ASAN_OPTIONS: detect_leaks=0
60 changes: 60 additions & 0 deletions .github/workflows/bazel_cpu_py_import_rbe.yml
@@ -0,0 +1,60 @@
# CI - Bazel CPU tests with py_import (RBE)
#
# This workflow runs the Bazel CPU tests with py_import dependency. It can only be triggered by
# other workflows via `workflow_call`. It is used by the `CI - Wheel Tests (Continuous)` workflows
# to run the Bazel CPU tests.
#
# It consists of the following job:
# run-tests:
# - Executes the `run_bazel_test_cpu_py_import_rbe.sh` script, which performs the following actions:
# - Runs the Bazel CPU tests with py_import dependency.
name: CI - Bazel CPU tests with py_import (RBE)
permissions: {}
on:
workflow_call:
inputs:
runner:
description: "Which runner should the workflow run on?"
type: string
default: "linux-x86-n2-16"
python:
description: "Which python version to test?"
type: string
default: "3.12"
enable-x64:
description: "Should x64 mode be enabled?"
type: string
default: "0"
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: string
default: 'no'

jobs:
run-tests:
defaults:
run:
# Explicitly set the shell to bash
shell: bash
runs-on: ${{ inputs.runner }}
container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') ||
(contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') }}
env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}

name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') ||
(contains(inputs.runner, 'linux-arm64') && 'linux arm64') }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}"

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel CPU tests with py_import (RBE)
timeout-minutes: 60
run: ./ci/run_bazel_test_cpu_py_import_rbe.sh
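
Because this workflow is exposed only via `workflow_call`, it never runs on its own; a caller such as the continuous wheel-test workflow would invoke it roughly as in the sketch below (caller job name and overrides are illustrative; omitted inputs fall back to the defaults declared above):

jobs:
  bazel-cpu-py-import:                           # illustrative caller job
    uses: ./.github/workflows/bazel_cpu_py_import_rbe.yml
    with:
      runner: "linux-x86-n2-16"
      python: "3.13"                             # override the "3.12" default
      enable-x64: "1"                            # inputs are strings, not booleans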
19 changes: 13 additions & 6 deletions .github/workflows/bazel_cpu_rbe.yml
@@ -18,7 +18,7 @@ on:
branches:
- main
- 'release/**'

permissions: {}
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
# Don't cancel in-progress jobs for main/release branches.
@@ -28,8 +28,8 @@ jobs:
run_tests:
if: github.event.repository.fork == false
runs-on: ${{ matrix.runner }}
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') ||
(contains(matrix.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') }}
env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
@@ -46,13 +46,20 @@ jobs:
enable-x_64: 1
- python: "3.13"
enable-x_64: 0
name: "Bazel CPU tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
# Only test a single Python version on Arm64 as we don't run the tests.
- python: "3.10"
runner: "linux-arm64-c4a-16"
name: "Bazel CPU ${{ (contains(matrix.runner, 'linux-arm64') && 'build only' || 'tests') }} (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
# End Presubmit Naming Check github-cpu-presubmits
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel CPU Tests with RBE
# Since we do not have a Linux Arm64 RBE pool, we do not run the tests on Arm64. Instead, we
# cross-compile the tests on the Linux x86 RBE pool.
- name: ${{ (contains(matrix.runner, 'linux-arm64') && 'Build' || 'Run') }} Bazel CPU Tests with RBE
run: ./ci/run_bazel_test_cpu_rbe.sh
46 changes: 30 additions & 16 deletions .github/workflows/bazel_cuda_non_rbe.yml
@@ -17,48 +17,52 @@ on:
runner:
description: "Which runner should the workflow run on?"
type: string
required: true
default: "linux-x86-n2-16"
python:
description: "Which python version to test?"
type: string
required: true
default: "3.12"
enable-x64:
description: "Should x64 mode be enabled?"
type: string
required: true
default: "0"
jaxlib-version:
description: "Which jaxlib version to test? (head/pypi_latest)"
type: string
default: "head"
gcs_download_uri:
description: "GCS location URI from where the artifacts should be downloaded"
required: true
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
type: string
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: boolean
required: false
default: false

type: string
default: 'no'
permissions: {}
jobs:
run-tests:
defaults:
run:
# Explicitly set the shell to bash
shell: bash
runs-on: ${{ inputs.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest"
container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.8-cudnn9.8:latest"

env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
# Enable writing to the Bazel remote cache bucket.
JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1"

name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
name: "jaxlib=${{ inputs.jaxlib-version }},
${{ (contains(inputs.runner, 'h100') && 'h100') ||
(contains(inputs.runner, 'b200') && 'b200') ||
(contains(inputs.runner, 'l4') && 'l4') }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}"

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Set env vars for use in artifact download URL
run: |
os=$(uname -s | awk '{print tolower($0)}')
@@ -77,11 +81,21 @@ jobs:
# fails. Instead, we verify the outcome in the next step so that we can print a more
# informative error message.
continue-on-error: true
run: >-
mkdir -p $(pwd)/dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
run: |
mkdir -p $(pwd)/dist
gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/

if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then
PYTHON=python${{ inputs.python }}
$PYTHON -m pip download jaxlib jax-cuda12-pjrt jax-cuda12-plugin --dest $(pwd)/dist/
else
echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}"
exit 1
fi
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |
Expand All @@ -91,7 +105,7 @@ jobs:
exit 1
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel CUDA tests (Non-RBE)
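
The new `jaxlib-version` input lets callers choose between the wheels built earlier in the same run ("head") and the newest PyPI release ("pypi_latest"). A hedged sketch of a caller (the job name and bucket URI are placeholders, not from this PR):

jobs:
  cuda-tests-pypi:                               # hypothetical caller job
    uses: ./.github/workflows/bazel_cuda_non_rbe.yml
    with:
      runner: "linux-x86-g2-48-l4-4gpu"
      python: "3.12"
      jaxlib-version: "pypi_latest"              # pip-downloads jaxlib instead of fetching it from GCS
      gcs_download_uri: "gs://example-bucket/path"  # placeholder URI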
8 changes: 5 additions & 3 deletions .github/workflows/bazel_cuda_rbe.yml
@@ -23,12 +23,12 @@ concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
# Don't cancel in-progress jobs for main/release branches.
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}

permissions: {}
jobs:
run_tests:
if: github.event.repository.fork == false
runs-on: ${{ matrix.runner }}
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'
container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest'
env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
@@ -49,8 +49,10 @@ jobs:
# End Presubmit Naming Check github-cuda-presubmits
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel CUDA Tests with RBE