Skip to content

Commit 4ce490f

Browse files
committed
[interpreter] Implement SIMD extended multiply instructions
These were accepted into the proposal in WebAssembly#376. There are 12 instructions in total: - i16x8.extmul_{low,high}_i8x16_{s,u} - i32x4.extmul_{low,high}_i16x8_{s,u} - i64x2.extmul_{low,high}_i32x4_{s,u} The implementation is straightforward, widen (using existing operations), then a multiply with the wider shape. Added a test generation script that reuses some logic in the generator for arithmetic instructions. Since these instructions have different src and dst shapes, I tweaked the base class to allow for having different shapes.
1 parent 7554a37 commit 4ce490f

File tree

12 files changed

+1380
-25
lines changed

12 files changed

+1380
-25
lines changed

interpreter/exec/eval_simd.ml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ module SimdOp (SXX : Simd.S) (Value : ValueType with type t = SXX.t) = struct
101101
| I16x8 MaxS -> SXX.I16x8.max_s
102102
| I16x8 MaxU -> SXX.I16x8.max_u
103103
| I16x8 AvgrU -> SXX.I16x8.avgr_u
104+
| I16x8 ExtMulLowS -> SXX.I16x8_convert.extmul_low_s
105+
| I16x8 ExtMulHighS -> SXX.I16x8_convert.extmul_high_s
106+
| I16x8 ExtMulLowU -> SXX.I16x8_convert.extmul_low_u
107+
| I16x8 ExtMulHighU -> SXX.I16x8_convert.extmul_high_u
104108
| I32x4 Add -> SXX.I32x4.add
105109
| I32x4 Sub -> SXX.I32x4.sub
106110
| I32x4 MinS -> SXX.I32x4.min_s
@@ -119,9 +123,17 @@ module SimdOp (SXX : Simd.S) (Value : ValueType with type t = SXX.t) = struct
119123
| I32x4 GeS -> SXX.I32x4.ge_s
120124
| I32x4 GeU -> SXX.I32x4.ge_u
121125
| I32x4 DotI16x8S -> SXX.I32x4_convert.dot_i16x8_s
126+
| I32x4 ExtMulLowS -> SXX.I32x4_convert.extmul_low_s
127+
| I32x4 ExtMulHighS -> SXX.I32x4_convert.extmul_high_s
128+
| I32x4 ExtMulLowU -> SXX.I32x4_convert.extmul_low_u
129+
| I32x4 ExtMulHighU -> SXX.I32x4_convert.extmul_high_u
122130
| I64x2 Add -> SXX.I64x2.add
123131
| I64x2 Sub -> SXX.I64x2.sub
124132
| I64x2 Mul -> SXX.I64x2.mul
133+
| I64x2 ExtMulLowS -> SXX.I64x2_convert.extmul_low_s
134+
| I64x2 ExtMulHighS -> SXX.I64x2_convert.extmul_high_s
135+
| I64x2 ExtMulLowU -> SXX.I64x2_convert.extmul_low_u
136+
| I64x2 ExtMulHighU -> SXX.I64x2_convert.extmul_high_u
125137
| F32x4 Eq -> SXX.F32x4.eq
126138
| F32x4 Ne -> SXX.F32x4.ne
127139
| F32x4 Lt -> SXX.F32x4.lt

interpreter/exec/simd.ml

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ sig
177177
val widen_high_s : t -> t
178178
val widen_low_u : t -> t
179179
val widen_high_u : t -> t
180+
val extmul_low_s : t -> t -> t
181+
val extmul_high_s : t -> t -> t
182+
val extmul_low_u : t -> t -> t
183+
val extmul_high_u : t -> t -> t
180184
end
181185
module I32x4_convert : sig
182186
val trunc_sat_f32x4_s : t -> t
@@ -186,10 +190,20 @@ sig
186190
val widen_low_u : t -> t
187191
val widen_high_u : t -> t
188192
val dot_i16x8_s : t -> t -> t
193+
val extmul_low_s : t -> t -> t
194+
val extmul_high_s : t -> t -> t
195+
val extmul_low_u : t -> t -> t
196+
val extmul_high_u : t -> t -> t
189197
end
190198
module I64x2_convert : sig
191199
val widen_low_s : t -> t
200+
val widen_high_s : t -> t
192201
val widen_low_u : t -> t
202+
val widen_high_u : t -> t
203+
val extmul_low_s : t -> t -> t
204+
val extmul_high_s : t -> t -> t
205+
val extmul_low_u : t -> t -> t
206+
val extmul_high_u : t -> t -> t
193207
end
194208
module F32x4_convert : sig
195209
val convert_i32x4_s : t -> t
@@ -417,6 +431,10 @@ struct
417431
let widen_low_u = widen Lib.List.take 0xffl
418432
let widen_high_u = widen Lib.List.drop 0xffl
419433

