Skip to content

Add function to mermaid diagram #1490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

williambdean
Copy link
Contributor

@williambdean williambdean commented Jun 20, 2025

Description

I explored using variables directly with is definitely doable. Some helpful functions where:

def _get_edges(variable):
    if variable.owner is None:
        return

    yield from ((variable, input_var) for input_var in variable.owner.inputs)

    for variable in variable.owner.inputs:
        yield from get_edges(variable)


def get_edges(variable):
    return list(_get_edges(variable))


def get_nodes(variable=None):
    edges = get_edges(variable)

    nodes = set()
    for child, parent in edges:
        nodes.add(child)
        nodes.add(parent)

    return nodes

Using the pydot formatter which already exists in pytensor, can provide some version of this already.

Some examples of this:

import pytensor
import pytensor.tensor as pt
from pytensor.mermaid import function_to_mermaid

alpha = pt.scalar("alpha")
beta = pt.vector("beta")
noise = pt.scalar("noise")

X = pt.matrix("X")

y = pt.dot(X, beta) + alpha + noise
y.name = "y"

fn = pytensor.function([X, alpha, beta, noise], y)
mermaid_code = function_to_mermaid(fn)

print(mermaid_code)
graph TD
%% Nodes:
n1["DimShuffle"]
n1@{ shape: rounded }
n2["noise"]
n2@{ shape: rect }
style n2 fill:#32CD32
n4["DimShuffle"]
n4@{ shape: rounded }
n5["alpha"]
n5@{ shape: rect }
style n5 fill:#32CD32
n7["Shape_i"]
n7@{ shape: rounded }
style n7 fill:#00FFFF
n8["X"]
n8@{ shape: rect }
style n8 fill:#32CD32
n8["X"]
n8@{ shape: rect }
style n8 fill:#32CD32
n10["AllocEmpty"]
n10@{ shape: rounded }
n12["CGemv"]
n12@{ shape: rounded }
n13["1.0"]
n13@{ shape: rect }
style n13 fill:#00FF7F
n14["beta"]
n14@{ shape: rect }
style n14 fill:#32CD32
n15["0.0"]
n15@{ shape: rect }
style n15 fill:#00FF7F
n17["Elemwise"]
n17@{ shape: rounded }
n18["y"]
n18@{ shape: rect }
style n18 fill:#1E90FF

%% Edges:
n2 --> n1
n5 --> n4
n8 --> n7
n7 --> n10
n10 --> n12
n13 --> n12
n8 --> n12
n14 --> n12
n15 --> n12
n12 --> n17
n4 --> n17
n1 --> n17
n17 --> n18
Loading
import pytensor.tensor as pt
import pytensor
from pytensor.mermaid import function_to_mermaid

x, y, z = pt.scalars('xyz')
e = x * y
op = pytensor.compile.builders.OpFromGraph([x, y], [e])
e2 = op(x, y) + z
op2 = pytensor.compile.builders.OpFromGraph([x, y, z], [e2])
e3 = op2(x, y, z) + z
f = pytensor.function([x, y, z], [e3])

print(function_to_mermaid(f))
graph TD
%% Nodes:
n1["OpFromGraph"]
n1@{ shape: rounded }
n2["x"]
n2@{ shape: rect }
style n2 fill:#32CD32
n3["y"]
n3@{ shape: rect }
style n3 fill:#32CD32
n4["z"]
n4@{ shape: rect }
style n4 fill:#32CD32
n4["z"]
n4@{ shape: rect }
style n4 fill:#32CD32
n6["Elemwise"]
n6@{ shape: rounded }
n7["dscalar"]
n7@{ shape: rect }
style n7 fill:#1E90FF

%% Edges:
n2 --> n1
n3 --> n1
n4 --> n1
n1 --> n6
n4 --> n6
n6 --> n7
Loading

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@williambdean
Copy link
Contributor Author

Seeing that tests are failing because of no pydot. Might be able to mock that behavior...

@ricardoV94
Copy link
Member

Rather install and test properly. Elemwise isn't a great name though, we have to check the function that extracts the name It should use str(op)

return "\n".join(mermaid_lines)


def _color_to_hex(color_name):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this a function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mermaid needs hexcolors. the pydotformatter has strings names for colors

@williambdean
Copy link
Contributor Author

Rather install and test properly.

Do you like the route of using pydot? Don't think it would be too hard to implement a custom formatter. We just need a light graph representation.

Elemwise isn't a great name though, we have to check the function that extracts the name It should use str(op)

Sure. Shall I change for the formatter then or go a different route?

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 20, 2025

I think pydot is overkill but I wouldn't mock if you're using it. Your original function was fine except it will iterate some edges repeatedly.

I'm sure you can repurpose FunctionGraph or something from graph.basic

@williambdean
Copy link
Contributor Author

Your original function was fine except it will iterate some edges repeatedly.

Can you explain the case that is missing or example that will fail

I'm sure you can repurpose FunctionGraph or something from graph.basic

I will explore. Is dprint logic general enough to support?

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 21, 2025

Can you explain the case that is missing or example that will fail

Don't know if it will fail, but a graph like this, you may end up navigating x -> exp_x twice:

x = pt.scalar("x")
y = pt.exp(x)
out = y + y * 2

I will explore. Is dprint logic general enough to support?

Not sure what you mean. I meant that in those modules you have utilities to iterate over a graph. It's unlikely you have to invent something new for that goal.

@williambdean
Copy link
Contributor Author

williambdean commented Jun 21, 2025

The d3viz based on

x = pt.scalar("x")
y = pt.exp(x)
out = y + y * 2

looks like (I constructed this from hand):

graph TD
A["x"]
style A fill:#32CD32
B["Elemwise"]
B@{ shape: rounded }
style B fill:#FF00FF
C["dscalar"]
style C fill:#1E90FF

A --> B
B --> C
Loading

Is that as expected? If not, the PydotFormatter is not correct

@williambdean
Copy link
Contributor Author

Recent change has it looking like

        graph TD
        %% Nodes:
        n1["Composite"]
        n1@{ shape: rounded }
        n2["x"]
        n2@{ shape: rect }
        style n2 fill:#32CD32
        n3["out"]
        n3@{ shape: rect }
        style n3 fill:#1E90FF

        %% Edges:
        n2 --> n1
        n1 --> n3

Loading

@ricardoV94
Copy link
Member

Seems like you are compiling/rewriting the graph, otherwise the Composite wouldn't be introduced

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants