Skip to content

CI: 06/02/25 upstream sync #441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1,794 commits into from

Conversation

rocm-repo-management-api-2[bot]
Copy link

Daily sync with upstream

danielsuo and others added 30 commits May 20, 2025 08:41
Part of a larger refactor. Today, `compile` returns a loaded executable i.e., fuses the compile and load functions. Eventually, `compile` should return an unloaded executable and `load` should return a loaded exectuable; the default jit path will still return a loaded executable.

PiperOrigin-RevId: 761098001
… simpler code. Slightly more efficient because we don’t initialize the outputs now.

PiperOrigin-RevId: 761113870
This just makes the corresponding conditions a bit easier to read.

PiperOrigin-RevId: 761137840
Since the `pallas/tpu/ops/random` directory was missing an `__init__.py` file, it was inadvertently excluded from the released JAX distribution. I don't see any reason why this submodule shouldn't be included so let's fix that!

To deal with the fact that they weren't included in the distribution, we were also monkey patching these files into the wheel when testing, but that's no longer needed.

PiperOrigin-RevId: 761138525
PiperOrigin-RevId: 761182341
This bug is marked as fixed upstream.
…d preserve partition specs everywhere internally.

**This is because spec -> names canonicalization gets rid of unreduced axes present on PartitionSpecs and we want to preserve that**. We can thread 2 new parameters called `in_unreduced` and `out_unreduced` and keep `in_names`, `out_names` but that doesn't buy us anything except for more lines added and complexity :)

It's better to just use pspecs everywhere. It's a net reduction in lines of code too!

PiperOrigin-RevId: 761196531
chex now requires jax>=0.4.27, so the previous backward-compatibility is no longer necessary.

PiperOrigin-RevId: 761212887
Expected small regression as we insert kv masking logic which needs to unpack/pack.

PiperOrigin-RevId: 761217303
…els hermetically.

PiperOrigin-RevId: 761231648
PiperOrigin-RevId: 761237485
This just wastes time since the error will never be read.

PiperOrigin-RevId: 761285843
This was previously removed in jax-ml@18ff6ca, and that promptly broke our CI again. I am guessing the problem is actually too few threads, not a NumPy deadlock as I originally guessed.
PiperOrigin-RevId: 761306262
I ran the TPU race checker on it and it did report a number of races
that were uncovered by recent Mosaic compiler changes.

PiperOrigin-RevId: 761447182
PiperOrigin-RevId: 761482778
Google-ML-Automation and others added 26 commits May 30, 2025 02:44
…nels

I also changed the lowering to override --jax_include_full_tracebacks_in_locations
so that we get a single location per emitted op, since the
ensure-debug-info-scope-on-llvm-func pass in MLIR does not correctly handle
nested CallSiteLocs.

PiperOrigin-RevId: 765112273
We have not been shipping Mac x86 for some time.
…are not set and their order actually matters for all-reduce.

PiperOrigin-RevId: 765199626
…mats`

This is part of a broader renaming of "layout" to "format".

PiperOrigin-RevId: 765205967
This enables row broadcast for int8 and int4 on TPUv4.

PiperOrigin-RevId: 765252479
…ng them to their own BUILD rule

Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.

This change stops short of actually making the main jax package build rule depend on the new api build rule, because some downstream targets need to be migrated and pytype errors need to be fixed before we can land the final change.

PiperOrigin-RevId: 765341918
Why this change?
* JAX is missing a simple data serialization functionality that is compatible with pytrees.
* Serialization of list of arrays is already supported, leading users to implement one-off data serialization solutions.

New API:
```
def save(data: PyTreeT, directory: str | PathLike[str], overwrite: bool = True,
         ts_specs: PyTreeT | None = None) -> None:
  ...

def load(directory: str | PathLike[str], shardings: PyTreeT,
         mask: PyTreeT | None = None, ts_specs: PyTreeT | None = None
         ) -> PyTreeT:
  ...

def load_pytreedef(directory: str | PathLike[str]) -> PyTreeT:
  ...
```

PiperOrigin-RevId: 765345616
jax.config.update("jax_pjrt_client_create_options", {"max_inflight_computations": 64})

PiperOrigin-RevId: 765357524
PiperOrigin-RevId: 765371666
And relax the test skip conditions. Somehow we skipped everything before. Also, this should fix jax-ml#29092.

PiperOrigin-RevId: 765380392
PiperOrigin-RevId: 765494419
(`nondiff_argnames`) in addition to by index (`nondiff_argnums`).
The implementation normalizes `nondiff_argnames` to indices in the constructor
and merges them with `nondiff_argnums`, allowing the rest of the custom derivative logic to continue using a unified list of indices.

PiperOrigin-RevId: 765730837
PiperOrigin-RevId: 765852528
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner June 2, 2025 06:02
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot enabled auto-merge (rebase) June 2, 2025 06:02
auto-merge was automatically disabled June 19, 2025 16:33

Pull request was closed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.