-
Notifications
You must be signed in to change notification settings - Fork 1.1k
feat(mjx): Make solver differentiable and fix name collisions #2721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
feat(mjx): Make solver differentiable and fix name collisions #2721
Conversation
- [x] I branched from main - [x] My pull request is against main - [x] My code follows the DeepMind Style Guide - [x] I ran git clang-format The details Resolves - Fixes google-deepmind#2259: [MJX] jax.lax.while_loop in solver.py prevents computation of backward gradients (google-deepmind#2259) Proposed Changes This pull request introduces two main changes: 1. It resolves the core issue by enabling reverse-mode differentiation (gradients) through the MJX solver. 2. It fixes latent name-collision bugs in the project structure that were discovered during testing. 1. Differentiable Solver The solver now conditionally uses a differentiable code path when m.opt.tolerance is set to 0. The key insight came from a comment by a maintainer in the issue thread: in the C++ version of MuJoCo, setting tolerance = 0 is the established idiom for forcing a fixed number of solver iterations. To maintain consistency with the C++ API and the expectations of the existing user base, we adopted this same convention. Instead of introducing a new API option, the code now checks for m.opt.tolerance == 0. If true, it uses the project's existing _while_loop_scan function, which is a jax.lax.scan-based implementation that is safe for reverse-mode autodiff. If false, it continues to use the more performant jax.lax.while_loop. Code Change in `mjx/mujoco/mjx/_src/solver.py`: 1 if m.opt.iterations == 1: 2 ctx = body(ctx) 3 else: 4 # if tolerance is 0, scan is differentiable. otherwise, use while_loop. 5 ctx = jax.lax.cond( 6 m.opt.tolerance == 0, 7 lambda c: _while_loop_scan(cond, body, c, m.opt.iterations), 8 lambda c: jax.lax.while_loop(cond, body, c), 9 ctx, 10 ) 2. Fixing Standard Library Name Collisions During testing, we discovered that running test scripts directly from the file path caused ImportError and AttributeError exceptions. This was because two files in the mjx/_src directory had the same names as standard Python libraries: * math.py * dataclasses.py This "shadowing" of standard libraries is a latent bug that makes the codebase fragile and difficult for contributors to work with. We have fixed this by renaming the files and updating all corresponding imports. * math.py -> mjx_math.py * dataclasses.py -> mjx_dataclasses.py Behavior Before Change - Attempting to compute gradients through mjx.step or mjx.solve would raise a ValueError: Reverse-mode differentiation does not work for lax.while_loop... if m.opt.iterations > 1. - Running individual test files from the command line would fail with import errors due to standard library name collisions. Reason for Changes The primary reason is to enable gradient-based optimization through the MJX solver, a key feature for many users in robotics and machine learning. The secondary changes were made to fix critical, latent bugs in the project's structure that hindered development and testing, thereby improving the overall health and maintainability of the codebase. Test Coverage A new test, test_solver_differentiable, has been added to mjx/mujoco/mjx/_src/solver_test.py. This test verifies the fix by creating a simple pendulum model, setting m.opt.tolerance = 0, and asserting that the gradient of the acceleration with respect to control inputs is non-zero. All existing tests, plus the new test, now pass successfully. The final command used to run the tests successfully was: 1 export PYTHONPATH=/Users/devanshvarshney/mujoco/mjx:$PYTHONPATH && venv/bin/python /Users/devanshvarshney/mujoco/mjx/mujoco/mjx/_src/solver_test.py The output confirms that all 14 tests, including our new one, passed: 1 ... 2 [ RUN ] SolverTest.test_solver_differentiable 3 I0706 16:48:37.037750 8581750528 io.py:152] Using JAX default device: TFRT_CPU_0. 4 I0706 16:48:37.037802 8581750528 io.py:39] No CUDA GPU devices found in jax.devices("cuda"). 5 I0706 16:48:37.040613 8581750528 io.py:152] Using JAX default device: TFRT_CPU_0. 6 I0706 16:48:37.040644 8581750528 io.py:39] No CUDA GPU devices found in jax.devices("cuda"). 7 [ OK ] SolverTest.test_solver_differentiable 8 ... 9 ---------------------------------------------------------------------- 10 Ran 14 tests in 12.173s 11 12 OK Additional Information A Note on the Name Collision Bugs: It is important to clarify why we encountered the math.py and dataclasses.py import errors when they may not have affected other developers. The errors were triggered because we ran the test script directly from its file path. This action places the script's directory (.../_src/) at the top of Python's import search path, causing it to load the project's local math.py instead of the standard library version. While this was caused by our specific testing method, the underlying issue—shadowing a standard library module—is a latent bug in the project's structure. A robust project should not fail when a developer runs a single test file, which is a common practice. Therefore, renaming these files is a crucial fix that improves the project's stability and maintainability for all future contributors. Splitting the Pull Request: The changes in this pull request could potentially be split into two separate PRs: 1. The core fix for the differentiable solver (solver.py and solver_test.py). 2. The fix for the standard library name collisions (renaming math.py and dataclasses.py and updating all imports).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@varshneydevansh can you isolate this PR down to just the change to make the solver differentiable? Then if you'd like we can discuss module import conventions in an issue or different PR. Thanks!
ctx = jax.lax.while_loop(cond, body, ctx) | ||
# if tolerance is 0, scan is differentiable. otherwise, use while_loop. | ||
ctx = jax.lax.cond( | ||
m.opt.tolerance == 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@varshneydevansh it's probably better to check this statically. It should produce better run-time performance. So something like this:
if m.opt.iterations == 1:
ctx = body(ctx)
elif m.opt.tolerance == 0.0:
ctx = _while_loop_scan(cond, body, ctx, m.opt.iterations)
else:
ctx = jax.lax.while_loop(cond, body, ctx)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I will do this currently fighting some UI threading async problem - https://gerrit.libreoffice.org/c/core/+/186822
The basics
The details
Resolves
Proposed Changes
This pull request introduces two main changes:
A. Differentiable Solver
The solver now conditionally uses a differentiable code path when
m.opt.tolerance is set to 0
.tolerance = 0
is the established idiom for forcing a fixed number of solver iterations.Instead of introducing a new API option, the code now checks for
m.opt.tolerance == 0
._while_loop_scan
function, which is ajax.lax.scan-based
implementation that is safe forreverse-mode autodiff
.jax.lax.while_loop
.B. Fixing Standard Library Name Collisions
During testing, we discovered that running test scripts directly from the file path caused
ImportError
andAttributeError
exceptions. This was because two files in themjx/_src
directory had the same names as standard Python libraries:math.py
dataclasses.py
This "shadowing" of standard libraries is a latent bug that makes the codebase fragile and difficult for contributors to work with. We have fixed this by renaming the files and updating all corresponding imports.
math.py -> mjx_math.py
dataclasses.py -> mjx_dataclasses.py
Behavior Before Change
mjx.step
ormjx.solve
would raise a `ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values'.Reason for Changes
Test Coverage
test_solver_differentiable
, has been added tomjx/mujoco/mjx/_src/solver_test.py
.m.opt.tolerance = 0
, and asserting that the gradient of the acceleration with respect to control inputs is non-zero.All existing tests, plus the new test, now pass successfully.
The final command used to run the tests successfully was:
export PYTHONPATH=/Users/devanshvarshney/mujoco/mjx:$PYTHONPATH && venv/bin/python /Users/devanshvarshney/mujoco/mjx/mujoco/mjx/_src/solver_test.py
The output confirms that all 14 tests, including our new one, passed:
Additional Information
Intuition & why?
The User's Goal: "Gradient-based trajectory optimization." This is a fancy term for a simple idea: We have a simulation (like a robot arm moving), and we want to tweak the inputs (the motor forces) to achieve a goal (like touching a specific point). To do this efficiently, we need to know how a small change in input affects the output. This "how" is the gradient.
The Software's Constraint: The error message says it all:
ValueError: Reverse-mode differentiation does not work for lax.while_loop ... with dynamic start/stop values
.This is the key.
while_loop
with a dynamic condition (e.g., while error > 0.001) is the opposite of static. The number of times it will run is unknown before we actually run it. JAX cannot build a fixed blueprint for a process with an unknown number of steps.A Real-World Analogy: Building Intuition
Imagine we have two ways to bake a cake:
The Static Recipe (
fori_loop
): This recipe says, "Mix for 3 minutes, then bake for exactly 30 minutes." The steps and their durations are fixed and known before we start. If the cake comes out badly, we can easily calculate the "gradient." For example, we can reason: "If I had baked it for one more minute, it would be 5% less gooey."The Dynamic Recipe (
while_loop
): This recipe says, "Mix until the batter is smooth, then bake until a toothpick comes out clean." The number of steps is dynamic. How long do we mix? It depends on the batter. How long do we bake? It depends on our oven and how often we check. If the cake is bad, it's very hard to reason backward. The process wasn't a fixed set of operations, so we can't easily calculate the "gradient" of how one change would affect the outcome.JAX is like the first type of baker. It needs a static recipe to do its work. The
MuJoCo
solver was using a dynamic recipe, which is efficient for simulation but bad for calculating gradients.A Note on the Name Collision Bugs:
It is important to clarify why we encountered the
math.py
anddataclasses.py
import errors when they may not have affected other developers.(.../_src/) at the top of Python's import search path
, causing it to load the project's local math.py instead of the standard library version.Splitting the Pull Request:
solver.py
andsolver_test.py
).math.py
anddataclasses.py
and updating all imports).