Skip to content

Commit b174518

Browse files
authored
Fixed group attribute in convolution op.
Also minor change reading shape dimensions so it works for different dimensionalities.
1 parent 2c1db54 commit b174518

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,13 @@ def any_version(cls, opset, ctx, node, **kwargs):
377377
data_format = str(node.attr["data_format"].s, encoding="utf8")
378378
shape_dim = -1
379379
if data_format == "NHWC":
380-
shape_dim = ctx.get_shape(node.input[0])[3]
380+
shape_dim = ctx.get_shape(node.input[0])[-1]
381381
elif data_format == "NCHW":
382382
shape_dim = ctx.get_shape(node.input[0])[1]
383383
if shape_dim != -1:
384-
groups = int(shape_dim / ctx.get_shape(node.input[1])[2])
384+
filter_in_channels = ctx.get_shape(node.input[1])[-2]
385+
if filter_in_channels != -1:
386+
groups = shape_dim // filter_in_channels
385387

386388
node.set_attr("group", groups)
387389

0 commit comments

Comments
 (0)