Skip to content

Commit 01bca17

Browse files
authored
Update math.py
1 parent 1c3af1f commit 01bca17

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tf2onnx/onnx_opset/math.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,13 +411,13 @@ def version_1(cls, ctx, node, **kwargs):
411411
@classmethod
412412
def version_10(cls, ctx, node, **kwargs):
413413
if (ctx.get_dtype(node.input[0]) in [onnx_pb.TensorProto.INT8, onnx_pb.TensorProto.UINT8] and
414-
ctx.get_dtype(node.input[1]) in [onnx_pb.TensorProto.INT8, onnx_pb.TensorProto.UINT8] and
415-
ctx.get_dtype(node.output[0]) == onnx_pb.TensorProto.INT32):
414+
ctx.get_dtype(node.input[1]) in [onnx_pb.TensorProto.INT8, onnx_pb.TensorProto.UINT8] and
415+
ctx.get_dtype(node.output[0]) == onnx_pb.TensorProto.INT32):
416416
node.type = "MatMulInteger"
417417
zero_point_node_a = ctx.make_const(utils.make_name("zero_point_a"),
418-
np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0]))))
418+
np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0]))))
419419
zero_point_node_b = ctx.make_const(utils.make_name("zero_point_b"),
420-
np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))))
420+
np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))))
421421
ctx.replace_inputs(node, [node.input[0], node.input[1], zero_point_node_a.output[0], zero_point_node_b.output[0]])
422422
cls.version_1(ctx, node, **kwargs)
423423

0 commit comments

Comments
 (0)