-
Notifications
You must be signed in to change notification settings - Fork 135
Rewrite scalar solve to division #1453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Rewrite scalar solve to division #1453
Conversation
Re sqr(sqrt) we should perhaps canonicalize to power, and have the However we may need to introduce some nan switches if x < 0, depending on the values of a (and or refuse to rewrite if either is unknown) |
Oh I didn't even think about the domain issue. That's a good reason we don't do that rewrite. |
We handle log(exp) and exp(log), fine, it's just that depending on the order the rewrite has a nan switch or not |
Common complex L |
a2fe5f3
to
f998b60
Compare
We had a test that assumed Cholesky of a 1x1 matrix should not be rewritten to I also added the sqrt(sqr(x)) rewrites, so I did some git history cleanup. This PR should be a rebase merge with two commits, one for each feature. |
11c1c22
to
a3f2543
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces optimizations for simplifying scalar operations and small linear algebra expressions.
- Adds rewrites to cancel nested
sqrt
/sqr
operations. - Optimizes
solve
,solve_triangular
, andcho_solve
on 1×1 matrices to simple division. - Updates tests to cover these new rewrites.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
tests/tensor/rewriting/test_math.py | Added TestSqrSqrt to verify sqrt /sqr cancellation rules. |
tests/tensor/rewriting/test_linalg.py | Added test_scalar_solve_to_division_rewrite for 1×1 solves. |
pytensor/tensor/rewriting/math.py | Defined local_sqrt_sqr rewrite to eliminate nested ops. |
pytensor/tensor/rewriting/linalg.py | Added 1×1 cholesky → sqrt rewrite and scalar_solve_to_divison . |
a3f2543
to
8ebf636
Compare
Codecov ReportAttention: Patch coverage is
❌ Your patch check has failed because the patch coverage (89.36%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1453 +/- ##
=======================================
Coverage 81.98% 81.99%
=======================================
Files 231 231
Lines 52192 52239 +47
Branches 9185 9197 +12
=======================================
+ Hits 42790 42832 +42
- Misses 7094 7097 +3
- Partials 2308 2310 +2
🚀 New features to boost your workflow:
|
Description
Small optimization that rewrites
solve(a,b) -> b / a
when the core shape ofa
is(1, 1)
. This avoids calling a LAPACK routine in a case where it's simply not necessary.This came up in the gradient of
minimize
, for cases where the function being minimized has only one input. The L_op in that case requires a bunch of linear algebra, but it can all be rewritten away when we're just dealing with scalars.I also tweaked
rewrite_cholesky_diag_to_sqrt_diag
to apply to the (1, 1) case -- this allowscho_solve
to be rewritten to just b / a ** 2, without any linalg calls.The only bummer is that thesqr
andsqrt
don't cancel. I get this graph:Here is the graph, including the new sqr/sqrt cancellation:
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1453.org.readthedocs.build/en/1453/