
Commit 20fed70

committed
speed up chunking & add separator chunking
1 parent f11e9f2 commit 20fed70

File tree

3 files changed: +215 −32 lines changed


nano_graphrag/_op.py

Lines changed: 75 additions & 17 deletions
@@ -4,6 +4,8 @@
 from typing import Union
 from collections import Counter, defaultdict
 
+import tiktoken
+
 from ._utils import (
     logger,
     clean_str,
@@ -28,26 +30,82 @@
 from .prompt import GRAPH_FIELD_SEP, PROMPTS
 
 
+
+
 def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+    tokens_list: list[list[int]], doc_keys, tiktoken_model, overlap_token_size=128, max_token_size=1024,
 ):
-    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
-    results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
-                "tokens": min(max_token_size, len(tokens) - start),
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
-    return results
+    results = []
+    for index, tokens in enumerate(tokens_list):
+        chunk_token = []
+        lengths = []
+        for start in range(0, len(tokens), max_token_size - overlap_token_size):
+            chunk_token.append(tokens[start : start + max_token_size])
+            lengths.append(min(max_token_size, len(tokens) - start))
+
+        # Decode one document's chunks at a time: across the corpus the token
+        # structure is list[list[list[int]]] (corpus -> doc -> chunk), which
+        # decode_batch cannot consume in a single call.
+        chunk_token = tiktoken_model.decode_batch(chunk_token)
+        for i, chunk in enumerate(chunk_token):
+            results.append(
+                {
+                    "tokens": lengths[i],
+                    "content": chunk.strip(),
+                    "chunk_order_index": i,
+                    "full_doc_id": doc_keys[index],
+                }
+            )
+
+    return results
+
+
+def chunking_by_seperators(
+    tokens_list: list[list[int]], doc_keys, tiktoken_model, overlap_token_size=128, max_token_size=1024
+):
+    from nano_graphrag._spliter import SeparatorSplitter
+
+    DEFAULT_SEPERATORS = [
+        # Paragraph separators
+        "\n\n",
+        "\r\n\r\n",
+        # Line breaks
+        "\n",
+        "\r\n",
+        # Sentence-ending punctuation
+        "。",  # Chinese period
+        "．",  # Full-width dot
+        ".",  # English period
+        "！",  # Chinese exclamation mark
+        "!",  # English exclamation mark
+        "？",  # Chinese question mark
+        "?",  # English question mark
+        # Whitespace characters
+        " ",  # Space
+        "\t",  # Tab
+        "\u3000",  # Full-width space
+        # Special characters
+        "\u200b",  # Zero-width space (used in some Asian languages)
+    ]
+
+    splitter = SeparatorSplitter(
+        separators=[tiktoken_model.encode(s) for s in DEFAULT_SEPERATORS],
+        chunk_size=max_token_size,
+        chunk_overlap=overlap_token_size,
+    )
+    results = []
+    for index, tokens in enumerate(tokens_list):
+        chunk_token = splitter.split_tokens(tokens)
+        lengths = [len(c) for c in chunk_token]
+
+        # Same caveat as above: decode per document, not across the corpus.
+        chunk_token = tiktoken_model.decode_batch(chunk_token)
+        for i, chunk in enumerate(chunk_token):
+            results.append(
+                {
+                    "tokens": lengths[i],
+                    "content": chunk.strip(),
+                    "chunk_order_index": i,
+                    "full_doc_id": doc_keys[index],
+                }
+            )
 
+    return results
 
 async def _handle_entity_relation_summary(
     entity_or_relation_name: str,
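
The reworked function can also be driven directly, outside of `GraphRAG.ainsert`. A minimal sketch, assuming `tiktoken` is installed; the documents, keys, and token sizes below are made-up illustrations:

    import tiktoken

    from nano_graphrag._op import chunking_by_token_size

    ENCODER = tiktoken.encoding_for_model("gpt-4o")
    docs = ["First document text ...", "Second document text ..."]  # illustrative inputs
    doc_keys = ["doc-0", "doc-1"]                                    # illustrative ids

    # encode_batch tokenizes every document in one parallel call; this is the
    # main source of the speedup over per-document encoding.
    tokens_list = ENCODER.encode_batch(docs, num_threads=16)

    chunks = chunking_by_token_size(
        tokens_list,
        doc_keys=doc_keys,
        tiktoken_model=ENCODER,
        overlap_token_size=100,
        max_token_size=1200,
    )
    for chunk in chunks:
        print(chunk["full_doc_id"], chunk["chunk_order_index"], chunk["tokens"])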

nano_graphrag/_spliter.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+from typing import List, Optional, Union, Literal
+
+
+class SeparatorSplitter:
+    def __init__(
+        self,
+        separators: Optional[List[List[int]]] = None,
+        keep_separator: Union[bool, Literal["start", "end"]] = "end",
+        chunk_size: int = 4000,
+        chunk_overlap: int = 200,
+        length_function: callable = len,
+    ):
+        self._separators = separators or [[10], [13, 10]]  # default to newline tokens ("\n", "\r\n") as separators
+        self._keep_separator = keep_separator
+        self._chunk_size = chunk_size
+        self._chunk_overlap = chunk_overlap
+        self._length_function = length_function
+
+    def split_tokens(self, tokens: List[int]) -> List[List[int]]:
+        splits = self._split_tokens_with_separators(tokens)
+        return self._merge_splits(splits)
+
+    def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
+        splits = []
+        current_split = []
+        i = 0
+        while i < len(tokens):
+            separator_found = False
+            for separator in self._separators:
+                if tokens[i : i + len(separator)] == separator:
+                    if current_split:
+                        if self._keep_separator == "end":
+                            current_split.extend(separator)
+                            splits.append(current_split)
+                            current_split = []
+                        elif self._keep_separator == "start":
+                            splits.append(current_split)
+                            current_split = separator[:]
+                        else:
+                            splits.append(current_split)
+                            current_split = []
+                    elif self._keep_separator:
+                        current_split.extend(separator)
+                    i += len(separator)
+                    separator_found = True
+                    break
+            if not separator_found:
+                current_split.append(tokens[i])
+                i += 1
+        if current_split:
+            splits.append(current_split)
+        return [s for s in splits if s]
+
+    def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
+        merged_splits = []
+        current_split = []
+        current_length = 0
+        separator = [] if self._keep_separator is False else self._separators[-1]
+
+        for split in splits:
+            if self._length_function(current_split) + self._length_function(split) <= self._chunk_size:
+                if current_split and separator:
+                    current_split.extend(separator)
+                current_split.extend(split)
+            else:
+                if current_split:
+                    merged_splits.append(current_split)
+                current_split = split
+            if self._length_function(current_split) >= self._chunk_size:
+                merged_splits.append(current_split)
+                current_split = []
+        if current_split:
+            merged_splits.append(current_split)
+
+        if self._chunk_overlap > 0:
+            return self._enforce_overlap(merged_splits)
+        return merged_splits
+
+    def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
+        new_chunks = []
+        for i, chunk in enumerate(chunks):
+            if i == 0:
+                new_chunks.append(chunk)
+            else:
+                overlap_tokens = chunks[i - 1][-self._chunk_overlap:]
+                new_chunk = overlap_tokens + chunk
+                if self._length_function(new_chunk) > self._chunk_size:
+                    new_chunk = new_chunk[-self._chunk_size:]
+                new_chunks.append(new_chunk)
+        return new_chunks
+
+
+# EXAMPLE USAGE
+if __name__ == "__main__":
+    import tiktoken
+
+    tokenizer = tiktoken.encoding_for_model("gpt-4")
+
+    def tokenize(text: str) -> List[int]:
+        return tokenizer.encode(text)
+
+    def detokenize(tokens: List[int]) -> str:
+        return tokenizer.decode(tokens)
+
+    # Create a splitter instance
+    splitter = SeparatorSplitter(
+        separators=[tokenize("\n"), tokenize(".")],  # use newline and period as separators
+        chunk_size=5,
+        chunk_overlap=0,
+        keep_separator="end",
+    )
+
+    # Sample text
+    text = "This is a sample text. It contains multiple sentences.\nSome sentences are short. Others are longer."
+    tokens = tokenize(text)
+
+    # Split the tokens
+    split_tokens = splitter.split_tokens(tokens)
+
+    print("Split tokens:")
+    for i, token_chunk in enumerate(split_tokens):
+        print(f"Chunk {i + 1}:")
+        print(detokenize(token_chunk))
+        print("---")

nano_graphrag/graphrag.py

Lines changed: 19 additions & 15 deletions
@@ -5,6 +5,8 @@
 from functools import partial
 from typing import Callable, Dict, List, Optional, Type, Union, cast
 
+import tiktoken
+
 
 from ._llm import (
     gpt_4o_complete,
@@ -65,7 +67,7 @@ class GraphRAG:
     enable_naive_rag: bool = False
 
     # text chunking
-    chunk_func: Callable[[str, Optional[int], Optional[int], Optional[str]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
+    chunk_func: Callable[[list[list[int]], List[str], tiktoken.Encoding, Optional[int], Optional[int]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
     chunk_token_size: int = 1200
     chunk_overlap_token_size: int = 100
     tiktoken_model_name: str = "gpt-4o"
@@ -264,20 +266,22 @@ async def ainsert(self, string_or_strings):
 
         # ---------- chunking
         inserting_chunks = {}
-        for doc_key, doc in new_docs.items():
-            chunks = {
-                compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                    **dp,
-                    "full_doc_id": doc_key,
-                }
-                for dp in self.chunk_func(
-                    doc["content"],
-                    overlap_token_size=self.chunk_overlap_token_size,
-                    max_token_size=self.chunk_token_size,
-                    tiktoken_model=self.tiktoken_model_name,
-                )
-            }
-            inserting_chunks.update(chunks)
+
+        new_docs_list = list(new_docs.items())
+        docs = [new_doc[1]["content"] for new_doc in new_docs_list]
+        doc_keys = [new_doc[0] for new_doc in new_docs_list]
+
+        # Batch-encode every pending document in one parallel tiktoken call
+        # instead of encoding inside a per-document loop.
+        ENCODER = tiktoken.encoding_for_model("gpt-4o")
+        tokens = ENCODER.encode_batch(docs, num_threads=16)
+        chunks = self.chunk_func(
+            tokens,
+            overlap_token_size=self.chunk_overlap_token_size,
+            max_token_size=self.chunk_token_size,
+            doc_keys=doc_keys,
+            tiktoken_model=ENCODER,
+        )
+        for chunk in chunks:
+            inserting_chunks.update(
+                {compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk}
+            )
+
         _add_chunk_keys = await self.text_chunks.filter_keys(
             list(inserting_chunks.keys())
         )
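
Taken together, the three files let callers swap the chunking strategy through the `chunk_func` field. A minimal sketch of opting into the new separator-based chunking; the `working_dir` path and input text are illustrative:

    from nano_graphrag import GraphRAG
    from nano_graphrag._op import chunking_by_seperators

    rag = GraphRAG(
        working_dir="./nano_graphrag_cache",  # illustrative path
        chunk_func=chunking_by_seperators,    # default is chunking_by_token_size
    )
    rag.insert("Some long document text ...")  # ainsert batch-encodes, then chunks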
