forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 5
CI: 06/02/25 upstream sync #441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Part of a larger refactor. Today, `compile` returns a loaded executable i.e., fuses the compile and load functions. Eventually, `compile` should return an unloaded executable and `load` should return a loaded exectuable; the default jit path will still return a loaded executable. PiperOrigin-RevId: 761098001
PiperOrigin-RevId: 761101581
… simpler code. Slightly more efficient because we don’t initialize the outputs now. PiperOrigin-RevId: 761113870
This just makes the corresponding conditions a bit easier to read. PiperOrigin-RevId: 761137840
Since the `pallas/tpu/ops/random` directory was missing an `__init__.py` file, it was inadvertently excluded from the released JAX distribution. I don't see any reason why this submodule shouldn't be included so let's fix that! To deal with the fact that they weren't included in the distribution, we were also monkey patching these files into the wheel when testing, but that's no longer needed. PiperOrigin-RevId: 761138525
PiperOrigin-RevId: 761150621
PiperOrigin-RevId: 761152994
PiperOrigin-RevId: 761182341
PiperOrigin-RevId: 761191493
This bug is marked as fixed upstream.
…d preserve partition specs everywhere internally. **This is because spec -> names canonicalization gets rid of unreduced axes present on PartitionSpecs and we want to preserve that**. We can thread 2 new parameters called `in_unreduced` and `out_unreduced` and keep `in_names`, `out_names` but that doesn't buy us anything except for more lines added and complexity :) It's better to just use pspecs everywhere. It's a net reduction in lines of code too! PiperOrigin-RevId: 761196531
PiperOrigin-RevId: 761204561
chex now requires jax>=0.4.27, so the previous backward-compatibility is no longer necessary. PiperOrigin-RevId: 761212887
Expected small regression as we insert kv masking logic which needs to unpack/pack. PiperOrigin-RevId: 761217303
… the same shape PiperOrigin-RevId: 761230458
…els hermetically. PiperOrigin-RevId: 761231648
PiperOrigin-RevId: 761268857
PiperOrigin-RevId: 761283866
This just wastes time since the error will never be read. PiperOrigin-RevId: 761285843
This was previously removed in jax-ml@18ff6ca, and that promptly broke our CI again. I am guessing the problem is actually too few threads, not a NumPy deadlock as I originally guessed.
PiperOrigin-RevId: 761306262
…sh in context. PiperOrigin-RevId: 761322712
…sion PiperOrigin-RevId: 761429155
PiperOrigin-RevId: 761442340
I ran the TPU race checker on it and it did report a number of races that were uncovered by recent Mosaic compiler changes. PiperOrigin-RevId: 761447182
PiperOrigin-RevId: 761482778
…nels I also changed the lowering to override --jax_include_full_tracebacks_in_locations so that we get a single location per emitted op, since the ensure-debug-info-scope-on-llvm-func pass in MLIR does not correctly handle nested CallSiteLocs. PiperOrigin-RevId: 765112273
We have not been shipping Mac x86 for some time.
PiperOrigin-RevId: 765167817
…are not set and their order actually matters for all-reduce. PiperOrigin-RevId: 765199626
PiperOrigin-RevId: 765199706
…mats` This is part of a broader renaming of "layout" to "format". PiperOrigin-RevId: 765205967
This enables row broadcast for int8 and int4 on TPUv4. PiperOrigin-RevId: 765252479
PiperOrigin-RevId: 765259002
PiperOrigin-RevId: 765266754
PiperOrigin-RevId: 765289353
PiperOrigin-RevId: 765330882
…ng them to their own BUILD rule Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This change stops short of actually making the main jax package build rule depend on the new api build rule, because some downstream targets need to be migrated and pytype errors need to be fixed before we can land the final change. PiperOrigin-RevId: 765341918
Why this change? * JAX is missing a simple data serialization functionality that is compatible with pytrees. * Serialization of list of arrays is already supported, leading users to implement one-off data serialization solutions. New API: ``` def save(data: PyTreeT, directory: str | PathLike[str], overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None: ... def load(directory: str | PathLike[str], shardings: PyTreeT, mask: PyTreeT | None = None, ts_specs: PyTreeT | None = None ) -> PyTreeT: ... def load_pytreedef(directory: str | PathLike[str]) -> PyTreeT: ... ``` PiperOrigin-RevId: 765345616
jax.config.update("jax_pjrt_client_create_options", {"max_inflight_computations": 64}) PiperOrigin-RevId: 765357524
PiperOrigin-RevId: 765371666
And relax the test skip conditions. Somehow we skipped everything before. Also, this should fix jax-ml#29092. PiperOrigin-RevId: 765380392
PiperOrigin-RevId: 765494419
(`nondiff_argnames`) in addition to by index (`nondiff_argnums`). The implementation normalizes `nondiff_argnames` to indices in the constructor and merges them with `nondiff_argnums`, allowing the rest of the custom derivative logic to continue using a unified list of indices. PiperOrigin-RevId: 765730837
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Daily sync with upstream