@@ -411,13 +411,13 @@ def version_1(cls, ctx, node, **kwargs):
411
411
@classmethod
412
412
def version_10 (cls , ctx , node , ** kwargs ):
413
413
if (ctx .get_dtype (node .input [0 ]) in [onnx_pb .TensorProto .INT8 , onnx_pb .TensorProto .UINT8 ] and
414
- ctx .get_dtype (node .input [1 ]) in [onnx_pb .TensorProto .INT8 , onnx_pb .TensorProto .UINT8 ] and
415
- ctx .get_dtype (node .output [0 ]) == onnx_pb .TensorProto .INT32 ):
414
+ ctx .get_dtype (node .input [1 ]) in [onnx_pb .TensorProto .INT8 , onnx_pb .TensorProto .UINT8 ] and
415
+ ctx .get_dtype (node .output [0 ]) == onnx_pb .TensorProto .INT32 ):
416
416
node .type = "MatMulInteger"
417
417
zero_point_node_a = ctx .make_const (utils .make_name ("zero_point_a" ),
418
- np .zeros (1 , dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (node .input [0 ]))))
418
+ np .zeros (1 , dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (node .input [0 ]))))
419
419
zero_point_node_b = ctx .make_const (utils .make_name ("zero_point_b" ),
420
- np .zeros (1 , dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (node .input [1 ]))))
420
+ np .zeros (1 , dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (node .input [1 ]))))
421
421
ctx .replace_inputs (node , [node .input [0 ], node .input [1 ], zero_point_node_a .output [0 ], zero_point_node_b .output [0 ]])
422
422
cls .version_1 (ctx , node , ** kwargs )
423
423
0 commit comments