
Commit f6877ef

ricardoV94 and aseyboldt committed
Add helper to build hessian vector product
Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent d3bd1f1 commit f6877ef

3 files changed: +148 −0 lines changed


doc/tutorial/gradients.rst

Lines changed: 10 additions & 0 deletions

@@ -267,6 +267,16 @@ or, making use of the R-operator:
 >>> f([4, 4], [2, 2])
 array([ 4., 4.])

+There is a builtin helper that uses the first method
+
+>>> x = pt.dvector('x')
+>>> v = pt.dvector('v')
+>>> y = pt.sum(x ** 2)
+>>> Hv = pytensor.gradient.hessian_vector_product(y, x, v)
+>>> f = pytensor.function([x, v], Hv)
+>>> f([4, 4], [2, 2])
+array([ 4., 4.])
+

 Final Pointers
 ==============
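For readers following the tutorial, the "first method" referenced in the added text is the double application of `grad` (take the gradient, contract it with `v`, and differentiate the resulting scalar again), which is also how the new helper is implemented in pytensor/gradient.py below. A minimal standalone sketch of that manual construction, with the symbols redefined here purely for illustration:

import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
v = pt.dvector("v")
y = pt.sum(x ** 2)

# "First method": take the gradient, contract it with v, then differentiate again
gy = pytensor.grad(y, x)
Hv_manual = pytensor.grad(pt.sum(gy * v), x)

f = pytensor.function([x, v], Hv_manual)
print(f([4, 4], [2, 2]))  # [4. 4.], matching the helper's output above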

pytensor/gradient.py

Lines changed: 69 additions & 0 deletions

@@ -2052,6 +2052,75 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
     return as_list_or_tuple(using_list, using_tuple, hessians)


+def hessian_vector_product(cost, wrt, p, **grad_kwargs):
+    """Return the expression of the Hessian times a vector p.
+
+    Parameters
+    ----------
+    cost: Scalar (0-dimensional) variable.
+    wrt: Vector (1-dimensional tensor) `Variable` or list of Vectors
+    p: Vector (1-dimensional tensor) `Variable` or list of Vectors
+        Each vector will be used for the Hessian-vector product with respect to each input variable.
+    **grad_kwargs:
+        Keyword arguments passed to the `grad` function.
+
+    Returns
+    -------
+    Vector or list of Vectors
+        The Hessian times `p` of the `cost` with respect to (elements of) `wrt`.
+
+    Examples
+    --------
+
+    .. testcode::
+
+        import numpy as np
+        from scipy.optimize import minimize
+
+        from pytensor import function
+        from pytensor.tensor import vector
+        from pytensor.gradient import jacobian, hessian_vector_product
+
+        x = vector('x')
+        p = vector('p')
+
+        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+        rosen_hessp = hessian_vector_product(rosen, x, p)
+        rosen_jac = jacobian(rosen, x)
+
+        rosen_fn = function([x], rosen)
+        rosen_jac_fn = function([x], rosen_jac)
+        rosen_hessp_fn = function([x, p], rosen_hessp)
+        x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
+        res = minimize(
+            rosen_fn,
+            x0,
+            method="Newton-CG",
+            jac=rosen_jac_fn,
+            hessp=rosen_hessp_fn,
+            options={"xtol": 1e-8},
+        )
+        assert res.success
+        np.testing.assert_allclose(res.x, np.ones_like(x0))
+
+
+    """
+    wrt_list = wrt if isinstance(wrt, Sequence) else [wrt]
+    p_list = p if isinstance(p, Sequence) else [p]
+    grad_wrt_list = grad(cost, wrt=wrt_list, **grad_kwargs)
+    hessian_cost = pytensor.tensor.add(
+        *[
+            (grad_wrt * p).sum()
+            for grad_wrt, p in zip(grad_wrt_list, p_list, strict=True)
+        ]
+    )
+    Hp_list = grad(hessian_cost, wrt=wrt_list, **grad_kwargs)
+
+    if isinstance(wrt, Variable):
+        return Hp_list[0]
+    return Hp_list
+
+
 def _is_zero(x):
     """
     Returns 'yes', 'no', or 'maybe' indicating whether x
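The new helper never materializes the full Hessian: it takes the gradient of `cost` with respect to each `wrt`, sums the dot products with the corresponding `p`, and differentiates that scalar once more, i.e. Hp = ∇_wrt((∇_wrt cost) · p). A minimal sketch checking the helper against the explicit `hessian` from the same module, using a toy cost chosen only for this illustration:

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import hessian, hessian_vector_product

x = pt.dvector("x")
p = pt.dvector("p")
cost = (x ** 3).sum()  # toy cost whose Hessian is diag(6 * x)

# Explicit route: build the full Hessian matrix, then a matrix-vector product
explicit_Hp = pt.dot(hessian(cost, x), p)

# New helper: double differentiation, no full Hessian is ever built
Hp = hessian_vector_product(cost, x, p)

f = pytensor.function([x, p], [explicit_Hp, Hp])
a, b = f(np.arange(1.0, 4.0), np.ones(3))
np.testing.assert_allclose(a, b)  # both equal 6 * x * p = [6., 12., 18.]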

tests/test_gradient.py

Lines changed: 69 additions & 0 deletions

@@ -2,6 +2,7 @@

 import numpy as np
 import pytest
+from scipy.optimize import rosen_hess_prod

 import pytensor
 import pytensor.tensor.basic as ptb
@@ -22,6 +23,7 @@
     grad_scale,
     grad_undefined,
     hessian,
+    hessian_vector_product,
     jacobian,
     subgraph_grad,
     zero_grad,
@@ -1081,3 +1083,70 @@ def test_jacobian_disconnected_inputs():
     func_s = pytensor.function([s2], jacobian_s)
     val = np.array(1.0).astype(pytensor.config.floatX)
     assert np.allclose(func_s(val), np.zeros(1))
+
+
+class TestHessianVectorProduct:
+    def test_rosen(self):
+        x = vector("x", dtype="float64")
+        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+        p = vector("p")
+
+        rosen_hess_prod_pt = hessian_vector_product(rosen, wrt=x, p=p)
+
+        x_test = 0.1 * np.arange(9)
+        p_test = 0.5 * np.arange(9)
+        np.testing.assert_allclose(
+            rosen_hess_prod_pt.eval({x: x_test, p: p_test}),
+            rosen_hess_prod(x_test, p_test),
+        )
+
+    def test_multiple_wrt(self):
+        x = vector("x", dtype="float64")
+        y = vector("y", dtype="float64")
+        p_x = vector("p_x", dtype="float64")
+        p_y = vector("p_y", dtype="float64")
+
+        cost = (x**2 - y**2).sum()
+        hessp_x, hessp_y = hessian_vector_product(cost, wrt=[x, y], p=[p_x, p_y])
+
+        hessp_fn = pytensor.function([x, y, p_x, p_y], [hessp_x, hessp_y])
+        test = {
+            # x, y don't matter
+            "x": np.random.normal(size=(3,)),
+            "y": np.random.normal(size=(3,)),
+            "p_x": [1, 2, 3],
+            "p_y": [3, 2, 1],
+        }
+        hessp_x_eval, hessp_y_eval = hessp_fn(**test)
+        np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
+        np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])
+
+    def test_doc_example(self):
+        import numpy as np
+        from scipy.optimize import minimize
+
+        from pytensor import function
+        from pytensor.gradient import hessian_vector_product, jacobian
+        from pytensor.tensor import vector
+
+        x = vector("x")
+        p = vector("p")
+
+        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+        rosen_hessp = hessian_vector_product(rosen, x, p)
+        rosen_jac = jacobian(rosen, x)
+
+        rosen_fn = function([x], rosen)
+        rosen_jac_fn = function([x], rosen_jac)
+        rosen_hessp_fn = function([x, p], rosen_hessp)
+        x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
+        res = minimize(
+            rosen_fn,
+            x0,
+            method="Newton-CG",
+            jac=rosen_jac_fn,
+            hessp=rosen_hessp_fn,
+            options={"xtol": 1e-8},
+        )
+        assert res.success
+        np.testing.assert_allclose(res.x, np.ones_like(x0))
