
Commit 554d90a

javidcf and fatcat-z authored
Fixed group attribute in convolution op (#2090)
* Fixed group attribute in convolution op. Also a minor change to reading shape dimensions so it works for different dimensionalities.
* Check that convolution kernel_shape values are valid: skip adding the optional attribute to Conv nodes if any shape value is negative.
* Fixed convolution kernel dimension checks.

Signed-off-by: Javier Dehesa <[email protected]>
Co-authored-by: Jay Zhang <[email protected]>
1 parent dd373a1 commit 554d90a

File tree

2 files changed: +23, -5 lines changed
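For context on the group fix: the ONNX Conv "group" attribute is the number of input channels divided by the filter's per-group input-channel count, and in TensorFlow's channels-last filter layouts that per-group count sits at index -2 of the kernel shape for both 1-D and 2-D convolutions, which is why the converter now indexes with [-2] instead of a hard-coded [2]. A minimal sketch of that arithmetic, using illustrative shapes rather than anything taken from the commit:

# Illustrative shapes only; they show why indexing the filter shape with [-2]
# works for both conv1d and conv2d kernels.
nhwc_input = [1, 224, 224, 32]   # conv2d input: N, H, W, C_in
hwio_filter = [3, 3, 8, 64]      # conv2d filter: k_h, k_w, C_in per group, C_out

nwc_input = [2, 10, 3]           # conv1d input: N, W, C_in (as in the new test)
wio_filter = [4, 3, 5]           # conv1d filter: k_w, C_in, C_out

for inp, flt in [(nhwc_input, hwio_filter), (nwc_input, wio_filter)]:
    in_channels = inp[-1]            # channels-last layouts: C_in is the last axis
    filter_in_channels = flt[-2]     # second-to-last axis at either rank
    print(in_channels // filter_in_channels)   # ONNX Conv "group": 4, then 1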

tests/test_backend.py

Lines changed: 14 additions & 0 deletions

@@ -6138,6 +6138,20 @@ def func(x):
         x_val = make_xval([2, 3])
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @check_opset_min_version(11, "Pad")
+    def test_conv_unknown_kernel_channels(self):
+        x_shape = [2, 10, 3]
+        x_val = make_xval(x_shape)
+        kernel_shape = [4, 3, 5]
+        kernel_val = make_xval(kernel_shape)
+        pad_val = np.array([[0, 0], [0, 0], [0, 0]], np.int64)
+        def func(x, kernel, pad):
+            # Make kernel dimensions unknown
+            kernel = tf.pad(kernel, pad)
+            conv = tf.nn.conv1d(x, kernel, stride=[1], padding='VALID')
+            return tf.identity(conv, name='output')
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: kernel_val, _INPUT2: pad_val})
+
     @check_tf_min_version("2.3.0")
     @check_opset_min_version(16, "ScatterND")
     @skip_tfjs("not supported in tfjs")

tf2onnx/onnx_opset/nn.py

Lines changed: 9 additions & 5 deletions

@@ -311,8 +311,9 @@ def conv_kernel_shape(ctx, node, input_idx, spatial=2):
     # Get spatial part.
     kernel_shape = kernel_shape[:spatial]
 
-    # Set new value and return it.
-    node.set_attr("kernel_shape", kernel_shape)
+    # Set attribute value only if all dimensions are known.
+    if all(d > 0 for d in kernel_shape):
+        node.set_attr("kernel_shape", kernel_shape)
 
     return kernel_shape
 
@@ -379,11 +380,13 @@ def any_version(cls, opset, ctx, node, **kwargs):
         data_format = str(node.attr["data_format"].s, encoding="utf8")
         shape_dim = -1
         if data_format == "NHWC":
-            shape_dim = ctx.get_shape(node.input[0])[3]
+            shape_dim = ctx.get_shape(node.input[0])[-1]
         elif data_format == "NCHW":
             shape_dim = ctx.get_shape(node.input[0])[1]
         if shape_dim != -1:
-            groups = int(shape_dim / ctx.get_shape(node.input[1])[2])
+            filter_in_channels = ctx.get_shape(node.input[1])[-2]
+            if filter_in_channels != -1:
+                groups = shape_dim // filter_in_channels
 
         node.set_attr("group", groups)
 
@@ -649,7 +652,8 @@ def version_1(cls, ctx, node, **kwargs):
             raise ValueError("input channel must be positive")
         k_output_channels = k_input_channels * k_channel_multiplier
 
-        node.set_attr("kernel_shape", [k_h, k_w])
+        if k_h > 0 and k_w > 0:
+            node.set_attr("kernel_shape", [k_h, k_w])
         strides = conv_dims_attr(node, "strides")
         dilations = conv_dims_attr(node, "dilations")
         node.set_attr("group", k_input_channels)
