|
21 | 21 | from typing import Any, cast
|
22 | 22 |
|
23 | 23 | from pytensor import function
|
24 |
| -from pytensor.graph import Apply |
25 | 24 | from pytensor.graph.basic import ancestors, walk
|
26 | 25 | from pytensor.scalar.basic import Cast
|
27 | 26 | from pytensor.tensor.elemwise import Elemwise
|
28 |
| -from pytensor.tensor.random.op import RandomVariable |
29 | 27 | from pytensor.tensor.shape import Shape
|
30 | 28 | from pytensor.tensor.variable import TensorVariable
|
31 | 29 |
|
@@ -240,42 +238,32 @@ class ModelGraph:
|
240 | 238 | def __init__(self, model):
|
241 | 239 | self.model = model
|
242 | 240 | self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
|
| 241 | + self._all_vars = {model[var_name] for var_name in self._all_var_names} |
243 | 242 | self.var_list = self.model.named_vars.values()
|
244 | 243 |
|
245 | 244 | def get_parent_names(self, var: TensorVariable) -> set[VarName]:
|
246 |
| - if var.owner is None or var.owner.inputs is None: |
| 245 | + if var.owner is None: |
247 | 246 | return set()
|
248 | 247 |
|
249 |
| - def _filter_non_parameter_inputs(var): |
250 |
| - node = var.owner |
251 |
| - if isinstance(node.op, Shape): |
252 |
| - # Don't show shape-related dependencies |
253 |
| - return [] |
254 |
| - if isinstance(node.op, RandomVariable): |
255 |
| - # Filter out rng and size parameters or RandomVariable nodes |
256 |
| - return node.op.dist_params(node) |
257 |
| - else: |
258 |
| - # Otherwise return all inputs |
259 |
| - return node.inputs |
260 |
| - |
261 |
| - blockers = set(self.model.named_vars) |
| 248 | + named_vars = self._all_vars |
262 | 249 |
|
263 | 250 | def _expand(x):
|
264 |
| - nonlocal blockers |
265 |
| - if x.name in blockers: |
| 251 | + if x in named_vars: |
| 252 | + # Don't go beyond named_vars |
266 | 253 | return [x]
|
267 |
| - if isinstance(x.owner, Apply): |
268 |
| - return reversed(_filter_non_parameter_inputs(x)) |
269 |
| - return [] |
270 |
| - |
271 |
| - parents = set() |
272 |
| - for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand): |
273 |
| - # Only consider nodes that are in the named model variables. |
274 |
| - vname = getattr(x, "name", None) |
275 |
| - if isinstance(vname, str) and vname in self._all_var_names: |
276 |
| - parents.add(VarName(vname)) |
277 |
| - |
278 |
| - return parents |
| 254 | + if x.owner is None: |
| 255 | + return [] |
| 256 | + if isinstance(x.owner.op, Shape): |
| 257 | + # Don't propagate shape-related dependencies |
| 258 | + return [] |
| 259 | + # Continue walking the graph through the inputs |
| 260 | + return x.owner.inputs |
| 261 | + |
| 262 | + return { |
| 263 | + VarName(ancestor.name) |
| 264 | + for ancestor in walk(nodes=var.owner.inputs, expand=_expand) |
| 265 | + if ancestor in named_vars |
| 266 | + } |
279 | 267 |
|
280 | 268 | def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
|
281 | 269 | if var_names is None:
|
|
0 commit comments