Skip to content

Commit afbddcb

Browse files
fix: div return type for integer arguments
1 parent bebf52f commit afbddcb

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

onnx2torch/node_converters/binary_math_operations.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,19 @@
1515
from onnx2torch.utils.common import old_style_broadcast
1616
from onnx2torch.utils.common import onnx_mapping_from_node
1717

18+
19+
def _onnx_div(first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
20+
if first.is_floating_point() or second.is_floating_point(): # float division
21+
return torch.div(first, second)
22+
23+
return torch.div(first, second, rounding_mode='trunc') # integer division
24+
25+
1826
_TORCH_FUNCTION_FROM_ONNX_TYPE = {
1927
'Add': torch.add,
2028
'Sub': torch.sub,
2129
'Mul': torch.mul,
22-
'Div': torch.div,
30+
'Div': _onnx_div,
2331
}
2432

2533

tests/node_converters/binary_operations_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,31 @@ def test_math_binary_operation(op_type: str) -> None: # pylint: disable=missing
3030

3131
model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs)
3232
check_onnx_model(model, test_inputs)
33+
34+
35+
@pytest.mark.parametrize(
36+
'x, y',
37+
[
38+
(1, 2),
39+
(1, 5),
40+
(5, 30),
41+
(-1, 2),
42+
(-1, 5),
43+
(5, -30),
44+
(5, 2),
45+
(-5, 2),
46+
],
47+
)
48+
def test_div_operation(x: int, y: int) -> None: # pylint: disable=missing-function-docstring
49+
x_ = np.array(x, dtype=np.int64) # pylint: disable=invalid-name
50+
y_ = np.array(y, dtype=np.int64) # pylint: disable=invalid-name
51+
test_inputs = {'x': x_, 'y': y_}
52+
53+
node = onnx.helper.make_node(
54+
op_type='Div',
55+
inputs=['x', 'y'],
56+
outputs=['z'],
57+
)
58+
59+
model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs)
60+
check_onnx_model(model, test_inputs)

0 commit comments

Comments
 (0)