From 44f67e0e07f8961b3736244a5cc5b5ef93029975 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 5 Jun 2025 14:08:07 +0200 Subject: [PATCH] Create test for mismatch between C and python Psi implementation --- pytensor/scalar/math.py | 70 +++++++++++++++++---------------------- tests/scalar/test_math.py | 34 +++++++++++++++++++ 2 files changed, 65 insertions(+), 39 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 86029e626f..ccfaff0ae9 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -378,53 +378,42 @@ def L_op(self, inputs, outputs, grads): def c_support_code(self, **kwargs): return """ - // For GPU support - #ifdef WITHIN_KERNEL - #define DEVICE WITHIN_KERNEL - #else - #define DEVICE - #endif - - #ifndef ga_double - #define ga_double double - #endif - #ifndef _PSIFUNCDEFINED #define _PSIFUNCDEFINED - DEVICE double _psi(ga_double x) { + double _psi(double x) { - /*taken from - Bernardo, J. M. (1976). Algorithm AS 103: - Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317. - http://www.uv.es/~bernardo/1976AppStatist.pdf */ + /*taken from + Bernardo, J. M. (1976). Algorithm AS 103: + Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317. + http://www.uv.es/~bernardo/1976AppStatist.pdf */ - ga_double y, R, psi_ = 0; - ga_double S = 1.0e-5; - ga_double C = 8.5; - ga_double S3 = 8.333333333e-2; - ga_double S4 = 8.333333333e-3; - ga_double S5 = 3.968253968e-3; - ga_double D1 = -0.5772156649; + double y, R, psi_ = 0; + double S = 1.0e-5; + double C = 8.5; + double S3 = 8.333333333e-2; + double S4 = 8.333333333e-3; + double S5 = 3.968253968e-3; + double D1 = -0.5772156649; - y = x; + y = x; - if (y <= 0.0) - return psi_; + if (y <= 0.0) + return psi_; - if (y <= S) - return D1 - 1.0/y; + if (y <= S) + return D1 - 1.0/y; - while (y < C) { - psi_ = psi_ - 1.0 / y; - y = y + 1; - } + while (y < C) { + psi_ = psi_ - 1.0 / y; + y = y + 1; + } - R = 1.0 / y; - psi_ = psi_ + log(y) - .5 * R ; - R= R*R; - psi_ = psi_ - R * (S3 - R * (S4 - R * S5)); + R = 1.0 / y; + psi_ = psi_ + log(y) - .5 * R ; + R= R*R; + psi_ = psi_ - R * (S3 - R * (S4 - R * S5)); - return psi_; + return psi_; } #endif """ @@ -433,10 +422,13 @@ def c_code(self, node, name, inp, out, sub): (x,) = inp (z,) = out if node.inputs[0].type in float_types: - return f"""{z} = - _psi({x});""" + dtype = "npy_" + node.outputs[0].dtype + return f"{z} = ({dtype}) _psi({x});" raise NotImplementedError("only floating point is implemented") + def c_code_cache_version(self): + return (1,) + psi = Psi(upgrade_to_float, name="psi") diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py index f4a9f2d414..63373577c1 100644 --- a/tests/scalar/test_math.py +++ b/tests/scalar/test_math.py @@ -2,6 +2,7 @@ import numpy as np import pytest +import scipy import scipy.special as sp import pytensor.tensor as pt @@ -19,6 +20,7 @@ gammal, gammau, hyp2f1, + psi, ) from tests.link.test_link import make_function @@ -149,3 +151,35 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads): (var.owner and isinstance(var.owner.op, ScalarLoop)) for var in ancestors(grad) ) + + +@pytest.mark.parametrize( + "linker", + [ + "py", + pytest.param( + "c", + marks=pytest.mark.xfail( + reason="C implementation does not support negative inputs" + ), + ), + ], +) +def test_psi(linker): + x = float64("x") + out = psi(x) + + fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run")) + fn.dprint() + + x_test = np.float64(0.5) + np.testing.assert_allclose( + fn(x_test), + scipy.special.psi(x_test), + strict=True, + ) + np.testing.assert_allclose( + fn(-x_test), + scipy.special.psi(-x_test), + strict=True, + )