|
| 1 | +/** |
| 2 | + * @brief Hardware-accelerated UTF-8 text processing utilities. |
| 3 | + * @file utf8.h |
| 4 | + * @author Ash Vardanian |
| 5 | +
|
| 6 | + */ |
| 7 | +#ifndef STRINGZILLA_UTF8_H_ |
| 8 | +#define STRINGZILLA_UTF8_H_ |
| 9 | + |
| 10 | +#include "types.h" |
| 11 | + |
| 12 | +#ifdef __cplusplus |
| 13 | +extern "C" { |
| 14 | +#endif |
| 15 | + |
| 16 | +SZ_PUBLIC sz_cptr_t sz_utf8_unpack_chunk_ice( // |
| 17 | + sz_cptr_t text, sz_size_t length, // |
| 18 | + sz_rune_t *runes, sz_size_t runes_capacity, // |
| 19 | + sz_size_t *runes_unpacked) { |
| 20 | + |
| 21 | + // Process up to the minimum of: available bytes, output capacity * 4, or optimal chunk size (64) |
| 22 | + sz_size_t chunk_size = sz_min_of_three(length, runes_capacity * 4, 64); |
| 23 | + sz_u512_vec_t text_vec, runes_vec; |
| 24 | + __mmask64 load_mask = sz_u64_mask_until_(chunk_size); |
| 25 | + text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, (sz_u8_t const *)text); |
| 26 | + |
| 27 | + // Check, how many of the next characters are single byte (ASCII) codepoints |
| 28 | + // ASCII bytes have bit 7 clear (0x00-0x7F), non-ASCII have bit 7 set (0x80-0xFF) |
| 29 | + __mmask64 non_ascii_mask = _mm512_movepi8_mask(text_vec.zmm); |
| 30 | + // Find first non-ASCII byte or end of loaded data |
| 31 | + sz_size_t ascii_prefix_length = sz_u64_ctz(non_ascii_mask | ~load_mask); |
| 32 | + |
| 33 | + if (ascii_prefix_length) { |
| 34 | + // Unpack the last 16 bytes of text into the next 16 runes. |
| 35 | + // Even if we have more than 16 ASCII characters, we don't want to overcomplicate control flow here. |
| 36 | + sz_size_t runes_to_place = sz_min_of_three(ascii_prefix_length, 16, runes_capacity); |
| 37 | + runes_vec.zmm = _mm512_cvtepu8_epi32(_mm512_castsi512_si128(text_vec.zmm)); |
| 38 | + _mm512_mask_storeu_epi32(runes, sz_u16_mask_until_(runes_to_place), runes_vec.zmm); |
| 39 | + *runes_unpacked = runes_to_place; |
| 40 | + return text + runes_to_place; |
| 41 | + } |
| 42 | + |
| 43 | + // Check for the number of 2-byte characters |
| 44 | + // 2-byte UTF-8: [lead, cont] where lead=110xxxxx (0xC0-0xDF), cont=10xxxxxx (0x80-0xBF) |
| 45 | + // In 16-bit little-endian: 0xCCLL where LL=lead, CC=cont |
| 46 | + // Mask: 0xC0E0 (cont & 0xC0, lead & 0xE0), Pattern: 0x80C0 (cont=0x80, lead=0xC0) |
| 47 | + __mmask32 non_two_byte_mask = |
| 48 | + _mm512_cmpneq_epi16_mask(_mm512_and_si512(text_vec.zmm, _mm512_set1_epi16(0xC0E0)), _mm512_set1_epi16(0x80C0)); |
| 49 | + sz_size_t two_byte_prefix_length = sz_u64_ctz(non_two_byte_mask); |
| 50 | + if (two_byte_prefix_length) { |
| 51 | + // Unpack the last 32 bytes of text into the next 32 runes. |
| 52 | + // Even if we have more than 32 two-byte characters, we don't want to overcomplicate control flow here. |
| 53 | + sz_size_t runes_to_place = sz_min_of_three(two_byte_prefix_length, 32, runes_capacity); |
| 54 | + runes_vec.zmm = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(text_vec.zmm)); |
| 55 | + // Decode 2-byte UTF-8: ((lead & 0x1F) << 6) | (cont & 0x3F) |
| 56 | + // After cvtepu16_epi32: value = 0x0000CCLL where LL=lead (bits 7-0), CC=cont (bits 15-8) |
| 57 | + runes_vec.zmm = _mm512_or_si512( // |
| 58 | + _mm512_slli_epi32(_mm512_and_si512(runes_vec.zmm, _mm512_set1_epi32(0x1FU)), 6), // (lead & 0x1F) << 6 |
| 59 | + _mm512_and_si512(_mm512_srli_epi32(runes_vec.zmm, 8), _mm512_set1_epi32(0x3FU))); // (cont & 0x3F) |
| 60 | + _mm512_mask_storeu_epi32(runes, sz_u32_mask_until_(runes_to_place), runes_vec.zmm); |
| 61 | + *runes_unpacked = runes_to_place; |
| 62 | + return text + runes_to_place * 2; |
| 63 | + } |
| 64 | + |
| 65 | + // Check for the number of 3-byte characters - in this case we can't easily cast to 16-bit integers |
| 66 | + // and check for equality, but we can pre-define the masks and values we expect at each byte position. |
| 67 | + // For 3-byte UTF-8 sequences, we check if bytes match the pattern: 1110xxxx 10xxxxxx 10xxxxxx |
| 68 | + // We need to check every 3rd byte starting from position 0. |
| 69 | + sz_u512_vec_t three_byte_mask_vec, three_byte_pattern_vec; |
| 70 | + three_byte_mask_vec.zmm = _mm512_set1_epi32(0x00C0C0F0); // Mask: [F0, C0, C0, 00] per 4-byte slot |
| 71 | + three_byte_pattern_vec.zmm = _mm512_set1_epi32(0x008080E0); // Pattern: [E0, 80, 80, 00] per 4-byte slot |
| 72 | + |
| 73 | + // Create permutation indices to gather 3-byte sequences into 4-byte slots |
| 74 | + // Input: [b0 b1 b2] [b3 b4 b5] [b6 b7 b8] ... (up to 16 triplets from 48 bytes) |
| 75 | + // Output: [b0 b1 b2 XX] [b3 b4 b5 XX] [b6 b7 b8 XX] ... (16 slots, 4th byte zeroed) |
| 76 | + sz_u512_vec_t permute_indices; |
| 77 | + permute_indices.zmm = _mm512_setr_epi32( |
| 78 | + // Triplets 0-3: [0,1,2,_] [3,4,5,_] [6,7,8,_] [9,10,11,_] |
| 79 | + 0x40020100, 0x40050403, 0x40080706, 0x400B0A09, |
| 80 | + // Triplets 4-7: [12,13,14,_] [15,16,17,_] [18,19,20,_] [21,22,23,_] |
| 81 | + 0x400E0D0C, 0x40111010, 0x40141312, 0x40171615, |
| 82 | + // Triplets 8-11: [24,25,26,_] [27,28,29,_] [30,31,32,_] [33,34,35,_] |
| 83 | + 0x401A1918, 0x401D1C1B, 0x40201F1E, 0x40232221, |
| 84 | + // Triplets 12-15: [36,37,38,_] [39,40,41,_] [42,43,44,_] [45,46,47,_] |
| 85 | + 0x40262524, 0x40292827, 0x402C2B2A, 0x402F2E2D); |
| 86 | + |
| 87 | + // Permute to gather triplets into slots |
| 88 | + sz_u512_vec_t gathered_triplets; |
| 89 | + gathered_triplets.zmm = _mm512_permutexvar_epi8(permute_indices.zmm, text_vec.zmm); |
| 90 | + |
| 91 | + // Check if gathered bytes match 3-byte UTF-8 pattern |
| 92 | + sz_u512_vec_t masked_triplets; |
| 93 | + masked_triplets.zmm = _mm512_and_si512(gathered_triplets.zmm, three_byte_mask_vec.zmm); |
| 94 | + __mmask16 three_byte_match_mask = _mm512_cmpeq_epi32_mask(masked_triplets.zmm, three_byte_pattern_vec.zmm); |
| 95 | + sz_size_t three_byte_prefix_length = sz_u64_ctz(~three_byte_match_mask); |
| 96 | + |
| 97 | + if (three_byte_prefix_length) { |
| 98 | + // Unpack up to 16 three-byte characters (48 bytes of input). |
| 99 | + sz_size_t runes_to_place = sz_min_of_three(three_byte_prefix_length, 16, runes_capacity); |
| 100 | + // Decode: ((b0 & 0x0F) << 12) | ((b1 & 0x3F) << 6) | (b2 & 0x3F) |
| 101 | + // gathered_triplets has: [b0, b1, b2, XX] in each 32-bit slot (little-endian: 0xXXb2b1b0) |
| 102 | + // Extract: b0 from bits 7-0, b1 from bits 15-8, b2 from bits 23-16 |
| 103 | + runes_vec.zmm = _mm512_or_si512( |
| 104 | + _mm512_or_si512( |
| 105 | + // (b0 & 0x0F) << 12 |
| 106 | + _mm512_slli_epi32(_mm512_and_si512(gathered_triplets.zmm, _mm512_set1_epi32(0x0FU)), 12), |
| 107 | + // (b1 & 0x3F) << 6 |
| 108 | + _mm512_slli_epi32( |
| 109 | + _mm512_and_si512(_mm512_srli_epi32(gathered_triplets.zmm, 8), _mm512_set1_epi32(0x3FU)), 6)), |
| 110 | + _mm512_and_si512(_mm512_srli_epi32(gathered_triplets.zmm, 16), _mm512_set1_epi32(0x3FU))); // (b2 & 0x3F) |
| 111 | + _mm512_mask_storeu_epi32(runes, sz_u16_mask_until_(runes_to_place), runes_vec.zmm); |
| 112 | + *runes_unpacked = runes_to_place; |
| 113 | + return text + runes_to_place * 3; |
| 114 | + } |
| 115 | + |
| 116 | + // Check for the number of 4-byte characters |
| 117 | + // For 4-byte UTF-8 sequences: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx |
| 118 | + // With a homogeneous 4-byte prefix, we have perfect 4-byte alignment (up to 16 sequences in 64 bytes) |
| 119 | + sz_u512_vec_t four_byte_mask_vec, four_byte_pattern_vec; |
| 120 | + four_byte_mask_vec.zmm = _mm512_set1_epi32(0xC0C0C0F8); // Mask: [F8, C0, C0, C0] per 4-byte slot |
| 121 | + four_byte_pattern_vec.zmm = _mm512_set1_epi32(0x808080F0); // Pattern: [F0, 80, 80, 80] per 4-byte slot |
| 122 | + |
| 123 | + // Mask and check for 4-byte pattern in each 32-bit slot |
| 124 | + sz_u512_vec_t masked_quads; |
| 125 | + masked_quads.zmm = _mm512_and_si512(text_vec.zmm, four_byte_mask_vec.zmm); |
| 126 | + __mmask16 four_byte_match_mask = _mm512_cmpeq_epi32_mask(masked_quads.zmm, four_byte_pattern_vec.zmm); |
| 127 | + sz_size_t four_byte_prefix_length = sz_u64_ctz(~four_byte_match_mask); |
| 128 | + |
| 129 | + if (four_byte_prefix_length) { |
| 130 | + // Unpack up to 16 four-byte characters (64 bytes of input). |
| 131 | + sz_size_t runes_to_place = sz_min_of_three(four_byte_prefix_length, 16, runes_capacity); |
| 132 | + // Decode: ((b0 & 0x07) << 18) | ((b1 & 0x3F) << 12) | ((b2 & 0x3F) << 6) | (b3 & 0x3F) |
| 133 | + runes_vec.zmm = _mm512_or_si512( |
| 134 | + _mm512_or_si512( |
| 135 | + // (b0 & 0x07) << 18 |
| 136 | + _mm512_slli_epi32(_mm512_and_si512(text_vec.zmm, _mm512_set1_epi32(0x07U)), 18), |
| 137 | + // (b1 & 0x3F) << 12 |
| 138 | + _mm512_slli_epi32(_mm512_and_si512(_mm512_srli_epi32(text_vec.zmm, 8), _mm512_set1_epi32(0x3FU)), 12)), |
| 139 | + _mm512_or_si512( |
| 140 | + // (b2 & 0x3F) << 6 |
| 141 | + _mm512_slli_epi32(_mm512_and_si512(_mm512_srli_epi32(text_vec.zmm, 16), _mm512_set1_epi32(0x3FU)), 6), |
| 142 | + // (b3 & 0x3F) |
| 143 | + _mm512_and_si512(_mm512_srli_epi32(text_vec.zmm, 24), _mm512_set1_epi32(0x3FU)))); |
| 144 | + _mm512_mask_storeu_epi32(runes, sz_u16_mask_until_(runes_to_place), runes_vec.zmm); |
| 145 | + *runes_unpacked = runes_to_place; |
| 146 | + return text + runes_to_place * 4; |
| 147 | + } |
| 148 | + |
| 149 | + // Seems like broken unicoode? |
| 150 | + *runes_unpacked = 0; |
| 151 | + return text; |
| 152 | +} |
| 153 | + |
| 154 | +#if defined(__clang__) |
| 155 | +#pragma clang attribute pop |
| 156 | +#elif defined(__GNUC__) |
| 157 | +#pragma GCC pop_options |
| 158 | +#endif |
| 159 | +#endif // SZ_USE_ICE |
| 160 | +#pragma endregion // Ice Lake Implementation |
| 161 | + |
| 162 | +#pragma region Haswell Implementation |
| 163 | +#if SZ_USE_HASWELL |
| 164 | +#if defined(__clang__) |
| 165 | +#pragma clang attribute push(__attribute__((target("avx2,bmi,bmi2,popcnt"))), apply_to = function) |
| 166 | +#elif defined(__GNUC__) |
| 167 | +#pragma GCC push_options |
| 168 | +#pragma GCC target("avx2,bmi,bmi2,popcnt") |
| 169 | +#endif |
| 170 | + |
| 171 | +SZ_PUBLIC sz_cptr_t sz_utf8_unpack_chunk_haswell( // |
| 172 | + sz_cptr_t text, sz_size_t length, // |
| 173 | + sz_rune_t *runes, sz_size_t runes_capacity, // |
| 174 | + sz_size_t *runes_unpacked) { |
| 175 | + // Fallback to serial implementation for now |
| 176 | + // A future optimization could use AVX2 for decoding |
| 177 | + return sz_utf8_unpack_chunk_serial(text, length, runes, runes_capacity, runes_unpacked); |
| 178 | +} |
| 179 | + |
| 180 | +#if defined(__clang__) |
| 181 | +#pragma clang attribute pop |
| 182 | +#elif defined(__GNUC__) |
| 183 | +#pragma GCC pop_options |
| 184 | +#endif |
| 185 | +#endif // SZ_USE_HASWELL |
| 186 | +#pragma endregion // Haswell Implementation |
| 187 | + |
| 188 | +#pragma region NEON Implementation |
| 189 | +#if SZ_USE_NEON |
| 190 | +#if defined(__clang__) |
| 191 | +#pragma clang attribute push(__attribute__((target("+simd"))), apply_to = function) |
| 192 | +#elif defined(__GNUC__) |
| 193 | +#pragma GCC push_options |
| 194 | +#pragma GCC target("+simd") |
| 195 | +#endif |
| 196 | + |
| 197 | +SZ_PUBLIC sz_cptr_t sz_utf8_unpack_chunk_neon( // |
| 198 | + sz_cptr_t text, sz_size_t length, // |
| 199 | + sz_rune_t *runes, sz_size_t runes_capacity, // |
| 200 | + sz_size_t *runes_unpacked) { |
| 201 | + // TODO: Implement a fast NEON version once we come up with an AVX-512 design. |
| 202 | + return sz_utf8_unpack_chunk_serial(text, length, runes, runes_capacity, runes_unpacked); |
| 203 | +} |
| 204 | + |
| 205 | +#if defined(__clang__) |
| 206 | +#pragma clang attribute pop |
| 207 | +#elif defined(__GNUC__) |
| 208 | +#pragma GCC pop_options |
| 209 | +#endif |
| 210 | +#endif // SZ_USE_NEON |
| 211 | + |
| 212 | +#pragma endregion // NEON Implementation |
| 213 | + |
| 214 | +#ifdef __cplusplus |
| 215 | +} |
| 216 | +#endif |
| 217 | + |
| 218 | +#endif // STRINGZILLA_UTF8_H_ |
0 commit comments