Skip to content

Commit 13ce7d1

Browse files
authored
feat: speed up chunking & add separator chunking (#48)
* speed up chunking & add separator chunking * add test code for splitter & reformat chunking methods * typo * fix overlap behaviour * typo * typo for type check
1 parent 70bbb67 commit 13ce7d1

File tree

4 files changed

+259
-32
lines changed

4 files changed

+259
-32
lines changed

nano_graphrag/_op.py

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import Union
55
from collections import Counter, defaultdict
66

7+
import tiktoken
8+
79
from ._utils import (
810
logger,
911
clean_str,
@@ -28,27 +30,102 @@
2830
from .prompt import GRAPH_FIELD_SEP, PROMPTS
2931

3032

33+
34+
3135
def chunking_by_token_size(
    tokens_list: list[list[int]],
    doc_keys,
    tiktoken_model,
    overlap_token_size=128,
    max_token_size=1024,
):
    """Chunk each pre-tokenized document into fixed-size, overlapping token windows.

    Args:
        tokens_list: one token-id list per document (already encoded).
        doc_keys: document ids aligned 1:1 with ``tokens_list``.
        tiktoken_model: a ``tiktoken.Encoding``; only ``decode_batch`` is used.
        overlap_token_size: number of tokens shared by consecutive chunks.
        max_token_size: maximum number of tokens per chunk.

    Returns:
        A flat list of chunk dicts with keys ``tokens`` (chunk length),
        ``content`` (decoded, stripped text), ``chunk_order_index``
        (position within its document) and ``full_doc_id``.

    Raises:
        ValueError: if ``overlap_token_size >= max_token_size`` — the window
            would never advance (``range`` with step <= 0 would otherwise
            silently produce no chunks at all).
    """
    if overlap_token_size >= max_token_size:
        raise ValueError("overlap_token_size must be smaller than max_token_size")

    step = max_token_size - overlap_token_size
    results = []
    for doc_index, tokens in enumerate(tokens_list):
        chunk_tokens = []
        lengths = []
        for start in range(0, len(tokens), step):
            chunk_tokens.append(tokens[start : start + max_token_size])
            lengths.append(min(max_token_size, len(tokens) - start))

        # The corpus is list[doc[chunk[token]]]; decode_batch only handles one
        # level of nesting, so decode this document's chunks in a single call.
        decoded = tiktoken_model.decode_batch(chunk_tokens)
        for chunk_index, chunk_text in enumerate(decoded):
            results.append(
                {
                    "tokens": lengths[chunk_index],
                    "content": chunk_text.strip(),
                    "chunk_order_index": chunk_index,
                    "full_doc_id": doc_keys[doc_index],
                }
            )
    return results
def chunking_by_seperators(
    tokens_list: list[list[int]],
    doc_keys,
    tiktoken_model,
    overlap_token_size=128,
    max_token_size=1024,
):
    """Chunk each pre-tokenized document on natural-language separators.

    Splits every document at paragraph / line / sentence boundaries (see the
    separator list below), then greedily merges the pieces into chunks of at
    most ``max_token_size`` tokens with ``overlap_token_size`` tokens of
    overlap, delegating to :class:`SeparatorSplitter`.

    Args:
        tokens_list: one token-id list per document (already encoded).
        doc_keys: document ids aligned 1:1 with ``tokens_list``.
        tiktoken_model: a ``tiktoken.Encoding``; ``encode``/``decode_batch`` are used.
        overlap_token_size: desired token overlap between consecutive chunks.
        max_token_size: maximum number of tokens per chunk.

    Returns:
        A flat list of chunk dicts with keys ``tokens``, ``content``,
        ``chunk_order_index`` and ``full_doc_id`` (same shape as
        ``chunking_by_token_size``).
    """
    # Local import keeps module load order flexible (avoids import cycles).
    from nano_graphrag._spliter import SeparatorSplitter

    DEFAULT_SEPERATORS = [
        # Paragraph separators
        "\n\n",
        "\r\n\r\n",
        # Line breaks
        "\n",
        "\r\n",
        # Sentence ending punctuation
        "。",  # Chinese period
        "．",  # Full-width dot
        ".",  # English period
        "！",  # Chinese exclamation mark
        "!",  # English exclamation mark
        "？",  # Chinese question mark
        "?",  # English question mark
        # Whitespace characters
        " ",  # Space
        "\t",  # Tab
        "\u3000",  # Full-width space
        # Special characters
        "\u200b",  # Zero-width space (used in some Asian languages)
    ]

    splitter = SeparatorSplitter(
        separators=[tiktoken_model.encode(s) for s in DEFAULT_SEPERATORS],
        chunk_size=max_token_size,
        chunk_overlap=overlap_token_size,
    )

    results = []
    for doc_index, tokens in enumerate(tokens_list):
        chunk_tokens = splitter.split_tokens(tokens)
        lengths = [len(chunk) for chunk in chunk_tokens]

        # The corpus is list[doc[chunk[token]]]; decode_batch only handles one
        # level of nesting, so decode this document's chunks in a single call.
        decoded = tiktoken_model.decode_batch(chunk_tokens)
        for chunk_index, chunk_text in enumerate(decoded):
            results.append(
                {
                    "tokens": lengths[chunk_index],
                    "content": chunk_text.strip(),
                    "chunk_order_index": chunk_index,
                    "full_doc_id": doc_keys[doc_index],
                }
            )
    return results

51110

111+
def get_chunks(new_docs, chunk_func=chunking_by_token_size, tiktoken_model_name="gpt-4o", **chunk_func_params):
    """Batch-tokenize documents and chunk them into content-hash-keyed records.

    Args:
        new_docs: mapping of ``doc_id -> {"content": str, ...}``.
        chunk_func: chunking strategy with the ``chunking_by_token_size``
            signature (tokens_list, doc_keys, tiktoken_model, ...).
        tiktoken_model_name: model name used to pick the tiktoken encoding
            (previously hard-coded to "gpt-4o"; kept as the default).
        **chunk_func_params: forwarded to ``chunk_func`` (e.g.
            ``overlap_token_size``, ``max_token_size``).

    Returns:
        Dict mapping ``chunk-<md5>`` ids to chunk dicts; hashing the content
        deduplicates identical chunks across documents.
    """
    new_docs_list = list(new_docs.items())
    docs = [doc["content"] for _, doc in new_docs_list]
    doc_keys = [doc_key for doc_key, _ in new_docs_list]

    # Encode the whole corpus in one threaded batch — much faster than
    # per-document encoding for large inserts.
    ENCODER = tiktoken.encoding_for_model(tiktoken_model_name)
    tokens = ENCODER.encode_batch(docs, num_threads=16)
    chunks = chunk_func(tokens, doc_keys=doc_keys, tiktoken_model=ENCODER, **chunk_func_params)

    return {
        compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk
        for chunk in chunks
    }
52129
async def _handle_entity_relation_summary(
53130
entity_or_relation_name: str,
54131
description: str,

nano_graphrag/_spliter.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from typing import List, Optional, Union, Literal
2+
3+
class SeparatorSplitter:
    """Split a token sequence into chunks at token-level separators.

    The token stream is first cut at any of the given separator token
    sequences, then the resulting pieces are greedily merged into chunks of
    at most ``chunk_size`` tokens.  With ``chunk_overlap`` > 0, consecutive
    chunks share up to that many tokens.

    Args:
        separators: separator patterns, each a list of token ids.
        keep_separator: attach a matched separator to the end of the piece
            before it (``True`` / ``"end"``) or to the start of the piece
            after it (``"start"``).
        chunk_size: maximum number of tokens per emitted chunk.
        chunk_overlap: desired token overlap between consecutive chunks.
        length_function: callable measuring a chunk's length (default ``len``).
    """

    def __init__(
        self,
        separators: Optional[List[List[int]]] = None,
        keep_separator: Union[bool, Literal["start", "end"]] = "end",
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: callable = len,
    ):
        self._separators = separators or []
        self._keep_separator = keep_separator
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function

    def split_tokens(self, tokens: List[int]) -> List[List[int]]:
        """Split *tokens* and return the list of chunk token lists."""
        splits = self._split_tokens_with_separators(tokens)
        return self._merge_splits(splits)

    def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
        # Left-to-right scan; the first separator matching at the current
        # position wins, so earlier entries in self._separators take priority.
        splits = []
        current_split = []
        i = 0
        while i < len(tokens):
            separator_found = False
            for separator in self._separators:
                if tokens[i:i+len(separator)] == separator:
                    if self._keep_separator in [True, "end"]:
                        current_split.extend(separator)
                    if current_split:
                        splits.append(current_split)
                        current_split = []
                    if self._keep_separator == "start":
                        current_split.extend(separator)
                    i += len(separator)
                    separator_found = True
                    break
            if not separator_found:
                current_split.append(tokens[i])
                i += 1
        if current_split:
            splits.append(current_split)
        return [s for s in splits if s]

    def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
        # Greedily pack consecutive splits into chunks of at most chunk_size.
        if not splits:
            return []

        merged_splits = []
        current_chunk = []

        for split in splits:
            if not current_chunk:
                current_chunk = split
            elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
                current_chunk.extend(split)
            else:
                merged_splits.append(current_chunk)
                current_chunk = split

        if current_chunk:
            merged_splits.append(current_chunk)

        # A single oversized chunk (no usable separators) falls back to plain
        # fixed-size windowing with overlap.
        # NOTE(review): when several chunks exist, an individual split longer
        # than chunk_size is emitted as-is — confirm whether it should also
        # be subdivided.
        if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
            return self._split_chunk(merged_splits[0])

        if self._chunk_overlap > 0:
            return self._enforce_overlap(merged_splits)

        return merged_splits

    def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
        # Fixed-size sliding windows.  A trailing window no longer than the
        # overlap is fully contained in the previous window, so skip it.
        result = []
        for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
            new_chunk = chunk[i:i + self._chunk_size]
            if len(new_chunk) > self._chunk_overlap:
                result.append(new_chunk)
        return result

    def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
        # Prefix every chunk (except the first) with the tail of the previous
        # chunk.  The borrowed overlap is capped so the result never exceeds
        # chunk_size.  The previous implementation truncated the combined
        # chunk from the right instead, which permanently dropped the tail
        # tokens of the final chunk whenever overlap + len(chunk) exceeded
        # chunk_size — a silent data-loss bug.
        result = []
        for i, chunk in enumerate(chunks):
            if i == 0:
                result.append(chunk)
            else:
                space_left = self._chunk_size - self._length_function(chunk)
                take = max(0, min(self._chunk_overlap, space_left))
                overlap = chunks[i - 1][-take:] if take > 0 else []
                result.append(overlap + chunk)
        return result

nano_graphrag/graphrag.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from functools import partial
66
from typing import Callable, Dict, List, Optional, Type, Union, cast
77

8+
import tiktoken
9+
810

911
from ._llm import (
1012
gpt_4o_complete,
@@ -18,6 +20,7 @@
1820
chunking_by_token_size,
1921
extract_entities,
2022
generate_community_report,
23+
get_chunks,
2124
local_query,
2225
global_query,
2326
naive_query,
@@ -65,7 +68,7 @@ class GraphRAG:
6568
enable_naive_rag: bool = False
6669

6770
# text chunking
68-
chunk_func: Callable[[str, Optional[int], Optional[int], Optional[str]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
71+
chunk_func: Callable[[list[list[int]],List[str],tiktoken.Encoding, Optional[int], Optional[int], ], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
6972
chunk_token_size: int = 1200
7073
chunk_overlap_token_size: int = 100
7174
tiktoken_model_name: str = "gpt-4o"
@@ -263,21 +266,11 @@ async def ainsert(self, string_or_strings):
263266
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
264267

265268
# ---------- chunking
266-
inserting_chunks = {}
267-
for doc_key, doc in new_docs.items():
268-
chunks = {
269-
compute_mdhash_id(dp["content"], prefix="chunk-"): {
270-
**dp,
271-
"full_doc_id": doc_key,
272-
}
273-
for dp in self.chunk_func(
274-
doc["content"],
275-
overlap_token_size=self.chunk_overlap_token_size,
276-
max_token_size=self.chunk_token_size,
277-
tiktoken_model=self.tiktoken_model_name,
278-
)
279-
}
280-
inserting_chunks.update(chunks)
269+
270+
inserting_chunks = get_chunks(new_docs=new_docs,chunk_func=self.chunk_func,overlap_token_size=self.chunk_overlap_token_size,
271+
max_token_size=self.chunk_token_size)
272+
273+
281274
_add_chunk_keys = await self.text_chunks.filter_keys(
282275
list(inserting_chunks.keys())
283276
)

tests/test_splitter.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import unittest
2+
from typing import List
3+
4+
from nano_graphrag._spliter import SeparatorSplitter
5+
6+
# SeparatorSplitter is imported above from nano_graphrag._spliter
7+
8+
class TestSeparatorSplitter(unittest.TestCase):
    """Tests for SeparatorSplitter using character codepoints as stand-in tokens."""

    def setUp(self):
        # One character == one token keeps the expected chunks readable.
        self.tokenize = lambda text: [ord(c) for c in text]
        self.detokenize = lambda tokens: ''.join(chr(t) for t in tokens)

    def test_split_with_custom_separator(self):
        # Text is cut at '\n' or '.'; each separator stays on the chunk it ends.
        splitter = SeparatorSplitter(
            separators=[self.tokenize('\n'), self.tokenize('.')],
            chunk_size=19,
            chunk_overlap=0,
            keep_separator="end",
        )
        result = splitter.split_tokens(self.tokenize("This is a test.\nAnother test."))
        self.assertEqual(
            result,
            [self.tokenize("This is a test.\n"), self.tokenize("Another test.")],
        )

    def test_chunk_size_limit(self):
        # With no separator present, chunks are plain fixed-size windows.
        splitter = SeparatorSplitter(
            chunk_size=5,
            chunk_overlap=0,
            separators=[self.tokenize("\n")],
        )
        result = splitter.split_tokens(self.tokenize("1234567890"))
        self.assertEqual(result, [self.tokenize("12345"), self.tokenize("67890")])

    def test_chunk_overlap(self):
        # Consecutive windows share two tokens when chunk_overlap=2.
        splitter = SeparatorSplitter(
            chunk_size=5,
            chunk_overlap=2,
            separators=[self.tokenize("\n")],
        )
        result = splitter.split_tokens(self.tokenize("1234567890"))
        self.assertEqual(
            result,
            [self.tokenize("12345"), self.tokenize("45678"), self.tokenize("7890")],
        )
62+
# Allow running this test module directly: `python tests/test_splitter.py`.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)