[v0.5.10][2] Fix apply_chat_template behavior for transformers >=5.0#926

Open
yueming-yuan wants to merge 5 commits into bump-sglang-v0.5.10 from fix/mask-utils-transformers-v5-v2

Conversation

@yueming-yuan
Collaborator

@yueming-yuan yueming-yuan commented Apr 6, 2026

ci-sglang-pr: sglang-miles-v0.5.10

Summary

  • transformers 5.x changed apply_chat_template(tokenize=True) to return BatchEncoding instead of list[int]
  • Added _apply_chat_template_ids() wrapper that normalizes the return type

Replaces #925 (was merged then reverted).

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a utility function _apply_chat_template_ids to handle changes in the transformers library (version 5.0+) where apply_chat_template may return a dictionary instead of a list. All existing calls to the tokenizer have been updated to use this wrapper. Feedback suggests adding type hints to the new function and explicitly setting return_dict=False to improve robustness and maintainability.

Comment on lines +9 to +16
def _apply_chat_template_ids(tokenizer, messages, **kwargs) -> list[int]:
    """Wrapper that always returns list[int] from apply_chat_template(tokenize=True).

    transformers >=5.0 returns BatchEncoding instead of list[int]."""
    result = tokenizer.apply_chat_template(messages, tokenize=True, **kwargs)
    if isinstance(result, list):
        return result
    return result["input_ids"]

Severity: medium

The _apply_chat_template_ids wrapper is a good addition for compatibility with transformers 5.0. However, it can be improved by adding type hints for better maintainability and consistency with the rest of the file. Also, explicitly setting return_dict=False via kwargs.setdefault ensures that current versions of transformers return the expected list type, while the isinstance check provides a robust fallback for future versions where the default might change or the flag might be ignored.

Note: Passing tokenize in kwargs to this function will cause a TypeError because it is already explicitly passed to apply_chat_template.

Suggested change

Original:

def _apply_chat_template_ids(tokenizer, messages, **kwargs) -> list[int]:
    """Wrapper that always returns list[int] from apply_chat_template(tokenize=True).
    transformers >=5.0 returns BatchEncoding instead of list[int]."""
    result = tokenizer.apply_chat_template(messages, tokenize=True, **kwargs)
    if isinstance(result, list):
        return result
    return result["input_ids"]

Suggested:

def _apply_chat_template_ids(tokenizer: AutoTokenizer, messages: list[dict], **kwargs) -> list[int]:
    """Wrapper that always returns list[int] from apply_chat_template(tokenize=True).
    transformers >=5.0 returns BatchEncoding instead of list[int]."""
    kwargs.setdefault("return_dict", False)
    result = tokenizer.apply_chat_template(messages, tokenize=True, **kwargs)
    if isinstance(result, list):
        return result
    return result["input_ids"]

@yueming-yuan yueming-yuan changed the title Fix mask_utils for transformers >=5.0 [v0.5.10] [2] Fix mask_utils for transformers >=5.0 Apr 6, 2026
@yueming-yuan yueming-yuan changed the title [v0.5.10] [2] Fix mask_utils for transformers >=5.0 [v0.5.10] [2] Fix apply_chat_template behavior for transformers >=5.0 Apr 6, 2026
@yueming-yuan yueming-yuan changed the title [v0.5.10] [2] Fix apply_chat_template behavior for transformers >=5.0 [v0.5.10][2] Fix apply_chat_template behavior for transformers >=5.0 Apr 6, 2026
@guapisolo guapisolo requested a review from yushengsu-thu as a code owner April 8, 2026 00:07
guapisolo and others added 2 commits April 8, 2026 00:08
Remove models broken by transformers v5 tokenizer unification
(DeepSeek-V3, step3, glm-4-9b-chat) and track them in a
TOOL_CALL_KNOWN_FAILURES list with root cause comments. Add new
passing models: Qwen3.5, Qwen3-Coder-Next, GLM-4.7-Flash, Kimi-K2.5,
MiniMax-M2.5, Nemotron-3-Super. Clean up debug helpers.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
transformers >=5.0 changed apply_chat_template(tokenize=True) to
return BatchEncoding instead of list[int]. Pass return_dict=False
to all 6 call sites in mask_utils.py to ensure list[int] on both
v4 and v5.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@guapisolo guapisolo force-pushed the fix/mask-utils-transformers-v5-v2 branch from 775a86c to f058d37 Compare April 8, 2026 00:08
Move Step-3.5-Flash from known failures into active tool-call test models, and clarify comments for remaining transformers v5 tokenizer/template incompatibilities.

Made-with: Cursor
@guapisolo
Collaborator

A report generated by cc & codex and briefly reviewed by me; it generally makes sense.

Transformers v5 Tokenizer Compatibility Analysis

Background

transformers==5.3.0 introduced a unified tokenizer architecture that merges the old "slow" (Python) and "fast" (Rust) tokenizer backends. As of April 8, 2026, under the current test matrix this leaves 4 known failures (8 parametrized cases across num_tools in {1, 2}) that previously passed on transformers==4.57.1.

test_tokenize_tool_responses validates that tokenize_tool_responses produces the correct token delta by checking:

decode(apply_chat_template(tokenize=True)[delta]) == apply_chat_template(tokenize=False)[delta]

i.e. the decode of the token delta should equal the text delta. For a few models, that assumption effectively relies on decode(encode(text)) == text, which no longer holds under v5.
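The round-trip assumption can be sketched with toy tokenizers (hypothetical classes, not anything in transformers); the lossy variant mimics the space-dropping behavior that Root Cause 1 describes:

```python
class ToyTokenizer:
    """Maps each character to its code point; decode is the exact inverse."""
    def encode(self, text):
        return [ord(c) for c in text]
    def decode(self, ids):
        return "".join(chr(i) for i in ids)

class LossyTokenizer(ToyTokenizer):
    """Simulates the v5 Metaspace symptom: spaces vanish on encode."""
    def encode(self, text):
        return [ord(c) for c in text if c != " "]

def round_trip_ok(tokenizer, text):
    # The property the test's delta comparison implicitly relies on.
    return tokenizer.decode(tokenizer.encode(text)) == text

print(round_trip_ok(ToyTokenizer(), '{"year": 2026}'))    # True
print(round_trip_ok(LossyTokenizer(), '{"year": 2026}'))  # False: space lost
```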

This document separates:

  • genuine tokenizer behavior changes that break the round-trip assumption
  • an API return-type change in apply_chat_template(tokenize=True) that requires a small compatibility fix

Root Cause 1: LlamaTokenizer Overwrites ByteLevel with Metaspace

Directly affected models: deepseek-ai/DeepSeek-V3, stepfun-ai/step3

What changed

In v4, AutoTokenizer.from_pretrained('deepseek-ai/DeepSeek-V3') loaded LlamaTokenizerFast, which read tokenizer.json directly and preserved the original configuration:

pre_tokenizer: ByteLevel(add_prefix_space=False)      # GPT-2 style
decoder:       ByteLevel(add_prefix_space=True)        # preserves special chars as-is

In v5, LlamaTokenizerFast became an alias for LlamaTokenizer. The unified LlamaTokenizer has a custom __init__ (transformers/models/llama/tokenization_llama.py:116-133) that rebuilds the Rust tokenizer from scratch using only vocab/merges extracted from tokenizer.json, and unconditionally hardcodes:

self._tokenizer.pre_tokenizer = Metaspace(replacement="▁", prepend_scheme="always")
self._tokenizer.decoder = Sequence([Replace("▁", " "), ByteFallback(), Fuse(), Strip()])

This happens because convert_to_native_format() (transformers/tokenization_utils_tokenizers.py:118) takes the elif branch for classes with a custom __init__, extracting only vocab/merges/post_processor but not pre_tokenizer or decoder from tokenizer.json.
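A hedged, pure-Python sketch of that dispatch (the real logic lives in transformers' convert_to_native_format; the stand-in classes here are hypothetical):

```python
def conversion_branch(tokenizer_cls):
    # Classes that define their own __init__ get rebuilt from vocab/merges
    # (elif branch: pre_tokenizer and decoder from tokenizer.json are lost);
    # others load the full tokenizer.json directly (if branch).
    if "__init__" in vars(tokenizer_cls):
        return "rebuild"
    return "direct-load"

class DirectBackend:                        # stands in for TokenizersBackend
    pass

class CustomInitTokenizer(DirectBackend):   # stands in for LlamaTokenizer
    def __init__(self):
        pass

print(conversion_branch(DirectBackend))         # direct-load
print(conversion_branch(CustomInitTokenizer))   # rebuild
```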

Consequences

Encoding changed -- the Metaspace pre_tokenizer handles spaces differently from ByteLevel:

v4 (ByteLevel): '{"year": 2026}' → ['{"', 'year', '":', 'Ġ', '202', '6', '}']  (7 tokens)
v5 (Metaspace): '{"year": 2026}' → ['{"', 'year', '":', '202', '6', '}']        (6 tokens, space lost)
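Illustrative only, using the token strings from the comparison above (Ġ is ByteLevel's marker for a leading space), to show why the decoded text no longer matches:

```python
v4_tokens = ['{"', 'year', '":', 'Ġ', '202', '6', '}']   # 7 tokens
v5_tokens = ['{"', 'year', '":', '202', '6', '}']        # 6 tokens

# Under ByteLevel, Ġ decodes to a space; Metaspace never emitted it.
v4_text = "".join(t.replace("Ġ", " ") for t in v4_tokens)
v5_text = "".join(v5_tokens)

print(v4_text)  # {"year": 2026}
print(v5_text)  # {"year":2026}  -- the space after ":" is gone
```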

Decoding changed -- Replace("▁", " ") in the decoder replaces in special token names:

v4: decode([128810]) → '<|tool▁outputs▁begin|>'   # original character preserved
v5: decode([128810]) → '<|tool outputs begin|>'    # ▁ replaced with space

Note: the token's stored text (id_to_token) is identical in both versions (<|tool▁outputs▁begin|>). The difference is purely in the decode pipeline output.

Why this explains the failing models

  • deepseek-ai/DeepSeek-V3 loads LlamaTokenizer under v5, hits the ByteLevel -> Metaspace overwrite, and fails because the decoded tool-output text loses spaces and has ▁ replaced with spaces in special token names.
  • stepfun-ai/step3 also loads LlamaTokenizer under v5. Its visible failure string differs from DeepSeek because it has a different chat template, but the symptom is the same category: tokenization/decoding no longer preserves the exact text delta expected by the test.
  • deepseek-ai/DeepSeek-V3.1 also loads LlamaTokenizer, so it remains in the same general risk family. However, its current tool-call chat template fails even earlier than the decode comparison: the template concatenates tool['function']['arguments'] as a string, while our dummy tool_calls[*].function.arguments is a dict, which raises TypeError: can only concatenate str (not "dict") to str.
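The V3.1 template failure is plain Python string semantics; a minimal reproduction (the "args: " prefix is made up, and json.dumps stands in for whatever serialization a string-expecting template would need):

```python
import json

arguments = {"year": 2026}

# Mirrors the template's failing concatenation: str + dict raises TypeError.
try:
    _ = "args: " + arguments
    raised = False
except TypeError:
    raised = True
print(raised)  # True

# Serializing the dict first is what a string-expecting template would need:
print("args: " + json.dumps(arguments))  # args: {"year": 2026}
```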

Upstream references


Root Cause 2: Legacy _decode Segmentation Removed in ChatGLM4Tokenizer

Affected model: THUDM/glm-4-9b-chat

What changed

In v4, ChatGLM4Tokenizer inherited from PreTrainedTokenizer, whose _decode method had legacy special-token segmentation logic:

# v4 PreTrainedTokenizer._decode (simplified):
for token in filtered_tokens:
    if token in legacy_added_tokens:
        # flush current non-special tokens through convert_tokens_to_string()
        # insert the special token as a literal string
    else:
        current_sub_text.append(token)

This meant convert_tokens_to_string() never received str-type special tokens -- only byte tokens.

In v5, ChatGLM4Tokenizer inherits from PythonBackend, whose _decode is simplified:

# v5 PythonBackend._decode:
filtered_tokens = self.convert_ids_to_tokens(token_ids)
text = self.convert_tokens_to_string(filtered_tokens)  # ALL tokens, including special ones

The bug in GLM's custom tokenizer

GLM's convert_tokens_to_string was never designed to handle str-type tokens:

def convert_tokens_to_string(self, tokens):
    text = ""
    temp = b""
    for t in tokens:
        if isinstance(t, str):    # e.g. '<|assistant|>'
            if temp:
                text += temp.decode("utf-8", errors="replace")
            # BUG 1: does not append t (the special token text) to output
            # BUG 2: does not reset temp = b""
        elif isinstance(t, bytes):
            temp += t
    if temp:
        text += temp.decode(...)  # temp was not reset, so content is appended again
    return text

Consequence

Input tokens:  [b'<', b'|', b'tool', b'|', b'>\n', b'{"', b'year', b'":', b' ', b'202', b'6', b'}', '<|assistant|>']

v4 output: '<|tool|>\n{"year": 2026} <|assistant|>'      # correct (legacy segmentation)
v5 output: '<|tool|>\n{"year": 2026}<|tool|>\n{"year": 2026}'  # content doubled, <|assistant|> lost
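A hypothetical patched convert_tokens_to_string, shown only to make the two bugs concrete (this is not the upstream fix, and it does not reproduce the legacy space inserted before <|assistant|> in the v4 output):

```python
def convert_tokens_to_string_fixed(tokens):
    text = ""
    temp = b""
    for t in tokens:
        if isinstance(t, str):          # e.g. '<|assistant|>'
            if temp:
                text += temp.decode("utf-8", errors="replace")
                temp = b""              # BUG 2 fixed: buffer is reset
            text += t                   # BUG 1 fixed: special token emitted
        elif isinstance(t, bytes):
            temp += t
    if temp:
        text += temp.decode("utf-8", errors="replace")
    return text

tokens = [b'<', b'|', b'tool', b'|', b'>\n', b'{"', b'year', b'":', b' ',
          b'202', b'6', b'}', '<|assistant|>']
print(convert_tokens_to_string_fixed(tokens))
# <|tool|>
# {"year": 2026}<|assistant|>
```

With the fixes, the content appears once and the trailing special token survives.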

Why Passing Models Still Work

Category 1: Custom __init__ rebuilds a ByteLevel-compatible tokenizer

Models: Qwen2.5-0.5B-Instruct, Qwen3-0.6B, Qwen3-4B-Instruct-2507, Qwen3-Coder-30B-A3B-Instruct, Qwen3.5-0.8B, Qwen3-Coder-Next, MiMo-7B-RL, MiniMax-M2, MiniMax-M2.5

These tokenizers define a custom __init__, so v5 reconstructs the Rust backend instead of loading the full tokenizer pipeline directly from tokenizer.json. However, unlike LlamaTokenizer, their hardcoded pipeline is still compatible with the model's original tokenizer config, so the overwrite is effectively a no-op.

  • Qwen2Tokenizer hardcodes a Sequence([Split(...), ByteLevel(...)]) pre_tokenizer plus a ByteLevel decoder.
  • GPT2Tokenizer hardcodes a ByteLevel pre_tokenizer plus a ByteLevel decoder.

In both cases, the effective encoding and decoding behavior remains aligned with tokenizer.json, so the tool-response round-trip still works.

Category 2: TokenizersBackend loaded directly -- no __init__ overwrite

Models: Mistral-7B-Instruct-v0.3, GLM-4.7-Flash, Step-3.5-Flash, NVIDIA-Nemotron-3-Super-120B-A12B-BF16

These models either use TokenizersBackend directly (no subclass __init__) or their tokenizer class doesn't override __init__. convert_to_native_format() takes the if branch that loads the Rust tokenizer directly from tokenizer.json, preserving all configuration including pre_tokenizer and decoder.

This is where zai-org/GLM-4.7-Flash belongs: unlike THUDM/glm-4-9b-chat, it does not rely on the old custom Python ChatGLM4Tokenizer decode path. It uses a direct Rust-backed tokenizer with a ByteLevel decoder, so the GLM-specific bug above does not apply.

stepfun-ai/Step-3.5-Flash now also belongs here. In the current revision it loads TokenizersBackend directly, not LlamaTokenizer, and the tool-response round-trip passes.

Category 3: PythonBackend custom tokenizers without bugs

Models: Kimi-K2-Instruct, Kimi-K2.5 (TikTokenTokenizer), internlm3-8b-instruct (InternLM3Tokenizer)

These are pure Python tokenizers that don't use the Rust decode pipeline at all. They also don't have the convert_tokens_to_string bug that GLM has.


Additional Compatibility Issue: apply_chat_template Return Type Change

Independently from the decode issues above, transformers v5 changed apply_chat_template(tokenize=True) to return BatchEncoding (a dict with input_ids and attention_mask) instead of list[int].

This broke mask_utils.py, where 6 call sites assumed the return type was list[int].

Fix: Add return_dict=False to all apply_chat_template(tokenize=True) calls. This parameter is supported in both v4 (no-op, already returns list[int]) and v5 (forces list[int] return).
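A sketch of the compatibility pattern (the stub classes below simulate v4/v5 behavior; the real helper in this PR is _apply_chat_template_ids):

```python
def apply_chat_template_ids(tokenizer, messages, **kwargs):
    # Force the legacy list[int] return on v5 (no-op on v4), and fall back
    # to ["input_ids"] defensively in case the flag is ignored.
    kwargs.setdefault("return_dict", False)
    result = tokenizer.apply_chat_template(messages, tokenize=True, **kwargs)
    return result if isinstance(result, list) else result["input_ids"]

class StubV4:
    # Stand-in for a v4 tokenizer: already returns list[int].
    def apply_chat_template(self, messages, tokenize=True, return_dict=False):
        return [1, 2, 3]

class StubV5:
    # Stand-in for a v5 tokenizer that ignores return_dict and returns a dict.
    def apply_chat_template(self, messages, tokenize=True, return_dict=False):
        return {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}

print(apply_chat_template_ids(StubV4(), []))  # [1, 2, 3]
print(apply_chat_template_ids(StubV5(), []))  # [1, 2, 3]
```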


Summary of Known Failures

| Model | Root Cause | Upstream Issue |
| --- | --- | --- |
| deepseek-ai/DeepSeek-V3 | LlamaTokenizer overwrites ByteLevel with Metaspace | #43066 |
| deepseek-ai/DeepSeek-V3.1 | Tool-call chat template expects string function.arguments; the dummy tool-call shape provides a dict | Model-side template issue |
| stepfun-ai/step3 | Same as DeepSeek-V3: LlamaTokenizer overwrites ByteLevel with Metaspace | #43066 |
| THUDM/glm-4-9b-chat | v5 removed legacy _decode segmentation, exposing custom tokenizer bug | N/A (model-side bug) |

Summary of Passing Models

| Model | Tokenizer Class | Backend | Decoder | Why Unaffected |
| --- | --- | --- | --- | --- |
| Qwen2.5-0.5B-Instruct | Qwen2Tokenizer | TokenizersBackend | ByteLevel | Hardcoded ByteLevel matches tokenizer.json |
| Qwen3-0.6B | Qwen2Tokenizer | TokenizersBackend | ByteLevel | Same |
| Qwen3-4B-Instruct-2507 | Qwen2Tokenizer | TokenizersBackend | ByteLevel | Same |
| Qwen3-Coder-30B-A3B-Instruct | Qwen2Tokenizer | TokenizersBackend | ByteLevel | Same |
| Qwen3.5-0.8B | Qwen2Tokenizer | TokenizersBackend | ByteLevel | Same |
| Qwen3-Coder-Next | Qwen2Tokenizer | TokenizersBackend | ByteLevel | Same |
| Mistral-7B-Instruct-v0.3 | TokenizersBackend | TokenizersBackend | Metaspace | Direct load from tokenizer.json, no overwrite |
| GLM-4.7-Flash | TokenizersBackend | TokenizersBackend | ByteLevel | Direct load from tokenizer.json; does not use the old ChatGLM Python decode path |
| Step-3.5-Flash | TokenizersBackend | TokenizersBackend | ByteLevel-compatible | Direct load; passes tool-response round-trip as of April 8, 2026 |
| Nemotron-3-Super-120B | TokenizersBackend | TokenizersBackend | ByteLevel | Direct load from tokenizer.json, no overwrite |
| MiniMax-M2 | GPT2Tokenizer | TokenizersBackend | ByteLevel | Hardcoded ByteLevel matches tokenizer.json |
| MiniMax-M2.5 | GPT2Tokenizer | TokenizersBackend | ByteLevel | Same |
| internlm3-8b-instruct | InternLM3Tokenizer | PythonBackend | N/A | Pure Python tokenizer, no Rust decode, no bug |
| Kimi-K2-Instruct | TikTokenTokenizer | PythonBackend | N/A | Pure Python tokenizer, no Rust decode, no bug |
| Kimi-K2.5 | TikTokenTokenizer | PythonBackend | N/A | Same |
| MiMo-7B-RL | Qwen2Tokenizer | TokenizersBackend | ByteLevel | Hardcoded ByteLevel matches tokenizer.json |
