@@ -3,10 +3,6 @@ use anyhow::anyhow;
 use anyhow::Result;
 use fancy_regex::Regex;
 use rustc_hash::FxHashMap as HashMap;
-use std::collections::HashSet;
-
-// used to handle errors in the core below
-impl std::error::Error for DecodeKeyError {}
 
 /// Rust API
 impl CoreBPE {
@@ -21,20 +17,23 @@ impl CoreBPE {
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
     ) -> Result<Self> {
-        let regex = Regex::new(pattern).map_err(|e| anyhow!(e.to_string()))?;
+        let regex = Regex::new(pattern)?;
 
         let special_regex = {
-            let _parts = special_tokens_encoder
+            let parts = special_tokens_encoder
                 .keys()
                 .map(|s| fancy_regex::escape(s))
                 .collect::<Vec<_>>();
-            Regex::new(&_parts.join("|")).map_err(|e| anyhow!(e.to_string()))?
+            Regex::new(&parts.join("|"))?
         };
 
         let decoder: HashMap<Rank, Vec<u8>> =
             encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
 
-        assert!(encoder.len() == decoder.len());
+        assert!(
+            encoder.len() == decoder.len(),
+            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
+        );
 
         let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
             .iter()
@@ -45,7 +44,7 @@ impl CoreBPE {
         let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
         sorted_token_bytes.sort();
 
-        Ok(CoreBPE {
+        Ok(Self {
             encoder,
             special_tokens_encoder,
             decoder,
@@ -58,23 +57,6 @@ impl CoreBPE {
         })
     }
 
-    pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
-        self._encode_ordinary_native(text)
-    }
-
-    pub fn encode(&self, text: &str, allowed_special: HashSet<&str>) -> Vec<Rank> {
-        self._encode_native(text, &allowed_special).0
-    }
-
-    pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
-        let allowed_special = self
-            .special_tokens_encoder
-            .keys()
-            .map(|s| s.as_str())
-            .collect();
-        self._encode_native(text, &allowed_special).0
-    }
-
     // ====================
     // Decoding
     // ====================
@@ -83,7 +65,7 @@ impl CoreBPE {
     ///
     /// If unicode validation is not wanted, see _decode_native.
     pub fn decode(&self, tokens: Vec<Rank>) -> Result<String> {
-        match String::from_utf8(self._decode_native(&tokens)?) {
+        match String::from_utf8(self.decode_bytes(&tokens)?) {
             Ok(text) => Ok(text),
             Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
         }
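
For context, a minimal usage sketch of the constructor and decode path touched above. It assumes Rank is a u32 alias, that the first parameter of CoreBPE::new is the byte-level encoder map (not visible in these hunks), and that CoreBPE is reachable from the calling code; treat it as an illustration under those assumptions, not part of the change.

use anyhow::Result;
use rustc_hash::FxHashMap as HashMap;

fn demo() -> Result<()> {
    // Toy byte-level vocabulary: token i decodes back to the single byte i.
    let encoder: HashMap<Vec<u8>, u32> = (0u32..256).map(|i| (vec![i as u8], i)).collect();
    let special_tokens_encoder: HashMap<String, u32> =
        [("<|endoftext|>".to_string(), 256)].into_iter().collect();

    // Any valid fancy-regex pattern will do for this sketch; CoreBPE is assumed in scope.
    let bpe = CoreBPE::new(encoder, special_tokens_encoder, r"\S+|\s+")?;

    // decode() validates UTF-8; tokens 104 and 105 are the bytes for "hi".
    assert_eq!(bpe.decode(vec![104, 105])?, "hi");
    Ok(())
}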