Skip to content

Fix bug with pickling PointFunc #7858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 20, 2025

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jul 20, 2025

Description

Fixes a bug when unpickling certain step samplers (BinaryMetropolis, BinaryGibbsMetropolis, and CategoricalGibbsMetropolis) when mp_ctx is spawn or forkserver

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@@ -425,3 +425,20 @@ def test_sampling_state(step_method, model_fn):
assert equal_sampling_states(final_state1, final_state2)
assert equal_dataclass_values(sample1, sample2)
assert equal_dataclass_values(stat1, stat2)


def test_binary_gibbs_with_spawn():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this specific to binary gibbs? That's the only one using PointFunc?

We should test spawning a PointFunc instead or this test won't be a regression test if binary gibbs changes to not using PointFunc

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no idea why it's these functions. model.compile_xxx returns a PointFunc, so everyone uses them, but only these 3 samplers trigger the error (I tried most of them)

Copy link
Member

@ricardoV94 ricardoV94 Jul 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most samplers build their own function on raveled inputs and don't use the wrapper PointFunc for performance

assert "shape=(?,)" in captured.out


def test_pickle_point_func():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be sure this failed before?

Copy link
Member Author

@jessegrabowski jessegrabowski Jul 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep

FAILED                         [100%]
tests/test_pytensorf.py:769 (test_pickle_point_func)
def test_pickle_point_func():
        """
        Regression test for https://github.com/pymc-devs/pymc/issues/7857
        """
        import cloudpickle
    
        x, y = pt.vectors("x", "y")
        outs = x * 2 + y**2
        f = compile([x, y], outs)
    
        point_f = PointFunc(f)
        point_f_pickled = cloudpickle.dumps(point_f)
>       point_f_unpickled = cloudpickle.loads(point_f_pickled)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

test_pytensorf.py:782: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../pymc/pytensorf.py:573: in __getattr__
    return getattr(self.f, item)
                   ^^^^^^
../pymc/pytensorf.py:573: in __getattr__
    return getattr(self.f, item)
                   ^^^^^^
../pymc/pytensorf.py:573: in __getattr__
    return getattr(self.f, item)
                   ^^^^^^
E   RecursionError: maximum recursion depth exceeded
!!! Recursion detected (same locals & position)

Copy link

codecov bot commented Jul 20, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.98%. Comparing base (ae43026) to head (55db526).
Report is 7 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7858      +/-   ##
==========================================
+ Coverage   92.92%   92.98%   +0.06%     
==========================================
  Files         107      108       +1     
  Lines       18299    18328      +29     
==========================================
+ Hits        17004    17043      +39     
+ Misses       1295     1285      -10     
Files with missing lines Coverage Δ
pymc/pytensorf.py 90.02% <100.00%> (ø)

... and 20 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94 ricardoV94 merged commit 9629256 into pymc-devs:main Jul 20, 2025
25 checks passed
@ricardoV94 ricardoV94 changed the title Check attributes have been set before calling getattr in PointFunc Fix bug with pickling PointFunc Jul 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Cannot unpickle certain step samplers when mp_ctx=spawn or forkserver
2 participants