Description
Describe the bug
I wrapped the single operator AdjustContrastv2 in TensorFlow 2.11 as a model and saved it as a frozen pb model file. I also converted the TF model with tf2onnx to get an ONNX model. Given the same input parameters, the two models produce inconsistent results with large errors.
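For reference, the TensorFlow documentation describes AdjustContrastv2 as adjusting each channel of each image independently to `(x - mean) * contrast_factor + mean`, where `mean` is the mean of that channel over height and width. A minimal NumPy sketch of that formula can serve as a third opinion when the two runtimes disagree. The function name `adjust_contrast_reference` and the random demo input are illustrative assumptions, not part of the attached files.

```python
import numpy as np


def adjust_contrast_reference(images, contrast_factor):
    """NumPy sketch of the documented AdjustContrastv2 formula (hypothetical helper)."""
    images = np.asarray(images, dtype=np.float32)
    # The last three axes are [height, width, channels]; average over height
    # and width so that each image and each channel gets its own mean.
    mean = images.mean(axis=(-3, -2), keepdims=True)
    return (images - mean) * np.float32(contrast_factor) + mean


# Synthetic stand-in with the same shape as the reported input (3, 3, 3, 2).
rng = np.random.default_rng(0)
demo_images = rng.random((3, 3, 3, 2), dtype=np.float32)
demo_out = adjust_contrast_reference(demo_images, 2.0)
print(demo_out.shape)
```

With the real inputs, comparing `adjust_contrast_reference(images, contrast_factor)` against each model's output using `np.testing.assert_allclose` should indicate which of the two deviates from the documented behavior.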
Urgency
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 18.04*): Ubuntu 20.04.5 LTS
- TensorFlow Version: 2.11.0.dev
- Python version: 3.8.10
- ONNX version (if applicable, e.g. 1.11*): 1.12.0
- ONNXRuntime version (if applicable, e.g. 1.11*): 1.12.1
To Reproduce
The link shows the two models before (tf.raw_ops.AdjustContrastv2_frozen_graph.pb) and after (tf.raw_ops.AdjustContrastv2_model.onnx) the conversion, and the two input parameters of the model (images.npy and contrast_factor.npy). The tf_save_model dir is the model saved using tf.saved_model.save.
The following code runs the two models separately, feeding them the same input, but the results are inconsistent.
import numpy as np
import tensorflow as tf
import onnxruntime as rt
images = np.load("images.npy")
contrast_factor = np.load("contrast_factor.npy")
onnx_model_path = "tf.raw_ops.AdjustContrastv2_model.onnx"
tf_model_path = "tf.raw_ops.AdjustContrastv2_frozen_graph.pb"

class OnnxModel():
    def __init__(self, onnx_path):
        self.onnx_session = rt.InferenceSession(onnx_path)
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)

    def get_output_name(self, onnx_session):
        output_name = []
        for node in onnx_session.get_outputs():
            output_name.append(node.name)
        return output_name

    def get_input_name(self, onnx_session):
        input_name = []
        for node in onnx_session.get_inputs():
            input_name.append(node.name)
        return input_name

    def get_input_feed(self, input_name, image_numpy):
        i = 0
        input_feed = {}
        for name in input_name:
            input_feed[name] = image_numpy[i]
            i += 1
        return input_feed

    def forward(self, numpy_list):
        input_feed = self.get_input_feed(self.input_name, numpy_list)
        output = self.onnx_session.run(self.output_name, input_feed=input_feed)
        return output

def onnx_model_test(model_path, test_args):
    model = OnnxModel(model_path)
    return model.forward(test_args)[0]

def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

def tf_model_test(model_path, test_args):
    with tf.io.gfile.GFile(model_path, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        loaded = graph_def.ParseFromString(f.read())
    # Wrap frozen graph to ConcreteFunctions
    frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                    inputs=["x:0", "x_1:0"],
                                    outputs=["PartitionedCall/AdjustContrastv2:0"],
                                    print_graph=True)
    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)
    predictions = frozen_func(x=tf.convert_to_tensor(test_args[0]), x_1=tf.convert_to_tensor(test_args[1]))
    return predictions[0]

res1 = onnx_model_test(onnx_model_path, (images, contrast_factor))
res2 = tf_model_test(tf_model_path, (images, contrast_factor))
np.testing.assert_allclose(res1, res2.numpy(), rtol=1e-4, atol=1e-4)
Here are the results:
--------------------------------------------------
Frozen model inputs:
[<tf.Tensor 'x:0' shape=(3, 3, 3, 2) dtype=float32>, <tf.Tensor 'x_1:0' shape=() dtype=float32>]
Frozen model outputs:
[<tf.Tensor 'PartitionedCall/AdjustContrastv2:0' shape=(3, 3, 3, 2) dtype=float32>]
Traceback (most recent call last):
  File "debug-cross-framework.py", line 81, in <module>
    np.testing.assert_allclose(res1, res2.numpy(), rtol=1e-4, atol=1e-4)
  File "/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 1527, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 844, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=0.0001, atol=0.0001
Mismatched elements: 54 / 54 (100%)
Max absolute difference: 0.3644538
Max relative difference: 0.28935865
 x: array([[[[1.789271, 2.70442 ],
         [3.019392, 1.692135],
         [2.755761, 3.188043]],...
 y: array([[[[1.91538 , 2.339967],
         [3.145502, 1.327681],
         [2.88187 , 2.823589]],...
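In the rows printed above, the ONNX output (`x`) and the TensorFlow output (`y`) appear to differ by a constant per channel rather than by noise. Since the op computes `(x - mean) * contrast_factor + mean`, two implementations that use different means would differ by exactly `(1 - contrast_factor) * (mean_a - mean_b)` per channel, so one possible explanation (an assumption, not verified against the converter source) is that the mean is reduced over different axes in the two models. The sketch below only checks the constant-offset pattern on the six printed pixel values; the variable names are illustrative.

```python
import numpy as np

# The three visible pixels (rows) x two channels (columns) copied from the
# assertion message above: x is the ONNX result, y is the TensorFlow result.
onnx_rows = np.array([[1.789271, 2.70442],
                      [3.019392, 1.692135],
                      [2.755761, 3.188043]])
tf_rows = np.array([[1.91538, 2.339967],
                    [3.145502, 1.327681],
                    [2.88187, 2.823589]])

diff = onnx_rows - tf_rows
# If the offset is constant within a channel, the spread of the differences
# along the pixel axis is close to zero for each channel.
per_channel_offset = diff.mean(axis=0)
per_channel_spread = diff.max(axis=0) - diff.min(axis=0)

print("per-channel offset:", per_channel_offset)
print("per-channel spread:", per_channel_spread)
```

If the same pattern holds over the full `(3, 3, 3, 2)` outputs, comparing `res1 - res2.numpy()` per image and per channel would narrow down how the two means differ.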
The command for model conversion is:
python -m tf2onnx.convert --saved-model {saved_tf_model_dir} --output {saved_onnx_model_path} --opset 17
The conversion log is:
2023-01-08 23:20:23.909106: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-01-08 23:20:23,919 - INFO - Using tensorflow=2.11.0.dev20220905, onnx=1.12.0, tf2onnx=1.12.1/b6d590
2023-01-08 23:20:23,920 - INFO - Using opset <onnx, 17>
2023-01-08 23:20:23,921 - INFO - Computed 0 values for constant folding
2023-01-08 23:20:23,925 - INFO - Optimizing ONNX model
2023-01-08 23:20:23,937 - INFO - After optimization: Identity -2 (2->0)
2023-01-08 23:20:23,938 - INFO -
2023-01-08 23:20:23,938 - INFO - Successfully converted TensorFlow model onnx_test/tf_model/ to ONNX
2023-01-08 23:20:23,938 - INFO - Model inputs: ['args_0', 'args_1']
2023-01-08 23:20:23,938 - INFO - Model outputs: ['output_0']
2023-01-08 23:20:23,938 - INFO - ONNX model is saved at xxxx/tf.raw_ops.AdjustContrastv2_model.onnx
Screenshots
TensorFlow tf.raw_ops.AdjustContrastv2_frozen_graph.pb:
ONNX tf.raw_ops.AdjustContrastv2_model.onnx:
Additional context