Skip to content

Commit c5dd0d8

Browse files
feat: Add HNSW vector storage (#18)
* Updated storage with hnswlib, unittests, benchmarking against nano vector storage, and a simple GraphRAG example * Added kwargs for vector storage cls to pass on hyperparameters for better speed/recall tradeoffs * Removed redundant passing of arguments --------- Co-authored-by: terence-gpt <numberchiffre@users.noreply.github.com>
1 parent 3ddcce7 commit c5dd0d8

File tree

6 files changed

+441
-4
lines changed

6 files changed

+441
-4
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import asyncio
2+
import time
3+
import numpy as np
4+
from tqdm import tqdm
5+
from nano_graphrag import GraphRAG
6+
from nano_graphrag._storage import NanoVectorDBStorage, HNSWVectorStorage
7+
from nano_graphrag._utils import wrap_embedding_func_with_attrs
8+
9+
10+
# Directory where GraphRAG caches its intermediate artifacts for this benchmark.
WORKING_DIR = "./nano_graphrag_cache_benchmark_hnsw_vs_nano_vector_storage"
# Number of synthetic documents to insert into each storage backend.
DATA_LEN = 100_000
# Dimensionality of the fake embedding vectors produced by sample_embedding.
FAKE_DIM = 1024
# Upsert batch size; equals DATA_LEN, so all data is inserted in one batch.
BATCH_SIZE = 100000
14+
15+
16+
@wrap_embedding_func_with_attrs(embedding_dim=FAKE_DIM, max_token_size=8192)
async def sample_embedding(texts: list[str]) -> np.ndarray:
    """Return random float32 vectors as stand-in embeddings (no model call)."""
    batch_size = len(texts)
    return np.random.rand(batch_size, FAKE_DIM).astype(np.float32)
19+
20+
21+
def generate_test_data():
22+
return {str(i): {"content": f"Test content {i}"} for i in range(DATA_LEN)}
23+
24+
25+
async def benchmark_storage(storage_class, name):
    """Measure insert, save, and query latency for one vector-storage backend.

    Args:
        storage_class: a BaseVectorStorage subclass to benchmark.
        name: label used for the namespace and progress/report output.

    Returns:
        Tuple (insert_time, save_time, avg_query_time) in seconds.
    """
    rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=sample_embedding)
    storage = storage_class(
        namespace=f"benchmark_{name}",
        global_config=rag.__dict__,
        embedding_func=sample_embedding,
        meta_fields={"content"},
    )

    test_data = generate_test_data()

    print(f"Benchmarking {name}...")
    # Total accounts for inserts (DATA_LEN), the save step (1), and the 100
    # query updates — the original used total=DATA_LEN and overflowed the bar.
    with tqdm(total=DATA_LEN + 1 + 100, desc=f"{name} Benchmark") as pbar:
        start_time = time.time()
        for i in range(0, len(test_data), BATCH_SIZE):
            batch = {k: test_data[k] for k in list(test_data.keys())[i:i + BATCH_SIZE]}
            await storage.upsert(batch)
            pbar.update(min(BATCH_SIZE, DATA_LEN - i))

        insert_time = time.time() - start_time

        save_start_time = time.time()
        await storage.index_done_callback()
        save_time = time.time() - save_start_time
        pbar.update(1)

        # Both backends embed the query text themselves, so pass a string.
        # The original passed a raw numpy vector, which query() does not accept.
        query_text = "benchmark query"
        query_times = []
        for _ in range(100):
            query_start = time.time()
            await storage.query(query_text, top_k=10)
            query_times.append(time.time() - query_start)
            pbar.update(1)

    avg_query_time = sum(query_times) / len(query_times)

    print(f"{name} - Insert: {insert_time:.2f}s, Save: {save_time:.2f}s, Avg Query: {avg_query_time:.4f}s")
    return insert_time, save_time, avg_query_time
63+
64+
65+
async def run_benchmarks():
    """Benchmark both vector-storage backends sequentially and print a summary."""
    print("Running NanoVectorDB benchmark...")
    nano_results = await benchmark_storage(NanoVectorDBStorage, "nano")

    print("\nRunning HNSWVectorStorage benchmark...")
    hnsw_results = await benchmark_storage(HNSWVectorStorage, "hnsw")

    nano_insert_time, nano_save_time, nano_query_time = nano_results
    hnsw_insert_time, hnsw_save_time, hnsw_query_time = hnsw_results
    print("\nBenchmark Results:")
    print(f"NanoVectorDB - Insert: {nano_insert_time:.2f}s, Save: {nano_save_time:.2f}s, Avg Query: {nano_query_time:.4f}s")
    print(f"HNSWVectorStorage - Insert: {hnsw_insert_time:.2f}s, Save: {hnsw_save_time:.2f}s, Avg Query: {hnsw_query_time:.4f}s")