434+
let extmul_low_s x y = I16x8.mul (widen_low_s x) (widen_low_s y)
435+
let extmul_high_s x y = I16x8.mul (widen_high_s x) (widen_high_s y)
436+
let extmul_low_u x y = I16x8.mul (widen_low_u x) (widen_low_u y)
437+
let extmul_high_u x y = I16x8.mul (widen_high_u x) (widen_high_u y)
420438
end
421439

422440
module I32x4_convert = struct
@@ -441,16 +459,28 @@ struct
441459
| [], [] -> []
442460
| _, _ -> assert false
443461
in Rep.of_i32x4 (dot xs ys)
462+
463+
let extmul_low_s x y = I32x4.mul (widen_low_s x) (widen_low_s y)
464+
let extmul_high_s x y = I32x4.mul (widen_high_s x) (widen_high_s y)
465+
let extmul_low_u x y = I32x4.mul (widen_low_u x) (widen_low_u y)
466+
let extmul_high_u x y = I32x4.mul (widen_high_u x) (widen_high_u y)
444467
end
445468

446469
module I64x2_convert = struct
447-
let widen mask x =
470+
let widen take_or_drop mask x =
448471
Rep.of_i64x2
449472
(List.map
450473
(fun i32 -> Int64.(logand mask (of_int32 i32)))
451-
(Lib.List.take 2 (Rep.to_i32x4 x)))
452-
let widen_low_s = widen 0xffffffffffffffffL
453-
let widen_low_u = widen 0xffffffffL
474+
(take_or_drop 2 (Rep.to_i32x4 x)))
475+
let widen_low_s = widen Lib.List.take 0xffffffffffffffffL
476+
let widen_high_s = widen Lib.List.drop 0xffffffffffffffffL
477+
let widen_low_u = widen Lib.List.take 0xffffffffL
478+
let widen_high_u = widen Lib.List.drop 0xffffffffL
479+
480+
let extmul_low_s x y = I64x2.mul (widen_low_s x) (widen_low_s y)
481+
let extmul_high_s x y = I64x2.mul (widen_high_s x) (widen_high_s y)
482+
let extmul_low_u x y = I64x2.mul (widen_low_u x) (widen_low_u y)
483+
let extmul_high_u x y = I64x2.mul (widen_high_u x) (widen_high_u y)
454484
end
455485

456486
module F32x4_convert = struct

interpreter/syntax/ast.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct
5555
| Swizzle | Shuffle of int list | NarrowS | NarrowU
5656
| AddSatS | AddSatU | SubSatS | SubSatU
5757
| DotI16x8S
58+
| ExtMulLowS | ExtMulHighS | ExtMulLowU | ExtMulHighU
5859
type funop = Abs | Neg | Sqrt
5960
| Ceil | Floor | Trunc | Nearest
6061
| ConvertI32x4S | ConvertI32x4U

interpreter/syntax/operators.ml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,10 @@ let i16x8_min_u = Binary (V128 V128Op.(I16x8 MinU))
340340
let i16x8_max_s = Binary (V128 V128Op.(I16x8 MaxS))
341341
let i16x8_max_u = Binary (V128 V128Op.(I16x8 MaxU))
342342
let i16x8_avgr_u = Binary (V128 V128Op.(I16x8 AvgrU))
343+
let i16x8_extmul_low_i8x16_s = Binary (V128 V128Op.(I16x8 ExtMulLowS))
344+
let i16x8_extmul_high_i8x16_s = Binary (V128 V128Op.(I16x8 ExtMulHighS))
345+
let i16x8_extmul_low_i8x16_u = Binary (V128 V128Op.(I16x8 ExtMulLowU))
346+
let i16x8_extmul_high_i8x16_u = Binary (V128 V128Op.(I16x8 ExtMulHighU))
343347

