diff --git a/html5ever/src/tokenizer/mod.rs b/html5ever/src/tokenizer/mod.rs index 471e2e7c..ad749b1d 100644 --- a/html5ever/src/tokenizer/mod.rs +++ b/html5ever/src/tokenizer/mod.rs @@ -706,11 +706,11 @@ impl Tokenizer { states::Data => loop { let set = small_char_set!('\r' '\0' '&' '<' '\n'); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))] let set_result = if !(self.opts.exact_errors || self.reconsume.get() || self.ignore_lf.get()) - && is_x86_feature_detected!("sse2") + && Self::is_supported_simd_feature_detected() { let front_buffer = input.peek_front_chunk_mut(); let Some(mut front_buffer) = front_buffer else { @@ -729,8 +729,8 @@ impl Tokenizer { self.pop_except_from(input, set) } else { // SAFETY: - // This CPU is guaranteed to support SSE2 due to the is_x86_feature_detected check above - let result = unsafe { self.data_state_sse2_fast_path(&mut front_buffer) }; + // This CPU is guaranteed to support SIMD due to the is_supported_simd_feature_detected check above + let result = unsafe { self.data_state_simd_fast_path(&mut front_buffer) }; if front_buffer.is_empty() { drop(front_buffer); @@ -743,7 +743,11 @@ impl Tokenizer { self.pop_except_from(input, set) }; - #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + #[cfg(not(any( + target_arch = "x86", + target_arch = "x86_64", + target_arch = "aarch64" + )))] let set_result = self.pop_except_from(input, set); let Some(set_result) = set_result else { @@ -1885,18 +1889,90 @@ impl Tokenizer { } } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - #[target_feature(enable = "sse2")] + /// Checks for supported SIMD feature, which is now either SSE2 for x86/x86_64 or NEON for aarch64. + fn is_supported_simd_feature_detected() -> bool { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + is_x86_feature_detected!("sse2") + } + + #[cfg(target_arch = "aarch64")] + { + std::arch::is_aarch64_feature_detected!("neon") + } + + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))] + false + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))] /// Implements the [data state] with SIMD instructions. + /// Calls SSE2- or NEON-specific function for chunks and processes any remaining bytes. /// /// The algorithm implemented is the naive SIMD approach described [here]. /// /// ### SAFETY: - /// Calling this function on a CPU that does not support SSE2 causes undefined behaviour. + /// Calling this function on a CPU that supports neither SSE2 nor NEON causes undefined behaviour. /// /// [data state]: https://html.spec.whatwg.org/#data-state /// [here]: https://lemire.me/blog/2024/06/08/scan-html-faster-with-simd-instructions-chrome-edition/ - unsafe fn data_state_sse2_fast_path(&self, input: &mut StrTendril) -> Option { + unsafe fn data_state_simd_fast_path(&self, input: &mut StrTendril) -> Option { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + let (mut i, mut n_newlines) = self.data_state_sse2_fast_path(input); + + #[cfg(target_arch = "aarch64")] + let (mut i, mut n_newlines) = self.data_state_neon_fast_path(input); + + // Process any remaining bytes (less than STRIDE) + while let Some(c) = input.as_bytes().get(i) { + if matches!(*c, b'<' | b'&' | b'\r' | b'\0') { + break; + } + if *c == b'\n' { + n_newlines += 1; + } + + i += 1; + } + + let set_result = if i == 0 { + let first_char = input.pop_front_char().unwrap(); + debug_assert!(matches!(first_char, '<' | '&' | '\r' | '\0')); + + // FIXME: Passing a bogus input queue is only relevant when c is \n, which can never happen in this case. + // Still, it would be nice to not have to do that. + // The same is true for the unwrap call. + let preprocessed_char = self + .get_preprocessed_char(first_char, &BufferQueue::default()) + .unwrap(); + SetResult::FromSet(preprocessed_char) + } else { + debug_assert!( + input.len() >= i, + "Trying to remove {:?} bytes from a tendril that is only {:?} bytes long", + i, + input.len() + ); + let consumed_chunk = input.unsafe_subtendril(0, i as u32); + input.unsafe_pop_front(i as u32); + SetResult::NotFromSet(consumed_chunk) + }; + + self.current_line.set(self.current_line.get() + n_newlines); + + Some(set_result) + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[target_feature(enable = "sse2")] + /// Implements the [data state] with SSE2 instructions for x86/x86_64. + /// Returns a pair of the number of bytes processed and the number of newlines found. + /// + /// ### SAFETY: + /// Calling this function on a CPU that does not support NEON causes undefined behaviour. + /// + /// [data state]: https://html.spec.whatwg.org/#data-state + unsafe fn data_state_sse2_fast_path(&self, input: &mut StrTendril) -> (usize, u64) { #[cfg(target_arch = "x86")] use std::arch::x86::{ __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128, @@ -1960,44 +2036,78 @@ impl Tokenizer { i += STRIDE; } - // Process any remaining bytes (less than STRIDE) - while let Some(c) = raw_bytes.get(i) { - if matches!(*c, b'<' | b'&' | b'\r' | b'\0') { - break; - } - if *c == b'\n' { - n_newlines += 1; - } + (i, n_newlines) + } - i += 1; - } + #[cfg(target_arch = "aarch64")] + #[target_feature(enable = "neon")] + /// Implements the [data state] with NEON SIMD instructions for AArch64. + /// Returns a pair of the number of bytes processed and the number of newlines found. + /// + /// ### SAFETY: + /// Calling this function on a CPU that does not support NEON causes undefined behaviour. + /// + /// [data state]: https://html.spec.whatwg.org/#data-state + unsafe fn data_state_neon_fast_path(&self, input: &mut StrTendril) -> (usize, u64) { + use std::arch::aarch64::{vceqq_u8, vdupq_n_u8, vld1q_u8, vmaxvq_u8, vorrq_u8}; - let set_result = if i == 0 { - let first_char = input.pop_front_char().unwrap(); - debug_assert!(matches!(first_char, '<' | '&' | '\r' | '\0')); + debug_assert!(!input.is_empty()); - // FIXME: Passing a bogus input queue is only relevant when c is \n, which can never happen in this case. - // Still, it would be nice to not have to do that. - // The same is true for the unwrap call. - let preprocessed_char = self - .get_preprocessed_char(first_char, &BufferQueue::default()) - .unwrap(); - SetResult::FromSet(preprocessed_char) - } else { - debug_assert!( - input.len() >= i, - "Trying to remove {:?} bytes from a tendril that is only {:?} bytes long", - i, - input.len() - ); - let consumed_chunk = input.unsafe_subtendril(0, i as u32); - input.unsafe_pop_front(i as u32); - SetResult::NotFromSet(consumed_chunk) - }; + let quote_mask = vdupq_n_u8(b'<'); + let escape_mask = vdupq_n_u8(b'&'); + let carriage_return_mask = vdupq_n_u8(b'\r'); + let zero_mask = vdupq_n_u8(b'\0'); + let newline_mask = vdupq_n_u8(b'\n'); - self.current_line.set(self.current_line.get() + n_newlines); + let raw_bytes: &[u8] = input.as_bytes(); + let start = raw_bytes.as_ptr(); - Some(set_result) + const STRIDE: usize = 16; + let mut i = 0; + let mut n_newlines = 0; + while i + STRIDE <= raw_bytes.len() { + // Load a 16 byte chunk from the input + let data = vld1q_u8(start.add(i)); + + // Compare the chunk against each mask + let quotes = vceqq_u8(data, quote_mask); + let escapes = vceqq_u8(data, escape_mask); + let carriage_returns = vceqq_u8(data, carriage_return_mask); + let zeros = vceqq_u8(data, zero_mask); + let newlines = vceqq_u8(data, newline_mask); + + // Combine all test results and create a bitmask from them. + // Each bit in the mask will be 1 if the character at the bit position is in the set and 0 otherwise. + let test_result = + vorrq_u8(vorrq_u8(quotes, zeros), vorrq_u8(escapes, carriage_returns)); + let bitmask = vmaxvq_u8(test_result); + let newline_mask = vmaxvq_u8(newlines); + if bitmask != 0 { + // We have reached one of the characters that cause the state machine to transition + let chunk_bytes = std::slice::from_raw_parts(start.add(i), STRIDE); + let position = chunk_bytes + .iter() + .position(|&b| matches!(b, b'<' | b'&' | b'\r' | b'\0')) + .unwrap(); + + n_newlines += chunk_bytes[..position] + .iter() + .filter(|&&b| b == b'\n') + .count() as u64; + + i += position; + break; + } else { + if newline_mask != 0 { + let chunk_bytes = std::slice::from_raw_parts(start.add(i), STRIDE); + n_newlines += chunk_bytes.iter().filter(|&&b| b == b'\n').count() as u64; + } + } + + i += STRIDE; + } + + (i, n_newlines) } }