Describe the bug
In some specific cases, converting tf.strided_slice produces an ONNX graph that does not behave like the original TensorFlow graph. For example, I get an incorrect ONNX graph when I try to convert the following TensorFlow code:
x = x[..., tf.newaxis, tf.newaxis]
In this particular case, tf2onnx splits the source StridedSlice operation into two ONNX nodes: Unsqueeze + Slice. The Unsqueeze op gets an 'axes' input equal to [1, 2], but to match the TensorFlow result it should have 'axes'=[2, 3].
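To make the difference between the two 'axes' values concrete, here is a minimal pure-Python sketch of how an Unsqueeze-style op maps an input shape to an output shape. The helper name unsqueeze_shape is hypothetical and is not part of tf2onnx or ONNX; it only assumes that 'axes' refers to positions in the output tensor, as the ONNX Unsqueeze operator does.

```python
def unsqueeze_shape(shape, axes):
    """Return the shape obtained by inserting size-1 dims at the given output positions.

    Hypothetical helper for illustration only; it models shapes, not data.
    """
    out_rank = len(shape) + len(axes)
    # Negative axes are counted from the end of the output shape.
    positions = sorted(a % out_rank for a in axes)
    remaining = iter(shape)
    # Walk the output positions: emit 1 for a new axis, else the next input dim.
    return tuple(1 if i in positions else next(remaining) for i in range(out_rank))


input_shape = (2, 0)
print(unsqueeze_shape(input_shape, [2, 3]))  # (2, 0, 1, 1), matches TensorFlow
print(unsqueeze_shape(input_shape, [1, 2]))  # (2, 1, 1, 0), what the converted graph yields
```

With a (2, 0) input, axes [1, 2] place the new dimensions before the last input dimension, while axes [2, 3] append them after it, which is what x[..., tf.newaxis, tf.newaxis] requires.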
I suppose that the problem is related to this block of code
System information
- Ubuntu 20.04.3 LTS
- Tensorflow 2.8.0
- Python 3.8.10 [GCC 9.3.0] on linux
To Reproduce
The following snippet illustrates the problem:
import tensorflow as tf
import tf2onnx
import numpy as np
import onnx
import onnxruntime as ort


class OrtModel:
    def __init__(self, onnx_path, custom_ops=None):
        so = ort.SessionOptions()
        if custom_ops is not None:
            so.register_custom_ops_library(custom_ops)
        self._session = ort.InferenceSession(onnx_path, so)
        self._model = onnx.load(onnx_path)

    def predict(self, inputs, output_names=None):
        return self._session.run(output_names, inputs)

    def __call__(self, *args, **kwargs):
        return self.predict(*args, **kwargs)

    def save_onnx(self, onnx_path):
        onnx.save(self._model, onnx_path)


def graph_with_strided_slice(x):
    return x[..., tf.newaxis, tf.newaxis]


signature = [tf.TensorSpec(shape=[None, None], dtype=tf.float64)]
tf_func = tf.function(func=graph_with_strided_slice, input_signature=signature)

test_data = np.zeros(shape=(2, 0))
print("Test data (shape)", test_data.shape)
test_tensor = tf.convert_to_tensor(test_data)
print("Tensorflow inference result (shape):", tf_func(test_tensor).shape)

SAVED_MODEL_PATH = "model.onnx"
proto, _ = tf2onnx.convert.from_function(
    function=tf_func,
    input_signature=signature,
    output_path=SAVED_MODEL_PATH,
    opset=14
)

input_names = [x.name for x in proto.graph.input]
output_names = [x.name for x in proto.graph.output]
model = OrtModel(SAVED_MODEL_PATH)
onnx_res = model({name: val for name, val in zip(input_names, [test_data])}, output_names)
print("OnnxRuntime inference result (shape):", onnx_res[0].shape)
Running the snippet prints the following output:
Test data (shape) (2, 0)
Tensorflow inference result (shape): (2, 0, 1, 1)
OnnxRuntime inference result (shape): (2, 1, 1, 0)
Additional context
It seems I've managed to fix the issue with the following patch:
diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py
index 4956422..d961777 100644
--- a/tf2onnx/onnx_opset/tensor.py
+++ b/tf2onnx/onnx_opset/tensor.py
@@ -963,6 +963,29 @@ class StridedSlice:
unqueeze_at.append(bit + ellipsis_gap)
begin_mask |= 1 << bit
end_mask |= 1 << bit
+
+ if ellipsis_mask:
+ unqueeze_at = []
+ ellipsis_gap = 0
+ num_new = 0
+ end_mask = node.get_attr("end_mask")
+ end_mask = end_mask.i if end_mask is not None else 0
+ begin_mask = node.get_attr("begin_mask")
+ begin_mask = begin_mask.i if begin_mask is not None else 0
+
+ for bit in range(32):
+ new_axis_flag = (new_axis_mask >> bit) & 1
+ ellipsis_flag = (ellipsis_mask >> bit) & 1
+ num_new += not ellipsis_flag and new_axis_flag
+
+ for bit in range(32):
+ if (ellipsis_mask >> bit) & 1:
+ ellipsis_gap = len(ctx.get_shape(input_x)) - param_rank + num_new + 1
+ elif (new_axis_mask >> bit) & 1:
+ effective_bit = bit if not ellipsis_gap else bit + ellipsis_gap - 1
+ unqueeze_at.append(effective_bit)
+ begin_mask |= 1 << bit
+ end_mask |= 1 << bit
input_x = GraphBuilder(ctx).make_unsqueeze(
{'data': input_x, 'axes': unqueeze_at})
However, this patch has not been properly tested, and I'm not sure whether it affects the code that follows it.
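For reference, the index arithmetic that the converter needs to get right can be sketched independently of tf2onnx. The following is not the tf2onnx implementation and not a replacement for the patch; the function new_axis_positions and the spec markers are hypothetical names introduced only for illustration. It assumes that every explicitly indexed dimension is a plain slice that keeps its dimension (no shrink_axis_mask) and that the spec contains at most one ellipsis.

```python
# Hypothetical markers describing one entry of an indexing expression.
ELLIPSIS, NEWAXIS, DIM = "...", "new", "dim"


def new_axis_positions(input_rank, spec):
    """Return the output positions of the axes inserted by tf.newaxis entries.

    Illustration only: assumes each DIM entry is a slice that keeps its
    dimension and that spec holds at most one ELLIPSIS.
    """
    # Input dimensions that are addressed explicitly by the spec.
    consumed = sum(1 for item in spec if item == DIM)
    # The ellipsis stands for all remaining input dimensions.
    ellipsis_span = input_rank - consumed

    positions = []
    out_pos = 0  # current position in the output shape
    for item in spec:
        if item == ELLIPSIS:
            out_pos += ellipsis_span
        elif item == NEWAXIS:
            positions.append(out_pos)
            out_pos += 1
        else:  # DIM: one input dimension carried over to the output
            out_pos += 1
    return positions


# x[..., tf.newaxis, tf.newaxis] on a rank-2 input
print(new_axis_positions(2, [ELLIPSIS, NEWAXIS, NEWAXIS]))  # [2, 3]
# x[tf.newaxis, ..., tf.newaxis] on a rank-2 input
print(new_axis_positions(2, [NEWAXIS, ELLIPSIS, NEWAXIS]))  # [0, 3]
```

Cases like these, with the ellipsis before, between, and after the new axes, would be useful regression tests for any fix, since the reported bug only appears when new axes follow the ellipsis.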