@@ -9,7 +9,7 @@ def __init__(
         chunk_overlap: int = 200,
         length_function: callable = len,
     ):
-        self._separators = separators or [[10], [13, 10]]  # use newline characters as separators by default
+        self._separators = separators or []
         self._keep_separator = keep_separator
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
@@ -27,18 +27,12 @@ def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
             separator_found = False
             for separator in self._separators:
                 if tokens[i:i + len(separator)] == separator:
+                    if self._keep_separator in [True, "end"]:
+                        current_split.extend(separator)
                     if current_split:
-                        if self._keep_separator == "end":
-                            current_split.extend(separator)
-                            splits.append(current_split)
-                            current_split = []
-                        elif self._keep_separator == "start":
-                            splits.append(current_split)
-                            current_split = separator[:]
-                        else:
-                            splits.append(current_split)
-                            current_split = []
-                    elif self._keep_separator:
+                        splits.append(current_split)
+                        current_split = []
+                    if self._keep_separator == "start":
                         current_split.extend(separator)
                     i += len(separator)
                     separator_found = True
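
Not part of the patch, but for reviewers: a quick sketch of how the three keep_separator modes behave after this simplification, poking the private helper directly with a toy separator token id 0. Expected outputs are traced by hand from the new branch logic above.

# Sketch only; SeparatorSplitter as defined in this file.
s = SeparatorSplitter(separators=[[0]], keep_separator="end")
print(s._split_tokens_with_separators([1, 2, 0, 3]))  # [[1, 2, 0], [3]]

# keep_separator=True takes the same path as "end" via `in [True, "end"]`
s = SeparatorSplitter(separators=[[0]], keep_separator="start")
print(s._split_tokens_with_separators([1, 2, 0, 3]))  # [[1, 2], [0, 3]]

s = SeparatorSplitter(separators=[[0]], keep_separator=False)
print(s._split_tokens_with_separators([1, 2, 0, 3]))  # [[1, 2], [3]]
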
@@ -51,71 +45,48 @@ def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
         return [s for s in splits if s]
 
     def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
+        if not splits:
+            return []
+
         merged_splits = []
-        current_split = []
-        current_length = 0
-        separator = [] if self._keep_separator is False else self._separators[-1]
+        current_chunk = []
 
         for split in splits:
-            if self._length_function(current_split) + self._length_function(split) <= self._chunk_size:
-                if current_split and separator:
-                    current_split.extend(separator)
-                current_split.extend(split)
+            if not current_chunk:
+                current_chunk = split
+            elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
+                current_chunk.extend(split)
             else:
-                if current_split:
-                    merged_splits.append(current_split)
-                current_split = split
-            if self._length_function(current_split) >= self._chunk_size:
-                merged_splits.append(current_split)
-                current_split = []
-        if current_split:
-            merged_splits.append(current_split)
+                merged_splits.append(current_chunk)
+                current_chunk = split
+
+        if current_chunk:
+            merged_splits.append(current_chunk)
+
+        if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
+            return self._split_chunk(merged_splits[0])
 
         if self._chunk_overlap > 0:
            return self._enforce_overlap(merged_splits)
+
         return merged_splits
 
+    def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
+        result = []
+        for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
+            result.append(chunk[i:i + self._chunk_size])
+        return result
+
     def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
-        new_chunks = []
+        result = []
         for i, chunk in enumerate(chunks):
             if i == 0:
-                new_chunks.append(chunk)
+                result.append(chunk)
             else:
-                overlap_tokens = chunks[i - 1][-self._chunk_overlap:]
-                new_chunk = overlap_tokens + chunk
+                overlap = chunks[i - 1][-self._chunk_overlap:]
+                new_chunk = overlap + chunk
                 if self._length_function(new_chunk) > self._chunk_size:
-                    new_chunk = new_chunk[-self._chunk_size:]
-                new_chunks.append(new_chunk)
-        return new_chunks
-
-# EXAMPLE USAGE
-if __name__ == "__main__":
-    import tiktoken
-    tokenizer = tiktoken.encoding_for_model("gpt-4")
-
-    def tokenize(text: str) -> List[int]:
-        return tokenizer.encode(text)
-
-    def detokenize(tokens: List[int]) -> str:
-        return tokenizer.decode(tokens)
-
-    # create a splitter instance
-    splitter = SeparatorSplitter(
-        separators=[tokenize('\n'), tokenize('.')],  # use newline and period as separators
-        chunk_size=5,
-        chunk_overlap=0,
-        keep_separator="end"
-    )
-
-    # sample text
-    text = "This is a sample text. It contains multiple sentences.\nSome sentences are short. Others are longer."
-    tokens = tokenize(text)
-
-    # split the tokens
-    split_tokens = splitter.split_tokens(tokens)
+                    new_chunk = new_chunk[:self._chunk_size]
+                result.append(new_chunk)
+        return result
 
-    print("Split tokens:")
-    for i, token_chunk in enumerate(split_tokens):
-        print(f"Chunk {i + 1}:")
-        print(detokenize(token_chunk))
-        print("---")
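
Again outside the diff: a small sketch exercising the two behaviors changed above, with toy sizes (chunk_size=5, chunk_overlap=2) so the arithmetic is easy to follow. The new _split_chunk steps by chunk_size - chunk_overlap, and _enforce_overlap now trims overflow from the tail ([:self._chunk_size]) instead of the head, so the prepended overlap tokens survive.

# Sketch only; SeparatorSplitter as defined in this file, len as length function.
s = SeparatorSplitter(separators=[[0]], chunk_size=5, chunk_overlap=2)

# _split_chunk: stride = 5 - 2 = 3, window width = 5, so windows overlap by 2
print(s._split_chunk(list(range(8))))
# -> [[0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [6, 7]]

# _enforce_overlap: the last two tokens of each chunk are prepended to the next;
# overflow beyond chunk_size is now cut from the end, not the front
print(s._enforce_overlap([[1, 2, 3, 4, 5], [6, 7, 8, 9]]))
# -> [[1, 2, 3, 4, 5], [4, 5, 6, 7, 8]]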
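The __main__ demo removed by this commit still runs against the refactored class; a minimal adaptation for reference (assumes tiktoken is installed and SeparatorSplitter is importable; the actual chunk boundaries depend on the tokenizer's ids):

import tiktoken

tokenizer = tiktoken.encoding_for_model("gpt-4")

splitter = SeparatorSplitter(
    separators=[tokenizer.encode("\n"), tokenizer.encode(".")],
    chunk_size=5,
    chunk_overlap=0,
    keep_separator="end",
)

text = "This is a sample text. It contains multiple sentences.\nSome are short."
for i, chunk in enumerate(splitter.split_tokens(tokenizer.encode(text)), 1):
    print(f"Chunk {i}: {tokenizer.decode(chunk)!r}")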