344348
let i32x4_splat = Convert (V128 V128Op.(I32x4 Splat))
345349
let i32x4_extract_lane imm = SimdExtract (V128Op.I32x4 (ZX, imm))
@@ -375,6 +379,10 @@ let i32x4_mul = Binary (V128 V128Op.(I32x4 Mul))
375379
let i32x4_trunc_sat_f32x4_s = Unary (V128 V128Op.(I32x4 TruncSatF32x4S))
376380
let i32x4_trunc_sat_f32x4_u = Unary (V128 V128Op.(I32x4 TruncSatF32x4U))
377381
let i32x4_dot_i16x8_s = Binary (V128 V128Op.(I32x4 DotI16x8S))
382+
let i32x4_extmul_low_i16x8_s = Binary (V128 V128Op.(I32x4 ExtMulLowS))
383+
let i32x4_extmul_high_i16x8_s = Binary (V128 V128Op.(I32x4 ExtMulHighS))
384+
let i32x4_extmul_low_i16x8_u = Binary (V128 V128Op.(I32x4 ExtMulLowU))
385+
let i32x4_extmul_high_i16x8_u = Binary (V128 V128Op.(I32x4 ExtMulHighU))
378386

379387
let i64x2_splat = Convert (V128 V128Op.(I64x2 Splat))
380388
let i64x2_extract_lane imm = SimdExtract (V128Op.I64x2 (ZX, imm))
@@ -386,6 +394,10 @@ let i64x2_mul = Binary (V128 V128Op.(I64x2 Mul))
386394
let i64x2_shl = SimdShift V128Op.(I64x2 Shl)
387395
let i64x2_shr_s = SimdShift V128Op.(I64x2 ShrS)
388396
let i64x2_shr_u = SimdShift V128Op.(I64x2 ShrU)
397+
let i64x2_extmul_low_i32x4_s = Binary (V128 V128Op.(I64x2 ExtMulLowS))
398+
let i64x2_extmul_high_i32x4_s = Binary (V128 V128Op.(I64x2 ExtMulHighS))
399+
let i64x2_extmul_low_i32x4_u = Binary (V128 V128Op.(I64x2 ExtMulLowU))
400+
let i64x2_extmul_high_i32x4_u = Binary (V128 V128Op.(I64x2 ExtMulHighU))
389401

390402
let f32x4_splat = Convert (V128 V128Op.(F32x4 Splat))
391403
let f32x4_extract_lane imm = SimdExtract (V128Op.F32x4 (ZX, imm))

interpreter/text/lexer.mll

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,19 @@ rule token = parse
578578
| "i32x4.dot_i16x8_s"
579579
{ BINARY i32x4_dot_i16x8_s }
580580

581+
| "i16x8.extmul_low_i8x16_"(sign as s)
582+
{ BINARY (ext s i16x8_extmul_low_i8x16_s i16x8_extmul_low_i8x16_u) }
583+
| "i16x8.extmul_high_i8x16_"(sign as s)
584+
{ BINARY (ext s i16x8_extmul_high_i8x16_s i16x8_extmul_high_i8x16_u) }
585+
| "i32x4.extmul_low_i16x8_"(sign as s)
586+
{ BINARY (ext s i32x4_extmul_low_i16x8_s i32x4_extmul_low_i16x8_u) }
587+
| "i32x4.extmul_high_i16x8_"(sign as s)
588+
{ BINARY (ext s i32x4_extmul_high_i16x8_s i32x4_extmul_high_i16x8_u) }
589+
| "i64x2.extmul_low_i32x4_"(sign as s)
590+
{ BINARY (ext s i64x2_extmul_low_i32x4_s i64x2_extmul_low_i32x4_u) }
591+
| "i64x2.extmul_high_i32x4_"(sign as s)
592+
{ BINARY (ext s i64x2_extmul_high_i32x4_s i64x2_extmul_high_i32x4_u) }
593+
581594
| (simd_shape as s) { SIMD_SHAPE (simd_shape s) }
582595

583596
| name as s { VAR s }

test/core/simd/meta/gen_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
'simd_f64x2_pmin_pmax',
3333
'simd_i32x4_dot_i16x8',
3434
'simd_load_lane',
35+
'simd_ext_mul',
3536
)
3637

3738

test/core/simd/meta/simd_arithmetic.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,27 @@ def __str__(self):
3535
def lane(self):
3636
return self.LANE_VALUE.get(self.LANE_TYPE)
3737

38+
@property
39+
def dst_lane(self):
40+
return self.lane
41+
42+
@property
43+
def src_lane(self):
44+
# Used for arithmetic that extends the lane, e.g. i16x8 lanes, which
45+
# are extended multiply to i32x4.
46+
if hasattr(self, 'SRC_LANE_TYPE'):
47+
return self.LANE_VALUE.get(self.SRC_LANE_TYPE)
48+
else:
49+
return self.lane
50+
3851
@property
3952
def normal_unary_op_test_data(self):
40-
lane = self.lane
53+
lane = self.src_lane
4154
return [0, 1, -1, lane.max - 1, lane.min + 1, lane.min, lane.max, lane.mask]
4255

