
Commit a6d4930

benbrandt and zurawiki authored
Update to latest tiktoken + new model support (#106)
Co-authored-by: Roger Zurawicki <[email protected]>
1 parent f597108 commit a6d4930

5 files changed: 132 additions, 302 deletions


.gitmodules

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 [submodule "vendor/tiktoken"]
 	path = vendor/tiktoken
 	url = https://github.com/openai/tiktoken.git
-	ref = refs/tags/0.8.0
+	ref = refs/tags/0.9.0

tiktoken-rs/src/patched_tiktoken.rs

Lines changed: 9 additions & 27 deletions
@@ -3,10 +3,6 @@ use anyhow::anyhow;
 use anyhow::Result;
 use fancy_regex::Regex;
 use rustc_hash::FxHashMap as HashMap;
-use std::collections::HashSet;
-
-// used to handle errors in the core below
-impl std::error::Error for DecodeKeyError {}
 
 /// Rust API
 impl CoreBPE {
@@ -21,20 +17,23 @@ impl CoreBPE {
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
     ) -> Result<Self> {
-        let regex = Regex::new(pattern).map_err(|e| anyhow!(e.to_string()))?;
+        let regex = Regex::new(pattern)?;
 
         let special_regex = {
-            let _parts = special_tokens_encoder
+            let parts = special_tokens_encoder
                 .keys()
                 .map(|s| fancy_regex::escape(s))
                 .collect::<Vec<_>>();
-            Regex::new(&_parts.join("|")).map_err(|e| anyhow!(e.to_string()))?
+            Regex::new(&parts.join("|"))?
         };
 
         let decoder: HashMap<Rank, Vec<u8>> =
             encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
 
-        assert!(encoder.len() == decoder.len());
+        assert!(
+            encoder.len() == decoder.len(),
+            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
+        );
 
         let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
             .iter()
@@ -45,7 +44,7 @@ impl CoreBPE {
         let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
         sorted_token_bytes.sort();
 
-        Ok(CoreBPE {
+        Ok(Self {
             encoder,
             special_tokens_encoder,
             decoder,
@@ -58,23 +57,6 @@ impl CoreBPE {
         })
     }
 
-    pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
-        self._encode_ordinary_native(text)
-    }
-
-    pub fn encode(&self, text: &str, allowed_special: HashSet<&str>) -> Vec<Rank> {
-        self._encode_native(text, &allowed_special).0
-    }
-
-    pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
-        let allowed_special = self
-            .special_tokens_encoder
-            .keys()
-            .map(|s| s.as_str())
-            .collect();
-        self._encode_native(text, &allowed_special).0
-    }
-
     // ====================
     // Decoding
     // ====================
@@ -83,7 +65,7 @@ impl CoreBPE {
     ///
     /// If unicode validation is not wanted, see _decode_native.
    pub fn decode(&self, tokens: Vec<Rank>) -> Result<String> {
-        match String::from_utf8(self._decode_native(&tokens)?) {
+        match String::from_utf8(self.decode_bytes(&tokens)?) {
         Ok(text) => Ok(text),
         Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
     }
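
For orientation, a minimal round-trip sketch against the tiktoken-rs public API. It assumes o200k_base(), encode_with_special_tokens(), and decode() keep their current shapes (the encode methods deleted above are presumably now provided by the vendored upstream implementation); this is illustrative, not part of the diff.

// Hypothetical usage sketch; not part of this commit.
use tiktoken_rs::o200k_base;

fn main() -> anyhow::Result<()> {
    // Load the o200k_base BPE used by the newly supported models.
    let bpe = o200k_base()?;

    // Encode, permitting any registered special-token strings in the input.
    let tokens = bpe.encode_with_special_tokens("hello world");

    // decode() now routes through decode_bytes() and then validates UTF-8.
    let text = bpe.decode(tokens)?;
    assert_eq!(text, "hello world");
    Ok(())
}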

tiktoken-rs/src/tokenizer.rs

Lines changed: 8 additions & 0 deletions
@@ -32,13 +32,17 @@ pub enum Tokenizer {
 // https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/tiktoken/model.py#L7
 const MODEL_PREFIX_TO_TOKENIZER: &[(&str, Tokenizer)] = &[
     ("o1-", Tokenizer::O200kBase),
+    ("o3-", Tokenizer::O200kBase),
+    ("o4-", Tokenizer::O200kBase),
     // chat
+    ("gpt-4.1-", Tokenizer::O200kBase),
     ("chatgpt-4o-", Tokenizer::O200kBase),
     ("gpt-4o-", Tokenizer::O200kBase),
     ("gpt-4-", Tokenizer::Cl100kBase),
     ("gpt-3.5-turbo-", Tokenizer::Cl100kBase),
     ("gpt-35-turbo-", Tokenizer::Cl100kBase),
     // fine-tuned
+    ("ft:gpt-4o", Tokenizer::O200kBase),
     ("ft:gpt-4", Tokenizer::Cl100kBase),
     ("ft:gpt-3.5-turbo", Tokenizer::Cl100kBase),
     ("ft:davinci-002", Tokenizer::Cl100kBase),
@@ -48,7 +52,11 @@ const MODEL_PREFIX_TO_TOKENIZER: &[(&str, Tokenizer)] = &[
 // Keep this in sync with:
 // https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/tiktoken/model.py#L22
 const MODEL_TO_TOKENIZER: &[(&str, Tokenizer)] = &[
+    // reasoning
+    ("o1", Tokenizer::O200kBase),
+    ("o3", Tokenizer::O200kBase),
     // chat
+    ("gpt-4.1", Tokenizer::O200kBase),
     ("chatgpt-4o-latest", Tokenizer::O200kBase),
     ("gpt-4o", Tokenizer::O200kBase),
     ("gpt-4", Tokenizer::Cl100kBase),
