3939
4040#define INTERP_MIN (a , b ) ((a) < (b) ? (a) : (b))
4141
42- void linear_coeffs (int w , int outw , int * xofs , float * alpha )
42+ static void linear_coeffs (int w , int outw , int * xofs , float * alpha , int align_corner )
4343{
4444 double scale = (double )w / outw ;
45-
45+ if (align_corner )
46+ {
47+ scale = (double )(w - 1 ) / (outw - 1 );
48+ }
4649 for (int dx = 0 ; dx < outw ; dx ++ )
4750 {
48- float fx = (float )((dx )* scale );
51+ float fx = (float )((dx + 0.5 ) * scale - 0.5 );
52+ if (align_corner )
53+ {
54+ fx = (float )(dx * scale );
55+ }
56+
4957 int sx = floor (fx );
5058 fx -= sx ;
5159
@@ -193,7 +201,7 @@ int ref_interp_fp32(struct tensor* input_tensor, struct tensor* output_tensor, s
193201 }
194202 }
195203 }
196- else if (param -> resize_type == 2 )
204+ else if (param -> resize_type == 2 || param -> resize_type == 4 )
197205 {
198206 float * input = (float * )input_tensor -> data ;
199207 float * output = (float * )output_tensor -> data ;
@@ -222,8 +230,9 @@ int ref_interp_fp32(struct tensor* input_tensor, struct tensor* output_tensor, s
222230 float * alpha = (float * )(buf + param -> output_width + param -> output_height ); //new float[ow * 2];
223231 float * beta = (float * )(buf + param -> output_width + param -> output_height + param -> output_width * 2 ); //new float[oh * 2];
224232
225- linear_coeffs (in_w , out_w , xofs , alpha );
226- linear_coeffs (in_h , out_h , yofs , beta );
233+ int align_corner = param -> resize_type == 2 ? 0 : 1 ;
234+ linear_coeffs (in_w , out_w , xofs , alpha , align_corner );
235+ linear_coeffs (in_h , out_h , yofs , beta , align_corner );
227236
228237 for (int q = 0 ; q < channel ; ++ q )
229238 {
@@ -290,7 +299,7 @@ int ref_interp_int8(struct tensor* input_tensor, struct tensor* output_tensor, s
290299 }
291300 }
292301 }
293- else if (param -> resize_type == 2 )
302+ else if (param -> resize_type == 2 || param -> resize_type == 4 )
294303 {
295304 int batch = input_tensor -> dims [0 ];
296305 int channel = input_tensor -> dims [1 ];
@@ -316,8 +325,9 @@ int ref_interp_int8(struct tensor* input_tensor, struct tensor* output_tensor, s
316325 float * alpha = (float * )(buf + param -> output_width + param -> output_height ); //new float[ow * 2];
317326 float * beta = (float * )(buf + param -> output_width + param -> output_height + param -> output_width * 2 ); //new float[oh * 2];
318327
319- linear_coeffs (in_w , out_w , xofs , alpha );
320- linear_coeffs (in_h , out_h , yofs , beta );
328+ int align_corner = param -> resize_type == 2 ? 0 : 1 ;
329+ linear_coeffs (in_w , out_w , xofs , alpha , align_corner );
330+ linear_coeffs (in_h , out_h , yofs , beta , align_corner );
321331
322332 for (int q = 0 ; q < channel ; ++ q )
323333 {
@@ -398,7 +408,7 @@ int ref_interp_uint8(struct tensor* input_tensor, struct tensor* output_tensor,
398408 }
399409 }
400410 }
401- else if (param -> resize_type == 2 )
411+ else if (param -> resize_type == 2 || param -> resize_type == 4 )
402412 {
403413 int batch = input_tensor -> dims [0 ];
404414 int channel = input_tensor -> dims [1 ];
@@ -424,8 +434,9 @@ int ref_interp_uint8(struct tensor* input_tensor, struct tensor* output_tensor,
424434 float * alpha = (float * )(buf + param -> output_width + param -> output_height ); //new float[ow * 2];
425435 float * beta = (float * )(buf + param -> output_width + param -> output_height + param -> output_width * 2 ); //new float[oh * 2];
426436
427- linear_coeffs (in_w , out_w , xofs , alpha );
428- linear_coeffs (in_h , out_h , yofs , beta );
437+ int align_corner = param -> resize_type == 2 ? 0 : 1 ;
438+ linear_coeffs (in_w , out_w , xofs , alpha , align_corner );
439+ linear_coeffs (in_h , out_h , yofs , beta , align_corner );
429440
430441 for (int q = 0 ; q < channel ; ++ q )
431442 {
0 commit comments