diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 6aae8efab3..8317d2be63 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -608,6 +608,9 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] output = node.outputs[0] if input is not None and output is not None: + input.shape = _merge_shapes(input.shape, output.shape) + if input.type is None: + input.type = output.type state.set_sym_value(output, input) return None