@@ -414,11 +414,12 @@ def version_10(cls, ctx, node, **kwargs):
414
414
ctx .get_dtype (node .input [1 ]) in [onnx_pb .TensorProto .INT8 , onnx_pb .TensorProto .UINT8 ] and
415
415
ctx .get_dtype (node .output [0 ]) == onnx_pb .TensorProto .INT32 ):
416
416
node .type = "MatMulInteger"
417
- zero_point_node_a = ctx .make_const (utils .make_name ("zero_point_a" ),
418
- np .zeros (1 , dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (node .input [0 ]))))
419
- zero_point_node_b = ctx .make_const (utils .make_name ("zero_point_b" ),
420
- np .zeros (1 , dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (node .input [1 ]))))
421
- ctx .replace_inputs (node , [node .input [0 ], node .input [1 ], zero_point_node_a .output [0 ], zero_point_node_b .output [0 ]])
417
+ zpdata_a = np .zeros (1 , dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (node .input [0 ])))
418
+ zero_point_node_a = ctx .make_const (utils .make_name ("zero_point_a" ), zpdata_a )
419
+ zpdata_b = np .zeros (1 , dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (node .input [1 ])))
420
+ zero_point_node_b = ctx .make_const (utils .make_name ("zero_point_b" ), zpdata_b )
421
+ ctx .replace_inputs (node , [node .input [0 ], node .input [1 ],
422
+ zero_point_node_a .output [0 ], zero_point_node_b .output [0 ]])
422
423
cls .version_1 (ctx , node , ** kwargs )
423
424
424
425
@tf_op ("Erf" )
0 commit comments