Skip to content

Commit 4fadcaa

Browse files
committed
enable SIMD, memory pool
Signed-off-by: eric-epsilla <eric@epsilla.com>
1 parent 7b6c107 commit 4fadcaa

7 files changed

Lines changed: 163 additions & 7 deletions

File tree

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.3.22
1+
0.3.23

engine/CMakeLists.txt

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ list(APPEND LIB_FILES ${UTILS_FILES})
6868
file(GLOB UTILS_FILES "db/execution/*.cpp")
6969
list(APPEND LIB_FILES ${UTILS_FILES})
7070
file(GLOB UTILS_FILES "db/index/*.cpp")
71+
# 排除有问题的 distances.cpp,使用 distance_simd.cpp
72+
list(REMOVE_ITEM UTILS_FILES "${CMAKE_CURRENT_SOURCE_DIR}/db/index/distances.cpp")
7173
list(APPEND LIB_FILES ${UTILS_FILES})
7274
file(GLOB UTILS_FILES "db/index/nsg/*.cpp")
7375
list(APPEND LIB_FILES ${UTILS_FILES})
@@ -311,3 +313,27 @@ if(CLANG_TIDY_PATH)
311313
DEPENDS vector_db_test
312314
)
313315
endif()
316+
317+
# ============================================================
318+
# Phase 1 优化: SIMD 距离计算
319+
# ============================================================
320+
include(CheckCXXCompilerFlag)
321+
322+
# 检测 AVX2 支持
323+
check_cxx_compiler_flag("-mavx2" COMPILER_SUPPORTS_AVX2)
324+
if(COMPILER_SUPPORTS_AVX2)
325+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma")
326+
add_definitions(-DUSE_AVX2)
327+
message(STATUS "✓ AVX2 SIMD support enabled")
328+
endif()
329+
330+
# 检测 ARM NEON 支持
331+
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64|ARM64")
332+
add_definitions(-DUSE_NEON)
333+
message(STATUS "✓ ARM NEON SIMD support enabled")
334+
endif()
335+
336+
# 使用 SIMD 优化的距离计算
337+
add_definitions(-DUSE_SIMD_OPTIMIZED)
338+
message(STATUS "✓ SIMD optimized distance calculation enabled")
339+

engine/db/execution/vec_search_executor.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "db/execution/vec_search_executor.hpp"
22
#include "utils/safe_memory_ops.hpp"
3+
#include "utils/memory_pool.hpp"
34

45
#include <omp.h>
56

