109 changes: 93 additions & 16 deletions nano_graphrag/_op.py
@@ -4,6 +4,8 @@
from typing import Union
from collections import Counter, defaultdict

+import tiktoken
+
from ._utils import (
    logger,
    clean_str,
@@ -28,27 +30,102 @@
from .prompt import GRAPH_FIELD_SEP, PROMPTS


def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
+    tokens_list: list[list[int]], doc_keys, tiktoken_model, overlap_token_size=128, max_token_size=1024,
):
-    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
-    results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
-                "tokens": min(max_token_size, len(tokens) - start),
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
+    results = []
+    for index, tokens in enumerate(tokens_list):
+        chunk_token = []
+        lengths = []
+        for start in range(0, len(tokens), max_token_size - overlap_token_size):
+            chunk_token.append(tokens[start : start + max_token_size])
+            lengths.append(min(max_token_size, len(tokens) - start))
+
+        # Tricky detail: across the corpus the token structure is
+        # list[list[list[int]]] (corpus -> doc -> chunk), so it cannot be
+        # decoded in a single call; decode each document's chunks as a batch.
+        chunk_token = tiktoken_model.decode_batch(chunk_token)
+        for i, chunk in enumerate(chunk_token):
+            results.append(
+                {
+                    "tokens": lengths[i],
+                    "content": chunk.strip(),
+                    "chunk_order_index": i,
+                    "full_doc_id": doc_keys[index],
+                }
+            )

    return results

+def chunking_by_seperators(
+    tokens_list: list[list[int]], doc_keys, tiktoken_model, overlap_token_size=128, max_token_size=1024
+):
+    from nano_graphrag._spliter import SeparatorSplitter
+
+    DEFAULT_SEPERATORS = [
+        # Paragraph separators
+        "\n\n",
+        "\r\n\r\n",
+        # Line breaks
+        "\n",
+        "\r\n",
+        # Sentence-ending punctuation
+        "。",  # Chinese period
+        "．",  # Full-width dot
+        ".",  # English period
+        "！",  # Chinese exclamation mark
+        "!",  # English exclamation mark
+        "？",  # Chinese question mark
+        "?",  # English question mark
+        # Whitespace characters
+        " ",  # Space
+        "\t",  # Tab
+        "\u3000",  # Full-width space
+        # Special characters
+        "\u200b",  # Zero-width space (used in some Asian languages)
+    ]
+
+    splitter = SeparatorSplitter(
+        separators=[tiktoken_model.encode(s) for s in DEFAULT_SEPERATORS],
+        chunk_size=max_token_size,
+        chunk_overlap=overlap_token_size,
+    )
+    results = []
+    for index, tokens in enumerate(tokens_list):
+        chunk_token = splitter.split_tokens(tokens)
+        lengths = [len(c) for c in chunk_token]
+
+        # Same caveat as above: decode each document's chunks as a batch.
+        chunk_token = tiktoken_model.decode_batch(chunk_token)
+        for i, chunk in enumerate(chunk_token):
+            results.append(
+                {
+                    "tokens": lengths[i],
+                    "content": chunk.strip(),
+                    "chunk_order_index": i,
+                    "full_doc_id": doc_keys[index],
+                }
+            )
+
+    return results


+def get_chunks(new_docs, chunk_func=chunking_by_token_size, **chunk_func_params):
+    inserting_chunks = {}
+
+    new_docs_list = list(new_docs.items())
+    docs = [new_doc[1]["content"] for new_doc in new_docs_list]
+    doc_keys = [new_doc[0] for new_doc in new_docs_list]
+
+    # NOTE: this hardcodes the gpt-4o encoding; GraphRAG.tiktoken_model_name
+    # is not consulted here.
+    ENCODER = tiktoken.encoding_for_model("gpt-4o")
+    tokens = ENCODER.encode_batch(docs, num_threads=16)
+    chunks = chunk_func(
+        tokens, doc_keys=doc_keys, tiktoken_model=ENCODER, **chunk_func_params
+    )
+
+    for chunk in chunks:
+        inserting_chunks.update(
+            {compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk}
+        )
+
+    return inserting_chunks


async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
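Taken together, the new _op.py functions separate tokenization from chunking: get_chunks batch-encodes the whole corpus once, and each chunking strategy only slices token lists. A minimal sketch of calling get_chunks directly, with a made-up two-document corpus (real keys are produced upstream by compute_mdhash_id with a "doc-" prefix):

from nano_graphrag._op import get_chunks, chunking_by_seperators

# Hypothetical corpus, keyed the way ainsert keys new docs.
new_docs = {
    "doc-1": {"content": "First paragraph.\n\nSecond paragraph of document one."},
    "doc-2": {"content": "A single short paragraph in document two."},
}

chunks = get_chunks(
    new_docs,
    chunk_func=chunking_by_seperators,  # or chunking_by_token_size (the default)
    overlap_token_size=100,
    max_token_size=1200,
)
for chunk_id, chunk in chunks.items():
    print(chunk_id, chunk["full_doc_id"], chunk["chunk_order_index"], chunk["tokens"])

Each value carries tokens, content, chunk_order_index, and full_doc_id, so callers no longer need to stitch the document id back in afterwards.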
94 changes: 94 additions & 0 deletions nano_graphrag/_spliter.py
@@ -0,0 +1,94 @@
from typing import Callable, List, Literal, Optional, Union

class SeparatorSplitter:
def __init__(
self,
separators: Optional[List[List[int]]] = None,
keep_separator: Union[bool, Literal["start", "end"]] = "end",
chunk_size: int = 4000,
chunk_overlap: int = 200,
        length_function: Callable[[List[int]], int] = len,
):
self._separators = separators or []
self._keep_separator = keep_separator
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function

def split_tokens(self, tokens: List[int]) -> List[List[int]]:
splits = self._split_tokens_with_separators(tokens)
return self._merge_splits(splits)

def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
splits = []
current_split = []
i = 0
while i < len(tokens):
separator_found = False
for separator in self._separators:
if tokens[i:i+len(separator)] == separator:
if self._keep_separator in [True, "end"]:
current_split.extend(separator)
if current_split:
splits.append(current_split)
current_split = []
if self._keep_separator == "start":
current_split.extend(separator)
i += len(separator)
separator_found = True
break
if not separator_found:
current_split.append(tokens[i])
i += 1
if current_split:
splits.append(current_split)
return [s for s in splits if s]

def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
if not splits:
return []

merged_splits = []
current_chunk = []

for split in splits:
if not current_chunk:
current_chunk = split
elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
current_chunk.extend(split)
else:
merged_splits.append(current_chunk)
current_chunk = split

if current_chunk:
merged_splits.append(current_chunk)

if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
return self._split_chunk(merged_splits[0])

if self._chunk_overlap > 0:
return self._enforce_overlap(merged_splits)

return merged_splits

def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
result = []
for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
new_chunk = chunk[i:i + self._chunk_size]
            if len(new_chunk) > self._chunk_overlap:  # keep only chunks longer than the overlap
result.append(new_chunk)
return result

def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
result = []
for i, chunk in enumerate(chunks):
if i == 0:
result.append(chunk)
else:
overlap = chunks[i-1][-self._chunk_overlap:]
new_chunk = overlap + chunk
if self._length_function(new_chunk) > self._chunk_size:
new_chunk = new_chunk[:self._chunk_size]
result.append(new_chunk)
return result
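
SeparatorSplitter operates purely on token-id sequences, so any encoder can drive it. A minimal usage sketch, assuming the same tiktoken gpt-4o encoding used in _op.py (the text and sizes are illustrative; whether a separator matches depends on how the encoder tokenizes the surrounding text):

import tiktoken

from nano_graphrag._spliter import SeparatorSplitter

enc = tiktoken.encoding_for_model("gpt-4o")
text = "First sentence. Second sentence.\n\nA new paragraph that runs on for a while."

splitter = SeparatorSplitter(
    separators=[enc.encode("\n\n"), enc.encode(".")],  # paragraphs first, then periods
    chunk_size=16,
    chunk_overlap=4,
)
chunks = splitter.split_tokens(enc.encode(text))
print([enc.decode(c) for c in chunks])

Splits are attempted at the first listed separator matching at each position, greedily merged up to chunk_size, and chunk_overlap is then enforced by prepending the tail of the previous chunk.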

25 changes: 9 additions & 16 deletions nano_graphrag/graphrag.py
@@ -5,6 +5,8 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Type, Union, cast

+import tiktoken
+

from ._llm import (
    gpt_4o_complete,
@@ -18,6 +20,7 @@
    chunking_by_token_size,
    extract_entities,
    generate_community_report,
+    get_chunks,
    local_query,
    global_query,
    naive_query,
@@ -65,7 +68,7 @@ class GraphRAG:
    enable_naive_rag: bool = False

    # text chunking
-    chunk_func: Callable[[str, Optional[int], Optional[int], Optional[str]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
+    chunk_func: Callable[[list[list[int]], List[str], tiktoken.Encoding, Optional[int], Optional[int]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o"
@@ -263,21 +266,11 @@ async def ainsert(self, string_or_strings):
        logger.info(f"[New Docs] inserting {len(new_docs)} docs")

        # ---------- chunking
-        inserting_chunks = {}
-        for doc_key, doc in new_docs.items():
-            chunks = {
-                compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                    **dp,
-                    "full_doc_id": doc_key,
-                }
-                for dp in self.chunk_func(
-                    doc["content"],
-                    overlap_token_size=self.chunk_overlap_token_size,
-                    max_token_size=self.chunk_token_size,
-                    tiktoken_model=self.tiktoken_model_name,
-                )
-            }
-            inserting_chunks.update(chunks)
+        inserting_chunks = get_chunks(
+            new_docs=new_docs,
+            chunk_func=self.chunk_func,
+            overlap_token_size=self.chunk_overlap_token_size,
+            max_token_size=self.chunk_token_size,
+        )

        _add_chunk_keys = await self.text_chunks.filter_keys(
            list(inserting_chunks.keys())
        )
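With ainsert delegating to get_chunks, swapping the chunking strategy becomes a one-line configuration change. A sketch, leaving every other GraphRAG option at its default (the working_dir path is hypothetical):

from nano_graphrag import GraphRAG
from nano_graphrag._op import chunking_by_seperators

rag = GraphRAG(
    working_dir="./nano_graphrag_cache",  # hypothetical path
    chunk_func=chunking_by_seperators,
)
rag.insert("Some long document text ...")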
63 changes: 63 additions & 0 deletions tests/test_splitter.py
@@ -0,0 +1,63 @@
import unittest
from typing import List

from nano_graphrag._spliter import SeparatorSplitter


class TestSeparatorSplitter(unittest.TestCase):

def setUp(self):
self.tokenize = lambda text: [ord(c) for c in text] # Simple tokenizer for testing
self.detokenize = lambda tokens: ''.join(chr(t) for t in tokens)

def test_split_with_custom_separator(self):
splitter = SeparatorSplitter(
separators=[self.tokenize('\n'), self.tokenize('.')],
chunk_size=19,
chunk_overlap=0,
keep_separator="end"
)
text = "This is a test.\nAnother test."
tokens = self.tokenize(text)
expected = [
self.tokenize("This is a test.\n"),
self.tokenize("Another test."),
]
result = splitter.split_tokens(tokens)

self.assertEqual(result, expected)

def test_chunk_size_limit(self):
splitter = SeparatorSplitter(
chunk_size=5,
chunk_overlap=0,
separators=[self.tokenize("\n")]
)
text = "1234567890"
tokens = self.tokenize(text)
expected = [
self.tokenize("12345"),
self.tokenize("67890")
]
result = splitter.split_tokens(tokens)
self.assertEqual(result, expected)

def test_chunk_overlap(self):
splitter = SeparatorSplitter(
chunk_size=5,
chunk_overlap=2,
separators=[self.tokenize("\n")]
)
text = "1234567890"
tokens = self.tokenize(text)
expected = [
self.tokenize("12345"),
self.tokenize("45678"),
self.tokenize("7890"),
]
result = splitter.split_tokens(tokens)
self.assertEqual(result, expected)

if __name__ == '__main__':
unittest.main()
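
One code path the tests leave uncovered is keep_separator="start". A quick sketch using the same ord/chr tokenizer as setUp (the expected output follows from tracing _split_tokens_with_separators and _merge_splits):

from nano_graphrag._spliter import SeparatorSplitter

tokenize = lambda text: [ord(c) for c in text]
detokenize = lambda tokens: "".join(chr(t) for t in tokens)

splitter = SeparatorSplitter(
    separators=[tokenize(".")],
    keep_separator="start",
    chunk_size=2,
    chunk_overlap=0,
)
chunks = splitter.split_tokens(tokenize("a.b.c"))
print([detokenize(c) for c in chunks])  # ['a', '.b', '.c']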