@@ -1143,12 +1143,18 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init
1143
1143
Expr pattern;
1144
1144
const char *intrin;
1145
1145
Target::Feature required_feature;
1146
+ std::vector<int > extra_operands;
1146
1147
};
1147
1148
// clang-format off
1148
1149
static const Pattern patterns[] = {
1149
1150
{VectorReduce::Add, 4 , i32 (widening_mul (wild_i8x_, wild_i8x_)), " dot_product" , Target::ARMDotProd},
1150
1151
{VectorReduce::Add, 4 , i32 (widening_mul (wild_u8x_, wild_u8x_)), " dot_product" , Target::ARMDotProd},
1151
1152
{VectorReduce::Add, 4 , u32 (widening_mul (wild_u8x_, wild_u8x_)), " dot_product" , Target::ARMDotProd},
1153
+ // A sum is the same as a dot product with a vector of ones, and this appears to
1154
+ // be a bit faster.
1155
+ {VectorReduce::Add, 4 , i32 (wild_i8x_), " dot_product" , Target::ARMDotProd, {1 }},
1156
+ {VectorReduce::Add, 4 , i32 (wild_u8x_), " dot_product" , Target::ARMDotProd, {1 }},
1157
+ {VectorReduce::Add, 4 , u32 (wild_u8x_), " dot_product" , Target::ARMDotProd, {1 }},
1152
1158
};
1153
1159
// clang-format on
1154
1160
@@ -1162,13 +1168,17 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init
1162
1168
continue ;
1163
1169
}
1164
1170
if (expr_match (p.pattern , op->value , matches)) {
1165
- if (factor != 4 ) {
1166
- Expr equiv = VectorReduce::make (op->op , op->value , op->value .type ().lanes () / 4 );
1171
+ if (factor != p. factor ) {
1172
+ Expr equiv = VectorReduce::make (op->op , op->value , op->value .type ().lanes () / p. factor );
1167
1173
equiv = VectorReduce::make (op->op , equiv, op->type .lanes ());
1168
1174
codegen_vector_reduce (equiv.as <VectorReduce>(), init);
1169
1175
return ;
1170
1176
}
1171
1177
1178
+ for (int i : p.extra_operands ) {
1179
+ matches.push_back (make_const (matches[0 ].type (), i));
1180
+ }
1181
+
1172
1182
Expr i = init;
1173
1183
if (!i.defined ()) {
1174
1184
i = make_zero (op->type );
0 commit comments