Skip to content

Commit 94c0eca

Browse files
authored
Use dot products for sums. (#5954)
1 parent 5a0d1e5 commit 94c0eca

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/CodeGen_ARM.cpp

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1143,12 +1143,18 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init
11431143
Expr pattern;
11441144
const char *intrin;
11451145
Target::Feature required_feature;
1146+
std::vector<int> extra_operands;
11461147
};
11471148
// clang-format off
11481149
static const Pattern patterns[] = {
11491150
{VectorReduce::Add, 4, i32(widening_mul(wild_i8x_, wild_i8x_)), "dot_product", Target::ARMDotProd},
11501151
{VectorReduce::Add, 4, i32(widening_mul(wild_u8x_, wild_u8x_)), "dot_product", Target::ARMDotProd},
11511152
{VectorReduce::Add, 4, u32(widening_mul(wild_u8x_, wild_u8x_)), "dot_product", Target::ARMDotProd},
1153+
// A sum is the same as a dot product with a vector of ones, and this appears to
1154+
// be a bit faster.
1155+
{VectorReduce::Add, 4, i32(wild_i8x_), "dot_product", Target::ARMDotProd, {1}},
1156+
{VectorReduce::Add, 4, i32(wild_u8x_), "dot_product", Target::ARMDotProd, {1}},
1157+
{VectorReduce::Add, 4, u32(wild_u8x_), "dot_product", Target::ARMDotProd, {1}},
11521158
};
11531159
// clang-format on
11541160

@@ -1162,13 +1168,17 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init
11621168
continue;
11631169
}
11641170
if (expr_match(p.pattern, op->value, matches)) {
1165-
if (factor != 4) {
1166-
Expr equiv = VectorReduce::make(op->op, op->value, op->value.type().lanes() / 4);
1171+
if (factor != p.factor) {
1172+
Expr equiv = VectorReduce::make(op->op, op->value, op->value.type().lanes() / p.factor);
11671173
equiv = VectorReduce::make(op->op, equiv, op->type.lanes());
11681174
codegen_vector_reduce(equiv.as<VectorReduce>(), init);
11691175
return;
11701176
}
11711177

1178+
for (int i : p.extra_operands) {
1179+
matches.push_back(make_const(matches[0].type(), i));
1180+
}
1181+
11721182
Expr i = init;
11731183
if (!i.defined()) {
11741184
i = make_zero(op->type);

test/correctness/simd_op_check.cpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1110,6 +1110,14 @@ class SimdOpCheck : public SimdOpCheckTest {
11101110
for (int v : {2, 4}) {
11111111
check("udot", v, sum(u32(in_u8(f * x + r)) * in_u8(f * x + r + 32)));
11121112
check("sdot", v, sum(i32(in_i8(f * x + r)) * in_i8(f * x + r + 32)));
1113+
if (f == 4) {
1114+
// This doesn't generate for higher reduction factors because the
1115+
// intermediate is 16-bit instead of 32-bit. It seems like it would
1116+
// be slower to fix this (because the intermediate sum would be
1117+
// 32-bit instead of 16-bit).
1118+
check("udot", v, sum(u32(in_u8(f * x + r))));
1119+
check("sdot", v, sum(i32(in_i8(f * x + r))));
1120+
}
11131121
}
11141122
}
11151123
}

0 commit comments

Comments
 (0)