diff --git a/Cargo.toml b/Cargo.toml
index e02d3b35..712f2e8a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -44,6 +44,7 @@ thiserror = "2.0.2"
 tracing = { version = "0.1" }
 url = "2.2"
 walkdir = { version = "2", optional = true }
+arc-swap = "1.7.1"
 
 # Cloud storage support
 base64 = { version = "0.22", default-features = false, features = ["std"], optional = true }
@@ -60,7 +61,7 @@ rustls-pemfile = { version = "2.0", default-features = false, features = ["std"]
 serde = { version = "1.0", default-features = false, features = ["derive"], optional = true }
 serde_json = { version = "1.0", default-features = false, features = ["std"], optional = true }
 serde_urlencoded = { version = "0.7", optional = true }
-tokio = { version = "1.29.0", features = ["sync", "macros", "rt", "time", "io-util"] }
+tokio = { version = "1.29.0", features = ["sync", "macros", "rt", "rt-multi-thread", "time", "io-util"] }
 
 [target.'cfg(target_family="unix")'.dev-dependencies]
 nix = { version = "0.30.0", features = ["fs"] }
@@ -105,4 +106,9 @@ features = ["js"]
 
 [[test]]
 name = "get_range_file"
 path = "tests/get_range_file.rs"
-required-features = ["fs"]
\ No newline at end of file
+required-features = ["fs"]
+
+[[bench]]
+name = "cache_benchmark"
+harness = false
+required-features = []
diff --git a/benches/cache_benchmark.rs b/benches/cache_benchmark.rs
new file mode 100644
index 00000000..2a393aa5
--- /dev/null
+++ b/benches/cache_benchmark.rs
@@ -0,0 +1,412 @@
+use arc_swap::ArcSwapOption;
+use rand::Rng;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::{Mutex, RwLock};
+
+#[derive(Debug, Clone)]
+#[allow(dead_code)]
+struct Credentials {
+    token: String,
+    expiry: Instant,
+}
+
+// Arc-swap based cache implementation
+#[derive(Debug)]
+struct ArcSwapCache {
+    cache: ArcSwapOption<Credentials>,
+}
+
+impl ArcSwapCache {
+    fn new() -> Self {
+        Self {
+            cache: ArcSwapOption::new(None),
+        }
+    }
+
+    fn get(&self) -> Option<Credentials> {
+        self.cache.load_full().as_ref().map(|c| (**c).clone())
+    }
+
+    fn update(&self, creds: Credentials) {
+        self.cache.store(Some(Arc::new(creds)));
+    }
+}
+
+// RwLock based cache implementation
+#[derive(Debug)]
+struct RwLockCache {
+    cache: RwLock<Option<Credentials>>,
+}
+
+impl RwLockCache {
+    fn new() -> Self {
+        Self {
+            cache: RwLock::new(None),
+        }
+    }
+
+    async fn get(&self) -> Option<Credentials> {
+        self.cache.read().await.clone()
+    }
+
+    async fn update(&self, creds: Credentials) {
+        *self.cache.write().await = Some(creds);
+    }
+}
+
+// Mutex based cache implementation (original problematic approach)
+#[derive(Debug)]
+struct MutexCache {
+    cache: Mutex<Option<Credentials>>,
+}
+
+impl MutexCache {
+    fn new() -> Self {
+        Self {
+            cache: Mutex::new(None),
+        }
+    }
+
+    async fn get(&self) -> Option<Credentials> {
+        self.cache.lock().await.clone()
+    }
+
+    async fn update(&self, creds: Credentials) {
+        *self.cache.lock().await = Some(creds);
+    }
+}
+
+fn generate_delay(jitter_us: (u64, u64)) -> u64 {
+    let mut rng = rand::rng();
+    rng.random_range(jitter_us.0..=jitter_us.1)
+}
+
+async fn simulate_request(jitter_us: (u64, u64)) {
+    let delay = generate_delay(jitter_us);
+    tokio::time::sleep(Duration::from_micros(delay)).await;
+}
+
+async fn run_arc_swap_benchmark(
+    num_requests: usize,
+    concurrent_tasks: usize,
+    jitter_us: (u64, u64),
+) -> (Duration, Vec<Duration>) {
+    let cache = Arc::new(ArcSwapCache::new());
+
+    // Initialize cache with a very long TTL (no expiry during test = uncontended)
+    cache.update(Credentials {
+        token: "initial_token".to_string(),
+        expiry: Instant::now() + Duration::from_secs(86400), // 24 hours
+    });
+
+    // No background updater - simulating uncontended case where TTL never expires
+
+    // Run benchmark
+    let start = Instant::now();
+    let mut latencies = Vec::with_capacity(num_requests * concurrent_tasks);
+    let requests_per_task = num_requests; // Each task does this many requests
+
+    let mut tasks = Vec::new();
+    for _ in 0..concurrent_tasks {
+        let cache = cache.clone();
+        tasks.push(tokio::spawn(async move {
+            let mut task_latencies = Vec::with_capacity(requests_per_task);
+            for _ in 0..requests_per_task {
+                let req_start = Instant::now();
+
+                // Get credentials (lock-free read, no TTL check/refresh needed)
+                let _creds = cache.get();
+
+                // Simulate HTTP request work
+                simulate_request(jitter_us).await;
+
+                task_latencies.push(req_start.elapsed());
+            }
+            task_latencies
+        }));
+    }
+
+    // Collect all latencies
+    for task in tasks {
+        latencies.extend(task.await.unwrap());
+    }
+
+    let total_duration = start.elapsed();
+
+    (total_duration, latencies)
+}
+
+async fn run_rwlock_benchmark(
+    num_requests: usize,
+    concurrent_tasks: usize,
+    jitter_us: (u64, u64),
+) -> (Duration, Vec<Duration>) {
+    let cache = Arc::new(RwLockCache::new());
+
+    // Initialize cache with a very long TTL (no expiry during test = uncontended)
+    cache
+        .update(Credentials {
+            token: "initial_token".to_string(),
+            expiry: Instant::now() + Duration::from_secs(86400), // 24 hours
+        })
+        .await;
+
+    // No background updater - simulating uncontended case where TTL never expires
+
+    // Run benchmark
+    let start = Instant::now();
+    let mut latencies = Vec::with_capacity(num_requests * concurrent_tasks);
+    let requests_per_task = num_requests; // Each task does this many requests
+
+    let mut tasks = Vec::new();
+    for _ in 0..concurrent_tasks {
+        let cache = cache.clone();
+        tasks.push(tokio::spawn(async move {
+            let mut task_latencies = Vec::with_capacity(requests_per_task);
+            for _ in 0..requests_per_task {
+                let req_start = Instant::now();
+
+                // Get credentials (requires read lock acquisition even though no writes)
+                let _creds = cache.get().await;
+
+                // Simulate HTTP request work
+                simulate_request(jitter_us).await;
+
+                task_latencies.push(req_start.elapsed());
+            }
+            task_latencies
+        }));
+    }
+
+    // Collect all latencies
+    for task in tasks {
+        latencies.extend(task.await.unwrap());
+    }
+
+    let total_duration = start.elapsed();
+
+    (total_duration, latencies)
+}
+
+async fn run_mutex_benchmark(
+    num_requests: usize,
+    concurrent_tasks: usize,
+    jitter_us: (u64, u64),
+) -> (Duration, Vec<Duration>) {
+    let cache = Arc::new(MutexCache::new());
+
+    // Initialize cache with a very long TTL (no expiry during test = uncontended)
+    cache
+        .update(Credentials {
+            token: "initial_token".to_string(),
+            expiry: Instant::now() + Duration::from_secs(86400), // 24 hours
+        })
+        .await;
+
+    // No background updater - simulating uncontended case where TTL never expires
+
+    // Run benchmark
+    let start = Instant::now();
+    let mut latencies = Vec::with_capacity(num_requests * concurrent_tasks);
+    let requests_per_task = num_requests; // Each task does this many requests
+
+    let mut tasks = Vec::new();
+    for _ in 0..concurrent_tasks {
+        let cache = cache.clone();
+        tasks.push(tokio::spawn(async move {
+            let mut task_latencies = Vec::with_capacity(requests_per_task);
+            for _ in 0..requests_per_task {
+                let req_start = Instant::now();
+
+                // Get credentials (requires exclusive mutex lock even for reads)
+                let _creds = cache.get().await;
+
+                // Simulate HTTP request work
+                simulate_request(jitter_us).await;
+
+                task_latencies.push(req_start.elapsed());
+            }
+            task_latencies
+        }));
+    }
+
+    // Collect all latencies
+    for task in tasks {
+        latencies.extend(task.await.unwrap());
+    }
+
+    let total_duration = start.elapsed();
+
+    (total_duration, latencies)
+}
+
+fn calculate_percentiles(mut latencies: Vec<Duration>) -> (Duration, Duration, Duration, Duration) {
+    if latencies.is_empty() {
+        eprintln!("Warning: No latencies collected!");
+        return (
+            Duration::ZERO,
+            Duration::ZERO,
+            Duration::ZERO,
+            Duration::ZERO,
+        );
+    }
+
+    latencies.sort();
+    let len = latencies.len();
+
+    let p50 = latencies[len / 2];
+    let p95 = latencies[len * 95 / 100];
+    let p99 = latencies[len * 99 / 100];
+    let p999 = latencies[len * 999 / 1000];
+
+    (p50, p95, p99, p999)
+}
+
+#[derive(Debug)]
+struct BenchmarkResult {
+    concurrent_tasks: usize,
+    mutex_p50: Duration,
+    mutex_p99: Duration,
+    arc_p50: Duration,
+    arc_p99: Duration,
+    rwlock_p50: Duration,
+    rwlock_p99: Duration,
+}
+
+async fn run_benchmark_for_concurrency(
+    num_requests: usize,
+    concurrent_tasks: usize,
+    jitter_us: (u64, u64),
+) -> BenchmarkResult {
+    println!(
+        "\n=== Testing with {} concurrent readers ===",
+        concurrent_tasks
+    );
+
+    // Mutex
+    println!("  Running Mutex...");
+    let (_, mutex_latencies) = run_mutex_benchmark(num_requests, concurrent_tasks, jitter_us).await;
+    let (mutex_p50, _, mutex_p99, _) = calculate_percentiles(mutex_latencies);
+
+    // Arc-swap
+    println!("  Running Arc-swap...");
+    let (_, arc_latencies) =
+        run_arc_swap_benchmark(num_requests, concurrent_tasks, jitter_us).await;
+    let (arc_p50, _, arc_p99, _) = calculate_percentiles(arc_latencies);
+
+    // RwLock
+    println!("  Running RwLock...");
+    let (_, rwlock_latencies) =
+        run_rwlock_benchmark(num_requests, concurrent_tasks, jitter_us).await;
+    let (rwlock_p50, _, rwlock_p99, _) = calculate_percentiles(rwlock_latencies);
+
+    BenchmarkResult {
+        concurrent_tasks,
+        mutex_p50,
+        mutex_p99,
+        arc_p50,
+        arc_p99,
+        rwlock_p50,
+        rwlock_p99,
+    }
+}
+
+#[tokio::main(flavor = "multi_thread")]
+async fn main() {
+    println!("=== Credentials Cache Benchmark ===");
+    println!("Simulated HTTP request workload with varying concurrency");
+    println!("(Uncontended: TTL never expires, no writes during benchmark)\n");
+
+    let num_requests = 200; // Requests per task
+    let jitter_us = (1000, 2000); // 1-2ms jitter per simulated HTTP request (1.5ms ± 0.5ms)
+    let concurrency_levels = vec![100, 500, 1_000, 5_000, 10_000, 25_000]; // Number of concurrent tasks
+
+    println!("Configuration:");
+    println!("  Requests per task: {}", num_requests);
+    println!(
+        "  Request jitter: {}-{}µs (simulated HTTP work)",
+        jitter_us.0, jitter_us.1
+    );
+    println!("  Concurrency levels: {:?}", concurrency_levels);
+    println!("  Runtime: Multi-threaded");
+    println!("  Scenario: All reads, no credential refreshes (uncontended)\n");
+
+    // Warmup
+    println!("Running warmup...");
+    let _ = run_mutex_benchmark(100, 10, jitter_us).await;
+    let _ = run_arc_swap_benchmark(100, 10, jitter_us).await;
+    let _ = run_rwlock_benchmark(100, 10, jitter_us).await;
+
+    // Run benchmarks for each concurrency level
+    let mut results = Vec::new();
+    for &concurrent_tasks in &concurrency_levels {
+        let result = run_benchmark_for_concurrency(num_requests, concurrent_tasks, jitter_us).await;
+        results.push(result);
+    }
+
+    // Print results table
+    println!("\n\n=== Results Summary ===\n");
+
+    println!("Median Latency (p50):");
+    println!(
+        "{:<12} │ {:>10} │ {:>10} │ {:>10} │ {:>12} │ {:>14}",
+        "Concurrency", "Mutex", "Arc-swap", "RwLock", "Arc vs Mutex", "RwLock vs Mutex"
+    );
+    println!("{}", "─".repeat(92));
+
+    for result in &results {
+        let arc_improvement =
+            (result.mutex_p50.as_secs_f64() / result.arc_p50.as_secs_f64() - 1.0) * 100.0;
+        let rwlock_improvement =
+            (result.mutex_p50.as_secs_f64() / result.rwlock_p50.as_secs_f64() - 1.0) * 100.0;
+
+        let concurrency_label = if result.concurrent_tasks >= 1000 {
+            format!("{}k", result.concurrent_tasks / 1000)
+        } else {
+            result.concurrent_tasks.to_string()
+        };
+        println!(
+            "{:<12} │ {:>8.2}ms │ {:>8.2}ms │ {:>8.2}ms │ {:>11.1}% │ {:>13.1}%",
+            concurrency_label,
+            result.mutex_p50.as_secs_f64() * 1000.0,
+            result.arc_p50.as_secs_f64() * 1000.0,
+            result.rwlock_p50.as_secs_f64() * 1000.0,
+            arc_improvement,
+            rwlock_improvement
+        );
+    }
+
+    println!("\n\nTail Latency (p99):");
+    println!(
+        "{:<12} │ {:>10} │ {:>10} │ {:>10} │ {:>12} │ {:>14}",
+        "Concurrency", "Mutex", "Arc-swap", "RwLock", "Arc vs Mutex", "RwLock vs Mutex"
+    );
+    println!("{}", "─".repeat(92));
+
+    for result in &results {
+        let arc_improvement =
+            (result.mutex_p99.as_secs_f64() / result.arc_p99.as_secs_f64() - 1.0) * 100.0;
+        let rwlock_improvement =
+            (result.mutex_p99.as_secs_f64() / result.rwlock_p99.as_secs_f64() - 1.0) * 100.0;
+
+        let concurrency_label = if result.concurrent_tasks >= 1000 {
+            format!("{}k", result.concurrent_tasks / 1000)
+        } else {
+            result.concurrent_tasks.to_string()
+        };
+        println!(
+            "{:<12} │ {:>8.2}ms │ {:>8.2}ms │ {:>8.2}ms │ {:>11.1}% │ {:>13.1}%",
+            concurrency_label,
+            result.mutex_p99.as_secs_f64() * 1000.0,
+            result.arc_p99.as_secs_f64() * 1000.0,
+            result.rwlock_p99.as_secs_f64() * 1000.0,
+            arc_improvement,
+            rwlock_improvement
+        );
+    }
+
+    println!("\n\nKey Findings:");
+    println!("  - Positive % = improvement over Mutex (faster)");
+    println!("  - Negative % = regression vs Mutex (slower)");
+}
diff --git a/src/client/token.rs b/src/client/token.rs
index 81ffc110..bfd5aeba 100644
--- a/src/client/token.rs
+++ b/src/client/token.rs
@@ -15,7 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arc_swap::ArcSwapOption;
 use std::future::Future;
+use std::sync::Arc;
 use std::time::{Duration, Instant};
 
 use tokio::sync::Mutex;
@@ -33,7 +35,8 @@ pub(crate) struct TemporaryToken<T> {
 /// [`TemporaryToken`] based on its expiry
 #[derive(Debug)]
 pub(crate) struct TokenCache<T> {
-    cache: Mutex<Option<(TemporaryToken<T>, Instant)>>,
+    cache: ArcSwapOption<CacheEntry<T>>,
+    refresh_lock: Mutex<()>,
     min_ttl: Duration,
     fetch_backoff: Duration,
 }
@@ -41,7 +44,8 @@ pub(crate) struct TokenCache<T> {
 impl<T> Default for TokenCache<T> {
     fn default() -> Self {
         Self {
-            cache: Default::default(),
+            cache: ArcSwapOption::new(None),
+            refresh_lock: Default::default(),
             min_ttl: Duration::from_secs(300),
             // How long to wait before re-attempting a token fetch after receiving one that
             // is still within the min-ttl
@@ -62,32 +66,58 @@ impl<T: Clone + Send> TokenCache<T> {
         F: FnOnce() -> Fut + Send,
         Fut: Future<Output = Result<TemporaryToken<T>, E>> + Send,
     {
-        let now = Instant::now();
-        let mut locked = self.cache.lock().await;
+        if let Some(token) = self.try_get_cached() {
+            return Ok(token);
+        }
+
+        // Only one fetch at a time
+        let _refresh = self.refresh_lock.lock().await;
+
+        // Re-check after acquiring lock in case another task refreshed already
+        if let Some(token) = self.try_get_cached() {
+            return Ok(token);
+        }
+
+        let fetched = f().await?;
+        let token_clone = fetched.token.clone();
+        let entry = Arc::new(CacheEntry {
+            token: fetched.token,
+            expiry: fetched.expiry,
+            fetched_at: Instant::now(),
+        });
+        self.cache.store(Some(entry));
+        Ok(token_clone)
+    }
 
-        if let Some((cached, fetched_at)) = locked.as_ref() {
-            match cached.expiry {
+    fn try_get_cached(&self) -> Option<T> {
+        let now = Instant::now();
+        if let Some(entry) = self.cache.load_full() {
+            match entry.expiry {
                 Some(ttl) => {
-                    if ttl.checked_duration_since(now).unwrap_or_default() > self.min_ttl ||
-                        // if we've recently attempted to fetch this token and it's not actually
-                        // expired, we'll wait to re-fetch it and return the cached one
-                        (fetched_at.elapsed() < self.fetch_backoff && ttl.checked_duration_since(now).is_some())
+                    let remaining = ttl.checked_duration_since(now).unwrap_or_default();
+                    if remaining > self.min_ttl
+                        || (entry.fetched_at.elapsed() < self.fetch_backoff
+                            && ttl.checked_duration_since(now).is_some())
                     {
-                        return Ok(cached.token.clone());
+                        return Some(entry.token.clone());
                     }
                 }
-                None => return Ok(cached.token.clone()),
+                None => {
+                    return Some(entry.token.clone());
+                }
             }
         }
-
-        let cached = f().await?;
-        let token = cached.token.clone();
-        *locked = Some((cached, Instant::now()));
-
-        Ok(token)
+        None
     }
 }
 
+#[derive(Debug)]
+struct CacheEntry<T> {
+    token: T,
+    expiry: Option<Instant>,
+    fetched_at: Instant,
+}
+
 #[cfg(test)]
 mod test {
     use crate::client::token::{TemporaryToken, TokenCache};