diff --git a/onnxscript/rewriter/redundant_scatter_nd.py b/onnxscript/rewriter/redundant_scatter_nd.py index 1ba6477f52..4d96360cd7 100644 --- a/onnxscript/rewriter/redundant_scatter_nd.py +++ b/onnxscript/rewriter/redundant_scatter_nd.py @@ -24,10 +24,6 @@ from onnxscript.rewriter import pattern as orp -def fail(*args): - return onnxscript.rewriter.MatchResult().fail(*args) - - class ScatterAllDynamic(orp.RewriteRuleClassBase): def pattern(self, op, data, axis, transposed_data, updates): # Construct update-indices spanning an entire axis: @@ -41,24 +37,26 @@ def pattern(self, op, data, axis, transposed_data, updates): def check(self, context, data, axis, transposed_data, **_): # Check that updated-indices represent the full range of the first dimension of the transposed data. # That is: check that the data.shape[axis] matches transposed_data.shape[0]. + result = onnxscript.rewriter.MatchResult() axis_value = ir_utils.get_singleton_value(axis) if not isinstance(axis_value, int): - return fail("Axis value must be a constant integer.", axis) + return result.fail("Axis value must be a constant integer.", axis) shape: ir.Shape | None = data.shape if shape is None: - return fail("Data shape is not statically known.", data) + return result.fail("Data shape is not statically known.", data) updated_dim_value = shape[axis_value] transposed_data_shape: ir.Shape | None = transposed_data.shape if transposed_data_shape is None: - return fail("Transposed data shape is not statically known.", transposed_data) + return result.fail( + "Transposed data shape is not statically known.", transposed_data + ) actual_dim_value = transposed_data_shape[0] if updated_dim_value != actual_dim_value: # The first dimension of the transposed data does not match the updated dimension, # so we cannot apply this rule. - return fail( + return result.fail( "The first dimension of the transposed data does not match the updated dimension.", - data, - transposed_data, + [data, transposed_data], ) return True @@ -81,20 +79,23 @@ def check(self, context, data, indices, updates, **_): """Check if the ScatterND is redundant due to static indices covering entire tensor.""" # To validate data can be replaced directly by updates, we need to check the following: # 1. they have the same shape + result = onnxscript.rewriter.MatchResult() if data.shape is None: - return fail("The value 'data' shape is not statically known.", data) + return result.fail("The value 'data' shape is not statically known.", data) if updates.shape is None: - return fail("The value 'updates' shape is not statically known.", updates) + return result.fail("The value 'updates' shape is not statically known.", updates) if data.shape != updates.shape: - return fail("The shape of 'data' and 'updates' are different.", data, updates) + return result.fail( + "The shape of 'data' and 'updates' are different.", [data, updates] + ) # 2. the indices is referring to the whole data, which is from 0 to data.shape[0] if indices.const_value is None: - return fail("The value 'indices' is not statically known.", indices) + return result.fail("The value 'indices' is not statically known.", indices) expected_indices = [[i] for i in range(data.shape[0])] actual_indices = indices.const_value.numpy().tolist() if actual_indices != expected_indices: - return fail("The 'indices' is not referring to the whole data.", indices) + return result.fail("The 'indices' is not referring to the whole data.", indices) return True