
Commit c86763f

ricardoV94 and aseyboldt committed

Add helper to build hessian vector product

Co-authored-by: Adrian Seyboldt <[email protected]>

1 parent 94a055d commit c86763f

File tree

3 files changed, +126 -0 lines changed


doc/tutorial/gradients.rst

Lines changed: 10 additions & 0 deletions

@@ -267,6 +267,16 @@ or, making use of the R-operator:

 >>> f([4, 4], [2, 2])
 array([ 4., 4.])

+There is a built-in helper that uses the first method:
+
+>>> x = pt.dvector('x')
+>>> v = pt.dvector('v')
+>>> y = pt.sum(x ** 2)
+>>> Hv = pytensor.gradient.hessian_vector_product(y, x, v)
+>>> f = pytensor.function([x, v], Hv)
+>>> f([4, 4], [2, 2])
+array([ 4., 4.])
+

 Final Pointers
 ==============
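For comparison, here is an editorial sketch (not part of this commit) of the equivalent Hessian-vector product built by hand with the R-operator, which the surrounding tutorial section refers to: apply `Rop` to the backward-mode gradient.

>>> import pytensor
>>> import pytensor.tensor as pt
>>> from pytensor.gradient import Rop, grad
>>> x = pt.dvector('x')
>>> v = pt.dvector('v')
>>> y = pt.sum(x ** 2)
>>> gy = grad(y, x)          # backward-mode gradient of the scalar cost
>>> Hv = Rop(gy, x, v)       # forward-mode product of the Hessian with v
>>> f = pytensor.function([x, v], Hv)
>>> f([4, 4], [2, 2])
array([ 4., 4.])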

pytensor/gradient.py

Lines changed: 77 additions & 0 deletions

@@ -2052,6 +2052,83 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
     return as_list_or_tuple(using_list, using_tuple, hessians)


+def hessian_vector_product(cost, wrt, p, **grad_kwargs):
+    """Return the expression of the Hessian times a vector `p`.
+
+    Notes
+    -----
+    This function uses backward autodiff twice to obtain the desired expression.
+    You may want to build the equivalent expression manually, by combining
+    backward- with forward-mode autodiff (if all Ops support it).
+    See {ref}`docs/_tutcomputinggrads#Hessian-times-a-Vector` for how to do this.
+
+    Parameters
+    ----------
+    cost: Scalar (0-dimensional) `Variable`.
+    wrt: Vector (1-dimensional tensor) `Variable` or list of Vectors
+    p: Vector (1-dimensional tensor) `Variable` or list of Vectors
+        Each vector will be used for the Hessian-vector product with respect to
+        the corresponding input variable.
+    **grad_kwargs:
+        Keyword arguments passed to the `grad` function.
+
+    Returns
+    -------
+    Vector or list of Vectors
+        The Hessian times `p` of the `cost` with respect to (elements of) `wrt`.
+
+    Examples
+    --------
+
+    >>> import numpy as np
+    >>> from scipy.optimize import minimize
+    >>> from pytensor import function
+    >>> from pytensor.tensor import vector
+    >>> from pytensor.gradient import grad, hessian_vector_product
+    >>>
+    >>> x = vector('x')
+    >>> p = vector('p')
+    >>>
+    >>> rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+    >>> rosen_jac = grad(rosen, x)
+    >>> rosen_hessp = hessian_vector_product(rosen, x, p)
+    >>>
+    >>> rosen_fn = function([x], rosen)
+    >>> rosen_jac_fn = function([x], rosen_jac)
+    >>> rosen_hessp_fn = function([x, p], rosen_hessp)
+    >>> x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
+    >>> res = minimize(
+    ...     rosen_fn,
+    ...     x0,
+    ...     method="Newton-CG",
+    ...     jac=rosen_jac_fn,
+    ...     hessp=rosen_hessp_fn,
+    ...     options={"xtol": 1e-8, "disp": True},
+    ... )
+    Optimization terminated successfully.
+             Current function value: 0.000000
+             Iterations: 24
+             Function evaluations: 33
+             Gradient evaluations: 33
+             Hessian evaluations: 66
+    >>> res.x
+    array([1.        , 1.        , 1.        , 0.99999999, 0.99999999])
+    """
+    wrt_list = wrt if isinstance(wrt, Sequence) else [wrt]
+    p_list = p if isinstance(p, Sequence) else [p]
+    grad_wrt_list = grad(cost, wrt=wrt_list, **grad_kwargs)
+    hessian_cost = pytensor.tensor.add(
+        *[
+            (grad_wrt * p).sum()
+            for grad_wrt, p in zip(grad_wrt_list, p_list, strict=True)
+        ]
+    )
+    Hp_list = grad(hessian_cost, wrt=wrt_list, **grad_kwargs)
+
+    if isinstance(wrt, Variable):
+        return Hp_list[0]
+    return Hp_list
+
+
 def _is_zero(x):
     """
     Returns 'yes', 'no', or 'maybe' indicating whether x
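As an editorial aside (not part of this commit), one quick way to sanity-check the new helper is to compare it against the dense Hessian produced by the existing `hessian` helper visible in the hunk context above. The toy cost below is illustrative only; its Hessian is diag(6 * x), so both expressions should evaluate to [6., 12., 18.] at the chosen test point.

>>> import numpy as np
>>> import pytensor
>>> import pytensor.tensor as pt
>>> from pytensor.gradient import hessian, hessian_vector_product
>>> x = pt.dvector('x')
>>> p = pt.dvector('p')
>>> cost = pt.sum(x ** 3)                         # toy cost; Hessian is diag(6 * x)
>>> Hp = hessian_vector_product(cost, wrt=x, p=p)  # new helper
>>> H = hessian(cost, wrt=x)                       # dense Hessian for reference
>>> f = pytensor.function([x, p], [Hp, pt.dot(H, p)])
>>> hp_val, dense_val = f([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
>>> np.allclose(hp_val, dense_val)
True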

tests/test_gradient.py

Lines changed: 39 additions & 0 deletions

@@ -2,6 +2,7 @@

 import numpy as np
 import pytest
+from scipy.optimize import rosen_hess_prod

 import pytensor
 import pytensor.tensor.basic as ptb

@@ -22,6 +23,7 @@
     grad_scale,
     grad_undefined,
     hessian,
+    hessian_vector_product,
     jacobian,
     subgraph_grad,
     zero_grad,

@@ -1081,3 +1083,40 @@ def test_jacobian_disconnected_inputs():
     func_s = pytensor.function([s2], jacobian_s)
     val = np.array(1.0).astype(pytensor.config.floatX)
     assert np.allclose(func_s(val), np.zeros(1))
+
+
+class TestHessianVectorProduct:
+    def test_rosen(self):
+        x = vector("x", dtype="float64")
+        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+
+        p = vector("p", dtype="float64")
+        rosen_hess_prod_pt = hessian_vector_product(rosen, wrt=x, p=p)
+
+        x_test = 0.1 * np.arange(9)
+        p_test = 0.5 * np.arange(9)
+        np.testing.assert_allclose(
+            rosen_hess_prod_pt.eval({x: x_test, p: p_test}),
+            rosen_hess_prod(x_test, p_test),
+        )
+
+    def test_multiple_wrt(self):
+        x = vector("x", dtype="float64")
+        y = vector("y", dtype="float64")
+        p_x = vector("p_x", dtype="float64")
+        p_y = vector("p_y", dtype="float64")
+
+        cost = (x**2 - y**2).sum()
+        hessp_x, hessp_y = hessian_vector_product(cost, wrt=[x, y], p=[p_x, p_y])
+
+        hessp_fn = pytensor.function([x, y, p_x, p_y], [hessp_x, hessp_y])
+        test = {
+            # x, y don't matter
+            "x": np.full((3,), np.nan),
+            "y": np.full((3,), np.nan),
+            "p_x": [1, 2, 3],
+            "p_y": [3, 2, 1],
+        }
+        hessp_x_eval, hessp_y_eval = hessp_fn(**test)
+        np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
+        np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])
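For reference (editorial note, not part of the commit): with cost = sum(x**2 - y**2), the gradient is (2*x, -2*y) and the Hessian is block-diagonal with constant blocks 2*I and -2*I, so the expected products are 2 * p_x = [2, 4, 6] and -2 * p_y = [-6, -4, -2], matching the assertions above. That constancy is also why the NaN values supplied for x and y cannot affect the result.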
