diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 1383cea263..ecc00dbe1a 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -477,7 +477,8 @@ def cond_make_inplace(fgraph, node): Reshape, Unbroadcast, pt.math.Dot, - pt.math.MaxAndArgmax, + pt.math.Max, + pt.math.Argmax, pt.subtensor.Subtensor, pt.subtensor.IncSubtensor, pt.basic.Alloc, diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index 2f92364379..81ff82ada2 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -2,7 +2,7 @@ from pytensor.link.jax.dispatch import jax_funcify from pytensor.tensor.blas import BatchedDot -from pytensor.tensor.math import Dot, MaxAndArgmax +from pytensor.tensor.math import Argmax, Dot, Max from pytensor.tensor.nlinalg import ( SVD, Det, @@ -104,18 +104,28 @@ def batched_dot(a, b): return batched_dot -@jax_funcify.register(MaxAndArgmax) -def jax_funcify_MaxAndArgmax(op, **kwargs): +@jax_funcify.register(Max) +def jax_funcify_Max(op, **kwargs): axis = op.axis - def maxandargmax(x, axis=axis): + def max(x): + max_res = jnp.max(x, axis) + + return max_res + + return max + + +@jax_funcify.register(Argmax) +def jax_funcify_Argmax(op, **kwargs): + axis = op.axis + + def argmax(x): if axis is None: axes = tuple(range(x.ndim)) else: axes = tuple(int(ax) for ax in axis) - max_res = jnp.max(x, axis) - # NumPy does not support multiple axes for argmax; this is a # work-around keep_axes = jnp.array( @@ -138,6 +148,6 @@ def maxandargmax(x, axis=axis): max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64") - return max_res, max_idx_res + return max_idx_res - return maxandargmax + return argmax diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 207ebd5cf2..fbbec2587c 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -44,7 +44,7 @@ ) from pytensor.scalar.basic import add as add_as from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum +from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.type import scalar @@ -985,8 +985,8 @@ def log_softmax_py_fn(x): return log_softmax -@numba_funcify.register(MaxAndArgmax) -def numba_funcify_MaxAndArgmax(op, node, **kwargs): +@numba_funcify.register(Argmax) +def numba_funcify_Argmax(op, node, **kwargs): axis = op.axis x_at = node.inputs[0] x_dtype = x_at.type.numpy_dtype @@ -996,8 +996,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs): if x_ndim == 0: @numba_basic.numba_njit(inline="always") - def maxandargmax(x): - return x, 0 + def argmax(x): + return 0 else: axes = tuple(int(ax) for ax in axis) @@ -1006,20 +1006,6 @@ def maxandargmax(x): # work-around keep_axes = tuple(i for i in range(x_ndim) if i not in axes) - reduce_max_py_fn = create_multiaxis_reducer( - scalar_maximum, - -np.inf, - axes, - x_ndim, - x_dtype, - return_scalar=False, - ) - reduce_max = jit_compile_reducer( - Apply(node.op, node.inputs, [node.outputs[0].clone()]), - reduce_max_py_fn, - reduce_to_scalar=False, - ) - reduced_x_ndim = x_ndim - len(axes) + 1 argmax_axis = create_axis_apply_fn( np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64 @@ -1030,9 +1016,7 @@ def maxandargmax(x): sl2 = slice(len(keep_axes), None) @numba_basic.numba_njit - def maxandargmax(x): - max_res = reduce_max(x) - + def argmax(x): # 
Not-reduced axes in front transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order)) kept_shape = transposed_x.shape[sl1] @@ -1048,6 +1032,6 @@ def maxandargmax(x): max_idx_res = argmax_axis(reshaped_x) - return max_res, max_idx_res + return max_idx_res - return maxandargmax + return argmax diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 63a943e1f1..181e813f50 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -14,7 +14,6 @@ from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.link.c.type import Generic from pytensor.misc.safe_asarray import _asarray from pytensor.printing import pprint from pytensor.raise_op import Assert @@ -29,6 +28,7 @@ constant, stack, switch, + zeros_like, ) from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.elemwise import ( @@ -107,6 +107,14 @@ float64_atol = 1e-8 +def __getattr__(name): + if name == "MaxAndArgmax": + raise AttributeError( + "The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative." + ) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + def _get_atol_rtol(a, b): tiny = ("float16",) narrow = ("float32", "complex64") @@ -134,215 +142,6 @@ def _allclose(a, b, rtol=None, atol=None): return np.allclose(a, b, atol=atol_, rtol=rtol_) -class MaxAndArgmax(COp): - """ - Calculate the max and argmax over a given axis or over all axes. - - """ - - nin = 2 # tensor, axis - nout = 2 # max val, max idx - E_axis = "invalid axis" - params_type = Generic() - __props__ = ("axis",) - _f16_ok = True - - def __init__(self, axis): - assert isinstance(axis, tuple | list) - self.axis = tuple(axis) - - def get_params(self, node): - return self.axis - - def make_node(self, x): - x = as_tensor_variable(x) - - # Keep the original shapes for axes on which we do not perform the max/argmax. - all_axes = set(self.axis) - inputs = [x] - out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes) - outputs = [ - tensor(dtype=x.type.dtype, shape=out_shape, name="max"), - tensor(dtype="int64", shape=out_shape, name="argmax"), - ] - return Apply(self, inputs, outputs) - - def perform(self, node, inp, outs): - x = inp[0] - axes = self.axis - max, max_idx = outs - if axes is None: - axes = tuple(range(x.ndim)) - else: - axes = tuple(int(ax) for ax in axes) - max[0] = _asarray(np.max(x, axes), dtype=node.outputs[0].dtype) - # Numpy does not support multiple axes for argmax - # Work around - keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64") - # Not-reduced axes in front - transposed_x = np.transpose(x, np.concatenate((keep_axes, axes))) - kept_shape = transposed_x.shape[: len(keep_axes)] - reduced_shape = transposed_x.shape[len(keep_axes) :] - - # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64 - # Otherwise reshape would complain citing float arg - new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64")) - reshaped_x = transposed_x.reshape(new_shape) - - max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") - - def c_code(self, node, name, inp, out, sub): - if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim: - raise NotImplementedError( - "NumPy C-API can compute max and argmax only for 1 axis or for all axes." 
- ) - x = inp[0] - axis = sub["params"] - max, argmax = out - fail = sub["fail"] - ret = """ - #if PY_MAJOR_VERSION >= 3 - #ifndef PyInt_AS_LONG - #define PyInt_AS_LONG PyLong_AS_LONG - #endif - #endif - - int axis; - - if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) { - axis = NPY_MAXDIMS; - } else if(PyTuple_GET_SIZE(%(axis)s) == 1) { - PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0); - axis = (int)PyInt_AS_LONG(axis_object); - if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) { - PyErr_SetString(PyExc_ValueError, - "MaxAndArgmax: bad axis argument"); - %(fail)s - } - } else { - PyErr_SetString(PyExc_NotImplementedError, - "MaxAndArgmax: NumPy C-API can compute max and argmax only for 1 axis or for all axes."); - %(fail)s - } - - Py_CLEAR(%(max)s); - Py_CLEAR(%(argmax)s);//todo pass them as out parameter. - - %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL); - if (%(max)s == NULL) { - %(fail)s; - } - if (!PyArray_CheckExact(%(max)s)) { - %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); - if(%(max)s == NULL){ - %(fail)s; - } - } - - %(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL); - if (%(argmax)s == NULL) { - Py_CLEAR(%(max)s); - %(fail)s; - } - if (!PyArray_CheckExact(%(argmax)s)) { - %(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); - if(%(argmax)s == NULL){ - %(fail)s; - } - } - if (PyArray_TYPE(%(argmax)s) != NPY_INT64) { - PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64); - if (NULL == tmp){ - %(fail)s; - } - Py_DECREF(%(argmax)s); - %(argmax)s = (PyArrayObject*)tmp; - } - """ - return ret % locals() - - def c_code_cache_version(self): - return (5,) - - def infer_shape(self, fgraph, node, shapes): - ishape = shapes[0] - rval = tuple( - ishape[i] - for (i, b) in enumerate(node.inputs[0].type.broadcastable) - if i not in self.axis - ) - return [rval, rval] - - def R_op(self, inputs, eval_points): - if eval_points[0] is None: - return [None, None] - if len(self.axis) != 1: - raise ValueError("R_op supported for arg_max only for one axis!") - if self.axis[0] > 1: - raise ValueError("R_op supported for arg_max only when axis is 0 or 1") - if inputs[0].ndim != 2: - raise ValueError("R_op supported for arg_max only when input is a matrix") - max_vals, max_pos = self.make_node(*inputs).outputs - if self.axis[0] == 0: - return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None] - else: - return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None] - - def grad(self, inp, grads): - # The strict sense mathematical gradient of the maximum function is - # not calculated here for it is not defined at every point where some - # coordinates are identical. However, since the latter set has null - # Lebesgue measure, the result may be interpreted as weak gradient. - - # @note: This function should work correctly for L{vector}s. 
- # (x, y), (gz, gw) - # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy - # gMax * dMax/dx + gArgMax * dArgMax/dx, - # gMax * dMax/daxis + gArgMax * dArgMax/daxis - # g_max has one less dimension than x, so you need to complete - # g_max to x's shape when axis=0 the broadcasting mechanism - # does it automatically - x = inp[0] - axis = as_tensor_variable(self.axis) - g_max, g_max_idx = grads - - g_max_disconnected = isinstance(g_max.type, DisconnectedType) - g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType) - - # if the op is totally disconnected, so are its inputs - if g_max_disconnected and g_max_idx_disconnected: - return [DisconnectedType()(), DisconnectedType()()] - - # if the max is disconnected but the argmax is not, - # the gradient on its inputs is zero - if g_max_disconnected: - return [x.zeros_like()] - if NoneConst.equals(axis): - axis_ = list(range(x.ndim)) - else: - axis_ = axis - xmax = max(x, axis_) - - # Raise the g_max and xmax to the same number of dim as the input. - pattern = [] - out_dim = 0 - if NoneConst.equals(axis): - # We are taking the max/argmax over all dimensions. - axis = None - for i in range(x.ndim): - if axis is None or i in axis.data: - pattern.append("x") - else: - pattern.append(out_dim) - out_dim += 1 - g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max) - xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax) - - # Set the grad to the correct position. - g_x = eq(xmax_pad, x) * g_max_pad - return (g_x,) - - class Argmax(COp): """ Calculate the argmax over a given axis or over all axes. @@ -359,7 +158,7 @@ class Argmax(COp): def __init__(self, axis): if axis is not None: axis = tuple(axis) - self.axis = tuple(axis) + self.axis = axis def get_params(self, node): if self.axis is not None and len(self.axis) == 1: @@ -395,7 +194,6 @@ def perform(self, node, inp, outs): (max_idx,) = outs if axes is None: axes = tuple(range(x.ndim)) - # Numpy does not support multiple axes for argmax # Work around keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64") @@ -403,7 +201,7 @@ def perform(self, node, inp, outs): transposed_x = np.transpose(x, np.concatenate((keep_axes, axes))) kept_shape = transposed_x.shape[: len(keep_axes)] reduced_shape = transposed_x.shape[len(keep_axes) :] - new_shape = (*kept_shape, np.prod(reduced_shape)) + new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64")) reshaped_x = transposed_x.reshape(new_shape) max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") @@ -470,6 +268,9 @@ def infer_shape(self, fgraph, node, shapes): ) return [rval] + def R_op(self, inputs, eval_points): + raise ValueError("Argmax is non-diifferentiable") + def grad(self, inp, grads): (x,) = inp @@ -477,7 +278,6 @@ def grad(self, inp, grads): @_vectorize_node.register(Argmax) -@_vectorize_node.register(MaxAndArgmax) def vectorize_argmax_node(op, node, batch_x): core_ndim = node.inputs[0].type.ndim batch_ndim = batch_x.type.ndim - core_ndim @@ -595,12 +395,24 @@ def max_and_argmax(a, axis=None, keepdims=False): """ # Check axis and convert it to a Python list of integers. - # Axis will be used as an op param of MaxAndArgmax. + # Axis will be used as an op param of Max and Argmax. 
a = as_tensor_variable(a) + + is_axis_empty = False + if axis == (): + is_axis_empty = True + axis = check_and_normalize_axes(a, axis) - if len(axis) == 0: - axis = list(range(a.type.ndim)) - out, argout = MaxAndArgmax(axis)(a) + + if len(axis) == 0 and not is_axis_empty: + axis = None + + out = Max(axis)(a) + + if not is_axis_empty: + argout = Argmax(axis)(a) + else: + argout = zeros_like(a, dtype="int64") if keepdims: out = makeKeepDims(a, out, axis) @@ -654,6 +466,74 @@ def clone(self, **kwargs): axis = kwargs.get("axis", self.axis) return type(self)(axis=axis) + def grad(self, inp, grads): + # The strict sense mathematical gradient of the maximum function is + # not calculated here for it is not defined at every point where some + # coordinates are identical. However, since the latter set has null + # Lebesgue measure, the result may be interpreted as weak gradient. + + # @note: This function should work correctly for L{vector}s. + # (x, y), (gz, gw) + # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy + # gMax * dMax/dx + gArgMax * dArgMax/dx, + # gMax * dMax/daxis + gArgMax * dArgMax/daxis + # g_max has one less dimension than x, so you need to complete + # g_max to x's shape when axis=0 the broadcasting mechanism + # does it automatically + x = inp[0] + if self.axis is None: + self.axis = tuple(range(x.ndim)) + axis = as_tensor_variable(self.axis) + (g_max,) = grads + + g_max_disconnected = isinstance(g_max.type, DisconnectedType) + + # if the op is totally disconnected, so are its inputs + if g_max_disconnected: + return [DisconnectedType()()] + + # if NoneConst.equals(axis): + if axis is None: + axis_ = list(range(x.ndim)) + else: + axis_ = axis + xmax = max(x, axis_) + + # Raise the g_max and xmax to the same number of dim as the input. + pattern = [] + out_dim = 0 + if NoneConst.equals(axis): + # We are taking the max/argmax over all dimensions. + axis = None + for i in range(x.ndim): + if axis is None or i in axis.data: + pattern.append("x") + else: + pattern.append(out_dim) + out_dim += 1 + g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max) + xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax) + + # Set the grad to the correct position. + g_x = eq(xmax_pad, x) * g_max_pad + return (g_x,) + + def R_op(self, inputs, eval_points): + if eval_points[0] is None: + return [None, None] + if len(self.axis) != 1: + raise ValueError("R_op supported for arg_max only for one axis!") + if self.axis[0] > 1: + raise ValueError("R_op supported for arg_max only when axis is 0 or 1") + if inputs[0].ndim != 2: + raise ValueError("R_op supported for arg_max only when input is a matrix") + max_pos = Argmax(self.axis).make_node(*inputs).outputs + # print(eval_points[0].eval()) + if self.axis[0] == 0: + return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None] + else: + return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None] + class Min(NonZeroDimsCAReduce): nfunc_spec = ("min", 1, 1) @@ -685,16 +565,6 @@ def max(x, axis=None, keepdims=False): We return an error as numpy when we reduce a dim with a shape of 0. """ - - # We have a choice of implementing this call with the - # CAReduce op or the MaxAndArgmax op. - - # MaxAndArgmax supports grad and Rop, so we prefer to use that. - # CAReduce is faster, but optimizations will replace MaxAndArgmax[0] - # with CAReduce at compile time, so at this stage the important - # thing is supporting all user interface features, not speed. - # Some cases can be implemented only with CAReduce. 
- out = max_and_argmax(x, axis)[0] if keepdims: diff --git a/pytensor/tensor/rewriting/uncanonicalize.py b/pytensor/tensor/rewriting/uncanonicalize.py index 15a316c5a0..a44870ded2 100644 --- a/pytensor/tensor/rewriting/uncanonicalize.py +++ b/pytensor/tensor/rewriting/uncanonicalize.py @@ -35,31 +35,12 @@ from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter from pytensor.tensor.basic import Alloc, alloc, constant from pytensor.tensor.elemwise import CAReduce, DimShuffle -from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg +from pytensor.tensor.math import Min, neg from pytensor.tensor.rewriting.basic import register_uncanonicalize from pytensor.tensor.shape import Reshape, reshape from pytensor.tensor.subtensor import Subtensor -@register_uncanonicalize -@node_rewriter([MaxAndArgmax]) -def local_max_and_argmax(fgraph, node): - """ - If we don't use the argmax, change it to a max only. - """ - if isinstance(node.op, MaxAndArgmax): - axis = node.op.axis - if len(fgraph.clients[node.outputs[1]]) == 0: - new = Max(axis)(node.inputs[0]) - copy_stack_trace(node.outputs[0], new) - return [new, None] - - if len(fgraph.clients[node.outputs[0]]) == 0: - new = Argmax(axis)(node.inputs[0]) - copy_stack_trace(node.outputs[0], new) - return [None, new] - - @register_uncanonicalize @node_rewriter([neg]) def local_max_to_min(fgraph, node): @@ -71,7 +52,7 @@ def local_max_to_min(fgraph, node): Notes ----- We don't need an opt that will do the reverse as by default - the interface put only MaxAndArgmax into the graph. + the interface put only Max into the graph. """ if node.op == neg and node.inputs[0].owner: diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 3a64fda364..2175670ee6 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -11,7 +11,7 @@ from pytensor.link.jax import JAXLinker from pytensor.tensor import blas as pt_blas from pytensor.tensor import nlinalg as pt_nlinalg -from pytensor.tensor.math import MaxAndArgmax, maximum +from pytensor.tensor.math import Argmax, Max, maximum from pytensor.tensor.math import max as pt_max from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector from tests.link.jax.test_basic import compare_jax_and_py @@ -88,7 +88,8 @@ def test_jax_basic_multiout_omni(): # Test that a single output of a multi-output `Op` can be used as input to # another `Op` x = dvector() - mx, amx = MaxAndArgmax([0])(x) + mx = Max([0])(x) + amx = Argmax([0])(x) out = mx * amx out_fg = FunctionGraph([x], [out]) compare_jax_and_py(out_fg, [np.r_[1, 2]]) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index b8c131ead6..8bbbe164fc 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -552,8 +552,53 @@ def test_LogSoftmax(x, axis, exc): ), ], ) -def test_MaxAndArgmax(x, axes, exc): - g = ptm.MaxAndArgmax(axes)(x) +def test_Max(x, axes, exc): + g = ptm.Max(axes)(x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, SharedVariable | Constant) + ], + ) + + +@pytest.mark.parametrize( + "x, axes, exc", + [ + ( + set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")), + [], + None, + ), + ( + set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), + [0], + None, + 
), + ( + set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + [0], + None, + ), + ( + set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + [0, 1], + None, + ), + ], +) +def test_Argmax(x, axes, exc): + g = ptm.Argmax(axes)(x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 84322989bf..29c07456b5 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import ( Dot, - MaxAndArgmax, + Max, Prod, Sum, _conj, @@ -3734,8 +3734,8 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None): return # In mode FAST_COMPILE, the rewrites don't replace the - # `MaxAndArgmax` `Op`. - if isinstance(node.op, MaxAndArgmax): + # `Max` `Op`. + if isinstance(node.op, Max): return # TODO FIXME: Refactor this test so that it makes a direct assertion and diff --git a/tests/tensor/rewriting/test_uncanonicalize.py b/tests/tensor/rewriting/test_uncanonicalize.py index d36447ac20..9d5011b6db 100644 --- a/tests/tensor/rewriting/test_uncanonicalize.py +++ b/tests/tensor/rewriting/test_uncanonicalize.py @@ -9,8 +9,6 @@ from pytensor.graph.rewriting.basic import out2in from pytensor.link.basic import PerformLinker from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import MaxAndArgmax, max_and_argmax -from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import min as pt_min from pytensor.tensor.rewriting.uncanonicalize import ( local_alloc_dimshuffle, @@ -23,67 +21,12 @@ from tests.link.test_link import make_function -class TestMaxAndArgmax: - def test_optimization(self): - # If we use only the max output, we should replace this op with - # a faster one. 
- mode = pytensor.compile.mode.get_default_mode().including( - "canonicalize", "fast_run" - ) - - for axis in [0, 1, -1]: - n = matrix() - - f = function([n], max_and_argmax(n, axis)[0], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, CAReduce) - - f = function([n], max_and_argmax(n, axis), mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, MaxAndArgmax) - - class TestMinMax: def setup_method(self): self.mode = pytensor.compile.mode.get_default_mode().including( "canonicalize", "fast_run" ) - def test_optimization_max(self): - data = np.asarray(np.random.random((2, 3)), dtype=config.floatX) - n = matrix() - - for axis in [0, 1, -1]: - f = function([n], pt_max(n, axis), mode=self.mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, CAReduce) - f(data) - - f = function([n], pt_max(-n, axis), mode=self.mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 2 - assert isinstance(topo[0].op, Elemwise) - assert isinstance(topo[0].op.scalar_op, ps.Neg) - assert isinstance(topo[1].op, CAReduce) - f(data) - - f = function([n], -pt_max(n, axis), mode=self.mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 2 - assert isinstance(topo[0].op, CAReduce) - assert isinstance(topo[1].op, Elemwise) - assert isinstance(topo[1].op.scalar_op, ps.Neg) - f(data) - - f = function([n], -pt_max(-n, axis), mode=self.mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, CAReduce) # min - f(data) - def test_optimization_min(self): data = np.asarray(np.random.random((2, 3)), dtype=config.floatX) n = matrix() diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index e346348406..6b6f8def13 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -11,6 +11,7 @@ from numpy.testing import assert_array_equal from scipy.special import logsumexp as scipy_logsumexp +import pytensor import pytensor.scalar as ps from pytensor.compile.debugmode import DebugMode from pytensor.compile.function import function @@ -39,7 +40,7 @@ from pytensor.tensor.math import ( Argmax, Dot, - MaxAndArgmax, + Max, Mean, Prod, ProdWithoutZeros, @@ -760,11 +761,12 @@ def test_isnan(): class TestMaxAndArgmax: def setup_method(self): - MaxAndArgmax.debug = 0 + Max.debug = 0 + Argmax.debug = 0 def test_basic(self): - n = as_tensor_variable(5.0) - v, i = eval_outputs(max_and_argmax(n)) + n = as_tensor_variable(5) + v, i = eval_outputs(max_and_argmax(n, axis=())) assert v == 5.0 assert i == 0 assert i.dtype == "int64" @@ -1030,31 +1032,45 @@ def test_vectorize(self, core_axis, batch_axis): x = tensor(shape=(5, 5, 5, 5)) batch_x = tensor(shape=(3, 5, 5, 5, 5)) - # Test MaxAndArgmax - max_x, argmax_x = max_and_argmax(x, axis=core_axis) - node = max_x.owner - assert isinstance(node.op, MaxAndArgmax) - - new_node = vectorize_node(node, batch_x) - assert isinstance(new_node.op, MaxAndArgmax) - assert new_node.op.axis == batch_axis + argmax_x = argmax(x, axis=core_axis) - # Test Argmax - # Argmax is not user-facing, so we have to create it manually - node = Argmax(axis=node.op.axis).make_node(x) + arg_max_node = argmax_x.owner + new_node = vectorize_node(arg_max_node, batch_x) - new_node = vectorize_node(node, batch_x) assert isinstance(new_node.op, Argmax) assert new_node.op.axis == batch_axis + def test_max_empty_axis(self): + x = np.random.normal(size=(2, 3, 5, 7)) + axis = () + + non_axis = tuple(i for i in range(x.ndim) 
if i not in axis) + shape_axis = tuple(x.shape[dim] for dim in axis) + shape_non_axis = tuple(x.shape[dim] for dim in non_axis) + x_transposed = x.transpose(*axis, *non_axis) + + x_axis_raveled = x_transposed.reshape( + np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int) + ) + max_x = max_and_argmax(x, axis=axis)[0].eval() + argmax_x = max_and_argmax(x, axis=axis)[1].eval() + + raveled_max = x_axis_raveled[ + argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int)) + ] + indirect_max = raveled_max.reshape(shape_non_axis) + + np.testing.assert_allclose(max_x, x.max(axis=axis)) + np.testing.assert_allclose(indirect_max, x.max(axis=axis)) + class TestArgminArgmax: def setup_method(self): - MaxAndArgmax.debug = 0 + Argmax.debug = 0 def test_scalar(self): for fct in [argmin, argmax]: - n = as_tensor_variable(5.0) + n = as_tensor_variable([5.0]) i = eval_outputs(fct(n)) assert i == 0 v = eval_outputs(fct(n).shape) @@ -1212,7 +1228,7 @@ def test_bool(self): class TestMinMax: def setup_method(self): - MaxAndArgmax.debug = 0 + Max.debug = 0 def test_scalar(self): for fct in [max, min]: @@ -1379,6 +1395,7 @@ def _grad_list(self): # check_grad_max(data, eval_outputs(grad(max_and_argmax(n, # axis=1)[0], n)),axis=1) + @pytest.mark.xfail(reason="Fails due to #770") def test_uint(self): for dtype in ("uint8", "uint16", "uint32", "uint64"): itype = np.iinfo(dtype) @@ -1404,6 +1421,14 @@ def test_bool(self): assert np.all(i) +def test_MaxAndArgmax_deprecated(): + with pytest.raises( + AttributeError, + match="The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative.", + ): + pytensor.tensor.math.MaxAndArgmax + + rng = np.random.default_rng(seed=utt.fetch_seed()) TestClip1 = makeTester( name="ClipTester", @@ -2572,27 +2597,50 @@ def test_Mean(self): [adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean ) - def test_MaxAndArgmax(self): + def test_Max(self): + adtens3 = dtensor3() + adtens3_val = random(4, 5, 3) + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, None), [adtens3_val], Max + ) + + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Max + ) + + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Max + ) + + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Max + ) + + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Max + ) + + def test_Argmax(self): adtens3 = dtensor3() adtens3_val = random(4, 5, 3) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, None), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, None), [adtens3_val], Argmax ) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, 0), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Argmax ) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, 1), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Argmax ) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, 2), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Argmax ) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Argmax ) def test_Dot(self):
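
Illustrative sketch of the resulting behaviour (not part of the patch): after this change the user-facing max_and_argmax helper builds separate Max and Argmax nodes instead of a single MaxAndArgmax node, and the removed class is only reachable through the module-level __getattr__ shim, which raises AttributeError. The snippet below only uses names that appear in this diff (max_and_argmax, Max, Argmax) plus standard PyTensor compilation; the concrete input values are made up for the example.

    import numpy as np
    import pytensor
    from pytensor.tensor.type import matrix
    from pytensor.tensor.math import Argmax, Max, max_and_argmax

    x = matrix("x")
    # max_and_argmax now returns two outputs coming from two separate Ops.
    max_x, argmax_x = max_and_argmax(x, axis=0)
    assert isinstance(max_x.owner.op, Max)
    assert isinstance(argmax_x.owner.op, Argmax)

    f = pytensor.function([x], [max_x, argmax_x])
    # Column-wise max and argmax; expected roughly [array([2., 3.]), array([1, 0])].
    print(f(np.array([[1.0, 3.0], [2.0, 0.0]])))

    # Accessing the removed class now fails via the deprecation shim:
    # pytensor.tensor.math.MaxAndArgmax  ->  AttributeError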