|
| 1 | +# Copyright (c) ONNX Project Contributors |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +"""Output ONNX spec in YAML format. |
| 5 | +
|
| 6 | +Usage: |
| 7 | +
|
| 8 | + python spec_to_yaml.py --output onnx-spec/defs |
| 9 | +""" |
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +import argparse |
| 14 | +import enum |
| 15 | +import pathlib |
| 16 | +from collections.abc import Iterable |
| 17 | +from typing import Any |
| 18 | + |
| 19 | +from ruamel.yaml import YAML |
| 20 | + |
| 21 | +import onnx |
| 22 | + |
| 23 | + |
| 24 | +def dump_onnx_object( |
| 25 | + onnx_obj: onnx.defs.OpSchema |
| 26 | + | onnx.defs.OpSchema.Attribute |
| 27 | + | onnx.defs.OpSchema.FormalParameter |
| 28 | + | onnx.defs.OpSchema.TypeConstraintParam, |
| 29 | +) -> dict[str, Any]: |
| 30 | + res = {} |
| 31 | + for attr in dir(onnx_obj): |
| 32 | + if attr.startswith("_"): |
| 33 | + continue |
| 34 | + value = getattr(onnx_obj, attr) |
| 35 | + if isinstance(value, enum.EnumType) or "nanobind" in str(type(value)): |
| 36 | + continue |
| 37 | + if attr == "default_value" and isinstance( |
| 38 | + onnx_obj, onnx.defs.OpSchema.Attribute |
| 39 | + ): |
| 40 | + value = onnx.helper.get_attribute_value(value) |
| 41 | + value = dump_value(value) |
| 42 | + if not value: |
| 43 | + continue |
| 44 | + res[attr] = value |
| 45 | + return res |
| 46 | + |
| 47 | + |
| 48 | +def dump_enum(value: enum.Enum) -> str | None: |
| 49 | + for member in type(value): |
| 50 | + if member == value: |
| 51 | + if member.name == "Unknown": |
| 52 | + return None |
| 53 | + return member.name |
| 54 | + raise RuntimeError(f"Unhandled type {type(value)}") |
| 55 | + |
| 56 | + |
| 57 | +def dump_value(value: Any): # noqa: PLR0911 |
| 58 | + match value: |
| 59 | + case None: |
| 60 | + return None |
| 61 | + case ( |
| 62 | + onnx.defs.OpSchema() |
| 63 | + | onnx.defs.OpSchema.Attribute() |
| 64 | + | onnx.defs.OpSchema.FormalParameter() |
| 65 | + | onnx.defs.OpSchema.TypeConstraintParam() |
| 66 | + ): |
| 67 | + return dump_onnx_object(value) |
| 68 | + case onnx.FunctionProto(): |
| 69 | + return onnx.printer.to_text(value) |
| 70 | + case enum.Enum(): |
| 71 | + return dump_enum(value) |
| 72 | + case dict(): |
| 73 | + return {k: dump_value(v) for k, v in value.items()} |
| 74 | + case float() | int() | str(): |
| 75 | + return value |
| 76 | + case Iterable(): |
| 77 | + return type(value)(dump_value(v) for v in value) # type: ignore |
| 78 | + |
| 79 | + raise RuntimeError(f"Unhandled type {type(value)}") |
| 80 | + |
| 81 | + |
| 82 | +def main(): |
| 83 | + parser = argparse.ArgumentParser(description="Output ONNX spec in YAML format.") |
| 84 | + parser.add_argument("--output", help="Output directory", required=True) |
| 85 | + args = parser.parse_args() |
| 86 | + |
| 87 | + schemas = onnx.defs.get_all_schemas_with_history() |
| 88 | + yaml = YAML() |
| 89 | + yaml.indent(mapping=2, sequence=4, offset=2) |
| 90 | + |
| 91 | + latest_versions: dict = {} |
| 92 | + for schema in schemas: |
| 93 | + if schema.name in latest_versions: |
| 94 | + latest_versions[schema.name] = max( |
| 95 | + latest_versions[schema.name], schema.since_version |
| 96 | + ) |
| 97 | + else: |
| 98 | + latest_versions[schema.name] = schema.since_version |
| 99 | + for schema in schemas: |
| 100 | + schema_dict = dump_value(schema) |
| 101 | + domain = schema.domain or "ai.onnx" |
| 102 | + outdir = pathlib.Path(args.output) / domain |
| 103 | + if latest_versions[schema.name] != schema.since_version: |
| 104 | + outdir = outdir / "old" |
| 105 | + else: |
| 106 | + outdir = outdir / "latest" |
| 107 | + outdir.mkdir(parents=True, exist_ok=True) |
| 108 | + path = outdir / f"{schema.name}-{schema.since_version}.yaml" |
| 109 | + with open(path, "w", encoding="utf-8") as f: |
| 110 | + print(f"Writing {path}") |
| 111 | + yaml.dump(schema_dict, f) |
| 112 | + |
| 113 | + |
| 114 | +if __name__ == "__main__": |
| 115 | + main() |
0 commit comments