Skip to content

Rewrite scalar solve to division #1453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jun 7, 2025

Description

Small optimization that rewrites solve(a,b) -> b / a when the core shape of a is (1, 1). This avoids calling a LAPACK routine in a case where it's simply not necessary.

This came up in the gradient of minimize, for cases where the function being minimized has only one input. The L_op in that case requires a bunch of linear algebra, but it can all be rewritten away when we're just dealing with scalars.

I also tweaked rewrite_cholesky_diag_to_sqrt_diag to apply to the (1, 1) case -- this allows cho_solve to be rewritten to just b / a ** 2, without any linalg calls. The only bummer is that the sqr and sqrt don't cancel. I get this graph:

Here is the graph, including the new sqr/sqrt cancellation:

import pytensor.tensor as pt
import pytensor
x = pt.tensor('x', shape=(1, 1))
c_and_lower = pt.linalg.cholesky(x), True
b = pt.tensor('b', shape=(None,))

f = pytensor.function([x, b], pt.linalg.cho_solve(c_and_lower, b))

f.dprint()

Composite{(i1 / abs(i0))} [id A] 2
 ├─ Squeeze{axis=1} [id B] 1
 │  └─ x [id C]
 └─ DimShuffle{order=[x]} [id D] 0
    └─ b [id E]

Inner graphs:

Composite{(i1 / abs(i0))} [id A]
 ← true_div [id F] 'o0'
    ├─ i1 [id G]
    └─ Abs [id H]
       └─ i0 [id I]

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1453.org.readthedocs.build/en/1453/

Copilot

This comment was marked as outdated.

@jessegrabowski jessegrabowski added enhancement New feature or request graph rewriting linalg Linear algebra labels Jun 7, 2025
@ricardoV94
Copy link
Member

ricardoV94 commented Jun 7, 2025

Re sqr(sqrt) we should perhaps canonicalize to power, and have the pow(pow(x, a1), a2) -> pow(x, a1*a2) rewrite.

However we may need to introduce some nan switches if x < 0, depending on the values of a (and or refuse to rewrite if either is unknown)

@jessegrabowski
Copy link
Member Author

Oh I didn't even think about the domain issue. That's a good reason we don't do that rewrite.

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 8, 2025

Oh I didn't even think about the domain issue. That's a good reason we don't do that rewrite.

We handle log(exp) and exp(log), fine, it's just that depending on the order the rewrite has a nan switch or not

@jessegrabowski
Copy link
Member Author

Common complex L

@jessegrabowski jessegrabowski force-pushed the scalar-solve-rewrite branch 2 times, most recently from a2fe5f3 to f998b60 Compare June 23, 2025 10:35
@jessegrabowski
Copy link
Member Author

jessegrabowski commented Jun 23, 2025

We had a test that assumed Cholesky of a 1x1 matrix should not be rewritten to sqrt(diag(x)), which is obviously wrong. I'm not sure why we had that, but I added a check for that case to the cholesky diag rewrite, and removed the test.

I also added the sqrt(sqr(x)) rewrites, so I did some git history cleanup. This PR should be a rebase merge with two commits, one for each feature.

Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces optimizations for simplifying scalar operations and small linear algebra expressions.

  • Adds rewrites to cancel nested sqrt/sqr operations.
  • Optimizes solve, solve_triangular, and cho_solve on 1×1 matrices to simple division.
  • Updates tests to cover these new rewrites.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
tests/tensor/rewriting/test_math.py Added TestSqrSqrt to verify sqrt/sqr cancellation rules.
tests/tensor/rewriting/test_linalg.py Added test_scalar_solve_to_division_rewrite for 1×1 solves.
pytensor/tensor/rewriting/math.py Defined local_sqrt_sqr rewrite to eliminate nested ops.
pytensor/tensor/rewriting/linalg.py Added 1×1 choleskysqrt rewrite and scalar_solve_to_divison.

Copy link

codecov bot commented Jun 23, 2025

Codecov Report

Attention: Patch coverage is 89.36170% with 5 lines in your changes missing coverage. Please review.

Project coverage is 81.99%. Comparing base (236e50d) to head (8ebf636).

Files with missing lines Patch % Lines
pytensor/tensor/rewriting/linalg.py 88.46% 2 Missing and 1 partial ⚠️
pytensor/tensor/rewriting/math.py 90.47% 1 Missing and 1 partial ⚠️

❌ Your patch check has failed because the patch coverage (89.36%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1453   +/-   ##
=======================================
  Coverage   81.98%   81.99%           
=======================================
  Files         231      231           
  Lines       52192    52239   +47     
  Branches     9185     9197   +12     
=======================================
+ Hits        42790    42832   +42     
- Misses       7094     7097    +3     
- Partials     2308     2310    +2     
Files with missing lines Coverage Δ
pytensor/tensor/rewriting/math.py 89.50% <90.47%> (+0.01%) ⬆️
pytensor/tensor/rewriting/linalg.py 92.04% <88.46%> (-0.24%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request graph rewriting linalg Linear algebra
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants