
Commit 9900d35

add test code for splitter & reformat chunking methods
1 parent 2915526

File tree

5 files changed: +233 −79 lines changed

main.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+import os
+import sys
+
+sys.path.append("..")
+import logging
+import numpy as np
+from nano_graphrag import GraphRAG, QueryParam
+from nano_graphrag._utils import wrap_embedding_func_with_attrs
+from sentence_transformers import SentenceTransformer
+from nano_graphrag._op import chunking_by_seperators
+
+logging.basicConfig(level=logging.WARNING)
+logging.getLogger("nano-graphrag").setLevel(logging.INFO)
+
+WORKING_DIR = "/mnt/rangehow/nano-graphrag/neu_cache"
+
+from openai import AsyncOpenAI
+from nano_graphrag.base import BaseKVStorage
+from nano_graphrag._utils import compute_args_hash
+
+# CUSTOM LLM
+MODEL = "default"
+
+async def custom_model_if_cache(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    openai_async_client = AsyncOpenAI(
+        api_key="EMPTY", base_url="http://152.136.16.221:8203/v1"
+    )
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+
+    # Return the cached response if one exists -------------------
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(MODEL, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+    # -----------------------------------------------------
+
+    response = await openai_async_client.chat.completions.create(
+        model=MODEL, messages=messages, temperature=0, **kwargs,
+        timeout=10e6,
+    )
+
+    # Cache the response if caching is enabled -------------------
+    if hashing_kv is not None:
+        await hashing_kv.upsert(
+            {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
+        )
+        await hashing_kv.index_done_callback()
+    # -----------------------------------------------------
+    return response.choices[0].message.content
+
+# CUSTOM EMBEDDING
+EMBED_MODEL = SentenceTransformer(
+    "/mnt/rangehow/models/Conan-embedding-v1", cache_folder=WORKING_DIR, device="cpu"
+)
+
+# Use Sentence Transformers to generate embeddings with the local Conan-embedding-v1 model
+@wrap_embedding_func_with_attrs(
+    embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
+    max_token_size=EMBED_MODEL.max_seq_length,
+)
+async def local_embedding(texts: list[str]) -> np.ndarray:
+    return EMBED_MODEL.encode(texts, normalize_embeddings=True)
+
+rag = GraphRAG(
+    working_dir=WORKING_DIR,
+    embedding_func=local_embedding,
+    enable_llm_cache=True,
+    best_model_func=custom_model_if_cache,
+    cheap_model_func=custom_model_if_cache,
+    chunk_func=chunking_by_seperators,
+    best_model_max_async=1024,
+    cheap_model_max_async=1024,
+    entity_extract_max_gleaning=0,
+)
+
+documents = []
+input_directory = "/mnt/rangehow/neuspider/document/markdown_saved"
+filenames = [f for f in os.listdir(input_directory) if os.path.isfile(os.path.join(input_directory, f))]
+for filename in filenames:
+    with open(os.path.join(input_directory, filename), encoding="utf-8") as f:
+        string = f.read()
+        if len(string) < 50:
+            continue
+        documents.append(string)
+
+print(len(documents))
+rag.insert(documents)
+print(rag.query("Who is the most impressive person at Northeastern University?", param=QueryParam(mode="global")))
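The heart of custom_model_if_cache is the hash, check, upsert cycle around the LLM call. Below is a minimal, self-contained sketch of that same pattern; DictKV, stub_args_hash, and cached_call are illustrative stand-ins for BaseKVStorage, compute_args_hash, and the wrapped model call, not part of nano-graphrag:

import asyncio
import hashlib
import json

class DictKV:
    # In-memory stand-in for BaseKVStorage; mirrors the get_by_id /
    # upsert / index_done_callback calls used above.
    def __init__(self):
        self._data = {}
    async def get_by_id(self, key):
        return self._data.get(key)
    async def upsert(self, items: dict):
        self._data.update(items)
    async def index_done_callback(self):
        pass

def stub_args_hash(model, messages):
    # Stand-in for compute_args_hash: any stable hash of the arguments works.
    return hashlib.md5(json.dumps([model, messages]).encode()).hexdigest()

async def cached_call(kv, model, messages, call_llm):
    key = stub_args_hash(model, messages)
    hit = await kv.get_by_id(key)
    if hit is not None:
        return hit["return"]           # cache hit: no LLM call at all
    answer = await call_llm(messages)  # cache miss: one real call
    await kv.upsert({key: {"return": answer, "model": model}})
    await kv.index_done_callback()
    return answer

async def demo():
    kv = DictKV()
    fake_llm = lambda msgs: asyncio.sleep(0, result="a canned answer")
    msgs = [{"role": "user", "content": "hi"}]
    print(await cached_call(kv, "default", msgs, fake_llm))  # computed
    print(await cached_call(kv, "default", msgs, fake_llm))  # served from cache

asyncio.run(demo())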

nano_graphrag/_op.py

Lines changed: 19 additions & 0 deletions
@@ -107,6 +107,25 @@ def chunking_by_seperators(tokens_list: list[int], doc_keys, tiktoken_model, over
     return results

+def get_chunks(new_docs, chunk_func=chunking_by_token_size, **chunk_func_params):
+    inserting_chunks = {}
+
+    new_docs_list = list(new_docs.items())
+    docs = [new_doc[1]["content"] for new_doc in new_docs_list]
+    doc_keys = [new_doc[0] for new_doc in new_docs_list]
+
+    ENCODER = tiktoken.encoding_for_model("gpt-4o")
+    tokens = ENCODER.encode_batch(docs, num_threads=16)
+    chunks = chunk_func(tokens, doc_keys=doc_keys, tiktoken_model=ENCODER, **chunk_func_params)
+
+    for chunk in chunks:
+        inserting_chunks.update({compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk})
+
+    return inserting_chunks
+
 async def _handle_entity_relation_summary(
     entity_or_relation_name: str,
     description: str,
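Because get_chunks is now a free function, it can be exercised outside ainsert. A sketch of calling it directly, assuming chunking_by_token_size accepts the same max_token_size / overlap_token_size keywords that ainsert forwards in the graphrag.py change below:

from nano_graphrag._op import chunking_by_token_size, get_chunks

new_docs = {
    "doc-1": {"content": "First document text ..."},
    "doc-2": {"content": "Second document text ..."},
}
inserting_chunks = get_chunks(
    new_docs,
    chunk_func=chunking_by_token_size,
    max_token_size=1024,
    overlap_token_size=128,
)
# Keys are content hashes ("chunk-..." from compute_mdhash_id); values are
# the chunk dicts returned by chunk_func.
for chunk_id, chunk in inserting_chunks.items():
    print(chunk_id, len(chunk["content"]))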

nano_graphrag/_spliter.py

Lines changed: 36 additions & 65 deletions
@@ -9,7 +9,7 @@ def __init__(
         chunk_overlap: int = 200,
         length_function: callable = len,
     ):
-        self._separators = separators or [[10], [13, 10]]  # default to newline tokens as separators
+        self._separators = separators or []
         self._keep_separator = keep_separator
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
@@ -27,18 +27,12 @@ def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
             separator_found = False
             for separator in self._separators:
                 if tokens[i:i+len(separator)] == separator:
+                    if self._keep_separator in [True, "end"]:
+                        current_split.extend(separator)
                     if current_split:
-                        if self._keep_separator == "end":
-                            current_split.extend(separator)
-                            splits.append(current_split)
-                            current_split = []
-                        elif self._keep_separator == "start":
-                            splits.append(current_split)
-                            current_split = separator[:]
-                        else:
-                            splits.append(current_split)
-                            current_split = []
-                    elif self._keep_separator:
+                        splits.append(current_split)
+                        current_split = []
+                    if self._keep_separator == "start":
                         current_split.extend(separator)
                     i += len(separator)
                     separator_found = True
@@ -51,71 +45,48 @@ def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
         return [s for s in splits if s]

     def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
+        if not splits:
+            return []
+
         merged_splits = []
-        current_split = []
-        current_length = 0
-        separator = [] if self._keep_separator is False else self._separators[-1]
+        current_chunk = []

         for split in splits:
-            if self._length_function(current_split) + self._length_function(split) <= self._chunk_size:
-                if current_split and separator:
-                    current_split.extend(separator)
-                current_split.extend(split)
+            if not current_chunk:
+                current_chunk = split
+            elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
+                current_chunk.extend(split)
             else:
-                if current_split:
-                    merged_splits.append(current_split)
-                current_split = split
-                if self._length_function(current_split) >= self._chunk_size:
-                    merged_splits.append(current_split)
-                    current_split = []
-        if current_split:
-            merged_splits.append(current_split)
+                merged_splits.append(current_chunk)
+                current_chunk = split
+
+        if current_chunk:
+            merged_splits.append(current_chunk)
+
+        if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
+            return self._split_chunk(merged_splits[0])

         if self._chunk_overlap > 0:
             return self._enforce_overlap(merged_splits)
+
         return merged_splits

+    def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
+        result = []
+        for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
+            result.append(chunk[i:i + self._chunk_size])
+        return result
+
     def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
-        new_chunks = []
+        result = []
         for i, chunk in enumerate(chunks):
             if i == 0:
-                new_chunks.append(chunk)
+                result.append(chunk)
             else:
-                overlap_tokens = chunks[i-1][-self._chunk_overlap:]
-                new_chunk = overlap_tokens + chunk
+                overlap = chunks[i-1][-self._chunk_overlap:]
+                new_chunk = overlap + chunk
                 if self._length_function(new_chunk) > self._chunk_size:
-                    new_chunk = new_chunk[-self._chunk_size:]
-                new_chunks.append(new_chunk)
-        return new_chunks
-
-# EXAMPLE USAGE
-if __name__ == "__main__":
-    import tiktoken
-    tokenizer = tiktoken.encoding_for_model("gpt-4")
-
-    def tokenize(text: str) -> List[int]:
-        return tokenizer.encode(text)
-
-    def detokenize(tokens: List[int]) -> str:
-        return tokenizer.decode(tokens)
-
-    # create a splitter instance
-    splitter = SeparatorSplitter(
-        separators=[tokenize('\n'), tokenize('.')],  # newline and period as separators
-        chunk_size=5,
-        chunk_overlap=0,
-        keep_separator="end"
-    )
-
-    # sample text
-    text = "This is a sample text. It contains multiple sentences.\nSome sentences are short. Others are longer."
-    tokens = tokenize(text)
-
-    # split the tokens
-    split_tokens = splitter.split_tokens(tokens)
+                    new_chunk = new_chunk[:self._chunk_size]
+                result.append(new_chunk)
+        return result

-    print("Split tokens:")
-    for i, token_chunk in enumerate(split_tokens):
-        print(f"Chunk {i + 1}:")
-        print(detokenize(token_chunk))
-        print("---")

nano_graphrag/graphrag.py

Lines changed: 3 additions & 14 deletions
@@ -20,6 +20,7 @@
     chunking_by_token_size,
     extract_entities,
     generate_community_report,
+    get_chunks,
     local_query,
     global_query,
     naive_query,
@@ -265,21 +266,9 @@ async def ainsert(self, string_or_strings):
             logger.info(f"[New Docs] inserting {len(new_docs)} docs")

             # ---------- chunking
-            inserting_chunks = {}

-            new_docs_list = list(new_docs.items())
-            docs = [new_doc[1]["content"] for new_doc in new_docs_list]
-            doc_keys = [new_doc[0] for new_doc in new_docs_list]
-
-            ENCODER = tiktoken.encoding_for_model("gpt-4o")
-            tokens = ENCODER.encode_batch(docs, num_threads=16)
-            chunks = self.chunk_func(tokens, overlap_token_size=self.chunk_overlap_token_size,
-                                     max_token_size=self.chunk_token_size, doc_keys=doc_keys, tiktoken_model=ENCODER)
-            for chunk in chunks:
-                inserting_chunks.update({compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk})
+            inserting_chunks = get_chunks(new_docs=new_docs, chunk_func=self.chunk_func,
+                                          overlap_token_size=self.chunk_overlap_token_size,
+                                          max_token_size=self.chunk_token_size)

             _add_chunk_keys = await self.text_chunks.filter_keys(
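For anyone supplying a custom chunk_func (as main.py does with chunking_by_seperators), the refactor makes the contract visible: get_chunks passes the batched token lists, doc_keys, and the tiktoken encoder, plus whatever keywords ainsert forwards. A sketch of a conforming function; the output fields other than "content" (which get_chunks hashes into the chunk id) are an assumption, not a verified schema:

def my_chunk_func(tokens_list, doc_keys, tiktoken_model,
                  overlap_token_size=128, max_token_size=1024):
    # One token list per document, in doc_keys order; must return dicts
    # with at least a "content" key.
    results = []
    for tokens, doc_key in zip(tokens_list, doc_keys):
        step = max_token_size - overlap_token_size
        for order, start in enumerate(range(0, len(tokens), step)):
            window = tokens[start:start + max_token_size]
            results.append({
                "content": tiktoken_model.decode(window),  # assumed field names
                "tokens": len(window),
                "chunk_order_index": order,
                "full_doc_id": doc_key,
            })
    return results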

tests/test_splitter.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+import unittest
+from typing import List
+
+from nano_graphrag._spliter import SeparatorSplitter
+
+class TestSeparatorSplitter(unittest.TestCase):
+
+    def setUp(self):
+        self.tokenize = lambda text: [ord(c) for c in text]  # simple per-character tokenizer for testing
+        self.detokenize = lambda tokens: ''.join(chr(t) for t in tokens)
+
+    def test_split_with_custom_separator(self):
+        splitter = SeparatorSplitter(
+            separators=[self.tokenize('\n'), self.tokenize('.')],
+            chunk_size=19,
+            chunk_overlap=0,
+            keep_separator="end"
+        )
+        text = "This is a test.\nAnother test."
+        tokens = self.tokenize(text)
+        expected = [
+            self.tokenize("This is a test.\n"),
+            self.tokenize("Another test."),
+        ]
+        result = splitter.split_tokens(tokens)
+        self.assertEqual(result, expected)
+
+    def test_chunk_size_limit(self):
+        splitter = SeparatorSplitter(
+            chunk_size=5,
+            chunk_overlap=0,
+            separators=[self.tokenize("\n")]
+        )
+        text = "1234567890"
+        tokens = self.tokenize(text)
+        expected = [
+            self.tokenize("12345"),
+            self.tokenize("67890")
+        ]
+        result = splitter.split_tokens(tokens)
+        self.assertEqual(result, expected)
+
+    def test_chunk_overlap(self):
+        splitter = SeparatorSplitter(
+            chunk_size=5,
+            chunk_overlap=2,
+            separators=[self.tokenize("\n")]
+        )
+        text = "1234567890"
+        tokens = self.tokenize(text)
+        expected = [
+            self.tokenize("12345"),
+            self.tokenize("45678"),
+            self.tokenize("7890"),
+            self.tokenize("0")
+        ]
+        result = splitter.split_tokens(tokens)
+        self.assertEqual(result, expected)
+
+if __name__ == '__main__':
+    unittest.main()
