 from pytensor.graph.basic import Node
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.scalar.basic import Mul
-from pytensor.tensor.basic import get_underlying_scalar_constant_value
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Max
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorVariable
     _logprob_helper,
 )
 from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.logprob.utils import find_negated_var
 from pymc.math import logdiffexp
 from pymc.pytensorf import constant_fold
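The newly imported find_negated_var replaces the inline negation check that this diff removes from find_measurable_max_neg further down. A minimal sketch of what such a helper might look like, reconstructed from the removed lines (the actual helper lives in pymc.logprob.utils and may handle more cases, for example -1 * x as well as x * -1):

from typing import Optional

from pytensor.graph.basic import Variable
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError


def find_negated_var(var: Variable) -> Optional[Variable]:
    """Return x when var is the graph x * (-1), otherwise None (sketch only)."""
    node = var.owner
    if node is None or not (
        isinstance(node.op, Elemwise)
        and isinstance(node.op.scalar_op, Mul)
        and len(node.inputs) == 2
    ):
        return None
    try:
        # Mirror the removed inline check: the second input must be the constant -1
        if get_underlying_scalar_constant_value(node.inputs[1]) == -1:
            return node.inputs[0]
    except NotScalarConstantError:
        return None
    return None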
@@ -168,6 +166,13 @@ class MeasurableMaxNeg(Max):
 MeasurableVariable.register(MeasurableMaxNeg)
 
 
+class MeasurableDiscreteMaxNeg(Max):
+    """A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""
+
+
+MeasurableVariable.register(MeasurableDiscreteMaxNeg)
+
+
 @node_rewriter(tracks=[Max])
 def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -180,37 +185,20 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[
 
     base_var = node.inputs[0]
 
-    if base_var.owner is None:
-        return None
-
-    if not rv_map_feature.request_measurable(node.inputs):
-        return None
-
     # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
-    if not isinstance(base_var.owner.op, Elemwise):
+    if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
         return None
 
+    base_rv = find_negated_var(base_var)
+
     # negation is rv * (-1). Hence the scalar_op must be Mul
-    try:
-        if not (
-            isinstance(base_var.owner.op.scalar_op, Mul)
-            and len(base_var.owner.inputs) == 2
-            and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
-        ):
-            return None
-    except NotScalarConstantError:
+    if base_rv is None:
         return None
 
-    base_rv = base_var.owner.inputs[0]
-
     # Non-univariate distributions and non-RVs must be rejected
     if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
         return None
 
-    # TODO: We are currently only supporting continuous rvs
-    if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
-        return None
-
     # univariate i.i.d. test which also rules out other distributions
     for params in base_rv.owner.inputs[3:]:
         if params.type.ndim != 0:
@@ -222,11 +210,16 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[
     if axis != base_var_dims:
         return None
 
-    measurable_min = MeasurableMaxNeg(list(axis))
-    min_rv_node = measurable_min.make_node(base_var)
-    min_rv = min_rv_node.outputs
+    if not rv_map_feature.request_measurable([base_rv]):
+        return None
 
-    return min_rv
+    # distinguish measurable discrete and continuous (because logprob is different)
+    if base_rv.owner.op.dtype.startswith("int"):
+        measurable_min = MeasurableDiscreteMaxNeg(list(axis))
+    else:
+        measurable_min = MeasurableMaxNeg(list(axis))
+
+    return measurable_min.make_node(base_rv).outputs
 
 
 measurable_ir_rewrites_db.register(
@@ -238,14 +231,13 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[
 
 
 @_logprob.register(MeasurableMaxNeg)
-def max_neg_logprob(op, values, base_var, **kwargs):
+def max_neg_logprob(op, values, base_rv, **kwargs):
     r"""Compute the log-likelihood graph for the `Max` operation.
     The formula that we use here is :
         \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
     where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
     """
     (value,) = values
-    base_rv = base_var.owner.inputs[0]
 
     logprob = _logprob_helper(base_rv, -value)
     logcdf = _logcdf_helper(base_rv, -value)
@@ -254,3 +246,31 @@ def max_neg_logprob(op, values, base_var, **kwargs):
     logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)
 
     return logprob
+
+
+@_logprob.register(MeasurableDiscreteMaxNeg)
+def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
+    r"""Compute the log-likelihood graph for the `Max` operation.
+
+    The formula that we use here is :
+    .. math::
+        \ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
+    where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
+    """
+
+    (value,) = values
+
+    # The cdf of a negative variable is the survival at the negated value
+    logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
+    logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))
+
+    [n] = constant_fold([base_rv.size])
+
+    # Now we can use the same expression as the discrete max
+    logprob = pt.where(
+        pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
+        -pt.inf,
+        logdiffexp(n * logcdf_prev, n * logcdf),
+    )
+
+    return logprob
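With the rewrite and the MeasurableDiscreteMaxNeg logprob registered, the log-probability of the minimum of i.i.d. discrete variables can be derived through the usual pm.logp entry point. A minimal sketch of how this could be exercised, assuming a Geometric base distribution and using scipy only to spell out the closed form ln((1 - F(x - 1))^n - (1 - F(x))^n) for comparison; this is an illustration, not the PR's test suite:

import numpy as np
import pytensor.tensor as pt
import scipy.stats as st

import pymc as pm

p, n, test_value = 0.5, 3, 2

# Minimum of n i.i.d. Geometric(p) variables; internally rewritten to the
# Max of the negated variable and dispatched to MeasurableDiscreteMaxNeg
x = pm.Geometric.dist(p=p, size=n)
x_min = pt.min(x)
logp_min = pm.logp(x_min, test_value).eval()

# Closed form: P(min = x) = S(x - 1)^n - S(x)^n, with S(x) = 1 - F(x)
expected = np.log(st.geom.sf(test_value - 1, p) ** n - st.geom.sf(test_value, p) ** n)

np.testing.assert_allclose(logp_min, expected)

The pt.where guard in discrete_max_neg_logprob keeps this expression from producing nan when both survival terms correspond to zero probability, returning -inf instead.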