Skip to content

Commit fb09351

Browse files
committed
Fix: Special Cyrillic folding cases
1 parent b4e269e commit fb09351

File tree

3 files changed

+38
-38
lines changed

3 files changed

+38
-38
lines changed

include/stringzilla/utf8_case.h

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4414,9 +4414,8 @@ SZ_INTERNAL sz_cptr_t sz_utf8_case_insensitive_find_ice_latin1ab_upto16byte_( //
44144414

44154415
// Main loop - single load, 3-probe filter, verify candidates
44164416
for (; haystack_length >= 64; haystack += step, haystack_length -= step) {
4417-
__m512i const raw_chunk = _mm512_loadu_si512(haystack);
4418-
4419-
haystack_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(raw_chunk);
4417+
haystack_vec.zmm = _mm512_loadu_si512(haystack);
4418+
haystack_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(haystack_vec.zmm);
44204419

44214420
sz_u64_t first_matches = _mm512_cmpeq_epi8_mask(haystack_vec.zmm, probe_first_vec.zmm);
44224421
sz_u64_t mid_matches = _mm512_cmpeq_epi8_mask(haystack_vec.zmm, probe_mid_vec.zmm);
@@ -4454,10 +4453,10 @@ SZ_INTERNAL sz_cptr_t sz_utf8_case_insensitive_find_ice_latin1ab_upto16byte_( //
44544453
// Tail
44554454
if (haystack_length >= needle_length) {
44564455
__mmask64 const load_mask = sz_u64_mask_until_(haystack_length);
4457-
__m512i const raw_tail = _mm512_maskz_loadu_epi8(load_mask, haystack);
4456+
haystack_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, haystack);
44584457

44594458
__mmask64 const tail_valid = sz_u64_mask_until_(haystack_length - needle_length + 1);
4460-
haystack_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(raw_tail);
4459+
haystack_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(haystack_vec.zmm);
44614460

44624461
sz_u64_t first_matches = _mm512_cmpeq_epi8_mask(haystack_vec.zmm, probe_first_vec.zmm);
44634462
sz_u64_t mid_matches = _mm512_cmpeq_epi8_mask(haystack_vec.zmm, probe_mid_vec.zmm);
@@ -4580,14 +4579,14 @@ SZ_INTERNAL sz_cptr_t sz_utf8_case_insensitive_find_ice_latin1ab_(sz_cptr_t hays
45804579
sz_u512_vec_t haystack_first_vec, haystack_mid_vec, haystack_last_vec;
45814580
for (; haystack_length >= needle_length + 64; haystack += 62, haystack_length -= 62) {
45824581
// Load raw haystack chunks at adjusted offsets
4583-
__m512i const raw_first = _mm512_loadu_si512(haystack + safe_offset + load_first);
4584-
__m512i const raw_mid = _mm512_loadu_si512(haystack + safe_offset + load_mid);
4585-
__m512i const raw_last = _mm512_loadu_si512(haystack + safe_offset + load_last);
4582+
haystack_first_vec.zmm = _mm512_loadu_si512(haystack + safe_offset + load_first);
4583+
haystack_mid_vec.zmm = _mm512_loadu_si512(haystack + safe_offset + load_mid);
4584+
haystack_last_vec.zmm = _mm512_loadu_si512(haystack + safe_offset + load_last);
45864585

45874586
// Fold and compare
4588-
haystack_first_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(raw_first);
4589-
haystack_mid_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(raw_mid);
4590-
haystack_last_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(raw_last);
4587+
haystack_first_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(haystack_first_vec.zmm);
4588+
haystack_mid_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(haystack_mid_vec.zmm);
4589+
haystack_last_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(haystack_last_vec.zmm);
45914590

45924591
// 3-point Raita filter: find positions where all 3 probes match
45934592
// Shift masks by prefix to align probe byte position
@@ -4641,14 +4640,13 @@ SZ_INTERNAL sz_cptr_t sz_utf8_case_insensitive_find_ice_latin1ab_(sz_cptr_t hays
46414640
__mmask64 tail_mask_last = sz_u64_mask_until_(need_last < avail_last ? need_last : avail_last);
46424641

46434642
// Load raw tail chunks
4644-
__m512i raw_first = _mm512_maskz_loadu_epi8(tail_mask_first, haystack + safe_offset + load_first);
4645-
__m512i raw_mid = _mm512_maskz_loadu_epi8(tail_mask_mid, haystack + safe_offset + load_mid);
4646-
__m512i raw_last = _mm512_maskz_loadu_epi8(tail_mask_last, haystack + safe_offset + load_last);
4643+
haystack_first_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask_first, haystack + safe_offset + load_first);
4644+
haystack_mid_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask_mid, haystack + safe_offset + load_mid);
4645+
haystack_last_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask_last, haystack + safe_offset + load_last);
46474646

4648-
// Fold chunks
4649-
haystack_first_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(raw_first);
4650-
haystack_mid_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(raw_mid);
4651-
haystack_last_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(raw_last);
4647+
haystack_first_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(haystack_first_vec.zmm);
4648+
haystack_mid_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(haystack_mid_vec.zmm);
4649+
haystack_last_vec.zmm = sz_utf8_case_insensitive_find_ice_latin1ab_fold_zmm_(haystack_last_vec.zmm);
46524650

46534651
__mmask64 match_first =
46544652
_mm512_cmpeq_epi8_mask(haystack_first_vec.zmm, probe_first_vec.zmm) >> window->prefix_first;
@@ -4732,6 +4730,13 @@ SZ_INTERNAL __m512i sz_utf8_case_insensitive_find_ice_cyrillic_fold_zmm_(__m512i
47324730
__mmask64 is_cyrillic_d1_lower_mask =
47334731
_mm512_mask_cmplt_epu8_mask(is_after_d1_mask, _mm512_sub_epi8(source, x_80_zmm), x_10_zmm);
47344732

4733+
// Step 4b: Fold Cyrillic extended lowercase after D1 (second bytes 90-9F → subtract 0x10)
4734+
// This handles 'small io' ё (D1 91) → Ё (D0 81), є (D1 94) → Є (D0 84), etc.
4735+
// Range check: (byte - 0x90) < 0x10
4736+
__m512i const x_90_zmm = _mm512_set1_epi8((char)0x90);
4737+
__mmask64 is_cyrillic_d1_ext_mask =
4738+
_mm512_mask_cmplt_epu8_mask(is_after_d1_mask, _mm512_sub_epi8(source, x_90_zmm), x_10_zmm);
4739+
47354740
// Step 5: Fold ASCII uppercase (A-Z) → add 0x20
47364741
__mmask64 is_ascii_upper_mask = _mm512_cmplt_epu8_mask(_mm512_sub_epi8(source, x_41_zmm), x_1a_zmm);
47374742

@@ -4741,6 +4746,9 @@ SZ_INTERNAL __m512i sz_utf8_case_insensitive_find_ice_cyrillic_fold_zmm_(__m512i
47414746
// Step 7: Apply Cyrillic D1 lowercase folding (add 0x20)
47424747
result_zmm = _mm512_mask_add_epi8(result_zmm, is_cyrillic_d1_lower_mask, result_zmm, x_20_zmm);
47434748

4749+
// Step 7b: Apply Cyrillic D1 extended folding (subtract 0x10)
4750+
result_zmm = _mm512_mask_sub_epi8(result_zmm, is_cyrillic_d1_ext_mask, result_zmm, x_10_zmm);
4751+
47444752
// Step 8: Apply ASCII folding (add 0x20)
47454753
result_zmm = _mm512_mask_add_epi8(result_zmm, is_ascii_upper_mask, result_zmm, x_20_zmm);
47464754

@@ -5472,7 +5480,7 @@ SZ_INTERNAL sz_cptr_t sz_utf8_case_insensitive_find_ice_greek_upto16byte_( //
54725480
if (haystack_length >= needle_length) {
54735481
__mmask64 const load_mask = sz_u64_mask_until_(haystack_length);
54745482
__mmask64 const tail_valid = sz_u64_mask_until_(haystack_length - needle_length + 1);
5475-
5483+
54765484
haystack_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, haystack);
54775485

54785486
// Fall-back to serial code for Polytonic Greek characters, obsolete in modern texts
@@ -5677,8 +5685,8 @@ SZ_INTERNAL sz_cptr_t sz_utf8_case_insensitive_find_ice_greek_(sz_cptr_t haystac
56775685
_mm512_mask_cmpeq_epi8_mask(tail_mask_mid, haystack_mid_vec.zmm, x_e1_zmm) |
56785686
_mm512_mask_cmpeq_epi8_mask(tail_mask_last, haystack_last_vec.zmm, x_e1_zmm);
56795687
if (has_e1_mask) {
5680-
sz_cptr_t result = sz_utf8_case_insensitive_find_chunk_(haystack, haystack_length, needle, needle_length,
5681-
matched_length);
5688+
sz_cptr_t result =
5689+
sz_utf8_case_insensitive_find_chunk_(haystack, haystack_length, needle, needle_length, matched_length);
56825690
if (result) return result;
56835691
return SZ_NULL_CHAR;
56845692
}
@@ -6665,23 +6673,6 @@ SZ_PUBLIC sz_cptr_t sz_utf8_case_insensitive_find_ice( //
66656673
return result;
66666674
}
66676675

6668-
// Ligature handling: If the needle contains ligature bytes (EF AC xx or EF AD xx range),
6669-
// we must fall back to serial because ligatures like ﬁ (U+FB01, EF AC 81) expand to "fi" (2 bytes)
6670-
// during case-folding. SIMD can't match 3-byte ligature against 2-byte expansion.
6671-
// This is rare in practice but must be correct.
6672-
for (sz_size_t i = 0; i + 2 < needle_length; ++i) {
6673-
if ((sz_u8_t)needle[i] == 0xEF && ((sz_u8_t)needle[i + 1] == 0xAC || (sz_u8_t)needle[i + 1] == 0xAD)) {
6674-
return sz_utf8_case_insensitive_find_serial(haystack, haystack_length, needle, needle_length,
6675-
matched_length);
6676-
}
6677-
}
6678-
// Handle needle ending with EF or EF AC (partial ligature sequence)
6679-
if (needle_length >= 1 && (sz_u8_t)needle[needle_length - 1] == 0xEF)
6680-
return sz_utf8_case_insensitive_find_serial(haystack, haystack_length, needle, needle_length, matched_length);
6681-
if (needle_length >= 2 && (sz_u8_t)needle[needle_length - 2] == 0xEF &&
6682-
((sz_u8_t)needle[needle_length - 1] == 0xAC || (sz_u8_t)needle[needle_length - 1] == 0xAD))
6683-
return sz_utf8_case_insensitive_find_serial(haystack, haystack_length, needle, needle_length, matched_length);
6684-
66856676
// There is a way to perform case-insensitive substring search faster than case-folding both strings
66866677
// and calling a standard substring search algorithm on them. Case-folding bicameral scripts is typically
66876678
// a multi-step procedure:

scripts/test_stringzilla.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,13 @@ void test_utf8_case_fold_equivalence( //
615615
"\xD0\x9F", // П (U+041F)
616616
"\xD0\x9F\xD0\xA0\xD0\x98\xD0\x92\xD0\x95\xD0\xA2", // ПРИВЕТ
617617
"\xD0\xBF\xD1\x80\xD0\xB8\xD0\xB2\xD0\xB5\xD1\x82", // привет
618+
// Cyrillic Special
619+
"\xD0\x81", // Ё
620+
"\xD1\x91", // ё
621+
"\xD0\x84", // Є
622+
"\xD1\x94", // є
623+
"\xD0\x87", // Ї
624+
"\xD1\x97", // ї
618625
// Greek (2-byte UTF-8 starting with CE-CF)
619626
"\xCE\x91", // Α (U+0391)
620627
"\xCE\xA9", // Ω (U+03A9)

scripts/test_stringzilla.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
* @endcode
3535
*/
3636
#pragma once
37+
#include <cstdio> // `std::printf`, `std::fflush`
3738
#include <cstdlib> // `std::getenv`, `std::strtoul`
3839
#include <fstream> // `std::ifstream`
3940
#include <iostream> // `std::cout`, `std::endl`
@@ -142,6 +143,7 @@ inline void print_test_environment() noexcept {
142143
std::printf("- Test seed: %u%s\n", static_cast<unsigned>(seed), from_env ? " (from SZ_TESTS_SEED)" : "");
143144
double multiplier = get_iterations_multiplier();
144145
if (multiplier != 1.0) std::printf("- Iterations multiplier: %.2fx\n", multiplier);
146+
std::fflush(stdout); // Ensure output is visible even on crash
145147
}
146148

147149
/**

0 commit comments

Comments
 (0)