Skip to content

Commit 1f41cbd

Browse files
add resize_type for align_corner (OAID#1062)
* add resize_type for align_corner
1 parent 12bfb67 commit 1f41cbd

File tree

3 files changed

+50
-19
lines changed

3 files changed

+50
-19
lines changed

source/device/cpu/op/interp/cortex-a/interp_kernel_arm.c

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,21 @@
3232

3333
#define MIN(a, b) ((a) < (b) ? (a) : (b))
3434

35-
static void linear_coeffs(int w, int outw, int* xofs, float* alpha)
35+
static void linear_coeffs(int w, int outw, int* xofs, float* alpha, int align_corner)
3636
{
3737
double scale = (double)w / outw;
38-
38+
if (align_corner)
39+
{
40+
scale = (double)(w - 1) / (outw - 1);
41+
}
3942
for (int dx = 0; dx < outw; dx++)
4043
{
41-
float fx = (float)((dx)*scale);
44+
float fx = (float)((dx + 0.5) * scale - 0.5);
45+
if (align_corner)
46+
{
47+
fx = (float)(dx * scale);
48+
}
49+
4250
int sx = floor(fx);
4351
fx -= sx;
4452

@@ -498,7 +506,7 @@ int interp_run(struct tensor* output_tensor, struct tensor* input_tensor, struct
498506
}
499507
}
500508
}
501-
else if (resize_type == 2) // bilinear
509+
else if (resize_type == 2 || resize_type == 4) // bilinear
502510
{
503511
int* buf = (int*)sys_malloc((out_w + out_h + out_w * 2 + out_h * 2) * sizeof(int));
504512

@@ -508,8 +516,9 @@ int interp_run(struct tensor* output_tensor, struct tensor* input_tensor, struct
508516
float* alpha = (float*)(buf + out_w + out_h); // new float[ow * 2];
509517
float* beta = (float*)(buf + out_w + out_h + out_w * 2); // new float[oh * 2];
510518

511-
linear_coeffs(in_w, out_w, xofs, alpha);
512-
linear_coeffs(in_h, out_h, yofs, beta);
519+
int align_corner = interp_param->resize_type == 2 ? 0 : 1;
520+
linear_coeffs(in_w, out_w, xofs, alpha, align_corner);
521+
linear_coeffs(in_h, out_h, yofs, beta, align_corner);
513522

514523
#pragma omp parallel for num_threads(num_thread)
515524
for (int q = 0; q < in_c; ++q)

source/device/cpu/op/interp/interp_ref.c

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,21 @@
3939

4040
#define INTERP_MIN(a, b) ((a) < (b) ? (a) : (b))
4141

42-
void linear_coeffs(int w, int outw, int* xofs, float* alpha)
42+
static void linear_coeffs(int w, int outw, int* xofs, float* alpha, int align_corner)
4343
{
4444
double scale = (double)w / outw;
45-
45+
if (align_corner)
46+
{
47+
scale = (double)(w - 1) / (outw - 1);
48+
}
4649
for (int dx = 0; dx < outw; dx++)
4750
{
48-
float fx = (float)((dx)*scale);
51+
float fx = (float)((dx + 0.5) * scale - 0.5);
52+
if (align_corner)
53+
{
54+
fx = (float)(dx * scale);
55+
}
56+
4957
int sx = floor(fx);
5058
fx -= sx;
5159

@@ -193,7 +201,7 @@ int ref_interp_fp32(struct tensor* input_tensor, struct tensor* output_tensor, s
193201
}
194202
}
195203
}
196-
else if (param->resize_type == 2)
204+
else if (param->resize_type == 2 || param->resize_type == 4)
197205
{
198206
float* input = (float*)input_tensor->data;
199207
float* output = (float*)output_tensor->data;
@@ -222,8 +230,9 @@ int ref_interp_fp32(struct tensor* input_tensor, struct tensor* output_tensor, s
222230
float* alpha = (float*)(buf + param->output_width + param->output_height); //new float[ow * 2];
223231
float* beta = (float*)(buf + param->output_width + param->output_height + param->output_width * 2); //new float[oh * 2];
224232

225-
linear_coeffs(in_w, out_w, xofs, alpha);
226-
linear_coeffs(in_h, out_h, yofs, beta);
233+
int align_corner = param->resize_type == 2 ? 0 : 1;
234+
linear_coeffs(in_w, out_w, xofs, alpha, align_corner);
235+
linear_coeffs(in_h, out_h, yofs, beta, align_corner);
227236

228237
for (int q = 0; q < channel; ++q)
229238
{
@@ -290,7 +299,7 @@ int ref_interp_int8(struct tensor* input_tensor, struct tensor* output_tensor, s
290299
}
291300
}
292301
}
293-
else if (param->resize_type == 2)
302+
else if (param->resize_type == 2 || param->resize_type == 4)
294303
{
295304
int batch = input_tensor->dims[0];
296305
int channel = input_tensor->dims[1];
@@ -316,8 +325,9 @@ int ref_interp_int8(struct tensor* input_tensor, struct tensor* output_tensor, s
316325
float* alpha = (float*)(buf + param->output_width + param->output_height); //new float[ow * 2];
317326
float* beta = (float*)(buf + param->output_width + param->output_height + param->output_width * 2); //new float[oh * 2];
318327

319-
linear_coeffs(in_w, out_w, xofs, alpha);
320-
linear_coeffs(in_h, out_h, yofs, beta);
328+
int align_corner = param->resize_type == 2 ? 0 : 1;
329+
linear_coeffs(in_w, out_w, xofs, alpha, align_corner);
330+
linear_coeffs(in_h, out_h, yofs, beta, align_corner);
321331

322332
for (int q = 0; q < channel; ++q)
323333
{
@@ -398,7 +408,7 @@ int ref_interp_uint8(struct tensor* input_tensor, struct tensor* output_tensor,
398408
}
399409
}
400410
}
401-
else if (param->resize_type == 2)
411+
else if (param->resize_type == 2 || param->resize_type == 4)
402412
{
403413
int batch = input_tensor->dims[0];
404414
int channel = input_tensor->dims[1];
@@ -424,8 +434,9 @@ int ref_interp_uint8(struct tensor* input_tensor, struct tensor* output_tensor,
424434
float* alpha = (float*)(buf + param->output_width + param->output_height); //new float[ow * 2];
425435
float* beta = (float*)(buf + param->output_width + param->output_height + param->output_width * 2); //new float[oh * 2];
426436

427-
linear_coeffs(in_w, out_w, xofs, alpha);
428-
linear_coeffs(in_h, out_h, yofs, beta);
437+
int align_corner = param->resize_type == 2 ? 0 : 1;
438+
linear_coeffs(in_w, out_w, xofs, alpha, align_corner);
439+
linear_coeffs(in_h, out_h, yofs, beta, align_corner);
429440

430441
for (int q = 0; q < channel; ++q)
431442
{

tools/convert_tool/onnx/onnx2tengine.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2136,6 +2136,17 @@ int load_resize(ir_graph_t* graph, ir_node_t* node, const onnx::NodeProto& onnx_
21362136
interp_param->height_scale = 0;
21372137
interp_param->width_scale = 0;
21382138

2139+
int align_corner = 0;
2140+
for (int k = 0; k < onnx_node.attribute_size(); k++)
2141+
{
2142+
const onnx::AttributeProto& attr = onnx_node.attribute(k);
2143+
if (attr.name() == "coordinate_transformation_mode")
2144+
{
2145+
if (attr.s() == "align_corners")
2146+
align_corner = 1;
2147+
}
2148+
}
2149+
21392150
if (onnx_node.input_size() == 1)
21402151
{
21412152
for (int k = 0; k < onnx_node.attribute_size(); k++)
@@ -2198,7 +2209,7 @@ int load_resize(ir_graph_t* graph, ir_node_t* node, const onnx::NodeProto& onnx_
21982209
}
21992210
else if (mode == "bilinear" || mode == "linear")
22002211
{
2201-
interp_param->resize_type = 2;
2212+
interp_param->resize_type = align_corner == 0 ? 2 : 4;
22022213
}
22032214

22042215
return 0;

0 commit comments

Comments
 (0)