Skip to content

Commit 5f57001

Browse files
cyyeverjustinchuby
andauthored
Add a tool to dump ONNX schema into yaml files (#7480)
--------- Signed-off-by: cyy <[email protected]> Signed-off-by: Yuanyuan Chen <[email protected]> Co-authored-by: Justin Chu <[email protected]>
1 parent 3761d3d commit 5f57001

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

tools/spec_to_yaml.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) ONNX Project Contributors
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
"""Output ONNX spec in YAML format.
5+
6+
Usage:
7+
8+
python spec_to_yaml.py --output onnx-spec/defs
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import argparse
14+
import enum
15+
import pathlib
16+
from collections.abc import Iterable
17+
from typing import Any
18+
19+
from ruamel.yaml import YAML
20+
21+
import onnx
22+
23+
24+
def dump_onnx_object(
25+
onnx_obj: onnx.defs.OpSchema
26+
| onnx.defs.OpSchema.Attribute
27+
| onnx.defs.OpSchema.FormalParameter
28+
| onnx.defs.OpSchema.TypeConstraintParam,
29+
) -> dict[str, Any]:
30+
res = {}
31+
for attr in dir(onnx_obj):
32+
if attr.startswith("_"):
33+
continue
34+
value = getattr(onnx_obj, attr)
35+
if isinstance(value, enum.EnumType) or "nanobind" in str(type(value)):
36+
continue
37+
if attr == "default_value" and isinstance(
38+
onnx_obj, onnx.defs.OpSchema.Attribute
39+
):
40+
value = onnx.helper.get_attribute_value(value)
41+
value = dump_value(value)
42+
if not value:
43+
continue
44+
res[attr] = value
45+
return res
46+
47+
48+
def dump_enum(value: enum.Enum) -> str | None:
49+
for member in type(value):
50+
if member == value:
51+
if member.name == "Unknown":
52+
return None
53+
return member.name
54+
raise RuntimeError(f"Unhandled type {type(value)}")
55+
56+
57+
def dump_value(value: Any): # noqa: PLR0911
58+
match value:
59+
case None:
60+
return None
61+
case (
62+
onnx.defs.OpSchema()
63+
| onnx.defs.OpSchema.Attribute()
64+
| onnx.defs.OpSchema.FormalParameter()
65+
| onnx.defs.OpSchema.TypeConstraintParam()
66+
):
67+
return dump_onnx_object(value)
68+
case onnx.FunctionProto():
69+
return onnx.printer.to_text(value)
70+
case enum.Enum():
71+
return dump_enum(value)
72+
case dict():
73+
return {k: dump_value(v) for k, v in value.items()}
74+
case float() | int() | str():
75+
return value
76+
case Iterable():
77+
return type(value)(dump_value(v) for v in value) # type: ignore
78+
79+
raise RuntimeError(f"Unhandled type {type(value)}")
80+
81+
82+
def main():
83+
parser = argparse.ArgumentParser(description="Output ONNX spec in YAML format.")
84+
parser.add_argument("--output", help="Output directory", required=True)
85+
args = parser.parse_args()
86+
87+
schemas = onnx.defs.get_all_schemas_with_history()
88+
yaml = YAML()
89+
yaml.indent(mapping=2, sequence=4, offset=2)
90+
91+
latest_versions: dict = {}
92+
for schema in schemas:
93+
if schema.name in latest_versions:
94+
latest_versions[schema.name] = max(
95+
latest_versions[schema.name], schema.since_version
96+
)
97+
else:
98+
latest_versions[schema.name] = schema.since_version
99+
for schema in schemas:
100+
schema_dict = dump_value(schema)
101+
domain = schema.domain or "ai.onnx"
102+
outdir = pathlib.Path(args.output) / domain
103+
if latest_versions[schema.name] != schema.since_version:
104+
outdir = outdir / "old"
105+
else:
106+
outdir = outdir / "latest"
107+
outdir.mkdir(parents=True, exist_ok=True)
108+
path = outdir / f"{schema.name}-{schema.since_version}.yaml"
109+
with open(path, "w", encoding="utf-8") as f:
110+
print(f"Writing {path}")
111+
yaml.dump(schema_dict, f)
112+
113+
114+
if __name__ == "__main__":
115+
main()

0 commit comments

Comments
 (0)