75+
76+
77+
# Entry point: run the full benchmark suite under a single asyncio event loop.
if __name__ == "__main__":
    asyncio.run(run_benchmarks())

examples/using_hnsw_as_vectorDB.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import os
2+
from nano_graphrag import GraphRAG, QueryParam
3+
from nano_graphrag._llm import gpt_4o_mini_complete
4+
from nano_graphrag._storage import HNSWVectorStorage
5+
6+
7+
# Cache directory for this example's GraphRAG working files.
WORKING_DIR = "./nano_graphrag_cache_using_hnsw_as_vectorDB"
8+
9+
10+
def remove_if_exist(file):
    """Delete *file* if it exists; silently do nothing when it does not.

    Uses EAFP (try/except) instead of exists()+remove() to avoid the race
    where the file disappears between the check and the removal.
    """
    try:
        os.remove(file)
    except FileNotFoundError:
        pass
13+
14+
15+
def insert():
    """Index the mock corpus into a fresh HNSW-backed GraphRAG instance."""
    from time import time

    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        FAKE_TEXT = f.read()

    # Clear stale artifacts from any previous run so indexing starts clean.
    stale_artifacts = [
        "vdb_entities.json",
        "kv_store_full_docs.json",
        "kv_store_text_chunks.json",
        "kv_store_community_reports.json",
        "graph_chunk_entity_relation.graphml",
    ]
    for artifact in stale_artifacts:
        remove_if_exist(f"{WORKING_DIR}/{artifact}")

    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=HNSWVectorStorage,
        vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 100, "M": 8},
        best_model_max_async=1,
        cheap_model_max_async=1,
        best_model_func=gpt_4o_mini_complete,
        cheap_model_func=gpt_4o_mini_complete,
    )
    start = time()
    rag.insert(FAKE_TEXT)
    print("indexing time:", time() - start)
39+
40+
41+
def query():
    """Answer a global-mode question against the previously built index."""
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=HNSWVectorStorage,
        vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 100, "M": 8},
        best_model_max_async=1,
        cheap_model_max_async=1,
        best_model_func=gpt_4o_mini_complete,
        cheap_model_func=gpt_4o_mini_complete,
    )
    answer = rag.query(
        "What are the top themes in this story?", param=QueryParam(mode="global")
    )
    print(answer)
57+
58+
59+
# Build the index first, then run a query against it.
if __name__ == "__main__":
    insert()
    query()

nano_graphrag/_storage.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import hashlib
import json
import os
import pickle
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Union, cast

import hnswlib
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
@@ -115,6 +116,110 @@ async def index_done_callback(self):
115116
self._client.save()
116117

117118

