|
44 | 44 | from pytensor.graph.basic import equal_computations
|
45 | 45 |
|
46 | 46 | from pymc.distributions.continuous import Cauchy, ChiSquared
|
| 47 | +from pymc.distributions.discrete import Bernoulli |
47 | 48 | from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
|
48 | 49 | from pymc.logprob.transforms import (
|
49 | 50 | ArccoshTransform,
|
@@ -680,18 +681,51 @@ def test_multivariate_rv_transform(shift, scale):
|
680 | 681 | )
|
681 | 682 |
|
682 | 683 |
|
683 |
| -def test_discrete_rv_unary_transform_fails(): |
| 684 | +def test_not_implemented_discrete_rv_transform(): |
684 | 685 | y_rv = pt.exp(pt.random.poisson(1))
|
685 | 686 | with pytest.raises(RuntimeError, match="could not be derived"):
|
686 | 687 | conditional_logp({y_rv: y_rv.clone()})
|
687 | 688 |
|
688 |
| - |
689 |
| -def test_discrete_rv_multinary_transform_fails(): |
690 |
| - y_rv = 5 + pt.random.poisson(1) |
| 689 | + y_rv = 5 * pt.random.poisson(1) |
691 | 690 | with pytest.raises(RuntimeError, match="could not be derived"):
|
692 | 691 | conditional_logp({y_rv: y_rv.clone()})
|
693 | 692 |
|
694 | 693 |
|
| 694 | +def test_negated_discrete_rv_transform(): |
| 695 | + p = 0.7 |
| 696 | + rv = -Bernoulli.dist(p=p) |
| 697 | + vv = rv.type() |
| 698 | + logp_fn = pytensor.function([vv], logp(rv, vv)) |
| 699 | + |
| 700 | + # A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise} |
| 701 | + assert logp_fn(-2) == -np.inf |
| 702 | + np.testing.assert_allclose(logp_fn(-1), np.log(p)) |
| 703 | + np.testing.assert_allclose(logp_fn(0), np.log(1 - p)) |
| 704 | + assert logp_fn(1) == -np.inf |
| 705 | + |
| 706 | + # Logcdf and icdf not supported yet |
| 707 | + for func in (logcdf, icdf): |
| 708 | + with pytest.raises(NotImplementedError): |
| 709 | + func(rv, 0) |
| 710 | + |
| 711 | + |
| 712 | +def test_shifted_discrete_rv_transform(): |
| 713 | + p = 0.7 |
| 714 | + rv = Bernoulli.dist(p=p) + 5 |
| 715 | + vv = rv.type() |
| 716 | + rv_logp_fn = pytensor.function([vv], logp(rv, vv)) |
| 717 | + |
| 718 | + assert rv_logp_fn(4) == -np.inf |
| 719 | + np.testing.assert_allclose(rv_logp_fn(5), np.log(1 - p)) |
| 720 | + np.testing.assert_allclose(rv_logp_fn(6), np.log(p)) |
| 721 | + assert rv_logp_fn(7) == -np.inf |
| 722 | + |
| 723 | + # Logcdf and icdf not supported yet |
| 724 | + for func in (logcdf, icdf): |
| 725 | + with pytest.raises(NotImplementedError): |
| 726 | + func(rv, 0) |
| 727 | + |
| 728 | + |
695 | 729 | @pytest.mark.xfail(reason="Check not implemented yet")
|
696 | 730 | def test_invalid_broadcasted_transform_rv_fails():
|
697 | 731 | loc = pt.vector("loc")
|
|
0 commit comments