Skip to content

Commit 4880cf6

Browse files
committed
Fixed group attribute in convolution op. Also minor change reading shape dimensions so it works for different dimensionalities.
Signed-off-by: Javier Dehesa <[email protected]>
1 parent 434b4a7 commit 4880cf6

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,13 @@ def any_version(cls, opset, ctx, node, **kwargs):
377377
data_format = str(node.attr["data_format"].s, encoding="utf8")
378378
shape_dim = -1
379379
if data_format == "NHWC":
380-
shape_dim = ctx.get_shape(node.input[0])[3]
380+
shape_dim = ctx.get_shape(node.input[0])[-1]
381381
elif data_format == "NCHW":
382382
shape_dim = ctx.get_shape(node.input[0])[1]
383383
if shape_dim != -1:
384-
groups = int(shape_dim / ctx.get_shape(node.input[1])[2])
384+
filter_in_channels = ctx.get_shape(node.input[1])[-2]
385+
if filter_in_channels != -1:
386+
groups = shape_dim // filter_in_channels
385387

386388
node.set_attr("group", groups)
387389

0 commit comments

Comments
 (0)