Skip to content

CI: 06/17/25 upstream sync #474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2,090 commits into from

Conversation

rocm-repo-management-api-2[bot]
Copy link

Daily sync with upstream

jakevdp and others added 30 commits June 3, 2025 10:13
The conversion uses the `cvt.rn.satfinite.e4m3x2.f32` intrinsics, which means
that the saturation behaviour is different from XLA's default.

This does ask the question of which numerical behaviour we expect Mosaic GPU to
uphold---but we probably don't want to propagate NaNs in this case anyway.

PiperOrigin-RevId: 766737083
…xecutable`.

- Introduce `xla::PyExecutable` so we have a public constructor for returning an `xla::nb_class_ptr` from `PyClient::Compile`. There might be other acceptable ways of accomplishing this, but we have a `PyLoadedExecutable` object, so going for consistency.
- Migrate uses of `ifrt::Executable` to `ifrt::ExecutableRef` (an alias for `std::shared_ptr<ifrt::Executable>`). There might be undesirable consequences for doing this (i.e., a reason why this wasn't migrated before).

PiperOrigin-RevId: 766757937
For now it only works with the TFRT TPU runtime, because other PjRt plugins don't implement the necessary APIs. The per-shard indices of the source and destination shardings must be the same, and all shards must require cross-host transfers (support for a mixture of cross-host and host-local transfers is forthcoming).

Transfers take place via the xla::ifrt::PjRtClient::CopyArrays API, which copies the buffers from a set of arrays to a new device list. The distributed KV store from the coordination service is used to store metadata for cross-host transfers. The receiving process populates the store with a descriptor, and the sending process reads it and completes the send.

PiperOrigin-RevId: 766765989
We were not testing the logic for non-32-bit-wide dtypes, and as a result
missed that one of the `bitcast`s was converting between two types with
different bitwidths.

PiperOrigin-RevId: 766775563
…ry on L4.

New values (though small) show very good performance on ampere.

PiperOrigin-RevId: 766797437
…tch performance.

PiperOrigin-RevId: 766811714
… cores.

Internally, TPU interpret mode uses a new io_callback which spawns multiple threads to simulate multiple Megacore cores.

Also updates some comments / code / variable names to better distinguish between internal indices used in interpret mode vs. indices into the Pallas grid.

PiperOrigin-RevId: 766851983
Missed this in jax-ml@6c18aa8

PiperOrigin-RevId: 766873399
PiperOrigin-RevId: 766895212
…put's sharding is concrete i.e. does not contain an AbstractMesh

PiperOrigin-RevId: 766962130
… for GSPMD.

The final module that will be created by JAX export will contain a bit of Shardy and GSPMD ops. What we then do during compilation is detect whether there is a mix of these ops. If there is, we override the build option and instead use GSPMD for propagation (we have well tested code to export Shardy->GSPMD, but not vice versa).

PiperOrigin-RevId: 767064075
`TCGEN05_ROW` is to `TCGEN05` what `WGMMA_ROW` is to `WGMMA`.

PiperOrigin-RevId: 767068597
The lowering b/w Shardy and GSPMD is slightly different with the custom calls, so I needed to choose different test data based on whether or not Shardy was enabled.

PiperOrigin-RevId: 767074094
… without sharding rule.

PiperOrigin-RevId: 767131346
…:PyClient::CompileAndLoad`.

- Remove redundant `xla::PyClient` `compile` bindings.
- Remove host_callback arguments to `compile`.

PiperOrigin-RevId: 767135320
…mpile`.

Currently, we just forward any calls to `compiler.backend_compile_and_load`, which returns an `xla::PyLoadedExecutable` whereas we'd like `compiler.backend_compile` to return an unloaded `xla::PyExecutable`.

PiperOrigin-RevId: 767142396
PiperOrigin-RevId: 767149635
PiperOrigin-RevId: 767160894
…used_attention_stablehlo.py`.

PiperOrigin-RevId: 767166345
yashk2810 and others added 26 commits June 13, 2025 15:25
… standard snake case

PiperOrigin-RevId: 771296764
PiperOrigin-RevId: 771434581
Imported from GitHub PR jax-ml#28102

* add cudnn support for paged attention described in https://arxiv.org/pdf/2309.06180.
* add new arguments `page_table_k` and `page_table_v`.
* create a new interface `paged_attention` for paged attention.

Copybara import of the project:

--
003d8b1 by cjkkkk <[email protected]>:

add paged attn

Merging this change closes jax-ml#28102

COPYBARA_INTEGRATE_REVIEW=jax-ml#28102 from Cjkkkk:page_attention 003d8b1
PiperOrigin-RevId: 771980902
…o that it return an instance of AbstractTMEMRef and not AbstractMemoryRef

PiperOrigin-RevId: 772026526
PiperOrigin-RevId: 772026590
PiperOrigin-RevId: 772049204
PiperOrigin-RevId: 772096777
Current SciPy releases don't support Python 3.14, and building from source will resolve compatibility issues introduced by the new Python version: https://github.com/jax-ml/jax/actions/runs/15678323189/job/44163705784.

PiperOrigin-RevId: 772113604
JAX v0.4.38 (released Dec 17, 2024) no longer lowered to any legacy CPU custom calls. Following our export compatibility guide (https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility), the remaining legacy custom calls can be removed on June 15, 2025, 180 days after the 0.4.38 release.

PiperOrigin-RevId: 772126428
PiperOrigin-RevId: 772187342
`colocated_python` decorator wraps a function, and the returned function has a
special method `specialize` that lets the user provide explicit information of
the output spec or execution devices. This `specialize` method is in principle
not a part of `Callable` protocol, so access to it would not be valid typing.

This change relaxes the return type of `colocated_python` to `Any` so that
`specialize` method access does not cause typing check failure.

PiperOrigin-RevId: 772206576
Recursively calculates the roofline result for the primitives from the custom function.

PiperOrigin-RevId: 772252912
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner June 17, 2025 06:02
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot enabled auto-merge (rebase) June 17, 2025 06:02
auto-merge was automatically disabled June 19, 2025 15:58

Pull request was closed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.