forked from jax-ml/jax
CI: 06/17/25 upstream sync #474
Closed
PiperOrigin-RevId: 766730509
The conversion uses the `cvt.rn.satfinite.e4m3x2.f32` intrinsics, which means that the saturation behaviour differs from XLA's default. This raises the question of which numerical behaviour we expect Mosaic GPU to uphold, but we probably don't want to propagate NaNs in this case anyway. PiperOrigin-RevId: 766737083
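To make the difference concrete, here is a minimal pure-Python sketch of the two overflow policies. It assumes the non-saturating path maps out-of-range magnitudes to NaN (float8 e4m3fn has no infinity) and uses 448 as the largest finite e4m3fn magnitude. The function names are hypothetical, and rounding of in-range values and NaN inputs are deliberately not modelled.

```python
import math

E4M3FN_MAX = 448.0  # largest finite magnitude representable in float8 e4m3fn


def overflow_to_nan(x: float) -> float:
    """Assumed non-saturating policy: out-of-range magnitudes become NaN."""
    return x if abs(x) <= E4M3FN_MAX else math.nan


def saturate_finite(x: float) -> float:
    """Saturating policy: out-of-range magnitudes clamp to +/-E4M3FN_MAX."""
    if math.isnan(x):
        return x  # NaN handling is not modelled here; passed through unchanged
    return max(-E4M3FN_MAX, min(E4M3FN_MAX, x))


for value in (1.5, 1000.0, -1e9):
    print(value, overflow_to_nan(value), saturate_finite(value))
```

Under these assumptions, only the out-of-range inputs differ between the two policies: one yields NaN, the other a finite value with the sign preserved.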
PiperOrigin-RevId: 766747093
PiperOrigin-RevId: 766748530
…xecutable`. - Introduce `xla::PyExecutable` so we have a public constructor for returning an `xla::nb_class_ptr` from `PyClient::Compile`. There might be other acceptable ways of accomplishing this, but we already have a `PyLoadedExecutable` object, so this keeps things consistent. - Migrate uses of `ifrt::Executable` to `ifrt::ExecutableRef` (an alias for `std::shared_ptr<ifrt::Executable>`). There might be undesirable consequences of doing this (i.e., a reason why this wasn't migrated before). PiperOrigin-RevId: 766757937
For now it only works with the TFRT TPU runtime, because other PjRt plugins don't implement the necessary APIs. The per-shard indices of the source and destination shardings must be the same, and all shards must require cross-host transfers (support for a mixture of cross-host and host-local transfers is forthcoming). Transfers take place via the xla::ifrt::PjRtClient::CopyArrays API, which copies the buffers from a set of arrays to a new device list. The distributed KV store from the coordination service is used to store metadata for cross-host transfers. The receiving process populates the store with a descriptor, and the sending process reads it and completes the send. PiperOrigin-RevId: 766765989
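The descriptor handshake described above can be illustrated with a toy, single-process simulation. This is only a sketch: `ToyKVStore`, `receiver`, `sender`, and the `transfer/<id>` key layout are hypothetical stand-ins, not the coordination service or the `CopyArrays` API, and Python threads play the role of the two processes.

```python
import threading


class ToyKVStore:
    """In-memory stand-in for a distributed key-value store (hypothetical)."""

    def __init__(self):
        self._data = {}
        self._cond = threading.Condition()

    def set(self, key, value):
        with self._cond:
            self._data[key] = value
            self._cond.notify_all()

    def blocking_get(self, key, timeout=5.0):
        # Wait until some other thread has published the key.
        with self._cond:
            if not self._cond.wait_for(lambda: key in self._data, timeout=timeout):
                raise TimeoutError(key)
            return self._data[key]


def receiver(store, transfer_id, inbox):
    # The receiving side publishes a descriptor saying where data should land.
    store.set(f"transfer/{transfer_id}", {"inbox": inbox})


def sender(store, transfer_id, payload):
    # The sending side reads the descriptor, then completes the send.
    descriptor = store.blocking_get(f"transfer/{transfer_id}")
    descriptor["inbox"].append(payload)


def run_transfer(payload):
    store, inbox = ToyKVStore(), []
    threads = [
        threading.Thread(target=sender, args=(store, 0, payload)),
        threading.Thread(target=receiver, args=(store, 0, inbox)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return inbox


print(run_transfer([1, 2, 3]))
```

The sender is started first on purpose: it blocks until the descriptor appears, so the ordering of the two sides does not matter in this sketch.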
We were not testing the logic for non-32-bit-wide dtypes, and as a result missed that one of the `bitcast`s was converting between two types with different bitwidths. PiperOrigin-RevId: 766775563
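As an illustration of the class of bug described here, the following sketch uses Python's `struct` module to reinterpret bits and rejects casts between types of different widths. The `bitcast` helper is hypothetical and unrelated to the Mosaic GPU code itself.

```python
import struct


def bitcast(value, src_fmt: str, dst_fmt: str):
    """Reinterpret the bits of `value` (packed as src_fmt) as dst_fmt."""
    src_size = struct.calcsize(src_fmt)
    dst_size = struct.calcsize(dst_fmt)
    if src_size != dst_size:
        # A bitcast only makes sense between types of equal bitwidth.
        raise ValueError(
            f"bitwidth mismatch: {8 * src_size} bits vs {8 * dst_size} bits"
        )
    return struct.unpack(dst_fmt, struct.pack(src_fmt, value))[0]


# 32-bit float -> 32-bit unsigned int
print(hex(bitcast(1.0, "<f", "<I")))
# 16-bit float -> 16-bit unsigned int; the case that only 32-bit tests would miss
print(hex(bitcast(1.0, "<e", "<H")))
```

Testing only the 32-bit path would never exercise the width check, which is the gap the commit message describes.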
…ry on L4. New values (though small) show very good performance on Ampere. PiperOrigin-RevId: 766797437
…ainder left. Fixes jax-ml#29195 PiperOrigin-RevId: 766801968
…tch performance. PiperOrigin-RevId: 766811714
… cores. Internally, TPU interpret mode uses a new io_callback which spawns multiple threads to simulate multiple Megacore cores. Also updates some comments / code / variable names to better distinguish between internal indices used in interpret mode vs. indices into the Pallas grid. PiperOrigin-RevId: 766851983
Missed this in jax-ml@6c18aa8 PiperOrigin-RevId: 766873399
PiperOrigin-RevId: 766874092
PiperOrigin-RevId: 766895212
PiperOrigin-RevId: 766941330
…put's sharding is concrete i.e. does not contain an AbstractMesh PiperOrigin-RevId: 766962130
PiperOrigin-RevId: 766962750
… for GSPMD. The final module created by JAX export will contain a mix of Shardy and GSPMD ops. During compilation, we detect whether there is a mix of these ops. If there is, we override the build option and use GSPMD for propagation instead (we have well-tested code to export Shardy->GSPMD, but not vice versa). PiperOrigin-RevId: 767064075
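The decision rule in this message can be sketched as follows. This is a toy model: ops are represented by plain string labels, and `pick_partitioner` is a hypothetical name, not the JAX or XLA API.

```python
def pick_partitioner(op_kinds, shardy_requested=True):
    """Pick the propagation system for a module.

    `op_kinds` is a list of labels, each "shardy", "gspmd", or "other"
    (a hypothetical encoding of which system each op belongs to).
    """
    has_shardy = "shardy" in op_kinds
    has_gspmd = "gspmd" in op_kinds
    if has_shardy and has_gspmd:
        # Mixed module: override the build option and fall back to GSPMD,
        # since only the Shardy->GSPMD direction is assumed to be supported.
        return "gspmd"
    return "shardy" if shardy_requested else "gspmd"


print(pick_partitioner(["shardy", "other"]))
print(pick_partitioner(["shardy", "gspmd", "other"]))
```

Only the mixed case overrides the requested option; a module using a single system keeps whatever was configured.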
`TCGEN05_ROW` is to `TCGEN05` what `WGMMA_ROW` is to `WGMMA`. PiperOrigin-RevId: 767068597
The lowering differs slightly between Shardy and GSPMD with the custom calls, so I needed to choose different test data based on whether Shardy was enabled. PiperOrigin-RevId: 767074094
…antics. PiperOrigin-RevId: 767125470
… without sharding rule. PiperOrigin-RevId: 767131346
…:PyClient::CompileAndLoad`. - Remove redundant `xla::PyClient` `compile` bindings. - Remove host_callback arguments to `compile`. PiperOrigin-RevId: 767135320
…mpile`. Currently, we just forward any calls to `compiler.backend_compile_and_load`, which returns an `xla::PyLoadedExecutable` whereas we'd like `compiler.backend_compile` to return an unloaded `xla::PyExecutable`. PiperOrigin-RevId: 767142396
PiperOrigin-RevId: 767160894
…used_attention_stablehlo.py`. PiperOrigin-RevId: 767166345
PiperOrigin-RevId: 771244288
… standard snake case PiperOrigin-RevId: 771296764
PiperOrigin-RevId: 771434581
Imported from GitHub PR jax-ml#28102
* add cudnn support for paged attention described in https://arxiv.org/pdf/2309.06180.
* add new arguments `page_table_k` and `page_table_v`.
* create a new interface `paged_attention` for paged attention.
Copybara import of the project:
-- 003d8b1 by cjkkkk <[email protected]>: add paged attn
Merging this change closes jax-ml#28102
COPYBARA_INTEGRATE_REVIEW=jax-ml#28102 from Cjkkkk:page_attention 003d8b1
PiperOrigin-RevId: 771980902
…o that it returns an instance of AbstractTMEMRef and not AbstractMemoryRef PiperOrigin-RevId: 772026526
PiperOrigin-RevId: 772026590
PiperOrigin-RevId: 772049204
…n binding PiperOrigin-RevId: 772055206
PiperOrigin-RevId: 772057312
PiperOrigin-RevId: 772091145
PiperOrigin-RevId: 772092105
PiperOrigin-RevId: 772096777
Current SciPy releases don't support Python 3.14, and building from source will resolve compatibility issues introduced by the new Python version: https://github.com/jax-ml/jax/actions/runs/15678323189/job/44163705784. PiperOrigin-RevId: 772113604
JAX v0.4.38 (released Dec 17, 2024) no longer lowered to any legacy CPU custom calls. Following our export compatibility guide (https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility), the remaining legacy custom calls can be removed on June 15, 2025, 180 days after the 0.4.38 release. PiperOrigin-RevId: 772126428
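The 180-day window can be checked with a short date calculation, using only the dates stated in the message:

```python
from datetime import date, timedelta

release = date(2024, 12, 17)  # JAX v0.4.38 release date, as stated above
removal = release + timedelta(days=180)
print(removal)  # 2025-06-15
```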
PiperOrigin-RevId: 772134059
PiperOrigin-RevId: 772146614
PiperOrigin-RevId: 772187342
The `colocated_python` decorator wraps a function, and the returned function has a special method, `specialize`, that lets the user provide explicit information about the output spec or execution devices. This `specialize` method is not part of the `Callable` protocol, so accessing it fails type checking. This change relaxes the return type of `colocated_python` to `Any` so that accessing `specialize` no longer causes type-check failures. PiperOrigin-RevId: 772206576
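The typing issue can be reproduced with a toy decorator. Everything here is a hypothetical stand-in: `toy_colocated` and the `devices` keyword are illustrative names and do not reflect the real `colocated_python` signature.

```python
from typing import Any, Callable


def toy_colocated(fun: Callable[..., Any]) -> Any:
    """Toy decorator; returning Any lets callers use attributes beyond __call__."""

    class _Specializable:
        def __init__(self, devices=None):
            self.devices = devices

        def __call__(self, *args, **kwargs):
            return fun(*args, **kwargs)

        def specialize(self, devices=None):
            # Return a new wrapper carrying the extra (hypothetical) information.
            return _Specializable(devices=devices)

    return _Specializable()


@toy_colocated
def add_one(x):
    return x + 1


# With a return annotation of Callable[..., Any], a type checker would reject
# this attribute access; with Any it is accepted.
specialized = add_one.specialize(devices=["cpu:0"])
print(add_one(1), specialized(2), specialized.devices)
```

The trade-off in this sketch is that `Any` also disables checking of the call signature, which is the price of allowing the extra method.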
Recursively calculates the roofline result for the primitives from the custom function. PiperOrigin-RevId: 772252912
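A minimal sketch of what a recursive accumulation of this kind can look like, assuming each op carries its own cost plus an optional nested body. The `Op` class, its fields, and `roofline_totals` are hypothetical and are not the JAX roofline API.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Hypothetical op record: its own cost plus any nested ops."""
    name: str
    flops: int = 0
    bytes_accessed: int = 0
    body: list = field(default_factory=list)


def roofline_totals(op):
    """Return (flops, bytes) for `op`, recursing into its nested body."""
    flops, nbytes = op.flops, op.bytes_accessed
    for inner in op.body:
        inner_flops, inner_bytes = roofline_totals(inner)
        flops += inner_flops
        nbytes += inner_bytes
    return flops, nbytes


# A custom function whose body contains two primitives, one of them nested.
custom = Op("custom_fn", body=[
    Op("dot", flops=200, bytes_accessed=48),
    Op("inner_fn", body=[Op("add", flops=10, bytes_accessed=24)]),
])
print(roofline_totals(custom))
```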
PiperOrigin-RevId: 772266335
PiperOrigin-RevId: 772268506
Daily sync with upstream