Skip to content

Commit 6554683

Browse files
Dhruvanshu-JoshiricardoV94
authored andcommitted
Add logprob for discrete minimum of IID variables
Co-authored-by: Dhruvanshu-Joshi <[email protected]>
1 parent dfc4788 commit 6554683

File tree

3 files changed

+116
-30
lines changed

3 files changed

+116
-30
lines changed

pymc/logprob/order.py

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,7 @@
4141
from pytensor.graph.basic import Node
4242
from pytensor.graph.fg import FunctionGraph
4343
from pytensor.graph.rewriting.basic import node_rewriter
44-
from pytensor.scalar.basic import Mul
45-
from pytensor.tensor.basic import get_underlying_scalar_constant_value
4644
from pytensor.tensor.elemwise import Elemwise
47-
from pytensor.tensor.exceptions import NotScalarConstantError
4845
from pytensor.tensor.math import Max
4946
from pytensor.tensor.random.op import RandomVariable
5047
from pytensor.tensor.variable import TensorVariable
@@ -56,6 +53,7 @@
5653
_logprob_helper,
5754
)
5855
from pymc.logprob.rewriting import measurable_ir_rewrites_db
56+
from pymc.logprob.utils import find_negated_var
5957
from pymc.math import logdiffexp
6058
from pymc.pytensorf import constant_fold
6159

@@ -168,6 +166,13 @@ class MeasurableMaxNeg(Max):
168166
MeasurableVariable.register(MeasurableMaxNeg)
169167

170168

169+
class MeasurableDiscreteMaxNeg(Max):
170+
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""
171+
172+
173+
MeasurableVariable.register(MeasurableDiscreteMaxNeg)
174+
175+
171176
@node_rewriter(tracks=[Max])
172177
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
173178
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -180,37 +185,20 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[
180185

181186
base_var = node.inputs[0]
182187

183-
if base_var.owner is None:
184-
return None
185-
186-
if not rv_map_feature.request_measurable(node.inputs):
187-
return None
188-
189188
# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
190-
if not isinstance(base_var.owner.op, Elemwise):
189+
if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
191190
return None
192191

192+
base_rv = find_negated_var(base_var)
193+
193194
# negation is rv * (-1). Hence the scalar_op must be Mul
194-
try:
195-
if not (
196-
isinstance(base_var.owner.op.scalar_op, Mul)
197-
and len(base_var.owner.inputs) == 2
198-
and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
199-
):
200-
return None
201-
except NotScalarConstantError:
195+
if base_rv is None:
202196
return None
203197

204-
base_rv = base_var.owner.inputs[0]
205-
206198
# Non-univariate distributions and non-RVs must be rejected
207199
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
208200
return None
209201

210-
# TODO: We are currently only supporting continuous rvs
211-
if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
212-
return None
213-
214202
# univariate i.i.d. test which also rules out other distributions
215203
for params in base_rv.owner.inputs[3:]:
216204
if params.type.ndim != 0:
@@ -222,11 +210,16 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[
222210
if axis != base_var_dims:
223211
return None
224212

225-
measurable_min = MeasurableMaxNeg(list(axis))
226-
min_rv_node = measurable_min.make_node(base_var)
227-
min_rv = min_rv_node.outputs
213+
if not rv_map_feature.request_measurable([base_rv]):
214+
return None
228215

229-
return min_rv
216+
# distinguish measurable discrete and continuous (because logprob is different)
217+
if base_rv.owner.op.dtype.startswith("int"):
218+
measurable_min = MeasurableDiscreteMaxNeg(list(axis))
219+
else:
220+
measurable_min = MeasurableMaxNeg(list(axis))
221+
222+
return measurable_min.make_node(base_rv).outputs
230223

231224

232225
measurable_ir_rewrites_db.register(
@@ -238,14 +231,13 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[
238231

239232

240233
@_logprob.register(MeasurableMaxNeg)
241-
def max_neg_logprob(op, values, base_var, **kwargs):
234+
def max_neg_logprob(op, values, base_rv, **kwargs):
242235
r"""Compute the log-likelihood graph for the `Max` operation.
243236
The formula that we use here is :
244237
\ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
245238
where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
246239
"""
247240
(value,) = values
248-
base_rv = base_var.owner.inputs[0]
249241

250242
logprob = _logprob_helper(base_rv, -value)
251243
logcdf = _logcdf_helper(base_rv, -value)
@@ -254,3 +246,31 @@ def max_neg_logprob(op, values, base_var, **kwargs):
254246
logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)
255247

256248
return logprob
249+
250+
251+
@_logprob.register(MeasurableDiscreteMaxNeg)
252+
def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
253+
r"""Compute the log-likelihood graph for the `Max` operation.
254+
255+
The formula that we use here is :
256+
.. math::
257+
\ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
258+
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
259+
"""
260+
261+
(value,) = values
262+
263+
# The cdf of a negative variable is the survival at the negated value
264+
logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
265+
logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))
266+
267+
[n] = constant_fold([base_rv.size])
268+
269+
# Now we can use the same expression as the discrete max
270+
logprob = pt.where(
271+
pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
272+
-pt.inf,
273+
logdiffexp(n * logcdf_prev, n * logcdf),
274+
)
275+
276+
return logprob

pymc/logprob/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@
4949
from pytensor.graph.op import HasInnerGraph
5050
from pytensor.link.c.type import CType
5151
from pytensor.raise_op import CheckAndRaise
52+
from pytensor.scalar.basic import Mul
53+
from pytensor.tensor.basic import get_underlying_scalar_constant_value
54+
from pytensor.tensor.elemwise import Elemwise
55+
from pytensor.tensor.exceptions import NotScalarConstantError
5256
from pytensor.tensor.random.op import RandomVariable
5357
from pytensor.tensor.variable import TensorVariable
5458

@@ -296,3 +300,26 @@ def diracdelta_logprob(op, values, *inputs, **kwargs):
296300
(const_value,) = inputs
297301
values, const_value = pt.broadcast_arrays(values, const_value)
298302
return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf)
303+
304+
305+
def find_negated_var(var):
306+
"""Return a variable that is being multiplied by -1 or None otherwise."""
307+
308+
if (
309+
not (var.owner)
310+
and isinstance(var.owner.op, Elemwise)
311+
and isinstance(var.owner.op.scalar_op, Mul)
312+
):
313+
return None
314+
if len(var.owner.inputs) != 2:
315+
return None
316+
317+
inputs = var.owner.inputs
318+
for mul_var, mul_const in (inputs, reversed(inputs)):
319+
try:
320+
if get_underlying_scalar_constant_value(mul_const) == -1:
321+
return mul_var
322+
except NotScalarConstantError:
323+
continue
324+
325+
return None