119+
@dataclass
class HNSWVectorStorage(BaseVectorStorage):
    """Vector storage backed by an hnswlib approximate-nearest-neighbor index.

    Hyperparameters can be overridden per instance via
    ``global_config["vector_db_storage_cls_kwargs"]``. The index and its
    metadata sidecar are persisted in the working directory.
    """

    ef_construction: int = 100  # build-time quality/speed tradeoff
    M: int = 16  # max links per node; higher = better recall, more memory
    max_elements: int = 1000000  # hard capacity of the index
    ef_search: int = 50  # query-time candidate list size; must exceed top_k
    num_threads: int = -1  # -1 lets hnswlib choose
    _index: Any = field(init=False)
    # Maps str(index label) -> metadata dict (meta_fields subset plus 'id').
    _metadata: dict[str, dict] = field(default_factory=dict)
    _current_elements: int = 0

    def __post_init__(self):
        self._index_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
        )
        self._metadata_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
        )
        self._max_batch_size = self.global_config.get("embedding_batch_num", 100)

        # Per-instance overrides for the HNSW hyperparameters.
        hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
        self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
        self.M = hnsw_params.get("M", self.M)
        self.max_elements = hnsw_params.get("max_elements", self.max_elements)
        self.ef_search = hnsw_params.get("ef_search", self.ef_search)
        self.num_threads = hnsw_params.get("num_threads", self.num_threads)

        self._index = hnswlib.Index(space='cosine', dim=self.embedding_func.embedding_dim)
        if os.path.exists(self._index_file_name) and os.path.exists(self._metadata_file_name):
            self._index.load_index(self._index_file_name, max_elements=self.max_elements)
            with open(self._metadata_file_name, 'rb') as f:
                self._metadata, self._current_elements = pickle.load(f)
            logger.info(f"Loaded existing index for {self.namespace} with {self._current_elements} elements")
        else:
            self._index.init_index(
                max_elements=self.max_elements,
                ef_construction=self.ef_construction,
                M=self.M
            )
            logger.info(f"Created new index for {self.namespace}")
        # Apply ef on BOTH paths: the original only set it for fresh indexes,
        # so a reloaded index silently queried with hnswlib's default ef.
        self._index.set_ef(self.ef_search)

    @staticmethod
    def _label_for(id: str) -> int:
        """Map a string id to a stable non-negative integer hnswlib label.

        Digit-only ids keep their numeric value (backward compatible with
        previously built numeric indexes). Other ids get a deterministic
        md5-derived 63-bit label — the original used builtin hash(), which
        is randomized per process (PYTHONHASHSEED), breaking any persisted
        index on reload.
        """
        if id.isdigit():
            return int(id)
        digest = hashlib.md5(id.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big") >> 1  # keep within positive int64

    async def upsert(self, data: dict[str, dict]):
        """Embed and insert *data* (id -> item dict with a 'content' field).

        Raises:
            ValueError: on empty input or when capacity would be exceeded.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            raise ValueError("Attempting to insert empty data to vector DB")

        if self._current_elements + len(data) > self.max_elements:
            raise ValueError(f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}")

        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)

        ids = []
        for id, item in data.items():
            label = self._label_for(id)
            metadata = {k: v for k, v in item.items() if k in self.meta_fields}
            metadata['id'] = id
            # Key metadata by the index label so query() can find it again.
            # The original keyed by the raw id, which never matched the
            # hash()-derived labels returned by knn_query for non-digit ids.
            self._metadata[str(label)] = metadata
            ids.append(label)

        ids = np.array(ids)
        self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
        self._current_elements += len(data)

    async def query(self, query: str, top_k: int = 5) -> list[dict]:
        """Embed *query* and return up to *top_k* nearest items with metadata.

        Each result carries the stored meta_fields plus 'id', 'distance'
        (cosine distance) and 'similarity' (1 - distance).

        Raises:
            ValueError: when top_k >= ef_search (hnswlib requires ef >= k).
        """
        if len(self._metadata) == 0:
            return []

        if top_k >= self.ef_search:
            # The original raised with an inverted message
            # ("must be greater than or equal to ef_search").
            raise ValueError(f"top_k must be less than ef_search, got top_k={top_k} and ef_search={self.ef_search}")

        query_vector = await self.embedding_func([query])
        labels, distances = self._index.knn_query(
            data=query_vector,
            k=min(top_k, len(self._metadata)),
            num_threads=self.num_threads
        )

        results = []
        for label, distance in zip(labels[0], distances[0]):
            metadata = self._metadata.get(str(label))
            if metadata is not None:
                results.append({
                    **metadata,
                    "distance": distance,
                    "similarity": 1 - distance  # cosine space: distance = 1 - cos_sim
                })
        return results

    async def index_done_callback(self):
        """Persist the ANN index and its metadata sidecar to the working dir."""
        self._index.save_index(self._index_file_name)
        with open(self._metadata_file_name, 'wb') as f:
            pickle.dump((self._metadata, self._current_elements), f)
221+
222+
118223
@dataclass
119224
class NetworkXStorage(BaseGraphStorage):
120225
@staticmethod

nano_graphrag/graphrag.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class GraphRAG:
104104
# storage
105105
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
106106
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
107+
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
107108
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
108109
enable_llm_cache: bool = True
109110

@@ -150,7 +151,7 @@ def __post_init__(self):
150151
namespace="entities",
151152
global_config=asdict(self),
152153
embedding_func=self.embedding_func,
153-
meta_fields={"entity_name"},
154+
meta_fields={"entity_name"}
154155
)
155156
if self.enable_local
156157
else None

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ openai
22
tiktoken
33
networkx
44
graspologic
5-
nano-vectordb
5+
nano-vectordb
6+
hnswlib

0 commit comments

Comments
 (0)