diff --git a/html5ever/src/tokenizer/mod.rs b/html5ever/src/tokenizer/mod.rs
index 471e2e7c..ad749b1d 100644
--- a/html5ever/src/tokenizer/mod.rs
+++ b/html5ever/src/tokenizer/mod.rs
@@ -706,11 +706,11 @@ impl Tokenizer {
states::Data => loop {
let set = small_char_set!('\r' '\0' '&' '<' '\n');
- #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+ #[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
let set_result = if !(self.opts.exact_errors
|| self.reconsume.get()
|| self.ignore_lf.get())
- && is_x86_feature_detected!("sse2")
+ && Self::is_supported_simd_feature_detected()
{
let front_buffer = input.peek_front_chunk_mut();
let Some(mut front_buffer) = front_buffer else {
@@ -729,8 +729,8 @@ impl Tokenizer {
self.pop_except_from(input, set)
} else {
// SAFETY:
- // This CPU is guaranteed to support SSE2 due to the is_x86_feature_detected check above
- let result = unsafe { self.data_state_sse2_fast_path(&mut front_buffer) };
+ // This CPU is guaranteed to support SIMD due to the is_supported_simd_feature_detected check above
+ let result = unsafe { self.data_state_simd_fast_path(&mut front_buffer) };
if front_buffer.is_empty() {
drop(front_buffer);
@@ -743,7 +743,11 @@ impl Tokenizer {
self.pop_except_from(input, set)
};
- #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+ #[cfg(not(any(
+ target_arch = "x86",
+ target_arch = "x86_64",
+ target_arch = "aarch64"
+ )))]
let set_result = self.pop_except_from(input, set);
let Some(set_result) = set_result else {
@@ -1885,18 +1889,90 @@ impl Tokenizer {
}
}
- #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
- #[target_feature(enable = "sse2")]
+ /// Returns `true` if the CPU supports the SIMD feature used by the fast path: SSE2 on x86/x86_64, NEON on aarch64 (always `false` elsewhere).
+ fn is_supported_simd_feature_detected() -> bool {
+ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+ {
+ is_x86_feature_detected!("sse2")
+ }
+
+ #[cfg(target_arch = "aarch64")]
+ {
+ std::arch::is_aarch64_feature_detected!("neon")
+ }
+
+ #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
+ false
+ }
+
+ #[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
/// Implements the [data state] with SIMD instructions.
+ /// Calls SSE2- or NEON-specific function for chunks and processes any remaining bytes.
///
/// The algorithm implemented is the naive SIMD approach described [here].
///
/// ### SAFETY:
- /// Calling this function on a CPU that does not support SSE2 causes undefined behaviour.
+ /// Calling this function on a CPU that lacks the SIMD feature used for the target architecture (SSE2 on x86/x86_64, NEON on aarch64) causes undefined behaviour.
///
/// [data state]: https://html.spec.whatwg.org/#data-state
/// [here]: https://lemire.me/blog/2024/06/08/scan-html-faster-with-simd-instructions-chrome-edition/
- unsafe fn data_state_sse2_fast_path(&self, input: &mut StrTendril) -> Option<SetResult> {
+ unsafe fn data_state_simd_fast_path(&self, input: &mut StrTendril) -> Option<SetResult> {
+ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+ let (mut i, mut n_newlines) = self.data_state_sse2_fast_path(input);
+
+ #[cfg(target_arch = "aarch64")]
+ let (mut i, mut n_newlines) = self.data_state_neon_fast_path(input);
+
+ // Process any remaining bytes (less than STRIDE)
+ while let Some(c) = input.as_bytes().get(i) {
+ if matches!(*c, b'<' | b'&' | b'\r' | b'\0') {
+ break;
+ }
+ if *c == b'\n' {
+ n_newlines += 1;
+ }
+
+ i += 1;
+ }
+
+ let set_result = if i == 0 {
+ let first_char = input.pop_front_char().unwrap();
+ debug_assert!(matches!(first_char, '<' | '&' | '\r' | '\0'));
+
+ // FIXME: Passing a bogus input queue is only relevant when c is \n, which can never happen in this case.
+ // Still, it would be nice to not have to do that.
+ // The same is true for the unwrap call.
+ let preprocessed_char = self
+ .get_preprocessed_char(first_char, &BufferQueue::default())
+ .unwrap();
+ SetResult::FromSet(preprocessed_char)
+ } else {
+ debug_assert!(
+ input.len() >= i,
+ "Trying to remove {:?} bytes from a tendril that is only {:?} bytes long",
+ i,
+ input.len()
+ );
+ let consumed_chunk = input.unsafe_subtendril(0, i as u32);
+ input.unsafe_pop_front(i as u32);
+ SetResult::NotFromSet(consumed_chunk)
+ };
+
+ self.current_line.set(self.current_line.get() + n_newlines);
+
+ Some(set_result)
+ }
+
+ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+ #[target_feature(enable = "sse2")]
+ /// Implements the [data state] with SSE2 instructions for x86/x86_64.
+ /// Returns a pair of the number of bytes processed and the number of newlines found.
+ ///
+ /// ### SAFETY:
+ /// Calling this function on a CPU that does not support SSE2 causes undefined behaviour.
+ ///
+ /// [data state]: https://html.spec.whatwg.org/#data-state
+ unsafe fn data_state_sse2_fast_path(&self, input: &mut StrTendril) -> (usize, u64) {
#[cfg(target_arch = "x86")]
use std::arch::x86::{
__m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
@@ -1960,44 +2036,78 @@ impl Tokenizer {
i += STRIDE;
}
- // Process any remaining bytes (less than STRIDE)
- while let Some(c) = raw_bytes.get(i) {
- if matches!(*c, b'<' | b'&' | b'\r' | b'\0') {
- break;
- }
- if *c == b'\n' {
- n_newlines += 1;
- }
+ (i, n_newlines)
+ }
- i += 1;
- }
+ #[cfg(target_arch = "aarch64")]
+ #[target_feature(enable = "neon")]
+ /// Implements the [data state] with NEON SIMD instructions for AArch64.
+ /// Returns a pair of the number of bytes processed and the number of newlines found.
+ ///
+ /// ### SAFETY:
+ /// Calling this function on a CPU that does not support NEON causes undefined behaviour.
+ ///
+ /// [data state]: https://html.spec.whatwg.org/#data-state
+ unsafe fn data_state_neon_fast_path(&self, input: &mut StrTendril) -> (usize, u64) {
+ use std::arch::aarch64::{vceqq_u8, vdupq_n_u8, vld1q_u8, vmaxvq_u8, vorrq_u8};
- let set_result = if i == 0 {
- let first_char = input.pop_front_char().unwrap();
- debug_assert!(matches!(first_char, '<' | '&' | '\r' | '\0'));
+ debug_assert!(!input.is_empty());
- // FIXME: Passing a bogus input queue is only relevant when c is \n, which can never happen in this case.
- // Still, it would be nice to not have to do that.
- // The same is true for the unwrap call.
- let preprocessed_char = self
- .get_preprocessed_char(first_char, &BufferQueue::default())
- .unwrap();
- SetResult::FromSet(preprocessed_char)
- } else {
- debug_assert!(
- input.len() >= i,
- "Trying to remove {:?} bytes from a tendril that is only {:?} bytes long",
- i,
- input.len()
- );
- let consumed_chunk = input.unsafe_subtendril(0, i as u32);
- input.unsafe_pop_front(i as u32);
- SetResult::NotFromSet(consumed_chunk)
- };
+ let quote_mask = vdupq_n_u8(b'<');
+ let escape_mask = vdupq_n_u8(b'&');
+ let carriage_return_mask = vdupq_n_u8(b'\r');
+ let zero_mask = vdupq_n_u8(b'\0');
+ let newline_mask = vdupq_n_u8(b'\n');
- self.current_line.set(self.current_line.get() + n_newlines);
+ let raw_bytes: &[u8] = input.as_bytes();
+ let start = raw_bytes.as_ptr();
- Some(set_result)
+ const STRIDE: usize = 16;
+ let mut i = 0;
+ let mut n_newlines = 0;
+ while i + STRIDE <= raw_bytes.len() {
+ // Load a 16 byte chunk from the input
+ let data = vld1q_u8(start.add(i));
+
+ // Compare the chunk against each mask
+ let quotes = vceqq_u8(data, quote_mask);
+ let escapes = vceqq_u8(data, escape_mask);
+ let carriage_returns = vceqq_u8(data, carriage_return_mask);
+ let zeros = vceqq_u8(data, zero_mask);
+ let newlines = vceqq_u8(data, newline_mask);
+
+ // Combine all test results and reduce them to a single byte with a horizontal maximum.
+ // The reduced value is nonzero iff at least one byte of the chunk is in the set; the exact position is located by a scalar scan below.
+ let test_result =
+ vorrq_u8(vorrq_u8(quotes, zeros), vorrq_u8(escapes, carriage_returns));
+ let bitmask = vmaxvq_u8(test_result);
+ let newline_mask = vmaxvq_u8(newlines);
+ if bitmask != 0 {
+ // We have reached one of the characters that cause the state machine to transition
+ let chunk_bytes = std::slice::from_raw_parts(start.add(i), STRIDE);
+ let position = chunk_bytes
+ .iter()
+ .position(|&b| matches!(b, b'<' | b'&' | b'\r' | b'\0'))
+ .unwrap();
+
+ n_newlines += chunk_bytes[..position]
+ .iter()
+ .filter(|&&b| b == b'\n')
+ .count() as u64;
+
+ i += position;
+ break;
+ } else {
+ if newline_mask != 0 {
+ let chunk_bytes = std::slice::from_raw_parts(start.add(i), STRIDE);
+ n_newlines += chunk_bytes.iter().filter(|&&b| b == b'\n').count() as u64;
+ }
+ }
+
+ i += STRIDE;
+ }
+
+ (i, n_newlines)
}
}