Skip to content

Commit 9629256

Browse files
jessegrabowskiricardoV94
authored andcommitted
Add dprint method to PointFunc and pickle regression test
1 parent f34eb26 commit 9629256

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

pymc/pytensorf.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,10 +567,8 @@ def __init__(self, f):
567567
def __call__(self, state):
568568
return self.f(**state)
569569

570-
def __getattr__(self, item):
571-
"""Allow access to the original function attributes."""
572-
# This is only reached if `__getattribute__` fails.
573-
return getattr(self.f, item)
570+
def dprint(self, **kwrags):
571+
return self.f.dprint(**kwrags)
574572

575573

576574
class CallableTensor:

tests/test_pytensorf.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ def test_hessian_sign_change_warning(func):
746746
assert equal_computations([res_neg], [-res])
747747

748748

749-
def test_point_func():
749+
def test_point_func(capsys):
750750
x, y = pt.vectors("x", "y")
751751
outs = x * 2 + y**2
752752
f = compile([x, y], outs)
@@ -758,3 +758,30 @@ def test_point_func():
758758
dprint_res = point_f.dprint(file="str")
759759
expected_dprint_res = point_f.f.dprint(file="str")
760760
assert dprint_res == expected_dprint_res
761+
762+
point_f.dprint(print_shape=True)
763+
captured = capsys.readouterr()
764+
765+
# The shape=(?,) arises because the inputs are dvector. This checks that the dprint works, and the print_shape
766+
# kwargs was correctly forwarded
767+
assert "shape=(?,)" in captured.out
768+
769+
770+
def test_pickle_point_func():
771+
"""
772+
Regression test for https://github.com/pymc-devs/pymc/issues/7857
773+
"""
774+
import cloudpickle
775+
776+
x, y = pt.vectors("x", "y")
777+
outs = x * 2 + y**2
778+
f = compile([x, y], outs)
779+
780+
point_f = PointFunc(f)
781+
point_f_pickled = cloudpickle.dumps(point_f)
782+
point_f_unpickled = cloudpickle.loads(point_f_pickled)
783+
784+
# Check that the function survived the round-trip
785+
np.testing.assert_allclose(
786+
point_f_unpickled({"y": [3], "x": [2]}), point_f({"y": [3], "x": [2]})
787+
)

0 commit comments

Comments
 (0)