tests/logprob/test_order.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import re
3838

3939
import numpy as np
40+
import pytensor
4041
import pytensor.tensor as pt
4142
import pytest
4243
import scipy.stats as sp
@@ -254,3 +255,41 @@ def test_max_discrete(mu, size, value, axis):
254255
(x_max_logprob.eval({x_max_value: test_value})),
255256
rtol=1e-06,
256257
)
258+
259+
260+
@pytest.mark.parametrize(
261+
"mu, n, test_value, axis",
262+
[(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
263+
)
264+
def test_min_discrete(mu, n, test_value, axis):
265+
x = pm.Poisson.dist(name="x", mu=mu, size=(n,))
266+
x_min = pt.min(x, axis=axis)
267+
x_min_value = pt.scalar("x_min_value")
268+
x_min_logprob = logp(x_min, x_min_value)
269+
270+
sf_before = 1 - sp.poisson(mu).cdf(test_value - 1)
271+
sf = 1 - sp.poisson(mu).cdf(test_value)
272+
273+
expected_logp = np.log(sf_before**n - sf**n)
274+
275+
np.testing.assert_allclose(
276+
x_min_logprob.eval({x_min_value: test_value}),
277+
expected_logp,
278+
rtol=1e-06,
279+
)
280+
281+
282+
def test_min_max_bernoulli():
283+
p = 0.7
284+
q = 1 - p
285+
n = 3
286+
x = pm.Bernoulli.dist(name="x", p=p, shape=(n,))
287+
value = pt.scalar("value", dtype=int)
288+
289+
max_logp_fn = pytensor.function([value], pm.logp(pt.max(x), value))
290+
np.testing.assert_allclose(max_logp_fn(0), np.log(q**n))
291+
np.testing.assert_allclose(max_logp_fn(1), np.log(1 - q**n))
292+
293+
min_logp_fn = pytensor.function([value], pm.logp(pt.min(x), value))
294+
np.testing.assert_allclose(min_logp_fn(1), np.log(p**n))
295+
np.testing.assert_allclose(min_logp_fn(0), np.log(1 - p**n))

0 commit comments

Comments
 (0)