@@ -746,7 +746,7 @@ def test_hessian_sign_change_warning(func):
746
746
assert equal_computations ([res_neg ], [- res ])
747
747
748
748
749
- def test_point_func ():
749
+ def test_point_func (capsys ):
750
750
x , y = pt .vectors ("x" , "y" )
751
751
outs = x * 2 + y ** 2
752
752
f = compile ([x , y ], outs )
@@ -758,3 +758,30 @@ def test_point_func():
758
758
dprint_res = point_f .dprint (file = "str" )
759
759
expected_dprint_res = point_f .f .dprint (file = "str" )
760
760
assert dprint_res == expected_dprint_res
761
+
762
+ point_f .dprint (print_shape = True )
763
+ captured = capsys .readouterr ()
764
+
765
+ # The shape=(?,) arises because the inputs are dvector. This checks that the dprint works, and the print_shape
766
+ # kwargs was correctly forwarded
767
+ assert "shape=(?,)" in captured .out
768
+
769
+
770
+ def test_pickle_point_func ():
771
+ """
772
+ Regression test for https://github.com/pymc-devs/pymc/issues/7857
773
+ """
774
+ import cloudpickle
775
+
776
+ x , y = pt .vectors ("x" , "y" )
777
+ outs = x * 2 + y ** 2
778
+ f = compile ([x , y ], outs )
779
+
780
+ point_f = PointFunc (f )
781
+ point_f_pickled = cloudpickle .dumps (point_f )
782
+ point_f_unpickled = cloudpickle .loads (point_f_pickled )
783
+
784
+ # Check that the function survived the round-trip
785
+ np .testing .assert_allclose (
786
+ point_f_unpickled ({"y" : [3 ], "x" : [2 ]}), point_f ({"y" : [3 ], "x" : [2 ]})
787
+ )
0 commit comments