from typing import Callable, List, Optional, Union, Literal


class SeparatorSplitter:
    def __init__(
        self,
        separators: Optional[List[List[int]]] = None,
        keep_separator: Union[bool, Literal["start", "end"]] = "end",
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[List[int]], int] = len,
    ):
        # Default to newline separators: [10] is "\n", [13, 10] is "\r\n".
        self._separators = separators or [[10], [13, 10]]
        self._keep_separator = keep_separator
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function

    def split_tokens(self, tokens: List[int]) -> List[List[int]]:
        splits = self._split_tokens_with_separators(tokens)
        return self._merge_splits(splits)

    def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
        # Scan the token stream and cut it at every separator sequence,
        # keeping, dropping, or re-attaching the separator according to
        # `keep_separator`.
        splits = []
        current_split = []
        i = 0
        while i < len(tokens):
            separator_found = False
            for separator in self._separators:
                if tokens[i:i + len(separator)] == separator:
                    if current_split:
                        if self._keep_separator == "end":
                            current_split.extend(separator)
                            splits.append(current_split)
                            current_split = []
                        elif self._keep_separator == "start":
                            splits.append(current_split)
                            current_split = separator[:]
                        else:
                            splits.append(current_split)
                            current_split = []
                    elif self._keep_separator:
                        current_split.extend(separator)
                    i += len(separator)
                    separator_found = True
                    break
            if not separator_found:
                current_split.append(tokens[i])
                i += 1
        if current_split:
            splits.append(current_split)
        return [s for s in splits if s]

    def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
        # Greedily pack consecutive splits into chunks of at most `chunk_size` tokens.
        merged_splits = []
        current_split = []
        # Separator inserted between packed splits (dropped when keep_separator is False).
        separator = [] if self._keep_separator is False else self._separators[-1]

        for split in splits:
            if self._length_function(current_split) + self._length_function(split) <= self._chunk_size:
                if current_split and separator:
                    current_split.extend(separator)
                current_split.extend(split)
            else:
                if current_split:
                    merged_splits.append(current_split)
                current_split = split
                if self._length_function(current_split) >= self._chunk_size:
                    merged_splits.append(current_split)
                    current_split = []
        if current_split:
            merged_splits.append(current_split)

        if self._chunk_overlap > 0:
            return self._enforce_overlap(merged_splits)
        return merged_splits

    def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
        # Prepend the last `chunk_overlap` tokens of the previous chunk to each
        # chunk, trimming from the front if the result exceeds `chunk_size`.
        new_chunks = []
        for i, chunk in enumerate(chunks):
            if i == 0:
                new_chunks.append(chunk)
            else:
                overlap_tokens = chunks[i - 1][-self._chunk_overlap:]
                new_chunk = overlap_tokens + chunk
                if self._length_function(new_chunk) > self._chunk_size:
                    new_chunk = new_chunk[-self._chunk_size:]
                new_chunks.append(new_chunk)
        return new_chunks
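

# Illustrative helper (an assumption, not part of the original module): `chunk_text`
# sketches one way to wire SeparatorSplitter into a text-in / text-chunks-out
# pipeline. The `encode`/`decode` parameters stand in for any tokenizer's
# encode/decode functions (e.g. tiktoken's); the name and signature are hypothetical.
def chunk_text(
    text: str,
    encode: Callable[[str], List[int]],
    decode: Callable[[List[int]], str],
    separators: Optional[List[List[int]]] = None,
    chunk_size: int = 4000,
    chunk_overlap: int = 200,
) -> List[str]:
    splitter = SeparatorSplitter(
        separators=separators,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    # Encode the text, split the token ids, and decode each chunk back to text.
    return [decode(chunk) for chunk in splitter.split_tokens(encode(text))]
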

# EXAMPLE USAGE
if __name__ == "__main__":
    import tiktoken
    tokenizer = tiktoken.encoding_for_model("gpt-4")

    def tokenize(text: str) -> List[int]:
        return tokenizer.encode(text)

    def detokenize(tokens: List[int]) -> str:
        return tokenizer.decode(tokens)

    # Create a splitter instance
    splitter = SeparatorSplitter(
        separators=[tokenize('\n '), tokenize('.')],  # use newline and period as separators
        chunk_size=5,
        chunk_overlap=0,
        keep_separator="end",
    )

    # Sample text
    text = "This is a sample text. It contains multiple sentences.\n Some sentences are short. Others are longer."
    tokens = tokenize(text)

    # Split the tokens
    split_tokens = splitter.split_tokens(tokens)

    print("Split tokens:")
    for i, token_chunk in enumerate(split_tokens):
        print(f"Chunk {i + 1}:")
        print(detokenize(token_chunk))
        print("---")