22
22
23
23
from pytensor import function
24
24
from pytensor .graph .basic import ancestors , walk
25
- from pytensor .scalar .basic import Cast
26
- from pytensor .tensor .elemwise import Elemwise
27
25
from pytensor .tensor .shape import Shape
28
26
from pytensor .tensor .variable import TensorVariable
29
27
@@ -299,35 +297,28 @@ def make_compute_graph(
299
297
self , var_names : Iterable [VarName ] | None = None
300
298
) -> dict [VarName , set [VarName ]]:
301
299
"""Get map of var_name -> set(input var names) for the model."""
302
- input_map : dict [VarName , set [VarName ]] = defaultdict (set )
303
-
304
- for var_name in self .vars_to_plot (var_names ):
305
- var = self .model [var_name ]
306
- parent_name = self .get_parent_names (var )
307
- input_map [var_name ] = input_map [var_name ].union (parent_name )
308
-
309
- if var in self .model .observed_RVs :
310
- obs_node = self .model .rvs_to_values [var ]
311
-
312
- # loop created so that the elif block can go through this again
313
- # and remove any intermediate ops, notably dtype casting, to observations
314
- while True :
315
- obs_name = obs_node .name
316
- if obs_name and obs_name != var_name :
317
- input_map [var_name ] = input_map [var_name ].difference ({obs_name })
318
- input_map [obs_name ] = input_map [obs_name ].union ({var_name })
319
- break
320
- elif (
321
- # for cases where observations are cast to a certain dtype
322
- # see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
323
- obs_node .owner
324
- and isinstance (obs_node .owner .op , Elemwise )
325
- and isinstance (obs_node .owner .op .scalar_op , Cast )
326
- ):
327
- # we can retrieve the observation node by going up the graph
328
- obs_node = obs_node .owner .inputs [0 ]
329
- else :
330
- break
300
+ model = self .model
301
+ named_vars = self ._all_vars
302
+ input_map : dict [str , set [str ]] = defaultdict (set )
303
+
304
+ var_names_to_plot = self .vars_to_plot (var_names )
305
+ for var_name in var_names_to_plot :
306
+ parent_names = self .get_parent_names (model [var_name ])
307
+ input_map [var_name ].update (parent_names )
308
+
309
+ for var_name in var_names_to_plot :
310
+ if (var := model [var_name ]) in model .observed_RVs :
311
+ # Make observed `Data` variables flow from the observed RV, and not the other way around
312
+ # (In the generative graph they usually inform shape of the observed RV)
313
+ # We have to iterate over the ancestors of the observed values because there can be
314
+ # deterministic operations in between the `Data` variable and the observed value.
315
+ obs_var = model .rvs_to_values [var ]
316
+ for ancestor in ancestors ([obs_var ]):
317
+ if ancestor not in named_vars :
318
+ continue
319
+ obs_name = ancestor .name
320
+ input_map [var_name ].discard (obs_name )
321
+ input_map [obs_name ].add (var_name )
331
322
332
323
return input_map
333
324
@@ -348,7 +339,7 @@ def get_plates(
348
339
plates = defaultdict (set )
349
340
350
341
# TODO: Evaluate all RV shapes at once
351
- # This should help find discrepencies , and
342
+ # This should help find discrepancies , and
352
343
# avoids unnecessary function compiles for determining labels.
353
344
dim_lengths : dict [str , int ] = {
354
345
dim_name : fast_eval (value ).item () for dim_name , value in self .model .dim_lengths .items ()
0 commit comments