Skip to content

Commit ce29107

Browse files
sdua-nvfatcat-z and a co-author authored
Fix the axis of inserted QDQ for ConvTranspose (#2134)
The quantization axis of QDQ nodes that are being inserted before the kernel weights of all Conv nodes is currently 0. This is incorrect; ConvTranspose requires axis=1. Signed-off-by: Sirej Dua <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent f7d49c7 commit ce29107

File tree

4 files changed

+59
-9
lines changed

4 files changed

+59
-9
lines changed

tests/backend_test_base.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,10 @@ def run_tflite(self, tflite_path, feed_dict):
284284
# tflite sometimes converts from tf but produces an invalid model
285285
return None, None
286286

287-
def assert_shapes_correct(self, graph, allow_missing=False, run_checker=True):
287+
def assert_shapes_correct(self, graph, allow_missing=False, run_checker=True, check_shape=True):
288+
if not check_shape:
289+
return None
290+
288291
model_proto = graph.make_model("test")
289292

290293
if run_checker and not any(graph.get_shape(out) is None for out in graph.outputs + graph.input_names):
@@ -402,8 +405,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
402405
i = output_names_with_port.index(output_name)
403406
actual[i] = np.transpose(actual[i], constants.NCHW_TO_NHWC)
404407

405-
self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
406-
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
408+
self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape,
409+
check_dtype)
410+
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker,
411+
check_shape)
407412

408413
if graph_validator:
409414
self.assertTrue(graph_validator(g))
@@ -441,7 +446,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
441446
onnx_tfl_res[i] = np.transpose(onnx_tfl_res[i], constants.NCHW_TO_NHWC)
442447

443448
self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
444-
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
449+
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker,
450+
check_shape)
445451

446452
if graph_validator:
447453
self.assertTrue(graph_validator(g))
@@ -475,7 +481,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
475481

476482
self.assert_results_equal(tfjs_res, onnx_tfjs_res, rtol, atol, mtol, check_value, check_shape,
477483
check_dtype=False)
478-
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
484+
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker,
485+
check_shape)
479486

480487
if graph_validator:
481488
self.assertTrue(graph_validator(g))

tests/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
"check_op_count",
5252
"check_gru_count",
5353
"check_lstm_count",
54+
"check_quantization_axis",
5455
"timeout",
5556
]
5657

@@ -471,6 +472,8 @@ def check_lstm_count(graph, expected_count):
471472
def check_gru_count(graph, expected_count):
472473
return check_op_count(graph, "GRU", expected_count)
473474

475+
def check_quantization_axis(graph, op_type, expected_axis):
476+
return np.all(np.array([n.get_attr_int("axis") for n in group_nodes_by_type(graph)[op_type]]) == expected_axis)
474477

475478
_MAX_MS_OPSET_VERSION = 1
476479

tests/test_backend.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,44 @@ def func(x, output_shape_placeholder):
607607
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: output_shape},
608608
rtol=1e-05, process_args=process_args)
609609

610+
@check_opset_min_version(10, "quantize_and_dequantize")
611+
def test_conv2d_quantization_axis(self):
612+
x_shape = [1, 1, 5, 5]
613+
kernel_shape = _KERNEL3x3
614+
strides = [1, 1, 1, 1]
615+
x_val = make_xval(x_shape).transpose(NCHW_TO_NHWC)
616+
kernel_val = make_xval(_KERNEL3x3)
617+
618+
def func(x):
619+
f = tf.constant(kernel_val, name="kernel", dtype=tf.float32)
620+
kernel_dq = quantize_and_dequantize(f, 0, np.prod(kernel_shape))
621+
conv = tf.nn.conv2d(x, kernel_dq, strides=strides, padding="VALID")
622+
return tf.identity(conv, name=_TFOUTPUT)
623+
def graph_validator(g):
624+
return check_quantization_axis(g, "DequantizeLinear", 0)
625+
626+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator,
627+
check_shape=False)
628+
629+
@check_opset_min_version(10, "quantize_and_dequantize")
630+
def test_conv2d_transpose_quantization_axis(self):
631+
x_shape = [2, 6, 4, 3]
632+
output_shape = [2, 13, 9, 2]
633+
kernel_shape = [3, 3, 2, 3]
634+
strides = [1, 2, 2, 1]
635+
x_val = make_xval(x_shape)
636+
kernel_val = make_xval(kernel_shape)
637+
def func(x):
638+
f = tf.constant(kernel_val, name="kernel", dtype=tf.float32)
639+
kernel_dq = quantize_and_dequantize(f, 0, np.prod(kernel_shape))
640+
conv = tf.nn.conv2d_transpose(x, kernel_dq, output_shape, strides=strides, padding="VALID")
641+
return tf.identity(conv, name=_TFOUTPUT)
642+
def graph_validator(g):
643+
return check_quantization_axis(g, "DequantizeLinear", 1)
644+
645+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator,
646+
check_shape=False)
647+
610648
def test_depthwiseconv_0(self):
611649
x_shape = [1, 3, 4, 3]
612650
kernel_shape = [3, 3, 3, 3]

tf2onnx/onnx_opset/nn.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def get_channels_last_permutation(spatial):
5858

5959

6060
def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
61-
input_indices=None, output_indices=None, spatial=2):
61+
input_indices=None, output_indices=None, spatial=2,
62+
quantization_axis=0):
6263
"""Convert input and kernel from tensorflow to onnx. This may be required to
6364
insert transpose ops for input, kernel, and output unless they are constants
6465
and we can transpose the constant.
@@ -73,6 +74,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
7374
new_kernel_shape: Pass to reshape the kernel.
7475
input_indices: Indices that define the inputs.
7576
output_indices: Indices that define the outputs.
77+
quantization_axis: Axis for the inserted QDQ nodes
7678
"""
7779

7880
if input_indices is None:
@@ -151,8 +153,8 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
151153
weights_node.set_tensor_value(val)
152154
need_transpose = False
153155
# Change the quantization axis for Q and DQ node accordingly
154-
kernel_node.set_attr("axis", 0) # DQ node
155-
kernel_node.inputs[0].set_attr("axis", 0) # Q node
156+
kernel_node.set_attr("axis", quantization_axis) # DQ node
157+
kernel_node.inputs[0].set_attr("axis", quantization_axis) # Q node
156158
else:
157159
val = kernel_node.get_tensor_value(as_list=False)
158160
val = np.transpose(val, permutation)
@@ -607,7 +609,7 @@ def version_1(cls, ctx, node, **kwargs):
607609
ctx.replace_input(node, node.input[0], node.input[1], 0)
608610
ctx.replace_input(node, node.input[1], t, 1)
609611

610-
conv_convert_inputs(ctx, node, with_kernel=True, spatial=spatial)
612+
conv_convert_inputs(ctx, node, with_kernel=True, spatial=spatial, quantization_axis=1)
611613

612614
@classmethod
613615
def version_11(cls, ctx, node, **kwargs):

0 commit comments

Comments (0)