diff --git a/doc/tutorial/gradients.rst b/doc/tutorial/gradients.rst
index f8b7f7ff98..35dc852c77 100644
--- a/doc/tutorial/gradients.rst
+++ b/doc/tutorial/gradients.rst
@@ -101,9 +101,12 @@ PyTensor implements the :func:`pytensor.gradient.jacobian` macro that does all
 that is needed to compute the Jacobian. The following text explains how
 to do it manually.
 
+Using Scan
+----------
+
 In order to manually compute the Jacobian of some function ``y`` with
-respect to some parameter ``x`` we need to use `scan`. What we
-do is to loop over the entries in ``y`` and compute the gradient of
+respect to some parameter ``x`` we can use `scan`.
+In this case, we loop over the entries in ``y`` and compute the gradient of
 ``y[i]`` with respect to ``x``.
 
 .. note::
@@ -111,8 +114,7 @@ do is to loop over the entries in ``y`` and compute the gradient of
     `scan` is a generic op in PyTensor that allows writing in a symbolic
     manner all kinds of recurrent equations. While creating symbolic loops
     (and optimizing them for performance) is a hard task,
-    effort is being done for improving the performance of `scan`. We
-    shall return to :ref:`scan` later in this tutorial.
+    efforts are being made to improve the performance of `scan`.
 
 >>> import pytensor
 >>> import pytensor.tensor as pt
@@ -124,9 +126,9 @@ do is to loop over the entries in ``y`` and compute the gradient of
 array([[ 8.,  0.],
        [ 0.,  8.]])
 
-What we do in this code is to generate a sequence of integers from ``0`` to
-``y.shape[0]`` using `pt.arange`. Then we loop through this sequence, and
-at each step, we compute the gradient of element ``y[i]`` with respect to
+This code generates a sequence of integers from ``0`` to
+``y.shape[0]`` using `pt.arange`. Then it loops through this sequence, and
+at each step, computes the gradient of element ``y[i]`` with respect to
 ``x``. `scan` automatically concatenates all these rows, generating a
 matrix which corresponds to the Jacobian.
 
@@ -139,6 +141,31 @@ matrix which corresponds to the Jacobian.
 ``x`` anymore, while ``y[i]`` still is.
 
+Using automatic vectorization
+-----------------------------
+An alternative way to build the Jacobian is to vectorize the graph that computes a single row or column of the Jacobian.
+We can use `Lop` or `Rop` (more about them below) to obtain the row or column of the Jacobian and `vectorize_graph`
+to vectorize it into the full Jacobian matrix.
+
+>>> import pytensor
+>>> import pytensor.tensor as pt
+>>> from pytensor.gradient import Lop
+>>> from pytensor.graph import vectorize_graph
+>>> x = pt.dvector('x')
+>>> y = x ** 2
+>>> row_cotangent = pt.dvector("row_cotangent")  # Helper variable that will be replaced during vectorization
+>>> J_row = Lop(y, x, row_cotangent)
+>>> J = vectorize_graph(J_row, replace={row_cotangent: pt.eye(x.size)})
+>>> f = pytensor.function([x], J)
+>>> f([4, 4])
+array([[ 8.,  0.],
+       [ 0.,  8.]])
+
+This avoids the overhead of `scan`, at the cost of higher memory usage if the Jacobian expression has large intermediate operations.
+Also, not all graphs are safely vectorizable (e.g., if different rows require intermediate operations of different sizes).
+For these reasons, `jacobian` uses `scan` by default. The behavior can be changed by setting ``vectorize=True``.
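+
+The same result can also be obtained directly from :func:`pytensor.gradient.jacobian` by
+passing ``vectorize=True``, which wires up `Lop` and `vectorize_graph` for us. The following
+is a minimal doctest sketch reusing the toy graph from above:
+
+>>> from pytensor.gradient import jacobian
+>>> x = pt.dvector('x')
+>>> y = x ** 2
+>>> J = jacobian(y, x, vectorize=True)
+>>> f = pytensor.function([x], J)
+>>> f([4, 4])
+array([[ 8.,  0.],
+       [ 0.,  8.]])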
+
+
 
 Computing the Hessian
 =====================
diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index 96a39e09d9..5924fd7fcb 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -11,7 +11,7 @@
 import pytensor
 from pytensor.compile.ops import ViewOp
 from pytensor.configdefaults import config
-from pytensor.graph import utils
+from pytensor.graph import utils, vectorize_graph
 from pytensor.graph.basic import Apply, NominalVariable, Variable
 from pytensor.graph.null_type import NullType, null_type
 from pytensor.graph.op import get_test_values
@@ -703,15 +703,15 @@ def grad(
         grad_dict[var] = g_var
 
     def handle_disconnected(var):
-        message = (
-            "grad method was asked to compute the gradient "
-            "with respect to a variable that is not part of "
-            "the computational graph of the cost, or is used "
-            f"only by a non-differentiable operator: {var}"
-        )
         if disconnected_inputs == "ignore":
-            pass
+            return
         elif disconnected_inputs == "warn":
+            message = (
+                "grad method was asked to compute the gradient "
+                "with respect to a variable that is not part of "
+                "the computational graph of the cost, or is used "
+                f"only by a non-differentiable operator: {var}"
+            )
             warnings.warn(message, stacklevel=2)
         elif disconnected_inputs == "raise":
             message = utils.get_variable_trace_string(var)
@@ -2021,13 +2021,19 @@ def __str__(self):
 
             Exception args: {args_msg}"""
 
 
-def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
+def jacobian(
+    expression,
+    wrt,
+    consider_constant=None,
+    disconnected_inputs="raise",
+    vectorize=False,
+):
     """
     Compute the full Jacobian, row by row.
 
     Parameters
     ----------
-    expression : Vector (1-dimensional) :class:`~pytensor.graph.basic.Variable`
+    expression : :class:`~pytensor.graph.basic.Variable`
         Values that we are differentiating (that we want the Jacobian of)
     wrt : :class:`~pytensor.graph.basic.Variable` or list of Variables
         Term[s] with respect to which we compute the Jacobian
+    vectorize : bool, optional
+        If ``True``, build the Jacobian by vectorizing the graph that computes
+        a single row (via `Lop` and `vectorize_graph`) instead of looping with
+        `scan`. Defaults to ``False``.
@@ -2051,18 +2057,18 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
     output, then a zero variable is returned. The return value is of same type
     as `wrt`: a list/tuple or TensorVariable in all cases.
     """
+    from pytensor.tensor.basic import eye
+    from pytensor.tensor.extra_ops import broadcast_to
 
     if not isinstance(expression, Variable):
         raise TypeError("jacobian expects a Variable as `expression`")
 
-    if expression.ndim > 1:
-        raise ValueError(
-            "jacobian expects a 1 dimensional variable as `expression`."
- " If not use flatten to make it a vector" - ) - using_list = isinstance(wrt, list) using_tuple = isinstance(wrt, tuple) + grad_kwargs = { + "consider_constant": consider_constant, + "disconnected_inputs": disconnected_inputs, + } if isinstance(wrt, list | tuple): wrt = list(wrt) @@ -2070,43 +2076,55 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise wrt = [wrt] if all(expression.type.broadcastable): - # expression is just a scalar, use grad - return as_list_or_tuple( - using_list, - using_tuple, - grad( - expression.squeeze(), - wrt, - consider_constant=consider_constant, - disconnected_inputs=disconnected_inputs, - ), + jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs) + + elif vectorize: + expression_flat = expression.ravel() + row_tangent = _float_ones_like(expression_flat).type("row_tangent") + jacobian_single_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs) + + n_rows = expression_flat.size + jacobian_matrices = vectorize_graph( + jacobian_single_rows, + replace={row_tangent: eye(n_rows, dtype=row_tangent.dtype)}, ) + if disconnected_inputs != "raise": + # If the input is disconnected from the cost, `vectorize_graph` has no effect on the respective jacobian + # We have to broadcast the zeros explicitly here + for i, (jacobian_single_row, jacobian_matrix) in enumerate( + zip(jacobian_single_rows, jacobian_matrices, strict=True) + ): + if jacobian_single_row.ndim == jacobian_matrix.ndim: + jacobian_matrices[i] = broadcast_to( + jacobian_matrix, shape=(n_rows, *jacobian_matrix.shape) + ) - def inner_function(*args): - idx = args[0] - expr = args[1] - rvals = [] - for inp in args[2:]: - rval = grad( - expr[idx], - inp, - consider_constant=consider_constant, - disconnected_inputs=disconnected_inputs, + else: + + def inner_function(*args): + idx, expr, *wrt = args + return grad(expr[idx], wrt, **grad_kwargs) + + jacobian_matrices, updates = pytensor.scan( + inner_function, + sequences=pytensor.tensor.arange(expression.size), + non_sequences=[expression.ravel(), *wrt], + return_list=True, + ) + if updates: + raise ValueError( + "The scan used to build the jacobian matrices returned a list of updates" ) - rvals.append(rval) - return rvals - - # Computing the gradients does not affect the random seeds on any random - # generator used n expression (because during computing gradients we are - # just backtracking over old values. (rp Jan 2012 - if anyone has a - # counter example please show me) - jacobs, updates = pytensor.scan( - inner_function, - sequences=pytensor.tensor.arange(expression.shape[0]), - non_sequences=[expression, *wrt], - ) - assert not updates, "Scan has returned a list of updates; this should not happen." 
- return as_list_or_tuple(using_list, using_tuple, jacobs) + + if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim): + # There was some raveling or squeezing done prior to getting the jacobians + # Reshape into original shapes + jacobian_matrices = [ + jac_matrix.reshape((*expression.shape, *w.shape)) + for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True) + ] + + return as_list_or_tuple(using_list, using_tuple, jacobian_matrices) def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"): diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index 5092d55e6b..6cb46b6301 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -232,13 +232,13 @@ def vectorize_graph( def vectorize_graph( outputs: Sequence[Variable], replace: Mapping[Variable, Variable], -) -> Sequence[Variable]: ... +) -> list[Variable]: ... def vectorize_graph( outputs: Variable | Sequence[Variable], replace: Mapping[Variable, Variable], -) -> Variable | Sequence[Variable]: +) -> Variable | list[Variable]: """Vectorize outputs graph given mapping from old variables to expanded counterparts version. Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`. diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index b9e9c3164d..8225fd02ac 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -3081,6 +3081,10 @@ def flatten(x, ndim=1): else: dims = (-1,) + if len(dims) == _x.ndim: + # Nothing to ravel + return _x + x_reshaped = _x.reshape(dims) shape_kept_dims = _x.type.shape[: ndim - 1] bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :]) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 5e6271e170..b01a50e2fa 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3867,35 +3867,22 @@ class TestInferShape(utt.InferShapeTester): def test_Flatten(self): atens3 = tensor3() atens3_val = random(4, 5, 3) - for ndim in (3, 2, 1): + for ndim in (2, 1): self._compile_and_check( [atens3], [flatten(atens3, ndim)], [atens3_val], Reshape, - excluding=["local_useless_reshape"], ) amat = matrix() amat_val = random(4, 5) - for ndim in (2, 1): - self._compile_and_check( - [amat], - [flatten(amat, ndim)], - [amat_val], - Reshape, - excluding=["local_useless_reshape"], - ) - - avec = vector() - avec_val = random(4) ndim = 1 self._compile_and_check( - [avec], - [flatten(avec, ndim)], - [avec_val], + [amat], + [flatten(amat, ndim)], + [amat_val], Reshape, - excluding=["local_useless_reshape"], ) def test_Eye(self): diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 9673f8338e..89712c19dd 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -4,6 +4,7 @@ import pytensor import pytensor.tensor.basic as ptb +from pytensor import function from pytensor.configdefaults import config from pytensor.gradient import ( DisconnectedInputError, @@ -31,7 +32,7 @@ from pytensor.graph.null_type import NullType from pytensor.graph.op import Op from pytensor.scan.op import Scan -from pytensor.tensor.math import add, dot, exp, sigmoid, sqr, tanh +from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, tanh from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.random import RandomStream from pytensor.tensor.type import ( @@ -940,139 +941,207 @@ def test_undefined_grad_opt(): ) -def test_jacobian_vector(): - x = vector() - y = x * 2 - rng = np.random.default_rng(seed=utt.fetch_seed()) - - # test when the jacobian 
is called with a tensor as wrt - Jx = jacobian(y, x) - f = pytensor.function([x], Jx) - vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), np.eye(10) * 2) +@pytest.mark.parametrize("vectorize", [False, True], ids=lambda x: f"vectorize={x}") +class TestJacobian: + def test_jacobian_vector(self, vectorize): + x = vector() + y = x * 2 + rng = np.random.default_rng(seed=utt.fetch_seed()) + + # test when the jacobian is called with a tensor as wrt + Jx = jacobian(y, x, vectorize=vectorize) + f = function([x], Jx) + vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), np.eye(10) * 2) + + # test when the jacobian is called with a tuple as wrt + Jx = jacobian(y, (x,), vectorize=vectorize) + assert isinstance(Jx, tuple) + f = function([x], Jx[0]) + vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), np.eye(10) * 2) + + # test when the jacobian is called with a list as wrt + Jx = jacobian(y, [x], vectorize=vectorize) + assert isinstance(Jx, list) + f = function([x], Jx[0]) + vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), np.eye(10) * 2) + + # test when the jacobian is called with a list of two elements + z = vector() + y = x * z + Js = jacobian(y, [x, z], vectorize=vectorize) + f = function([x, z], Js) + vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) + vz = rng.uniform(size=(10,)).astype(pytensor.config.floatX) + vJs = f(vx, vz) + evx = np.zeros((10, 10)) + evz = np.zeros((10, 10)) + np.fill_diagonal(evx, vx) + np.fill_diagonal(evz, vz) + assert np.allclose(vJs[0], evz) + assert np.allclose(vJs[1], evx) + + def test_jacobian_matrix(self, vectorize): + x = matrix() + y = 2 * x.sum(axis=0) + rng = np.random.default_rng(seed=utt.fetch_seed()) + ev = np.zeros((10, 10, 10)) + for dx in range(10): + ev[dx, :, dx] = 2.0 + + # test when the jacobian is called with a tensor as wrt + Jx = jacobian(y, x, vectorize=vectorize) + f = function([x], Jx) + vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), ev) + + # test when the jacobian is called with a tuple as wrt + Jx = jacobian(y, (x,), vectorize=vectorize) + assert isinstance(Jx, tuple) + f = function([x], Jx[0]) + vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), ev) + + # test when the jacobian is called with a list as wrt + Jx = jacobian(y, [x], vectorize=vectorize) + assert isinstance(Jx, list) + f = function([x], Jx[0]) + vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), ev) + + # test when the jacobian is called with a list of two elements + z = matrix() + y = (x * z).sum(axis=1) + Js = jacobian(y, [x, z], vectorize=vectorize) + f = function([x, z], Js) + vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + vz = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + vJs = f(vx, vz) + evx = np.zeros((10, 10, 10)) + evz = np.zeros((10, 10, 10)) + for dx in range(10): + evx[dx, dx, :] = vx[dx, :] + evz[dx, dx, :] = vz[dx, :] + assert np.allclose(vJs[0], evz) + assert np.allclose(vJs[1], evx) + + def test_jacobian_scalar(self, vectorize): + x = scalar() + y = x * 2 + rng = np.random.default_rng(seed=utt.fetch_seed()) + + # test when the jacobian is called with a tensor as wrt + Jx = jacobian(y, x, vectorize=vectorize) + f = function([x], Jx) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + assert np.allclose(f(vx), 2) + + # test when input is a 
shape (1,) vector -- should still be treated as a scalar + Jx = jacobian(y[None], x) + f = function([x], Jx) + + # Ensure we hit the scalar grad case (doesn't use scan) + nodes = f.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Scan) for node in nodes) + + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + assert np.allclose(f(vx), 2) + + # test when the jacobian is called with a tuple as wrt + Jx = jacobian(y, (x,), vectorize=vectorize) + assert isinstance(Jx, tuple) + f = function([x], Jx[0]) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + assert np.allclose(f(vx), 2) + + # test when the jacobian is called with a list as wrt + Jx = jacobian(y, [x], vectorize=vectorize) + assert isinstance(Jx, list) + f = function([x], Jx[0]) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + assert np.allclose(f(vx), 2) + + # test when the jacobian is called with a list of two elements + z = scalar() + y = x * z + Jx = jacobian(y, [x, z], vectorize=vectorize) + f = function([x, z], Jx) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + vJx = f(vx, vz) + + assert np.allclose(vJx[0], vz) + assert np.allclose(vJx[1], vx) + + @pytest.mark.parametrize("square_jac", [False, True]) + def test_jacobian_matrix_expression(self, vectorize, square_jac): + x = vector("x", shape=(3,)) + y = outer(x, x) + if not square_jac: + y = y[:, 1:] + Jy_wrt_x = jacobian(y, wrt=x, vectorize=vectorize) + f = function([x], Jy_wrt_x) + x_test = np.arange(3, dtype=x.type.dtype) + res = f(x_test) + expected_res = np.array( + [ + # Jy[0]_wrt_x (y[0] = x[0] * x) + [[0, 0, 0], [1, 0, 0], [2, 0, 0]], + # Jy[1]_wrt_x (y[1] = x[1] * x) + [ + [1, 0, 0], + [0, 2, 0], + [0, 2, 1], + ], + # Jy[2]_wrt_x (y[2] = x[2] * x) + [ + [2, 0, 0], + [0, 2, 1], + [0, 0, 4], + ], + ] + ) + if not square_jac: + expected_res = expected_res[:, 1:, :] + np.testing.assert_allclose(res, expected_res) + + def test_jacobian_disconnected_inputs(self, vectorize): + # Test that disconnected inputs are properly handled by jacobian. 
+ s1 = scalar("s1") + s2 = scalar("s2") + jacobian_s = jacobian(1 + s1, s2, disconnected_inputs="ignore") + func_s = function([s2], jacobian_s) + val = np.array(1.0, dtype=config.floatX) + np.testing.assert_allclose(func_s(val), np.zeros(1)) + + v1 = vector("v1") + v2 = vector("v2") + jacobian_v = jacobian( + 1 + v1, v2, disconnected_inputs="ignore", vectorize=vectorize + ) + func_v = function([v1, v2], jacobian_v, on_unused_input="ignore") + val = np.arange(4.0, dtype=pytensor.config.floatX) + np.testing.assert_allclose(func_v(val, val), np.zeros((4, 4))) + + m1 = matrix("m1") + m2 = matrix("m2") + jacobian_m = jacobian( + 1 + m1[1:, 2:], m2, disconnected_inputs="ignore", vectorize=vectorize + ) + func_v = function([m1, m2], jacobian_m, on_unused_input="ignore") + val = np.ones((4, 4), dtype=config.floatX) + np.testing.assert_allclose(func_v(val, val), np.zeros((3, 2, 4, 4))) - # test when the jacobian is called with a tuple as wrt - Jx = jacobian(y, (x,)) - assert isinstance(Jx, tuple) - f = pytensor.function([x], Jx[0]) - vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), np.eye(10) * 2) + def test_benchmark(self, vectorize, benchmark): + x = vector("x", shape=(3,)) + y = outer(x, x) - # test when the jacobian is called with a list as wrt - Jx = jacobian(y, [x]) - assert isinstance(Jx, list) - f = pytensor.function([x], Jx[0]) - vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), np.eye(10) * 2) + jac_y = jacobian(y, x, vectorize=vectorize) - # test when the jacobian is called with a list of two elements - z = vector() - y = x * z - Js = jacobian(y, [x, z]) - f = pytensor.function([x, z], Js) - vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - vz = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - vJs = f(vx, vz) - evx = np.zeros((10, 10)) - evz = np.zeros((10, 10)) - np.fill_diagonal(evx, vx) - np.fill_diagonal(evz, vz) - assert np.allclose(vJs[0], evz) - assert np.allclose(vJs[1], evx) - - -def test_jacobian_matrix(): - x = matrix() - y = 2 * x.sum(axis=0) - rng = np.random.default_rng(seed=utt.fetch_seed()) - ev = np.zeros((10, 10, 10)) - for dx in range(10): - ev[dx, :, dx] = 2.0 - - # test when the jacobian is called with a tensor as wrt - Jx = jacobian(y, x) - f = pytensor.function([x], Jx) - vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), ev) - - # test when the jacobian is called with a tuple as wrt - Jx = jacobian(y, (x,)) - assert isinstance(Jx, tuple) - f = pytensor.function([x], Jx[0]) - vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), ev) - - # test when the jacobian is called with a list as wrt - Jx = jacobian(y, [x]) - assert isinstance(Jx, list) - f = pytensor.function([x], Jx[0]) - vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), ev) - - # test when the jacobian is called with a list of two elements - z = matrix() - y = (x * z).sum(axis=1) - Js = jacobian(y, [x, z]) - f = pytensor.function([x, z], Js) - vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - vz = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - vJs = f(vx, vz) - evx = np.zeros((10, 10, 10)) - evz = np.zeros((10, 10, 10)) - for dx in range(10): - evx[dx, dx, :] = vx[dx, :] - evz[dx, dx, :] = vz[dx, :] - assert np.allclose(vJs[0], evz) - assert np.allclose(vJs[1], evx) - - -def test_jacobian_scalar(): - x = scalar() - y = x * 2 - rng = 
np.random.default_rng(seed=utt.fetch_seed()) - - # test when the jacobian is called with a tensor as wrt - Jx = jacobian(y, x) - f = pytensor.function([x], Jx) - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - assert np.allclose(f(vx), 2) - - # test when input is a shape (1,) vector -- should still be treated as a scalar - Jx = jacobian(y[None], x) - f = pytensor.function([x], Jx) - - # Ensure we hit the scalar grad case (doesn't use scan) - nodes = f.maker.fgraph.apply_nodes - assert not any(isinstance(node.op, Scan) for node in nodes) - - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - assert np.allclose(f(vx), 2) - - # test when the jacobian is called with a tuple as wrt - Jx = jacobian(y, (x,)) - assert isinstance(Jx, tuple) - f = pytensor.function([x], Jx[0]) - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - assert np.allclose(f(vx), 2) - - # test when the jacobian is called with a list as wrt - Jx = jacobian(y, [x]) - assert isinstance(Jx, list) - f = pytensor.function([x], Jx[0]) - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - assert np.allclose(f(vx), 2) - - # test when the jacobian is called with a list of two elements - z = scalar() - y = x * z - Jx = jacobian(y, [x, z]) - f = pytensor.function([x, z], Jx) - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - vJx = f(vx, vz) - - assert np.allclose(vJx[0], vz) - assert np.allclose(vJx[1], vx) + fn = function([x], jac_y, trust_input=True) + benchmark(fn, np.array([0, 1, 2], dtype=x.type.dtype)) def test_hessian(): @@ -1084,25 +1153,7 @@ def test_hessian(): assert np.allclose(f(vx), np.eye(10) * 2) -def test_jacobian_disconnected_inputs(): - # Test that disconnected inputs are properly handled by jacobian. - - v1 = vector() - v2 = vector() - jacobian_v = pytensor.gradient.jacobian(1 + v1, v2, disconnected_inputs="ignore") - func_v = pytensor.function([v1, v2], jacobian_v) - val = np.arange(4.0).astype(pytensor.config.floatX) - assert np.allclose(func_v(val, val), np.zeros((4, 4))) - - s1 = scalar() - s2 = scalar() - jacobian_s = pytensor.gradient.jacobian(1 + s1, s2, disconnected_inputs="ignore") - func_s = pytensor.function([s2], jacobian_s) - val = np.array(1.0).astype(pytensor.config.floatX) - assert np.allclose(func_s(val), np.zeros(1)) - - -class TestHessianVectorProdudoct: +class TestHessianVectorProduct: def test_rosen(self): x = vector("x", dtype="float64") rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()