diff --git a/include/pytorch/tokenizers/hf_tokenizer.h b/include/pytorch/tokenizers/hf_tokenizer.h
index 54869c7..a8976fb 100644
--- a/include/pytorch/tokenizers/hf_tokenizer.h
+++ b/include/pytorch/tokenizers/hf_tokenizer.h
@@ -13,6 +13,7 @@
 #pragma once
 
 // Standard
+#include <optional>
 #include <string>
 
 // Local
@@ -46,7 +47,7 @@ class HFTokenizer : public detail::BPETokenizerBase {
 
   void _decode(const std::string& input, std::string& ret) const override;
 
-  PreTokenizer::Ptr _pretokenizer;
+  std::optional<PreTokenizer::Ptr> _pretokenizer;
   TokenDecoder::Ptr _decoder;
 };
 
diff --git a/include/pytorch/tokenizers/pre_tokenizer.h b/include/pytorch/tokenizers/pre_tokenizer.h
index 8462c9f..6605d79 100644
--- a/include/pytorch/tokenizers/pre_tokenizer.h
+++ b/include/pytorch/tokenizers/pre_tokenizer.h
@@ -121,7 +121,7 @@ class PreTokenizerConfig {
   /**
    * Construct the pre tokenizer instance from the member data
    */
-  PreTokenizer::Ptr create() const;
+  std::optional<PreTokenizer::Ptr> create() const;
 
   /**
    * Populate from a json config file
diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp
index fa62264..a3e0f16 100644
--- a/src/hf_tokenizer.cpp
+++ b/src/hf_tokenizer.cpp
@@ -65,10 +65,11 @@ Error HFTokenizer::load(const std::string& path) {
   try {
     std::vector<std::pair<std::string, std::uint64_t>> special_token_pairs;
     const auto& special_tokens = parsed_json.at("added_tokens");
-    auto special_token_map = TK_UNWRAP(detail::buildTokenMap(
-        special_tokens,
-        [](const auto& it) -> std::string { return it.at("content"); },
-        [](const auto& it) -> std::uint64_t { return it.at("id"); }));
+    auto special_token_map = TK_UNWRAP(
+        detail::buildTokenMap(
+            special_tokens,
+            [](const auto& it) -> std::string { return it.at("content"); },
+            [](const auto& it) -> std::uint64_t { return it.at("id"); }));
 
     // Create special token regex to help later with encoding.
     special_token_regex_ =
@@ -107,9 +108,12 @@ Error HFTokenizer::load(const std::string& path) {
   // Set up the pre-tokenizer
   try {
     std::cout << "Setting up pretokenizer..." << std::endl;
-    _pretokenizer = PreTokenizerConfig()
-                        .parse_json(parsed_json.at("pre_tokenizer"))
-                        .create();
+    auto pretokenizer = PreTokenizerConfig()
+                            .parse_json(parsed_json.at("pre_tokenizer"))
+                            .create();
+    if (pretokenizer) {
+      _pretokenizer = *pretokenizer;
+    }
     std::cout << "Pretokenizer set up" << std::endl;
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what());
@@ -249,17 +253,25 @@ Error HFTokenizer::_encode(
     const std::string& input,
     std::vector<uint64_t>& ret,
     uint64_t& last_piece_token_len) const {
-  for (const auto& piece : _pretokenizer->pre_tokenize(input)) {
+  auto encode_piece = [&](const std::string& piece) {
     const auto result = token_map_->tryGetInteger(piece);
     if (result) {
       last_piece_token_len = 1;
       ret.push_back(*result);
-      continue;
+    } else {
+      auto tokens = TK_UNWRAP(byte_pair_encode_(piece, *token_map_));
+      last_piece_token_len = tokens.size();
+      ret.insert(ret.end(), tokens.begin(), tokens.end());
     }
-    auto tokens = TK_UNWRAP(byte_pair_encode_(piece, *token_map_));
+  };
 
-    last_piece_token_len = tokens.size();
-    ret.insert(ret.end(), tokens.begin(), tokens.end());
+  if (_pretokenizer) {
+    for (const auto& piece : (*_pretokenizer)->pre_tokenize(input)) {
+      encode_piece(piece);
+    }
+  } else {
+    // If no pretokenizer, treat the entire input as a single piece
+    encode_piece(input);
   }
   return Error::Ok;
 }
diff --git a/src/pre_tokenizer.cpp b/src/pre_tokenizer.cpp
index 26a66c6..5965414 100644
--- a/src/pre_tokenizer.cpp
+++ b/src/pre_tokenizer.cpp
@@ -28,14 +28,13 @@ namespace tokenizers {
 PreTokenizerConfig::PreTokenizerConfig(std::string type)
     : type(std::move(type)) {}
 
-PreTokenizer::Ptr PreTokenizerConfig::create() const {
+std::optional<PreTokenizer::Ptr> PreTokenizerConfig::create() const {
   // NOTE: These types must line up with the type strings found in the
   //  tokenizers library
   //  https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/mod.rs#L73
   if (type == "Split") {
     if (!pattern) {
-      throw std::runtime_error(
-          "Missing pattern for PreTokenizer of type Split");
+      return std::nullopt; // Return nullopt if no pattern is provided
     }
     return PreTokenizer::Ptr(new RegexPreTokenizer(*pattern));
   }
@@ -68,7 +67,14 @@ PreTokenizer::Ptr PreTokenizerConfig::create() const {
         pretokenizers->begin(),
         pretokenizers->end(),
         std::back_inserter(pretoks),
-        [](const PreTokenizerConfig& cfg) { return cfg.create(); });
+        [](const PreTokenizerConfig& cfg) {
+          auto result = cfg.create();
+          if (!result) {
+            throw std::runtime_error(
+                "Failed to create pretokenizer in sequence");
+          }
+          return *result;
+        });
     return PreTokenizer::Ptr(new SequencePreTokenizer(pretoks));
   }
   throw std::runtime_error("Unsupported PreTokenizer type: " + type);
diff --git a/test/test_pre_tokenizer.cpp b/test/test_pre_tokenizer.cpp
index d6e2736..64e9351 100644
--- a/test/test_pre_tokenizer.cpp
+++ b/test/test_pre_tokenizer.cpp
@@ -122,32 +122,36 @@ class PreTokenizerConfigTest : public ::testing::Test {};
 
 TEST_F(PreTokenizerConfigTest, AllTypesSuccess) {
   // Regex
-  PreTokenizerConfig("Split").set_pattern(R"(o)").create();
+  EXPECT_TRUE(PreTokenizerConfig("Split").set_pattern(R"(o)").create());
 
   // Digits
-  PreTokenizerConfig("Digits").create();
-  PreTokenizerConfig("Digits").set_individual_digits(true).create();
-  PreTokenizerConfig("Digits").set_individual_digits(false).create();
+  EXPECT_TRUE(PreTokenizerConfig("Digits").create());
+  EXPECT_TRUE(
+      PreTokenizerConfig("Digits").set_individual_digits(true).create());
+  EXPECT_TRUE(
+      PreTokenizerConfig("Digits").set_individual_digits(false).create());
 
   // ByteLevel
-  PreTokenizerConfig("ByteLevel").create();
-  PreTokenizerConfig("ByteLevel").set_pattern(R"(o)").create();
-  PreTokenizerConfig("ByteLevel").set_add_prefix_space(true).create();
-  PreTokenizerConfig("ByteLevel")
-      .set_add_prefix_space(false)
-      .set_pattern(R"(o)")
-      .create();
+  EXPECT_TRUE(PreTokenizerConfig("ByteLevel").create());
+  EXPECT_TRUE(PreTokenizerConfig("ByteLevel").set_pattern(R"(o)").create());
+  EXPECT_TRUE(
+      PreTokenizerConfig("ByteLevel").set_add_prefix_space(true).create());
+  EXPECT_TRUE(PreTokenizerConfig("ByteLevel")
+                  .set_add_prefix_space(false)
+                  .set_pattern(R"(o)")
+                  .create());
 
   // Sequence
-  PreTokenizerConfig("Sequence")
-      .set_pretokenizers(
-          {PreTokenizerConfig("Digits"), PreTokenizerConfig("ByteLevel")})
-      .create();
+  EXPECT_TRUE(
+      PreTokenizerConfig("Sequence")
+          .set_pretokenizers(
+              {PreTokenizerConfig("Digits"), PreTokenizerConfig("ByteLevel")})
+          .create());
 }
 
 TEST_F(PreTokenizerConfigTest, AllTypesFailureCases) {
   // Regex
-  EXPECT_THROW(PreTokenizerConfig("Split").create(), std::runtime_error);
+  EXPECT_FALSE(PreTokenizerConfig("Split").create());
 
   // Sequence
   EXPECT_THROW(PreTokenizerConfig("Sequence").create(), std::runtime_error);
@@ -167,20 +171,21 @@ TEST_F(PreTokenizerConfigTest, ParseJson) {
   PreTokenizerConfig config;
   const auto ptok =
       config
-          .parse_json(json{
-              {"type", "Sequence"},
-              {"pretokenizers",
-               json{
-                   json{
-                       {"type", "Digits"},
-                       {"individual_digits", true},
-                   },
+          .parse_json(
+              json{
+                  {"type", "Sequence"},
+                  {"pretokenizers",
                    json{
-                       {"type", "ByteLevel"},
-                       {"add_prefix_space", false},
-                   },
-               }},
-          })
+                       json{
+                           {"type", "Digits"},
+                           {"individual_digits", true},
+                       },
+                       json{
+                           {"type", "ByteLevel"},
+                           {"add_prefix_space", false},
+                       },
+                   }},
+              })
           .create();
   assert_split_match(
       *ptok,
@@ -203,9 +208,10 @@ TEST_F(PreTokenizerConfigTest, ParseJsonOptionalKey) {
   PreTokenizerConfig config;
   const auto ptok =
       config
-          .parse_json(json{
-              {"type", "Digits"},
-          })
+          .parse_json(
+              json{
+                  {"type", "Digits"},
+              })
          .create();
   assert_split_match(
      *ptok,
@@ -217,12 +223,13 @@ TEST_F(PreTokenizerConfigTest, Split) {
   PreTokenizerConfig config;
   const auto ptok =
       config
-          .parse_json(json{
-              {"type", "Split"},
-              {"pattern",
-               {{"Regex",
-                 R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"}}},
-          })
+          .parse_json(
+              json{
+                  {"type", "Split"},
+                  {"pattern",
+                   {{"Regex",
+                     R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"}}},
+              })
           .create();
   assert_split_match(*ptok, "Hello World", {"Hello", " World"});
 }