Skip to content

Commit dbef0ca

Browse files
TF: add onnx_export option and onnx_comp_floor_div (#1453)
1 parent 52abd3a commit dbef0ca

File tree

4 files changed

+39
-9
lines changed

4 files changed

+39
-9
lines changed

returnn/tensor/_dim_extra.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,8 +1182,12 @@ def _bin_op_tf(a, b):
11821182
elif kind == "mul":
11831183
return a * b
11841184
elif kind in ("floordiv", "truediv"): # truediv assumes there is no remainder
1185+
if util.is_onnx_export_global():
1186+
return tf_util.onnx_compat_floor_div(a, b)
11851187
return a // b
11861188
elif kind == "ceildiv":
1189+
if util.is_onnx_export_global():
1190+
return -tf_util.onnx_compat_floor_div(-a, b)
11871191
return -(-a // b)
11881192
else:
11891193
raise ValueError("unknown op kind %r" % op.kind)

returnn/tf/frontend_low_level/_backend.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import tensorflow as tf
99

1010
import returnn.tf.compat as tf_compat
11-
from returnn.util.basic import NotSpecified
11+
from returnn.util.basic import NotSpecified, is_onnx_export_global
1212
from returnn.tensor import Tensor, Dim
1313
from returnn.tf.util import basic as tf_util
1414

@@ -132,14 +132,17 @@ def combine_raw(a: tf.Tensor, kind: str, b: tf.Tensor) -> tf.Tensor:
132132
:return: a `kind` b
133133
"""
134134
assert a.shape.ndims == b.shape.ndims or a.shape.ndims == 0 or b.shape.ndims == 0
135-
kind = {
136-
"sub": "subtract",
137-
"mul": "multiply",
138-
}.get(kind, kind)
139-
op = getattr(tf, kind, None) # e.g. tf.add
140-
# In tf v2, some ops like floordiv or mod exist in the tf.math namespace instead
141-
if op is None:
142-
op = getattr(tf.math, kind)
135+
if kind == "floordiv" and is_onnx_export_global():
136+
op = tf_util.onnx_compat_floor_div
137+
else:
138+
kind = {
139+
"sub": "subtract",
140+
"mul": "multiply",
141+
}.get(kind, kind)
142+
op = getattr(tf, kind, None) # e.g. tf.add
143+
# In tf v2, some ops like floordiv or mod exist in the tf.math namespace instead
144+
if op is None:
145+
op = getattr(tf.math, kind)
143146
with tf_util.same_control_flow_ctx([a, b]):
144147
return op(a, b)
145148

returnn/tf/util/basic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7820,3 +7820,14 @@ def is_axis_from_description_recurrent(axis, network, data):
78207820
if axis == single_step_dim:
78217821
return True
78227822
return False
7823+
7824+
7825+
def onnx_compat_floor_div(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
7826+
"""
7827+
:param a:
7828+
:param b:
7829+
:return: for onnx export compatible floor_divide
7830+
"""
7831+
# https://github.com/onnx/tensorflow-onnx/issues/2174
7832+
abs_a, abs_b = tf.abs(a), tf.abs(b)
7833+
return tf.where(a * b >= 0, a // b, -abs_a // abs_b - tf.cast(abs_a % abs_b != 0, dtype=a.dtype))

returnn/util/basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3627,6 +3627,18 @@ def get_global_inf_value() -> float:
36273627
return config.float("inf_value", _default_global_inf_value)
36283628

36293629

3630+
def is_onnx_export_global() -> bool:
3631+
"""
3632+
:return: False by default. If 'onnx_export' is set in the config, that value is used.
3633+
"""
3634+
from returnn.config import get_global_config
3635+
3636+
config = get_global_config(raise_exception=False)
3637+
if not config:
3638+
return False
3639+
return config.bool("onnx_export", False)
3640+
3641+
36303642
# See :func:`maybe_restart_returnn_with_atfork_patch` below for why you might want to use this.
36313643
_c_code_patch_atfork = """
36323644
#define _GNU_SOURCE

0 commit comments

Comments
 (0)