Skip to content

Commit 5708e10

Browse files
qant-umfatcat-z
andauthored
Update Resize (opset 11) layer to support scales option when dims are defined (#2137)
Signed-off-by: Quentin Muller <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent ec01956 commit 5708e10

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,20 +1391,37 @@ def version_11(cls, ctx, node, **kwargs):
13911391
else:
13921392
mode = "nearest"
13931393
roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
1394-
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64))
1395-
const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64))
1396-
const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([]).astype(np.float32))
13971394
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
1398-
shape_input = ctx.make_node("Shape", [input_nchw.output[0]])
1399-
sliced_shape = ctx.make_node("Slice", [shape_input.output[0], const_zero.output[0], const_two.output[0]])
1400-
size_int64 = ctx.make_node("Cast", [node.input[1]], attr={"to": onnx_pb.TensorProto.INT64})
1401-
concat_shape = ctx.make_node("Concat", [sliced_shape.output[0], size_int64.output[0]], {'axis': 0})
1402-
resize_inputs = [
1403-
input_nchw.output[0],
1404-
roi.output[0],
1405-
const_empty_float.output[0],
1406-
concat_shape.output[0]
1407-
]
1395+
shape = ctx.get_shape(node.input[0])
1396+
if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
1397+
target_shape = node.inputs[1].get_tensor_value()
1398+
n, h, w, c = shape
1399+
nh, nw = target_shape
1400+
if "sizes" in node.attr:
1401+
sizes_val = np.array([1.0, 1.0, nh, nw]).astype(np.int64)
1402+
resize_params = ctx.make_const(utils.make_name("sizes"), sizes_val, raw=False)
1403+
else: # scales
1404+
scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
1405+
resize_params = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
1406+
resize_inputs = [
1407+
input_nchw.output[0],
1408+
roi.output[0],
1409+
resize_params.output[0]
1410+
]
1411+
else:
1412+
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64))
1413+
const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64))
1414+
const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([]).astype(np.float32))
1415+
shape_input = ctx.make_node("Shape", [input_nchw.output[0]])
1416+
sliced_shape = ctx.make_node("Slice", [shape_input.output[0], const_zero.output[0], const_two.output[0]])
1417+
size_int64 = ctx.make_node("Cast", [node.input[1]], attr={"to": onnx_pb.TensorProto.INT64})
1418+
concat_shape = ctx.make_node("Concat", [sliced_shape.output[0], size_int64.output[0]], {'axis': 0})
1419+
resize_inputs = [
1420+
input_nchw.output[0],
1421+
roi.output[0],
1422+
const_empty_float.output[0],
1423+
concat_shape.output[0]
1424+
]
14081425
transformation_mode = "asymmetric"
14091426
nearest_mode = "floor"
14101427
if "align_corners" in node.attr and node.attr["align_corners"].i:

0 commit comments

Comments
 (0)