@@ -68,6 +69,31 @@ VecSearchExecutor::VecSearchExecutor(
6869
prefilter_enabled_(prefilter_enabled) {
6970
ann_index_ = ann_index;
7071

72+
// Initialize memory pool (singleton, only initialized once)
73+
static bool memory_pool_initialized = false;
74+
if (!memory_pool_initialized) {
75+
utils::MemoryPoolConfig config;
76+
config.initial_size = 128 * 1024 * 1024; // 128MB
77+
config.max_size = 1024 * 1024 * 1024; // 1GB
78+
config.enable_thread_cache = true;
79+
config.enable_stats = true;
80+
utils::MemoryPool::GetInstance().Initialize(config);
81+
memory_pool_initialized = true;
82+
}
83+
84+
// Pre-allocate vectors to reduce runtime allocations
85+
// These are already sized in the initializer list, but ensure capacity
86+
search_result_.reserve(L_master_);
87+
distance_.reserve(L_master_);
88+
init_ids_.reserve(L_master_);
89+
is_visited_.reserve(ann_index->record_number_);
90+
set_L_.reserve((num_threads - 1) * L_local + L_master);
91+
local_queues_sizes_.reserve(num_threads);
92+
local_queues_starts_.reserve(num_threads);
93+
if (brute_force_search_) {
94+
brute_force_queue_.reserve(ann_index->record_number_);
95+
}
96+
7197
// Log thread configuration for debugging - only in debug builds or when explicitly enabled
7298
#ifdef VECTORDB_DEBUG_BUILD
7399
static std::atomic<int> executor_count(0);

engine/db/index/distances.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <cmath>
66
#include <cstdio>
77
#include <cstring>
8+
#include <memory>
89

910

1011
#ifdef __AVX2__

engine/db/index/simd_config.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
// SIMD 配置
4+
#if defined(USE_AVX2)
5+
#define SIMD_WIDTH 8
6+
#define SIMD_TYPE "AVX2"
7+
#elif defined(USE_NEON)
8+
#define SIMD_WIDTH 4
9+
#define SIMD_TYPE "NEON"
10+
#else
11+
#define SIMD_WIDTH 1
12+
#define SIMD_TYPE "Scalar"
13+
#endif
14+
15+
namespace vectordb {
16+
namespace simd {
17+
inline const char* GetSIMDType() {
18+
return SIMD_TYPE;
19+
}
20+
21+
inline int GetSIMDWidth() {
22+
return SIMD_WIDTH;
23+
}
24+
}
25+
}

engine/db/vector_normalizer.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "db/batch_insertion_optimizer.hpp"
2+
3+
#ifdef __AVX2__
4+
#include <immintrin.h>
5+
6+
namespace vectordb {
7+
namespace engine {
8+
namespace db {
9+
10+
void VectorNormalizer::NormalizeBatchAVX2(float* vectors, size_t count, size_t dimension) {
11+
for (size_t i = 0; i < count; ++i) {
12+
float* vec = vectors + i * dimension;
13+
14+
// Compute norm using AVX2
15+
__m256 sum_vec = _mm256_setzero_ps();
16+
size_t j = 0;
17+
18+
for (; j + 7 < dimension; j += 8) {
19+
__m256 v = _mm256_loadu_ps(vec + j);
20+
sum_vec = _mm256_fmadd_ps(v, v, sum_vec);
21+
}
22+
23+
// Horizontal sum
24+
__m128 sum_high = _mm256_extractf128_ps(sum_vec, 1);
25+
__m128 sum_low = _mm256_castps256_ps128(sum_vec);
26+
__m128 sum128 = _mm_add_ps(sum_high, sum_low);
27+
sum128 = _mm_hadd_ps(sum128, sum128);
28+
sum128 = _mm_hadd_ps(sum128, sum128);
29+
30+
float sum = _mm_cvtss_f32(sum128);
31+
32+
// Handle remaining elements
33+
for (; j < dimension; ++j) {
34+
sum += vec[j] * vec[j];
35+
}
36+
37+
// Normalize using AVX2
38+
float inv_norm = 1.0f / std::sqrt(sum);
39+
__m256 inv_norm_vec = _mm256_set1_ps(inv_norm);
40+
41+
j = 0;
42+
for (; j + 7 < dimension; j += 8) {
43+
__m256 v = _mm256_loadu_ps(vec + j);
44+
v = _mm256_mul_ps(v, inv_norm_vec);
45+
_mm256_storeu_ps(vec + j, v);
46+
}
47+
48+
// Handle remaining elements
49+
for (; j < dimension; ++j) {
50+
vec[j] *= inv_norm;
51+
}
52+
}
53+
}
54+
55+
} // namespace db
56+
} // namespace engine
57+
} // namespace vectordb
58+
59+
#endif // __AVX2__

engine/utils/memory_pool.hpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ struct MemoryPoolStats {
6969
std::atomic<size_t> cache_hits{0};
7070
std::atomic<size_t> cache_misses{0};
7171

72+
// Delete copy constructor and assignment operator
73+
MemoryPoolStats() = default;
74+
MemoryPoolStats(const MemoryPoolStats&) = delete;
75+
MemoryPoolStats& operator=(const MemoryPoolStats&) = delete;
76+
7277
double GetCacheHitRate() const {
7378
size_t total = cache_hits + cache_misses;
7479
return total > 0 ? static_cast<double>(cache_hits) / total : 0.0;
@@ -416,10 +421,17 @@ class MemoryPool {
416421
}
417422

418423
/**
419-
* @brief Get pool statistics
424+
* @brief Get pool statistics (returns a snapshot copy)
420425
*/
421-
MemoryPoolStats GetStats() const {
422-
return stats_;
426+
void GetStats(MemoryPoolStats& out_stats) const {
427+
out_stats.total_allocated.store(stats_.total_allocated.load());
428+
out_stats.total_freed.store(stats_.total_freed.load());
429+
out_stats.current_usage.store(stats_.current_usage.load());
430+
out_stats.peak_usage.store(stats_.peak_usage.load());
431+
out_stats.allocation_count.store(stats_.allocation_count.load());
432+
out_stats.free_count.store(stats_.free_count.load());
433+
out_stats.cache_hits.store(stats_.cache_hits.load());
434+
out_stats.cache_misses.store(stats_.cache_misses.load());
423435
}
424436

425437
/**
@@ -437,8 +449,15 @@ class MemoryPool {
437449
// Clear large allocations
438450
large_allocations_.clear();
439451

440-
// Reset stats
441-
stats_ = MemoryPoolStats();
452+
// Reset stats (manually reset each atomic)
453+
stats_.total_allocated.store(0);
454+
stats_.total_freed.store(0);
455+
stats_.current_usage.store(0);
456+
stats_.peak_usage.store(0);
457+
stats_.allocation_count.store(0);
458+
stats_.free_count.store(0);
459+
stats_.cache_hits.store(0);
460+
stats_.cache_misses.store(0);
442461

443462
// Reinitialize if needed
444463
if (initialized_) {
@@ -601,7 +620,7 @@ class MemoryPool {
601620
std::unordered_map<std::thread::id, std::unique_ptr<ThreadCache>> thread_caches_;
602621

603622
MemoryPoolStats stats_;
604-
Logger logger_;
623+
engine::Logger logger_;
605624
};
606625

607626
/**

0 commit comments

Comments
 (0)