4356
@property
4457
def normal_binary_op_test_data(self):
45-
lane = self.lane
58+
lane = self.src_lane
4659
return [
4760
(0, 0),
4861
(0, 1),
@@ -170,7 +183,7 @@ def get_case_data(self):
170183
for data_group, v128_forms in self.bin_test_data:
171184
for data in data_group:
172185
case_data.append([op_name, [str(data[0]), str(data[1])],
173-
str(o.binary_op(data[0], data[1], self.lane)),
186+
str(o.binary_op(data[0], data[1], self.src_lane, self.dst_lane)),
174187
v128_forms])
175188
for data_group in self.full_bin_test_data:
176189
for data in data_group.get(op_name):
@@ -183,7 +196,7 @@ def get_case_data(self):
183196
for data_group, v128_forms in self.unary_test_data:
184197
for data in data_group:
185198
case_data.append([op_name, [str(data)],
186-
str(o.unary_op(data, self.lane)),
199+
str(o.unary_op(data, self.dst_lane)),
187200
v128_forms])
188201

189202
return case_data

test/core/simd/meta/simd_ext_mul.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,74 @@
11
#!/usr/bin/env python3
22

3-
"""
4-
TODO comment
3+
""" Base class for generating extended multiply instructions. These
4+
instructions 2 inputs of the same (narrower) lane shape, multiplies
5+
corresponding lanes with extension (no overflow/wraparound), producing 1 output
6+
of a (wider) shape. These instructions can choose to work on the low or high
7+
halves of the inputs, and perform signed or unsigned multiply.
8+
9+
Subclasses need to define 3 attributes:
10+
- LANE_TYPE (this is the output shape)
11+
- SRC_LANE_TYPE (this is the input (narrower) shape)
12+
- BINARY_OPS (list of operations)
513
"""
614

715
from simd_arithmetic import SimdArithmeticCase
816

17+
918
class SimdExtMulCase(SimdArithmeticCase):
1019
UNARY_OPS = ()
1120

21+
@property
22+
def full_bin_test_data(self):
23+
return []
24+
25+
def get_combine_cases(self):
26+
return ''
27+
28+
@property
29+
def bin_test_data(self):
30+
lane_forms = [self.SRC_LANE_TYPE, self.SRC_LANE_TYPE, self.LANE_TYPE]
31+
return [(self.normal_binary_op_test_data, lane_forms)]
32+
33+
@property
34+
def hex_binary_op_test_data(self):
35+
return []
1236

13-
class SimdI16x8(SimdExtMulCase):
14-
LANE_LEN = 16
37+
def gen_test_cases(self):
38+
wast_filename = '../simd_{wide}_extmul_{narrow}.wast'.format(
39+
wide=self.LANE_TYPE, narrow=self.SRC_LANE_TYPE)
40+
with open(wast_filename, 'w') as fp:
41+
fp.write(self.get_all_cases())
42+
43+
44+
class SimdI16x8ExtMulCase(SimdExtMulCase):
1545
LANE_TYPE = 'i16x8'
46+
SRC_LANE_TYPE = 'i8x16'
1647
BINARY_OPS = ('extmul_low_i8x16_s', 'extmul_high_i8x16_s',
17-
'extmul_low_i8x16_u', 'extmul_high_i8x16_u')
48+
'extmul_low_i8x16_u', 'extmul_high_i8x16_u')
49+
50+
51+
class SimdI32x4ExtMulCase(SimdExtMulCase):
52+
LANE_TYPE = 'i32x4'
53+
SRC_LANE_TYPE = 'i16x8'
54+
BINARY_OPS = ('extmul_low_i16x8_s', 'extmul_high_i16x8_s',
55+
'extmul_low_i16x8_u', 'extmul_high_i16x8_u')
56+
57+
58+
class SimdI64x2ExtMulCase(SimdExtMulCase):
59+
LANE_TYPE = 'i64x2'
60+
SRC_LANE_TYPE = 'i32x4'
61+
BINARY_OPS = ('extmul_low_i32x4_s', 'extmul_high_i32x4_s',
62+
'extmul_low_i32x4_u', 'extmul_high_i32x4_u')
1863

1964

2065
def gen_test_cases():
21-
simd_i16x8_arith = SimdI16x8ArithmeticCase()
22-
simd_i16x8_arith.gen_test_cases()
66+
simd_i16x8_ext_mul_case = SimdI16x8ExtMulCase()
67+
simd_i16x8_ext_mul_case.gen_test_cases()
68+
simd_i32x4_ext_mul_case = SimdI32x4ExtMulCase()
69+
simd_i32x4_ext_mul_case.gen_test_cases()
70+
simd_i64x2_ext_mul_case = SimdI64x2ExtMulCase()
71+
simd_i64x2_ext_mul_case.gen_test_cases()
2372

2473

2574
if __name__ == '__main__':

test/core/simd/meta/simd_integer_op.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class ArithmeticOp:
1515
add_sat_s, add_sat_u,
1616
sub_sat_s, sub_sat_u,
1717
min_s, min_u, max_s, max_u, avgr_u, abs
18+
ext_mul_s, ext_mul_u
1819
"""
1920
def __init__(self, op: str):
2021
self.op = op
@@ -121,7 +122,7 @@ def unary_op(self, operand, lane):
121122

122123
return str(result)
123124

124-
def binary_op(self, operand1, operand2, lane):
125+
def binary_op(self, operand1, operand2, src_lane, dst_lane=None):
125126
"""General integer arithmetic and saturating arithmetic operations
126127
with 2 operands.
127128
@@ -130,12 +131,15 @@ def binary_op(self, operand1, operand2, lane):
130131
add_sat_s, add_sat_u,
131132
sub_sat_s, sub_sat_u,
132133
min_s, min_u, max_s, max_u, avgr_u
134+
ext_mul_s, ext_mul_u (same as mul)
133135
134136
:param operand1: the operand 1, integer or literal string in hex or decimal format
135137
:param operand2: the operand 2, integer or literal string in hex or decimal format
136-
:param lane: the LaneValue instance of a lane in v128
138+
:param src_lane: the LaneValue instance of a lane in v128
137139
:return: the string of the result of <p1 self.op p2> in hex or decimal format
138140
"""
141+
if not dst_lane:
142+
dst_lane = src_lane
139143
v1 = operand1
140144
v2 = operand2
141145
base1 = base2 = 10
@@ -155,27 +159,35 @@ def binary_op(self, operand1, operand2, lane):
155159
value = v1 - v2
156160
elif self.op == 'mul':
157161
value = v1 * v2
162+
elif self.op.startswith('extmul_'):
163+
if self.op.endswith('s'):
164+
i1 = self.get_valid_value(v1, src_lane)
165+
i2 = self.get_valid_value(v2, src_lane)
166+
else:
167+
i1 = self.get_valid_value(v1, src_lane, signed=False)
168+
i2 = self.get_valid_value(v2, src_lane, signed=False)
169+
value = i1 * i2
158170
elif 'sat' in self.op:
159-
value = self._saturate(v1, v2, lane)
171+
value = self._saturate(v1, v2, src_lane)
160172
if self.op.endswith('_u'):
161173
result_signed = False
162174
elif self.op in ['min_s', 'max_s']:
163-
i1 = self.get_valid_value(v1, lane)
164-
i2 = self.get_valid_value(v2, lane)
175+
i1 = self.get_valid_value(v1, src_lane)
176+
i2 = self.get_valid_value(v2, src_lane)
165177
if self.op == 'min_s':
166178
return operand1 if i1 <= i2 else operand2
167179
else:
168180
return operand1 if i1 >= i2 else operand2
169181
elif self.op in ['min_u', 'max_u']:
170-
i1 = self.get_valid_value(v1, lane, signed=False)
171-
i2 = self.get_valid_value(v2, lane, signed=False)
182+
i1 = self.get_valid_value(v1, src_lane, signed=False)
183+
i2 = self.get_valid_value(v2, src_lane, signed=False)
172184
if self.op == 'min_u':
173185
return operand1 if i1 <= i2 else operand2
174186
else:
175187
return operand1 if i1 >= i2 else operand2
176188
elif self.op == 'avgr_u':
177-
i1 = self.get_valid_value(v1, lane, signed=False)
178-
i2 = self.get_valid_value(v2, lane, signed=False)
189+
i1 = self.get_valid_value(v1, src_lane, signed=False)
190+
i2 = self.get_valid_value(v2, src_lane, signed=False)
179191
result = (i1 + i2 + 1) // 2
180192
if base1 == 16 or base2 == 16:
181193
return hex(result)
@@ -184,5 +196,5 @@ def binary_op(self, operand1, operand2, lane):
184196
else:
185197
raise Exception('Unknown binary operation')
186198

187-
result = self.get_valid_value(value, lane, signed=result_signed)
199+
result = self.get_valid_value(value, dst_lane, signed=result_signed)
188200
return str(result)

0 commit comments

Comments
 (0)