From 31acc6f94f85770869918936ea245390cbb9148d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 20 Jun 2025 17:07:32 -0400 Subject: [PATCH 1/5] add function_to_mermaid --- pytensor/mermaid.py | 63 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 pytensor/mermaid.py diff --git a/pytensor/mermaid.py b/pytensor/mermaid.py new file mode 100644 index 0000000000..227ecca3b0 --- /dev/null +++ b/pytensor/mermaid.py @@ -0,0 +1,63 @@ +from pytensor.d3viz.formatting import PyDotFormatter + + +def function_to_mermaid(fn): + formatter = PyDotFormatter() + dot = formatter(fn) + + nodes = dot.get_nodes() + edges = dot.get_edges() + + mermaid_lines = ["graph TD"] + mermaid_lines.append("%% Nodes:") + for node in nodes: + name = node.get_name() + label = node.get_label() + shape = node.get_shape() + + if label.endswith("."): + label = f"{label}0" + + if shape == "box": + shape = "rect" + else: + shape = "rounded" + + mermaid_lines.extend( + [ + f'{name}["{label}"]', + f"{name}@{{ shape: {shape} }}", + ] + ) + + fillcolor = node.get_fillcolor() + if fillcolor is not None and not fillcolor.startswith("#"): + fillcolor = _color_to_hex(fillcolor) + mermaid_lines.append(f"style {name} fill:{fillcolor}") + + mermaid_lines.append("%% Edges:") + for edge in edges: + source = edge.get_source() + target = edge.get_destination() + + mermaid_lines.append(f"{source} --> {target}") + + return "\n".join(mermaid_lines) + + +def _color_to_hex(color_name): + """Based on the colors in d3viz module.""" + return { + "limegreen": "#32CD32", + "SpringGreen": "#00FF7F", + "YellowGreen": "#9ACD32", + "dodgerblue": "#1E90FF", + "lightgrey": "#D3D3D3", + "yellow": "#FFFF00", + "cyan": "#00FFFF", + "magenta": "#FF00FF", + "red": "#FF0000", + "blue": "#0000FF", + "green": "#008000", + "grey": "#808080", + }.get(color_name) From 09ecb2d766adf44c6de7e31b83143492d5807856 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 20 Jun 2025 17:13:41 -0400 Subject: [PATCH 2/5] add space before edges section --- pytensor/mermaid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/mermaid.py b/pytensor/mermaid.py index 227ecca3b0..4509a983e5 100644 --- a/pytensor/mermaid.py +++ b/pytensor/mermaid.py @@ -35,7 +35,7 @@ def function_to_mermaid(fn): fillcolor = _color_to_hex(fillcolor) mermaid_lines.append(f"style {name} fill:{fillcolor}") - mermaid_lines.append("%% Edges:") + mermaid_lines.append("\n%% Edges:") for edge in edges: source = edge.get_source() target = edge.get_destination() From df5a4c9a76d9adc8215390154c0729cd97d8e7e7 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 20 Jun 2025 17:14:04 -0400 Subject: [PATCH 3/5] add test for diagram --- tests/test_mermaid.py | 63 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests/test_mermaid.py diff --git a/tests/test_mermaid.py b/tests/test_mermaid.py new file mode 100644 index 0000000000..91180cf757 --- /dev/null +++ b/tests/test_mermaid.py @@ -0,0 +1,63 @@ +from textwrap import dedent + +import pytest + +from pytensor import function +from pytensor import tensor as pt +from pytensor.mermaid import function_to_mermaid + + +@pytest.fixture +def sample_function(): + x = pt.dmatrix("x") + y = pt.dvector("y") + z = pt.dot(x, y) + z.name = "z" + return function([x, y], z) + + +def test_function_to_mermaid(sample_function): + diagram = function_to_mermaid(sample_function) + + assert ( + diagram + == dedent(""" + graph TD + %% Nodes: + n1["Shape_i"] + n1@{ shape: rounded } + style n1 fill:#00FFFF + n2["x"] + n2@{ shape: rect } + style n2 fill:#32CD32 + n2["x"] + n2@{ shape: rect } + style n2 fill:#32CD32 + n4["AllocEmpty"] + n4@{ shape: rounded } + n6["CGemv"] + n6@{ shape: rounded } + n7["1.0"] + n7@{ shape: rect } + style n7 fill:#00FF7F + n8["y"] + n8@{ shape: rect } + style n8 fill:#32CD32 + n9["0.0"] + n9@{ shape: rect } + style n9 fill:#00FF7F + n10["z"] + n10@{ shape: rect } + style n10 fill:#1E90FF + + %% Edges: + n2 --> n1 + n1 --> n4 + n4 --> n6 + n7 --> n6 + n2 --> n6 + n8 --> n6 + n9 --> n6 + n6 --> n10 + """).strip() + ) From b0f7ce23163a9ac993dde7ab357e273272365cac Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 20 Jun 2025 18:14:48 -0400 Subject: [PATCH 4/5] install pydot for tests --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 831ab5d1bc..c4a5df837e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -196,7 +196,7 @@ jobs: if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi - pip install pytest-sphinx + pip install pytest-sphinx pydot pip install -e ./ micromamba list && pip freeze From 8004aef529a0364b277a6a703e0af35dcd49413d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 21 Jun 2025 09:33:12 -0400 Subject: [PATCH 5/5] take scalar op name if Elemwise --- pytensor/d3viz/formatting.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytensor/d3viz/formatting.py b/pytensor/d3viz/formatting.py index df39335c19..9198e77429 100644 --- a/pytensor/d3viz/formatting.py +++ b/pytensor/d3viz/formatting.py @@ -13,6 +13,7 @@ from pytensor.graph.basic import Apply, Constant, Variable, graph_inputs from pytensor.graph.fg import FunctionGraph from pytensor.printing import _try_pydot_import +from pytensor.tensor.elemwise import Elemwise class PyDotFormatter: @@ -291,6 +292,9 @@ def var_tag(var): def apply_label(node): """Return label of apply node.""" + if isinstance(node.op, Elemwise): + return node.op.scalar_op.__class__.__name__ + return node.op.__class__.__name__