Skip to content

Commit 3f30b55

Browse files
committed
Model graph: more robust handling of RV -> Observed dependencies
1 parent 7a2bd74 commit 3f30b55

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

pymc/model_graph.py

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222

2323
from pytensor import function
2424
from pytensor.graph.basic import ancestors, walk
25-
from pytensor.scalar.basic import Cast
26-
from pytensor.tensor.elemwise import Elemwise
2725
from pytensor.tensor.shape import Shape
2826
from pytensor.tensor.variable import TensorVariable
2927

@@ -299,35 +297,28 @@ def make_compute_graph(
299297
self, var_names: Iterable[VarName] | None = None
300298
) -> dict[VarName, set[VarName]]:
301299
"""Get map of var_name -> set(input var names) for the model."""
302-
input_map: dict[VarName, set[VarName]] = defaultdict(set)
303-
304-
for var_name in self.vars_to_plot(var_names):
305-
var = self.model[var_name]
306-
parent_name = self.get_parent_names(var)
307-
input_map[var_name] = input_map[var_name].union(parent_name)
308-
309-
if var in self.model.observed_RVs:
310-
obs_node = self.model.rvs_to_values[var]
311-
312-
# loop created so that the elif block can go through this again
313-
# and remove any intermediate ops, notably dtype casting, to observations
314-
while True:
315-
obs_name = obs_node.name
316-
if obs_name and obs_name != var_name:
317-
input_map[var_name] = input_map[var_name].difference({obs_name})
318-
input_map[obs_name] = input_map[obs_name].union({var_name})
319-
break
320-
elif (
321-
# for cases where observations are cast to a certain dtype
322-
# see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
323-
obs_node.owner
324-
and isinstance(obs_node.owner.op, Elemwise)
325-
and isinstance(obs_node.owner.op.scalar_op, Cast)
326-
):
327-
# we can retrieve the observation node by going up the graph
328-
obs_node = obs_node.owner.inputs[0]
329-
else:
330-
break
300+
model = self.model
301+
named_vars = self._all_vars
302+
input_map: dict[str, set[str]] = defaultdict(set)
303+
304+
var_names_to_plot = self.vars_to_plot(var_names)
305+
for var_name in var_names_to_plot:
306+
parent_names = self.get_parent_names(model[var_name])
307+
input_map[var_name].update(parent_names)
308+
309+
for var_name in var_names_to_plot:
310+
if (var := model[var_name]) in model.observed_RVs:
311+
# Make observed `Data` variables flow from the observed RV, and not the other way around
312+
# (In the generative graph they usually inform shape of the observed RV)
313+
# We have to iterate over the ancestors of the observed values because there can be
314+
# deterministic operations in between the `Data` variable and the observed value.
315+
obs_var = model.rvs_to_values[var]
316+
for ancestor in ancestors([obs_var]):
317+
if ancestor not in named_vars:
318+
continue
319+
obs_name = ancestor.name
320+
input_map[var_name].discard(obs_name)
321+
input_map[obs_name].add(var_name)
331322

332323
return input_map
333324

@@ -348,7 +339,7 @@ def get_plates(
348339
plates = defaultdict(set)
349340

350341
# TODO: Evaluate all RV shapes at once
351-
# This should help find discrepencies, and
342+
# This should help find discrepancies, and
352343
# avoids unnecessary function compiles for determining labels.
353344
dim_lengths: dict[str, int] = {
354345
dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items()

tests/test_model_graph.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,16 @@ def test_model_graph_with_intermediate_named_variables():
541541
assert ModelGraph(m3).make_compute_graph() == {"C": set(), "D": set(), "E": {"C"}}
542542

543543

544+
def test_model_graph_complex_observed_dependency():
545+
with pm.Model() as model:
546+
x = pm.Data("x", [0])
547+
y = pm.Data("y", [0])
548+
observed = pt.exp(x) + pt.log(y)
549+
pm.Normal("obs", mu=0, observed=observed)
550+
551+
assert ModelGraph(model).make_compute_graph() == {"obs": set(), "x": {"obs"}, "y": {"obs"}}
552+
553+
544554
@pytest.fixture
545555
def simple_model() -> pm.Model:
546556
with pm.Model() as model:

0 commit comments

Comments
 (0)