@@ -1391,20 +1391,37 @@ def version_11(cls, ctx, node, **kwargs):
1391
1391
else :
1392
1392
mode = "nearest"
1393
1393
roi = ctx .make_const (utils .make_name ("roi" ), np .array ([]).astype (np .float32 ))
1394
- const_zero = ctx .make_const (utils .make_name ("const_zero" ), np .array ([0 ]).astype (np .int64 ))
1395
- const_two = ctx .make_const (utils .make_name ("const_two" ), np .array ([2 ]).astype (np .int64 ))
1396
- const_empty_float = ctx .make_const (utils .make_name ("const_empty_float" ), np .array ([]).astype (np .float32 ))
1397
1394
input_nchw = ctx .make_node ("Transpose" , [node .input [0 ]], {"perm" : constants .NHWC_TO_NCHW })
1398
- shape_input = ctx .make_node ("Shape" , [input_nchw .output [0 ]])
1399
- sliced_shape = ctx .make_node ("Slice" , [shape_input .output [0 ], const_zero .output [0 ], const_two .output [0 ]])
1400
- size_int64 = ctx .make_node ("Cast" , [node .input [1 ]], attr = {"to" : onnx_pb .TensorProto .INT64 })
1401
- concat_shape = ctx .make_node ("Concat" , [sliced_shape .output [0 ], size_int64 .output [0 ]], {'axis' : 0 })
1402
- resize_inputs = [
1403
- input_nchw .output [0 ],
1404
- roi .output [0 ],
1405
- const_empty_float .output [0 ],
1406
- concat_shape .output [0 ]
1407
- ]
1395
+ shape = ctx .get_shape (node .input [0 ])
1396
+ if shape and shape [2 ] != - 1 and shape [1 ] != - 1 and node .inputs [1 ].is_const ():
1397
+ target_shape = node .inputs [1 ].get_tensor_value ()
1398
+ n , h , w , c = shape
1399
+ nh , nw = target_shape
1400
+ if "sizes" in node .attr :
1401
+ sizes_val = np .array ([1.0 , 1.0 , nh , nw ]).astype (np .int64 )
1402
+ resize_params = ctx .make_const (utils .make_name ("sizes" ), sizes_val , raw = False )
1403
+ else : # scales
1404
+ scale_val = np .array ([1.0 , 1.0 , float (nh ) / h , float (nw ) / w ]).astype (np .float32 )
1405
+ resize_params = ctx .make_const (utils .make_name ("scales" ), scale_val , raw = False )
1406
+ resize_inputs = [
1407
+ input_nchw .output [0 ],
1408
+ roi .output [0 ],
1409
+ resize_params .output [0 ]
1410
+ ]
1411
+ else :
1412
+ const_zero = ctx .make_const (utils .make_name ("const_zero" ), np .array ([0 ]).astype (np .int64 ))
1413
+ const_two = ctx .make_const (utils .make_name ("const_two" ), np .array ([2 ]).astype (np .int64 ))
1414
+ const_empty_float = ctx .make_const (utils .make_name ("const_empty_float" ), np .array ([]).astype (np .float32 ))
1415
+ shape_input = ctx .make_node ("Shape" , [input_nchw .output [0 ]])
1416
+ sliced_shape = ctx .make_node ("Slice" , [shape_input .output [0 ], const_zero .output [0 ], const_two .output [0 ]])
1417
+ size_int64 = ctx .make_node ("Cast" , [node .input [1 ]], attr = {"to" : onnx_pb .TensorProto .INT64 })
1418
+ concat_shape = ctx .make_node ("Concat" , [sliced_shape .output [0 ], size_int64 .output [0 ]], {'axis' : 0 })
1419
+ resize_inputs = [
1420
+ input_nchw .output [0 ],
1421
+ roi .output [0 ],
1422
+ const_empty_float .output [0 ],
1423
+ concat_shape .output [0 ]
1424
+ ]
1408
1425
transformation_mode = "asymmetric"
1409
1426
nearest_mode = "floor"
1410
1427
if "align_corners" in node .attr and node .attr ["align_corners" ].i :
0 commit comments