Skip to content

Commit dfc4788

Browse files
committed
Support discrete negation and addition
1 parent 94020c9 commit dfc4788

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

pymc/logprob/transforms.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,11 @@
117117
_logprob_helper,
118118
)
119119
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
120-
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
120+
from pymc.logprob.utils import (
121+
CheckParameterValue,
122+
check_potential_measurability,
123+
find_negated_var,
124+
)
121125

122126

123127
class Transform(abc.ABC):
@@ -229,6 +233,10 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
229233
other_inputs = list(inputs)
230234
measurable_input = other_inputs.pop(op.measurable_input_idx)
231235

236+
# Do not apply rewrite to discrete variables
237+
if measurable_input.type.dtype.startswith("int"):
238+
raise NotImplementedError("logcdf of transformed discrete variables not implemented")
239+
232240
backward_value = op.transform_elemwise.backward(value, *other_inputs)
233241

234242
# Fail if transformation is not injective
@@ -273,6 +281,10 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
273281
other_inputs = list(inputs)
274282
measurable_input = other_inputs.pop(op.measurable_input_idx)
275283

284+
# Do not apply rewrite to discrete variables
285+
if measurable_input.type.dtype.startswith("int"):
286+
raise NotImplementedError("icdf of transformed discrete variables not implemented")
287+
276288
if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
277289
pass
278290
elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
@@ -429,10 +441,15 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[li
429441
return None
430442

431443
[measurable_input] = measurable_inputs
444+
[measurable_output] = node.outputs
432445

433-
# Do not apply rewrite to discrete variables
446+
# Do not apply rewrite to discrete variables except for their addition and negation
434447
if measurable_input.type.dtype.startswith("int"):
435-
return None
448+
if not (find_negated_var(measurable_output) or isinstance(node.op.scalar_op, Add)):
449+
return None
450+
# Do not allow rewrite if output is cast to a float, because we don't have meta-info on the type of the MeasurableVariable
451+
if not measurable_output.type.dtype.startswith("int"):
452+
return None
436453

437454
# Check that other inputs are not potentially measurable, in which case this rewrite
438455
# would be invalid

tests/logprob/test_transforms.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from pytensor.graph.basic import equal_computations
4545

4646
from pymc.distributions.continuous import Cauchy, ChiSquared
47+
from pymc.distributions.discrete import Bernoulli
4748
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
4849
from pymc.logprob.transforms import (
4950
ArccoshTransform,
@@ -680,18 +681,51 @@ def test_multivariate_rv_transform(shift, scale):
680681
)
681682

682683

683-
def test_discrete_rv_unary_transform_fails():
684+
def test_not_implemented_discrete_rv_transform():
684685
y_rv = pt.exp(pt.random.poisson(1))
685686
with pytest.raises(RuntimeError, match="could not be derived"):
686687
conditional_logp({y_rv: y_rv.clone()})
687688

688-
689-
def test_discrete_rv_multinary_transform_fails():
690-
y_rv = 5 + pt.random.poisson(1)
689+
y_rv = 5 * pt.random.poisson(1)
691690
with pytest.raises(RuntimeError, match="could not be derived"):
692691
conditional_logp({y_rv: y_rv.clone()})
693692

694693

694+
def test_negated_discrete_rv_transform():
695+
p = 0.7
696+
rv = -Bernoulli.dist(p=p)
697+
vv = rv.type()
698+
logp_fn = pytensor.function([vv], logp(rv, vv))
699+
700+
# A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
701+
assert logp_fn(-2) == -np.inf
702+
np.testing.assert_allclose(logp_fn(-1), np.log(p))
703+
np.testing.assert_allclose(logp_fn(0), np.log(1 - p))
704+
assert logp_fn(1) == -np.inf
705+
706+
# Logcdf and icdf not supported yet
707+
for func in (logcdf, icdf):
708+
with pytest.raises(NotImplementedError):
709+
func(rv, 0)
710+
711+
712+
def test_shifted_discrete_rv_transform():
713+
p = 0.7
714+
rv = Bernoulli.dist(p=p) + 5
715+
vv = rv.type()
716+
rv_logp_fn = pytensor.function([vv], logp(rv, vv))
717+
718+
assert rv_logp_fn(4) == -np.inf
719+
np.testing.assert_allclose(rv_logp_fn(5), np.log(1 - p))
720+
np.testing.assert_allclose(rv_logp_fn(6), np.log(p))
721+
assert rv_logp_fn(7) == -np.inf
722+
723+
# Logcdf and icdf not supported yet
724+
for func in (logcdf, icdf):
725+
with pytest.raises(NotImplementedError):
726+
func(rv, 0)
727+
728+
695729
@pytest.mark.xfail(reason="Check not implemented yet")
696730
def test_invalid_broadcasted_transform_rv_fails():
697731
loc = pt.vector("loc")

0 commit comments

Comments
 (0)