diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 62c28894c0..4ff253f374 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -491,9 +491,7 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: # should handle this. Only the optimization to eliminate redundant Cast ops # should be needed here. - input_shape = input.shape - if input_shape is not None: - output.shape = input_shape.copy() + output.shape = _merge_shapes(output.shape, input.shape) input_dtype = _get_input_element_type(node, 0) output_dtype = _get_int_attribute(node, "to", None)