
Bug with TF StridedSlice op conversion #1849

@iolkhovsky

Describe the bug
Converting tf.strided_slice in some specific cases produces an ONNX graph that doesn't behave like the original one. For example, I get an incorrect ONNX graph when I try to convert the following TensorFlow code:

x = x[..., tf.newaxis, tf.newaxis]

In this particular case, tf2onnx splits the source StridedSlice operation into 2 ONNX nodes: Unsqueeze + Slice. The Unsqueeze op gets an 'axes' input equal to [1, 2], but to match the TF result it should have 'axes'=[2, 3].
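The expected axes follow from how the ellipsis expands: it stands for every input dimension not consumed by another non-`newaxis` entry, so each `tf.newaxis` written after it lands past those dimensions. Below is a minimal sketch of that bookkeeping in plain Python. It assumes a simplified spec made only of ellipsis, `newaxis` and ordinary slices (no shrink/integer indexing), and `new_axis_positions` plus its string tokens are hypothetical names for illustration, not tf2onnx code:

```python
def new_axis_positions(input_rank, spec):
    """Return the output positions of the newaxis entries in `spec`.

    `spec` is a list of tokens: "..." (ellipsis), "new" (tf.newaxis) or
    "slice" (an ordinary slice that keeps one input dimension).
    Shrink-axis (integer) indexing is deliberately not modelled.
    """
    # Each "slice" consumes one input dimension; the ellipsis absorbs the rest.
    consumed = sum(1 for tok in spec if tok == "slice")
    ellipsis_dims = input_rank - consumed if "..." in spec else 0

    positions = []
    out_pos = 0  # current position in the output tensor
    for tok in spec:
        if tok == "...":
            out_pos += ellipsis_dims
        elif tok == "new":
            positions.append(out_pos)
            out_pos += 1
        else:  # "slice"
            out_pos += 1
    return positions


# x[..., tf.newaxis, tf.newaxis] on a rank-2 input
print(new_axis_positions(2, ["...", "new", "new"]))  # [2, 3]
# x[tf.newaxis, ...] on a rank-2 input
print(new_axis_positions(2, ["new", "..."]))  # [0]
```

NumPy follows the same `...`/`np.newaxis` convention, so the result can be cross-checked there: `np.zeros((2, 0))[..., np.newaxis, np.newaxis].shape` is `(2, 0, 1, 1)`.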

I suspect the problem is related to this block of code

System information

  • Ubuntu 20.04.3 LTS
  • Tensorflow 2.8.0
  • Python 3.8.10 [GCC 9.3.0] on linux

To Reproduce
The following snippet illustrates the problem:

import tensorflow as tf
import tf2onnx
import numpy as np
import onnx
import onnxruntime as ort


class OrtModel:

    def __init__(self, onnx_path, custom_ops=None):
        so = ort.SessionOptions()
        if custom_ops is not None:
            so.register_custom_ops_library(custom_ops)
        self._session = ort.InferenceSession(onnx_path, so)
        self._model = onnx.load(onnx_path)

    def predict(self, inputs, output_names=None):
        return self._session.run(output_names, inputs)

    def __call__(self, *args, **kwargs):
        return self.predict(*args, **kwargs)

    def save_onnx(self, onnx_path):
        onnx.save(self._model, onnx_path)
        

def graph_with_strided_slice(x):
    return x[..., tf.newaxis, tf.newaxis]


signature = [tf.TensorSpec(shape=[None, None], dtype=tf.float64)]
tf_func = tf.function(func=graph_with_strided_slice, input_signature=signature)

test_data = np.zeros(shape=(2, 0))
print("Test data (shape)", test_data.shape)

test_tensor = tf.convert_to_tensor(test_data)
print("Tensorflow inference result (shape):", tf_func(test_tensor).shape)

SAVED_MODEL_PATH = "model.onnx"
proto, _ = tf2onnx.convert.from_function(
    function=tf_func,
    input_signature=signature,
    output_path=SAVED_MODEL_PATH,
    opset=14
)
input_names = [x.name for x in proto.graph.input]
output_names = [x.name for x in proto.graph.output]

model = OrtModel(SAVED_MODEL_PATH)
onnx_res = model({name: val for name, val in zip(input_names, [test_data])}, output_names)
print("OnnxRuntime inference result (shape):", onnx_res[0].shape)

Running the snippet produces the following output:

Test data (shape) (2, 0)
Tensorflow inference result (shape): (2, 0, 1, 1)
OnnxRuntime inference result (shape): (2, 1, 1, 0)
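The OnnxRuntime shape is exactly what an Unsqueeze with `axes=[1, 2]` produces, because ONNX `Unsqueeze` interprets `axes` as positions in the output tensor. NumPy's `np.expand_dims` with a tuple of axes (NumPy 1.18 or later) uses the same convention, so it can reproduce both shapes without TensorFlow or ONNX installed. This only illustrates the Unsqueeze semantics; it is not the converted graph itself:

```python
import numpy as np

test_data = np.zeros(shape=(2, 0))

# Axes emitted by the converter in this report
buggy = np.expand_dims(test_data, axis=(1, 2))
# Axes that match TensorFlow's result
expected = np.expand_dims(test_data, axis=(2, 3))

print("axes=[1, 2]:", buggy.shape)     # (2, 1, 1, 0), same as OnnxRuntime
print("axes=[2, 3]:", expected.shape)  # (2, 0, 1, 1), same as TensorFlow
```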

Additional context
I seem to have fixed the issue with the following patch:

diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py
index 4956422..d961777 100644
--- a/tf2onnx/onnx_opset/tensor.py
+++ b/tf2onnx/onnx_opset/tensor.py
@@ -963,6 +963,29 @@ class StridedSlice:
                     unqueeze_at.append(bit + ellipsis_gap)
                     begin_mask |= 1 << bit
                     end_mask |= 1 << bit
+            
+            if ellipsis_mask:
+                unqueeze_at = []
+                ellipsis_gap = 0
+                num_new = 0
+                end_mask = node.get_attr("end_mask")
+                end_mask = end_mask.i if end_mask is not None else 0
+                begin_mask = node.get_attr("begin_mask")
+                begin_mask = begin_mask.i if begin_mask is not None else 0
+
+                for bit in range(32):
+                    new_axis_flag = (new_axis_mask >> bit) & 1
+                    ellipsis_flag = (ellipsis_mask >> bit) & 1
+                    num_new += not ellipsis_flag and new_axis_flag
+
+                for bit in range(32):
+                    if (ellipsis_mask >> bit) & 1:
+                        ellipsis_gap = len(ctx.get_shape(input_x)) - param_rank + num_new + 1
+                    elif (new_axis_mask >> bit) & 1:
+                        effective_bit = bit if not ellipsis_gap else bit + ellipsis_gap - 1
+                        unqueeze_at.append(effective_bit)
+                        begin_mask |= 1 << bit
+                        end_mask |= 1 << bit
 
             input_x = GraphBuilder(ctx).make_unsqueeze(
                 {'data': input_x, 'axes': unqueeze_at})

That said, the patch above is not properly tested, and I'm not sure it doesn't break something in the code that follows.
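To see what the patch computes, its index arithmetic can be lifted into a standalone function. This sketch mirrors only the `unqueeze_at` loop from the diff (the `begin_mask`/`end_mask` updates are left out); `patch_unsqueeze_axes` is a hypothetical name, and `input_rank` stands in for `len(ctx.get_shape(input_x))`. For `x[..., tf.newaxis, tf.newaxis]`, TensorFlow builds three slice entries (`param_rank = 3`) with `ellipsis_mask = 0b001` and `new_axis_mask = 0b110`:

```python
def patch_unsqueeze_axes(input_rank, param_rank, ellipsis_mask, new_axis_mask):
    """Re-implementation of the unqueeze_at computation from the proposed patch."""
    unqueeze_at = []
    ellipsis_gap = 0

    # Count newaxis entries that are not the ellipsis itself.
    num_new = 0
    for bit in range(32):
        new_axis_flag = (new_axis_mask >> bit) & 1
        ellipsis_flag = (ellipsis_mask >> bit) & 1
        num_new += int(not ellipsis_flag and new_axis_flag)

    for bit in range(32):
        if (ellipsis_mask >> bit) & 1:
            # Number of input dimensions covered by the ellipsis.
            ellipsis_gap = input_rank - param_rank + num_new + 1
        elif (new_axis_mask >> bit) & 1:
            # Entries after the ellipsis are shifted by (ellipsis_gap - 1);
            # a gap of 0 is treated the same as "no ellipsis seen yet".
            effective_bit = bit if not ellipsis_gap else bit + ellipsis_gap - 1
            unqueeze_at.append(effective_bit)
    return unqueeze_at


print(patch_unsqueeze_axes(2, 3, 0b001, 0b110))  # [2, 3]
```

Written this way, edge cases are easy to probe. For example, with a scalar input and `x[..., tf.newaxis]` (`input_rank=0`, `param_rank=2`, `ellipsis_mask=0b01`, `new_axis_mask=0b10`) the ellipsis covers zero dimensions, so `ellipsis_gap` is 0 and the `if not ellipsis_gap` branch returns `[1]`, while the new axis should land at position 0 of a `(1,)`-shaped result. That case appears worth a test before relying on the patch.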
