Skip to content

Commit b13aef4

Browse files
committed
Improve: Faster LUT on Ice Lake and Zen4+
This design is cleaner, but I'm not seeing any gains on AMD Zen5. Closes #240
1 parent d8aac4a commit b13aef4

File tree

1 file changed

+25
-64
lines changed

1 file changed

+25
-64
lines changed

include/stringzilla/memory.h

Lines changed: 25 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,93 +1015,54 @@ SZ_PUBLIC void sz_lookup_ice(sz_ptr_t target, sz_size_t length, sz_cptr_t source
10151015
__mmask64 head_mask = sz_u64_mask_until_(head_length);
10161016
__mmask64 tail_mask = sz_u64_mask_until_(tail_length);
10171017

1018-
// We need to pull the lookup table into 4x ZMM registers.
1019-
// We can use `vpermi2b` instruction to perform the lookup in two ZMM registers with `_mm512_permutex2var_epi8`
1020-
// intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. Assuming we need to
1021-
// operate on 4 registers, it might be cleaner to use 2x separate `_mm512_permutexvar_epi8` calls.
1022-
// Combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards.
1023-
//
1024-
// - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)":
1025-
// - On Ice Lake: 3 cycles latency, ports: 1*p5
1026-
// - On Genoa: 6 cycles latency, ports: 1*FP12
1027-
// - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)":
1028-
// - On Ice Lake: 3 cycles latency, ports: 1*p05
1029-
// - On Genoa: 1 cycle latency, ports: 1*FP0123
1030-
// - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)":
1031-
// - On Ice Lake: 3 cycles latency, ports: 1*p5
1032-
// - On Genoa: 4 cycles latency, ports: 1*FP01
1018+
// We use VPERMI2B (`_mm512_permutex2var_epi8`) to perform 256-entry lookups efficiently.
1019+
// VPERMI2B uses bit 6 of each index to select between two 64-byte tables, allowing us to
1020+
// cover 128 entries per instruction (2 instructions for all 256 entries).
10331021
//
1022+
// For the high-bit (bit 7) selection, we use VPMOVB2M (`_mm512_movepi8_mask`) which extracts
1023+
// the sign bit of each byte directly to a mask register. This goes to port 0 on Intel,
1024+
// avoiding the port 5 bottleneck that VPTESTMB would cause.
10341025
sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec;
10351026
lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut));
10361027
lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64));
10371028
lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128));
10381029
lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192));
10391030

1040-
sz_u512_vec_t first_bit_vec, second_bit_vec;
1041-
first_bit_vec.zmm = _mm512_set1_epi8((char)0x80);
1042-
second_bit_vec.zmm = _mm512_set1_epi8((char)0x40);
1043-
1044-
__mmask64 first_bit_mask, second_bit_mask;
1045-
sz_u512_vec_t source_vec;
1046-
// If the top bit is set in each byte of `source_vec`, then we use `lookup_128_to_191_vec` or
1047-
// `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`.
1048-
sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec;
1049-
sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec;
1031+
__mmask64 high_bit_mask;
1032+
sz_u512_vec_t source_vec, low_half_vec, high_half_vec, result_vec;
10501033

10511034
// Handling the head.
10521035
if (head_length) {
10531036
source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source);
1054-
lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm);
1055-
lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm);
1056-
lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm);
1057-
lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm);
1058-
first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm);
1059-
second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm);
1060-
blended_0_to_127_vec.zmm =
1061-
_mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm);
1062-
blended_128_to_255_vec.zmm =
1063-
_mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm);
1064-
blended_0_to_255_vec.zmm =
1065-
_mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm);
1066-
_mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm);
1037+
// VPERMI2B: bit 6 selects between the two tables, bits 0-5 index within each
1038+
low_half_vec.zmm = _mm512_permutex2var_epi8(lut_0_to_63_vec.zmm, source_vec.zmm, lut_64_to_127_vec.zmm);
1039+
high_half_vec.zmm = _mm512_permutex2var_epi8(lut_128_to_191_vec.zmm, source_vec.zmm, lut_192_to_255_vec.zmm);
1040+
// VPMOVB2M: extract bit 7 (sign bit) of each byte directly to mask - uses port 0, not port 5
1041+
high_bit_mask = _mm512_movepi8_mask(source_vec.zmm);
1042+
result_vec.zmm = _mm512_mask_blend_epi8(high_bit_mask, low_half_vec.zmm, high_half_vec.zmm);
1043+
_mm512_mask_storeu_epi8(target, head_mask, result_vec.zmm);
10671044
source += head_length, target += head_length, length -= head_length;
10681045
}
10691046

10701047
// Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`.
10711048
while (length >= 64) {
10721049
source_vec.zmm = _mm512_loadu_si512(source);
1073-
lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm);
1074-
lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm);
1075-
lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm);
1076-
lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm);
1077-
first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm);
1078-
second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm);
1079-
blended_0_to_127_vec.zmm =
1080-
_mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm);
1081-
blended_128_to_255_vec.zmm =
1082-
_mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm);
1083-
blended_0_to_255_vec.zmm =
1084-
_mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm);
1085-
_mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon!
1050+
low_half_vec.zmm = _mm512_permutex2var_epi8(lut_0_to_63_vec.zmm, source_vec.zmm, lut_64_to_127_vec.zmm);
1051+
high_half_vec.zmm = _mm512_permutex2var_epi8(lut_128_to_191_vec.zmm, source_vec.zmm, lut_192_to_255_vec.zmm);
1052+
high_bit_mask = _mm512_movepi8_mask(source_vec.zmm);
1053+
result_vec.zmm = _mm512_mask_blend_epi8(high_bit_mask, low_half_vec.zmm, high_half_vec.zmm);
1054+
_mm512_store_si512(target, result_vec.zmm); //! Aligned store, our main weapon!
10861055
source += 64, target += 64, length -= 64;
10871056
}
10881057

10891058
// Handling the tail.
10901059
if (tail_length) {
10911060
source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source);
1092-
lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm);
1093-
lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm);
1094-
lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm);
1095-
lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm);
1096-
first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm);
1097-
second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm);
1098-
blended_0_to_127_vec.zmm =
1099-
_mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm);
1100-
blended_128_to_255_vec.zmm =
1101-
_mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm);
1102-
blended_0_to_255_vec.zmm =
1103-
_mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm);
1104-
_mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm);
1061+
low_half_vec.zmm = _mm512_permutex2var_epi8(lut_0_to_63_vec.zmm, source_vec.zmm, lut_64_to_127_vec.zmm);
1062+
high_half_vec.zmm = _mm512_permutex2var_epi8(lut_128_to_191_vec.zmm, source_vec.zmm, lut_192_to_255_vec.zmm);
1063+
high_bit_mask = _mm512_movepi8_mask(source_vec.zmm);
1064+
result_vec.zmm = _mm512_mask_blend_epi8(high_bit_mask, low_half_vec.zmm, high_half_vec.zmm);
1065+
_mm512_mask_storeu_epi8(target, tail_mask, result_vec.zmm);
11051066
source += tail_length, target += tail_length, length -= tail_length;
11061067
}
11071068
}

0 commit comments

Comments
 (0)