
Commit 45ac43c

ricardoV94 and aseyboldt committed
Add helper to build hessian product
Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent d3bd1f1 commit 45ac43c

File tree

2 files changed: +114 -0 lines changed

pytensor/gradient.py

Lines changed: 66 additions & 0 deletions
@@ -2052,6 +2052,72 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
     return as_list_or_tuple(using_list, using_tuple, hessians)
 
 
+def hessian_prod(cost, wrt, p, **grad_kwargs):
+    """Return the expression of the Hessian times a vector p.
+
+    Parameters
+    ----------
+    cost: Scalar (0-dimensional) variable.
+    wrt: Vector (1-dimensional tensor) `Variable` or list of Vectors
+    p: Vector (1-dimensional tensor) `Variable` or list of Vectors
+        Each vector will be used for the Hessian product with respect to
+        the corresponding input variable.
+    **grad_kwargs:
+        Keyword arguments passed to the `grad` function.
+
+    Returns
+    -------
+    Vector or list of Vectors
+        The Hessian of `cost` times `p`, with respect to (elements of) `wrt`.
+
+    Examples
+    --------
+
+    .. testcode::
+
+        import numpy as np
+        from scipy.optimize import minimize
+
+        from pytensor import function
+        from pytensor.tensor import vector
+        from pytensor.gradient import jacobian, hessian_prod
+
+        x = vector('x')
+        p = vector('p')
+
+        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+        rosen_hessp = hessian_prod(rosen, x, p)
+        rosen_jac = jacobian(rosen, x)
+
+        rosen_fn = function([x], rosen)
+        rosen_jac_fn = function([x], rosen_jac)
+        rosen_hessp_fn = function([x, p], rosen_hessp)
+        x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
+        res = minimize(
+            rosen_fn,
+            x0,
+            method="Newton-CG",
+            jac=rosen_jac_fn,
+            hessp=rosen_hessp_fn,
+            options={"xtol": 1e-8},
+        )
+        assert res.success
+        np.testing.assert_allclose(res.x, np.ones_like(x0), atol=1e-3)
+
+
+    """
+    wrt_list = wrt if isinstance(wrt, Sequence) else [wrt]
+    p_list = p if isinstance(p, Sequence) else [p]
+    grad_wrt_list = grad(cost, wrt=wrt_list, **grad_kwargs)
+    Hp_list = [
+        grad(grad_wrt @ p, wrt=wrt, **grad_kwargs)
+        for grad_wrt, wrt, p in zip(grad_wrt_list, wrt_list, p_list, strict=True)
+    ]
+
+    if isinstance(wrt, Variable):
+        return Hp_list[0]
+    return Hp_list
+
+
 def _is_zero(x):
     """
     Returns 'yes', 'no', or 'maybe' indicating whether x
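
The helper never materializes the full Hessian: it takes the gradient of `cost`, contracts it with `p`, and differentiates that scalar again, using the identity H·p = ∇((∇cost)·p), so only two gradient passes are needed. Below is a minimal sketch checking the new helper against the dense `hessian` builder; the quadratic cost and test values are illustrative, not part of the commit.

    import numpy as np

    from pytensor.gradient import hessian, hessian_prod
    from pytensor.tensor import vector

    x = vector("x", dtype="float64")
    p = vector("p", dtype="float64")

    # Any smooth scalar cost works; this one has a tridiagonal Hessian.
    cost = (x**2).sum() + (x[:-1] * x[1:]).sum()

    Hp = hessian_prod(cost, wrt=x, p=p)  # two gradient passes, no n x n matrix
    H = hessian(cost, wrt=x)             # dense Hessian, for comparison only

    x_test = np.arange(4, dtype="float64")
    p_test = np.ones(4)

    np.testing.assert_allclose(
        Hp.eval({x: x_test, p: p_test}),
        H.eval({x: x_test}) @ p_test,
    )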

tests/test_gradient.py

Lines changed: 48 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 import numpy as np
 import pytest
+from scipy.optimize import rosen_hess_prod
 
 import pytensor
 import pytensor.tensor.basic as ptb
@@ -22,6 +23,7 @@
     grad_scale,
     grad_undefined,
     hessian,
+    hessian_prod,
     jacobian,
     subgraph_grad,
     zero_grad,
@@ -1081,3 +1083,49 @@ def test_jacobian_disconnected_inputs():
     func_s = pytensor.function([s2], jacobian_s)
     val = np.array(1.0).astype(pytensor.config.floatX)
     assert np.allclose(func_s(val), np.zeros(1))
+
+
+def test_hessp():
+    x = vector("x", dtype="float64")
+    rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+    p = vector("p")
+
+    rosen_hess_prod_pt = hessian_prod(rosen, wrt=x, p=p)
+
+    x_test = 0.1 * np.arange(9)
+    p_test = 0.5 * np.arange(9)
+    np.testing.assert_allclose(
+        rosen_hess_prod_pt.eval({x: x_test, p: p_test}),
+        rosen_hess_prod(x_test, p_test),
+    )
+
+
+def test_hessp_example():
+    import numpy as np
+    from scipy.optimize import minimize
+
+    from pytensor import function
+    from pytensor.gradient import hessian_prod, jacobian
+    from pytensor.tensor import vector
+
+    x = vector("x")
+    p = vector("p")
+
+    rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+    rosen_hessp = hessian_prod(rosen, x, p)
+    rosen_jac = jacobian(rosen, x)
+
+    rosen_fn = function([x], rosen)
+    rosen_jac_fn = function([x], rosen_jac)
+    rosen_hessp_fn = function([x, p], rosen_hessp)
+    x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
+    res = minimize(
+        rosen_fn,
+        x0,
+        method="Newton-CG",
+        jac=rosen_jac_fn,
+        hessp=rosen_hessp_fn,
+        options={"xtol": 1e-8},
+    )
+    assert res.success
+    np.testing.assert_allclose(res.x, np.ones_like(x0), atol=1e-3)
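
As the docstring notes, `wrt` and `p` may also be lists, in which case one product is returned per input, each taken with respect to that input's own block of the Hessian. A hedged sketch of that usage follows; the separable cost and variable names are illustrative and assume the list handling shown in the diff above.

    import numpy as np

    from pytensor.gradient import hessian_prod
    from pytensor.tensor import vector

    a = vector("a", dtype="float64")
    b = vector("b", dtype="float64")
    pa = vector("pa", dtype="float64")
    pb = vector("pb", dtype="float64")

    # Separable cost: Hessian blocks are 2*I for a and diag(12 * b**2) for b.
    cost = (a**2).sum() + (b**4).sum()

    Hp_a, Hp_b = hessian_prod(cost, wrt=[a, b], p=[pa, pb])

    a_test = np.array([1.0, 2.0])
    b_test = np.array([3.0, 4.0])
    ones = np.ones(2)

    np.testing.assert_allclose(Hp_a.eval({a: a_test, pa: ones}), 2 * ones)
    np.testing.assert_allclose(Hp_b.eval({b: b_test, pb: ones}), 12 * b_test**2)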
