Skip to content

Commit e7eec5c

Browse files
Add support for wasm dot-product instruction (#5861)
* Add support for wasm dot-product instruction
1 parent f2143bf commit e7eec5c

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

src/CodeGen_WebAssembly.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,8 @@ const WasmIntrinsic intrinsic_defs[] = {
102102
{"saturating_narrow_i16x16_to_u8x16", UInt(8, 16), "saturating_narrow", {Int(16, 16)}, Target::WasmSimd128},
103103
{"saturating_narrow_i32x8_to_i16x8", Int(16, 8), "saturating_narrow", {Int(32, 8)}, Target::WasmSimd128},
104104
{"saturating_narrow_i32x8_to_u16x8", UInt(16, 8), "saturating_narrow", {Int(32, 8)}, Target::WasmSimd128},
105+
106+
{"llvm.wasm.dot", Int(32, 4), "dot_product", {Int(16, 8), Int(16, 8)}, Target::WasmSimd128},
105107
#endif
106108
};
107109
// clang-format on
@@ -185,6 +187,8 @@ void CodeGen_WebAssembly::codegen_vector_reduce(const VectorReduce *op, const Ex
185187
{VectorReduce::Add, 2, i32(wild_i16x_), "pairwise_widening_add", Target::WasmSimd128},
186188
{VectorReduce::Add, 2, u32(wild_u16x_), "pairwise_widening_add", Target::WasmSimd128},
187189
{VectorReduce::Add, 2, i32(wild_u16x_), "pairwise_widening_add", Target::WasmSimd128},
190+
191+
{VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", Target::WasmSimd128},
188192
};
189193
// clang-format on
190194

test/correctness/simd_op_check.cpp

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1753,12 +1753,15 @@ class SimdOpCheck : public SimdOpCheckTest {
17531753
check("i32x4.mul", 4 * w, i32_1 * i32_2);
17541754
check("i64x2.mul", 2 * w, i64_1 * i64_2);
17551755

1756-
// Integer dot product (16 -> 32)
1757-
// TODO(https://github.com/halide/Halide/issues/5130): NOT BEING GENERATED AT TRUNK
1758-
// {
1759-
// RDom r(0, 4);
1760-
// check("i32x4.dot_i16x8_s", 2 * w, sum(i32(in_i16(x * 4 + r)) * in_i16(x * 4 + r + 32)));
1761-
// }
1756+
if (Halide::Internal::get_llvm_version() >= 130) {
1757+
// Integer dot product (16 -> 32)
1758+
for (int f : {2, 4, 8}) {
1759+
RDom r(0, f);
1760+
for (int v : {1, 2, 4}) {
1761+
check("i32x4.dot_i16x8_s", w * v, sum(i32(in_i16(f * x + r)) * in_i16(f * x + r + 32)));
1762+
}
1763+
}
1764+
}
17621765

17631766
// Integer negation
17641767
check("i8x16.neg", 16 * w, -i8_1);

0 commit comments

Comments
 (0)