2 changes: 1 addition & 1 deletion noxfile.py
@@ -42,7 +42,7 @@
"packaging",
"protobuf",
)
ONNX_IR = "onnx_ir==0.1.7"
ONNX_IR = "onnx_ir==0.1.9"
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"


154 changes: 2 additions & 152 deletions onnxscript/ir/__init__.py
@@ -1,154 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""In-memory intermediate representation for ONNX graphs."""

__all__ = [
# Modules
"serde",
"traversal",
"convenience",
"external_data",
"tape",
# IR classes
"Tensor",
"ExternalTensor",
"StringTensor",
"LazyTensor",
"SymbolicDim",
"Shape",
"TensorType",
"OptionalType",
"SequenceType",
"SparseTensorType",
"TypeAndShape",
"Value",
"Attr",
"RefAttr",
"Node",
"Function",
"Graph",
"GraphView",
"Model",
# Constructors
"AttrFloat32",
"AttrFloat32s",
"AttrGraph",
"AttrGraphs",
"AttrInt64",
"AttrInt64s",
"AttrSparseTensor",
"AttrSparseTensors",
"AttrString",
"AttrStrings",
"AttrTensor",
"AttrTensors",
"AttrTypeProto",
"AttrTypeProtos",
"Input",
# Protocols
"ArrayCompatible",
"DLPackCompatible",
"TensorProtocol",
"ValueProtocol",
"ModelProtocol",
"NodeProtocol",
"GraphProtocol",
"GraphViewProtocol",
"AttributeProtocol",
"ReferenceAttributeProtocol",
"SparseTensorProtocol",
"SymbolicDimProtocol",
"ShapeProtocol",
"TypeProtocol",
"MapTypeProtocol",
"FunctionProtocol",
# Enums
"AttributeType",
"DataType",
# Types
"OperatorIdentifier",
# Protobuf compatible types
"TensorProtoTensor",
# Conversion functions
"from_proto",
"from_onnx_text",
"to_proto",
# Convenience constructors
"tensor",
"node",
# Pass infrastructure
"passes",
# IO
"load",
"save",
]

from onnx_ir import (
ArrayCompatible,
Attr,
AttrFloat32,
AttrFloat32s,
AttrGraph,
AttrGraphs,
AttributeProtocol,
AttributeType,
AttrInt64,
AttrInt64s,
AttrSparseTensor,
AttrSparseTensors,
AttrString,
AttrStrings,
AttrTensor,
AttrTensors,
AttrTypeProto,
AttrTypeProtos,
DataType,
DLPackCompatible,
ExternalTensor,
Function,
FunctionProtocol,
Graph,
GraphProtocol,
GraphView,
GraphViewProtocol,
Input,
LazyTensor,
MapTypeProtocol,
Model,
ModelProtocol,
Node,
NodeProtocol,
OperatorIdentifier,
OptionalType,
RefAttr,
ReferenceAttributeProtocol,
SequenceType,
Shape,
ShapeProtocol,
SparseTensorProtocol,
SparseTensorType,
StringTensor,
SymbolicDim,
SymbolicDimProtocol,
Tensor,
TensorProtocol,
TensorProtoTensor,
TensorType,
TypeAndShape,
TypeProtocol,
Value,
ValueProtocol,
convenience,
external_data,
from_onnx_text,
from_proto,
load,
node,
passes,
save,
serde,
tape,
tensor,
to_proto,
traversal,
)
# pylint: disable=wildcard-import,unused-wildcard-import
from onnx_ir import * # type: ignore # noqa: F403
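
With the hand-maintained __all__ and import list removed, onnxscript.ir now defers its public API to whatever onnx_ir exports. A minimal sketch (not part of the PR), assuming onnx_ir>=0.1.9 is installed as pinned in this change, showing that names from the deleted list still resolve through the wildcard import:

import numpy as np

from onnxscript import ir

# ir.Tensor, ir.DataType, etc. now come from onnx_ir via the wildcard re-export.
t = ir.Tensor(np.zeros((2, 3), dtype=np.float32), name="t")
assert t.dtype == ir.DataType.FLOAT
assert ir.DataType.BFLOAT16 is not None
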
@@ -14,11 +14,11 @@

class Bfloat16ConversionTest(unittest.TestCase):
def setUp(self) -> None:
self.v0 = ir.Input(name="v0", shape=ir.Shape([2, 3, 4]))
self.v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4]))
self.v0.dtype = ir.DataType.BFLOAT16
self.v1 = ir.Input(name="v1", shape=ir.Shape([2, 3, 4]))
self.v1 = ir.val(name="v1", shape=ir.Shape([2, 3, 4]))
self.v1.dtype = ir.DataType.BFLOAT16
self.v2 = ir.Input(name="v2", shape=ir.Shape([2, 3, 4]))
self.v2 = ir.val(name="v2", shape=ir.Shape([2, 3, 4]))
self.v2.dtype = ir.DataType.BFLOAT16

self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
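
The test changes in this and the following files all follow the same migration: the removed ir.Input constructor is replaced by ir.val, which creates a graph value directly. A rough sketch of the call shapes used in this PR, with the argument order inferred from the diffs rather than from the onnx_ir documentation:

from onnxscript import ir

# Keyword form: name and shape first, dtype assigned afterwards (as in setUp above).
v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4]))
v0.dtype = ir.DataType.BFLOAT16

# Positional form: name, dtype, shape in a single call.
x = ir.val("X", ir.DataType.FLOAT, ir.Shape([2, 3, 4]))

# Keyword form mirroring the old ir.Input(name, shape=..., type=...) signature.
y = ir.val("Y", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.FLOAT))
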
10 changes: 5 additions & 5 deletions onnxscript/rewriter/rules/common/_basic_rules_test.py
@@ -421,14 +421,14 @@ def _convert_shape(shape, name):
if isinstance(shape, np.ndarray):
shape = tape.initializer(ir.Tensor(shape, name=name))
elif isinstance(shape, (list, tuple)):
shape = ir.Input(name, ir.Shape(shape), ir.TensorType(ir.DataType.INT64))
shape = ir.val(name, ir.DataType.INT64, ir.Shape(shape))
tape.graph_like.inputs.append(shape)
else:
raise TypeError(f"Unsupported type {type(shape)} for shape.")
return shape

x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
y = ir.val("Y", ir.DataType.FLOAT)
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))

# Build the graph.
@@ -554,8 +554,8 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
class Flatten2ReshapeTest(unittest.TestCase):
@staticmethod
def create_model(input_shape, axis=1):
x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
y = ir.val("Y", ir.DataType.FLOAT)
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))

# Build the graph.
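
These rewriter tests pair ir.val with ir.tape.Tape to assemble small test graphs. A sketch of that pattern, assuming onnx_ir>=0.1.9; the op, shapes, and opset version below are illustrative only:

import numpy as np

from onnxscript import ir

x = ir.val("X", ir.DataType.FLOAT, ir.Shape([2, 3]))
y = ir.val("Y", ir.DataType.FLOAT)
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))

# Constants are registered as initializers on the tape's graph.
two = tape.initializer(ir.Tensor(np.full((2, 3), 2.0, dtype=np.float32), name="two"))

# Ops are recorded on the tape; an output value can be supplied explicitly,
# here the graph output declared above.
tape.op("Mul", inputs=[x, two], output=y)
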
8 changes: 4 additions & 4 deletions onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py
@@ -61,13 +61,13 @@ def build_model(

# Register operations in the tape
idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT
x = ir.Input("X", shape=input_shape, type=ir.TensorType(idtype))
x = ir.val("X", shape=input_shape, type=ir.TensorType(idtype))
y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes)
y = tape.op(
op_type,
inputs=[y, self.get_conv_weights(weight_shape, tape)],
attributes=conv_attributes,
output=ir.Input("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
output=ir.val("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
)
if op_type == "ConvInteger":
y.dtype = ir.DataType.INT32
@@ -290,12 +290,12 @@ def build_model(
raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")

# Register operations in the tape
x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
y = tape.op(
"Conv",
inputs=[x, *conv_inputs],
attributes=conv_attributes,
output=ir.Input("Y", shape=output_shape, type=x.type),
output=ir.val("Y", shape=output_shape, type=x.type),
)

# Build the model
8 changes: 4 additions & 4 deletions onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py
@@ -46,10 +46,10 @@ def get_test_model(
bias_shape = weight_shape[0] if transB else weight_shape[-1]
output_shape = ir.Shape(("?",) * input_shape.rank())

x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))

if weight_as_inputs:
w = ir.Input("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
w = ir.val("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
inputs.append(w)
else:
w = ir.tensor(
@@ -58,7 +58,7 @@
w = tape.initializer(w)

if bias_as_inputs:
b = ir.Input(
b = ir.val(
"B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
)
inputs.append(b)
@@ -77,7 +77,7 @@
y = tape.op(
"Add",
inputs=[y, b],
output=ir.Input("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
)

# Build the model
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -27,7 +27,7 @@ classifiers = [
dependencies = [
"ml_dtypes",
"numpy",
"onnx_ir>=0.1.7,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
"onnx_ir>=0.1.9,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
"onnx>=1.16",
"packaging",
"typing_extensions>=4.10",
@@ -41,7 +41,6 @@ onnxscript = ["py.typed"]
onnx = ["py.typed"]

[tool.pytest.ini_options]
filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"]
addopts = "-rsfEX --tb=short --color=yes"

[tool.mypy]