Skip to content

Commit 7a2bd74

Browse files
committed
Model graph: don't define dependencies based on variable names
1 parent f639a5a commit 7a2bd74

File tree

2 files changed

+36
-33
lines changed

2 files changed

+36
-33
lines changed

pymc/model_graph.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,9 @@
2121
from typing import Any, cast
2222

2323
from pytensor import function
24-
from pytensor.graph import Apply
2524
from pytensor.graph.basic import ancestors, walk
2625
from pytensor.scalar.basic import Cast
2726
from pytensor.tensor.elemwise import Elemwise
28-
from pytensor.tensor.random.op import RandomVariable
2927
from pytensor.tensor.shape import Shape
3028
from pytensor.tensor.variable import TensorVariable
3129

@@ -240,42 +238,32 @@ class ModelGraph:
240238
def __init__(self, model):
241239
self.model = model
242240
self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
241+
self._all_vars = {model[var_name] for var_name in self._all_var_names}
243242
self.var_list = self.model.named_vars.values()
244243

245244
def get_parent_names(self, var: TensorVariable) -> set[VarName]:
246-
if var.owner is None or var.owner.inputs is None:
245+
if var.owner is None:
247246
return set()
248247

249-
def _filter_non_parameter_inputs(var):
250-
node = var.owner
251-
if isinstance(node.op, Shape):
252-
# Don't show shape-related dependencies
253-
return []
254-
if isinstance(node.op, RandomVariable):
255-
# Filter out rng and size parameters or RandomVariable nodes
256-
return node.op.dist_params(node)
257-
else:
258-
# Otherwise return all inputs
259-
return node.inputs
260-
261-
blockers = set(self.model.named_vars)
248+
named_vars = self._all_vars
262249

263250
def _expand(x):
264-
nonlocal blockers
265-
if x.name in blockers:
251+
if x in named_vars:
252+
# Don't go beyond named_vars
266253
return [x]
267-
if isinstance(x.owner, Apply):
268-
return reversed(_filter_non_parameter_inputs(x))
269-
return []
270-
271-
parents = set()
272-
for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand):
273-
# Only consider nodes that are in the named model variables.
274-
vname = getattr(x, "name", None)
275-
if isinstance(vname, str) and vname in self._all_var_names:
276-
parents.add(VarName(vname))
277-
278-
return parents
254+
if x.owner is None:
255+
return []
256+
if isinstance(x.owner.op, Shape):
257+
# Don't propagate shape-related dependencies
258+
return []
259+
# Continue walking the graph through the inputs
260+
return x.owner.inputs
261+
262+
return {
263+
VarName(ancestor.name)
264+
for ancestor in walk(nodes=var.owner.inputs, expand=_expand)
265+
if ancestor in named_vars
266+
}
279267

280268
def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
281269
if var_names is None:

tests/test_model_graph.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ class TestVariableSelection:
470470
[
471471
(["c"], ["a", "b", "c"], {"c": {"a", "b"}, "a": set(), "b": set()}),
472472
(
473-
["L"],
473+
["L", "obs"],
474474
["pred", "obs", "L", "intermediate", "a", "b"],
475475
{
476476
"pred": {"intermediate"},
@@ -516,14 +516,29 @@ def test_model_graph_with_intermediate_named_variables():
516516
with pm.Model() as m1:
517517
a = pm.Normal("a", 0, 1, shape=3)
518518
pm.Normal("b", a.mean(axis=-1), 1)
519-
assert dict(ModelGraph(m1).make_compute_graph()) == {"a": set(), "b": {"a"}}
519+
assert ModelGraph(m1).make_compute_graph() == {"a": set(), "b": {"a"}}
520520

521521
with pm.Model() as m2:
522522
a = pm.Normal("a", 0, 1)
523523
b = a + 1
524524
b.name = "b"
525525
pm.Normal("c", b, 1)
526-
assert dict(ModelGraph(m2).make_compute_graph()) == {"a": set(), "c": {"a"}}
526+
assert ModelGraph(m2).make_compute_graph() == {"a": set(), "c": {"a"}}
527+
528+
# Regression test for https://github.com/pymc-devs/pymc/issues/7397
529+
with pm.Model() as m3:
530+
data = pt.as_tensor_variable(
531+
np.ones((5, 3)),
532+
name="C",
533+
)
534+
# C has the same name as `data` variable
535+
# This used to be wrongly picked up as a dependency
536+
C = pm.Deterministic("C", data)
537+
# D depends on a variable called `C` but this is not really one in the model
538+
D = pm.Deterministic("D", data)
539+
# This actually depends on the model variable `C`
540+
E = pm.Deterministic("E", C)
541+
assert ModelGraph(m3).make_compute_graph() == {"C": set(), "D": set(), "E": {"C"}}
527542

528543

529544
@pytest.fixture

0 commit comments

Comments
 (0)