
Make pretokenizer optional in hf tokenizer #81

Draft · wants to merge 1 commit into base: main
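In brief, `PreTokenizerConfig::create()` now returns `std::optional<PreTokenizer::Ptr>` instead of throwing when a `Split` config has no pattern, and `HFTokenizer` stores its pre-tokenizer as an optional, falling back to encoding the whole input as a single piece when none is configured. Below is a minimal caller-side sketch of the new contract; the include path and the pattern-less `Split` config are illustrative assumptions, not code from this PR.

```cpp
#include <iostream>
#include <optional>

#include <pytorch/tokenizers/pre_tokenizer.h>  // assumed include path

using tokenizers::PreTokenizer;
using tokenizers::PreTokenizerConfig;

int main() {
  // A "Split" config with no pattern used to throw std::runtime_error;
  // after this change create() reports the problem via std::nullopt instead.
  std::optional<PreTokenizer::Ptr> missing =
      PreTokenizerConfig("Split").create();
  if (!missing) {
    // The caller decides how to proceed. HFTokenizer::load() simply leaves
    // its optional _pretokenizer member unset, and _encode() later treats
    // the whole input as a single piece.
    std::cout << "no pre-tokenizer configured" << std::endl;
  }

  // With a pattern set, create() still yields a usable pre-tokenizer.
  auto split = PreTokenizerConfig("Split").set_pattern(R"(o)").create();
  return split ? 0 : 1;
}
```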
3 changes: 2 additions & 1 deletion include/pytorch/tokenizers/hf_tokenizer.h
@@ -13,6 +13,7 @@
 #pragma once
 
 // Standard
+#include <optional>
 #include <string>
 
 // Local
@@ -46,7 +47,7 @@ class HFTokenizer : public detail::BPETokenizerBase {
 
   void _decode(const std::string& input, std::string& ret) const override;
 
-  PreTokenizer::Ptr _pretokenizer;
+  std::optional<PreTokenizer::Ptr> _pretokenizer;
   TokenDecoder::Ptr _decoder;
 };
 
2 changes: 1 addition & 1 deletion include/pytorch/tokenizers/pre_tokenizer.h
@@ -121,7 +121,7 @@ class PreTokenizerConfig {
   /**
    * Construct the pre tokenizer instance from the member data
    */
-  PreTokenizer::Ptr create() const;
+  std::optional<PreTokenizer::Ptr> create() const;
 
   /**
    * Populate from a json config file
36 changes: 24 additions & 12 deletions src/hf_tokenizer.cpp
@@ -65,10 +65,11 @@ Error HFTokenizer::load(const std::string& path) {
   try {
     std::vector<std::pair<std::string, std::uint64_t>> special_token_pairs;
     const auto& special_tokens = parsed_json.at("added_tokens");
-    auto special_token_map = TK_UNWRAP(detail::buildTokenMap(
-        special_tokens,
-        [](const auto& it) -> std::string { return it.at("content"); },
-        [](const auto& it) -> std::uint64_t { return it.at("id"); }));
+    auto special_token_map = TK_UNWRAP(
+        detail::buildTokenMap(
+            special_tokens,
+            [](const auto& it) -> std::string { return it.at("content"); },
+            [](const auto& it) -> std::uint64_t { return it.at("id"); }));
 
     // Create special token regex to help later with encoding.
     special_token_regex_ =
@@ -107,9 +108,12 @@ Error HFTokenizer::load(const std::string& path) {
   // Set up the pre-tokenizer
   try {
     std::cout << "Setting up pretokenizer..." << std::endl;
-    _pretokenizer = PreTokenizerConfig()
-                        .parse_json(parsed_json.at("pre_tokenizer"))
-                        .create();
+    auto pretokenizer = PreTokenizerConfig()
+                            .parse_json(parsed_json.at("pre_tokenizer"))
+                            .create();
+    if (pretokenizer) {
+      _pretokenizer = *pretokenizer;
+    }
     std::cout << "Pretokenizer set up" << std::endl;
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what());
@@ -249,17 +253,25 @@ Error HFTokenizer::_encode(
     const std::string& input,
     std::vector<uint64_t>& ret,
     uint64_t& last_piece_token_len) const {
-  for (const auto& piece : _pretokenizer->pre_tokenize(input)) {
+  auto encode_piece = [&](const std::string& piece) {
     const auto result = token_map_->tryGetInteger(piece);
     if (result) {
       last_piece_token_len = 1;
       ret.push_back(*result);
-      continue;
+    } else {
+      auto tokens = TK_UNWRAP(byte_pair_encode_(piece, *token_map_));
+      last_piece_token_len = tokens.size();
+      ret.insert(ret.end(), tokens.begin(), tokens.end());
     }
-    auto tokens = TK_UNWRAP(byte_pair_encode_(piece, *token_map_));
+  };
 
-    last_piece_token_len = tokens.size();
-    ret.insert(ret.end(), tokens.begin(), tokens.end());
+  if (_pretokenizer) {
+    for (const auto& piece : (*_pretokenizer)->pre_tokenize(input)) {
+      encode_piece(piece);
+    }
+  } else {
+    // If no pretokenizer, treat the entire input as a single piece
+    encode_piece(input);
   }
   return Error::Ok;
 }
14 changes: 10 additions & 4 deletions src/pre_tokenizer.cpp
@@ -28,14 +28,13 @@ namespace tokenizers {
 PreTokenizerConfig::PreTokenizerConfig(std::string type)
     : type(std::move(type)) {}
 
-PreTokenizer::Ptr PreTokenizerConfig::create() const {
+std::optional<PreTokenizer::Ptr> PreTokenizerConfig::create() const {
   // NOTE: These types must line up with the type strings found in the
   // tokenizers library
   // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/mod.rs#L73
   if (type == "Split") {
     if (!pattern) {
-      throw std::runtime_error(
-          "Missing pattern for PreTokenizer of type Split");
+      return std::nullopt; // Return nullopt if no pattern is provided
     }
     return PreTokenizer::Ptr(new RegexPreTokenizer(*pattern));
   }
@@ -68,7 +67,14 @@ PreTokenizer::Ptr PreTokenizerConfig::create() const {
         pretokenizers->begin(),
         pretokenizers->end(),
         std::back_inserter(pretoks),
-        [](const PreTokenizerConfig& cfg) { return cfg.create(); });
+        [](const PreTokenizerConfig& cfg) {
+          auto result = cfg.create();
+          if (!result) {
+            throw std::runtime_error(
+                "Failed to create pretokenizer in sequence");
+          }
+          return *result;
+        });
     return PreTokenizer::Ptr(new SequencePreTokenizer(pretoks));
   }
   throw std::runtime_error("Unsupported PreTokenizer type: " + type);
83 changes: 45 additions & 38 deletions test/test_pre_tokenizer.cpp
@@ -122,32 +122,36 @@ class PreTokenizerConfigTest : public ::testing::Test {};
 
 TEST_F(PreTokenizerConfigTest, AllTypesSuccess) {
   // Regex
-  PreTokenizerConfig("Split").set_pattern(R"(o)").create();
+  EXPECT_TRUE(PreTokenizerConfig("Split").set_pattern(R"(o)").create());
 
   // Digits
-  PreTokenizerConfig("Digits").create();
-  PreTokenizerConfig("Digits").set_individual_digits(true).create();
-  PreTokenizerConfig("Digits").set_individual_digits(false).create();
+  EXPECT_TRUE(PreTokenizerConfig("Digits").create());
+  EXPECT_TRUE(
+      PreTokenizerConfig("Digits").set_individual_digits(true).create());
+  EXPECT_TRUE(
+      PreTokenizerConfig("Digits").set_individual_digits(false).create());
 
   // ByteLevel
-  PreTokenizerConfig("ByteLevel").create();
-  PreTokenizerConfig("ByteLevel").set_pattern(R"(o)").create();
-  PreTokenizerConfig("ByteLevel").set_add_prefix_space(true).create();
-  PreTokenizerConfig("ByteLevel")
-      .set_add_prefix_space(false)
-      .set_pattern(R"(o)")
-      .create();
+  EXPECT_TRUE(PreTokenizerConfig("ByteLevel").create());
+  EXPECT_TRUE(PreTokenizerConfig("ByteLevel").set_pattern(R"(o)").create());
+  EXPECT_TRUE(
+      PreTokenizerConfig("ByteLevel").set_add_prefix_space(true).create());
+  EXPECT_TRUE(PreTokenizerConfig("ByteLevel")
+                  .set_add_prefix_space(false)
+                  .set_pattern(R"(o)")
+                  .create());
 
   // Sequence
-  PreTokenizerConfig("Sequence")
-      .set_pretokenizers(
-          {PreTokenizerConfig("Digits"), PreTokenizerConfig("ByteLevel")})
-      .create();
+  EXPECT_TRUE(
+      PreTokenizerConfig("Sequence")
+          .set_pretokenizers(
+              {PreTokenizerConfig("Digits"), PreTokenizerConfig("ByteLevel")})
+          .create());
 }
 
 TEST_F(PreTokenizerConfigTest, AllTypesFailureCases) {
   // Regex
-  EXPECT_THROW(PreTokenizerConfig("Split").create(), std::runtime_error);
+  EXPECT_FALSE(PreTokenizerConfig("Split").create());
 
   // Sequence
   EXPECT_THROW(PreTokenizerConfig("Sequence").create(), std::runtime_error);
@@ -167,20 +171,21 @@ TEST_F(PreTokenizerConfigTest, AllTypesFailureCases) {
 TEST_F(PreTokenizerConfigTest, ParseJson) {
   PreTokenizerConfig config;
   const auto ptok = config
-                        .parse_json(json{
-                            {"type", "Sequence"},
-                            {"pretokenizers",
-                             json{
-                                 json{
-                                     {"type", "Digits"},
-                                     {"individual_digits", true},
-                                 },
-                                 json{
-                                     {"type", "ByteLevel"},
-                                     {"add_prefix_space", false},
-                                 },
-                             }},
-                        })
+                        .parse_json(
+                            json{
+                                {"type", "Sequence"},
+                                {"pretokenizers",
+                                 json{
+                                     json{
+                                         {"type", "Digits"},
+                                         {"individual_digits", true},
+                                     },
+                                     json{
+                                         {"type", "ByteLevel"},
+                                         {"add_prefix_space", false},
+                                     },
+                                 }},
+                            })
                         .create();
   assert_split_match(
       *ptok,
@@ -203,9 +208,10 @@ TEST_F(PreTokenizerConfigTest, ParseJson) {
 TEST_F(PreTokenizerConfigTest, ParseJsonOptionalKey) {
   PreTokenizerConfig config;
   const auto ptok = config
-                        .parse_json(json{
-                            {"type", "Digits"},
-                        })
+                        .parse_json(
+                            json{
+                                {"type", "Digits"},
+                            })
                         .create();
   assert_split_match(
       *ptok,
@@ -217,12 +223,13 @@ TEST_F(PreTokenizerConfigTest, Split) {
   PreTokenizerConfig config;
   const auto ptok =
       config
-          .parse_json(json{
-              {"type", "Split"},
-              {"pattern",
-               {{"Regex",
-                 R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"}}},
-          })
+          .parse_json(
+              json{
+                  {"type", "Split"},
+                  {"pattern",
+                   {{"Regex",
+                     R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"}}},
+              })
           .create();
   assert_split_match(*ptok, "Hello World", {"Hello", " World"});
 }
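For completeness, a hypothetical end-to-end sketch of the behavior this enables: loading an HF tokenizer whose `pre_tokenizer` entry cannot produce a usable pre-tokenizer (for example, a `Split` config without a pattern). Only `HFTokenizer::load()` returning `Error` and `Error::Ok` are taken from the diff above; the include path and file path are assumptions.

```cpp
#include <pytorch/tokenizers/hf_tokenizer.h>  // assumed include path

int main() {
  tokenizers::HFTokenizer tok;

  // With this change, a pre_tokenizer entry that yields std::nullopt from
  // PreTokenizerConfig::create() no longer aborts loading; the optional
  // _pretokenizer member simply stays unset.
  if (tok.load("path/to/tokenizer.json") != tokenizers::Error::Ok) {
    return 1;
  }

  // Encoding is inherited from detail::BPETokenizerBase; when no
  // pre-tokenizer is configured, _encode() treats the whole input as a
  // single piece instead of pre-splitting it.
  return 0;
}
```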