[v0.5.10] Fix mask_utils for transformers >=5.0 #925
yueming-yuan merged 2 commits into bump-sglang-v0.5.10
Conversation
Code Review
This pull request introduces a helper function _apply_chat_template_ids in miles/utils/mask_utils.py to ensure that apply_chat_template consistently returns a list of token IDs, addressing changes in transformers >= 5.0. The existing logic in get_system_message_length, gen_multi_turn_loss_mask_qwen, and gen_multi_turn_loss_mask_qwen3 has been updated to use this new helper. Feedback suggests simplifying the helper function by using the return_dict=False parameter or utilizing an existing utility function.
```python
def _apply_chat_template_ids(tokenizer, messages, **kwargs) -> list[int]:
    """Wrapper that always returns list[int] from apply_chat_template(tokenize=True).

    transformers >=5.0 returns BatchEncoding instead of list[int]."""
    result = tokenizer.apply_chat_template(messages, tokenize=True, **kwargs)
    if isinstance(result, list):
        return result
    return result["input_ids"]
```
The implementation of `_apply_chat_template_ids` can be simplified by passing `return_dict=False`. This is the standard way in the transformers library to make `apply_chat_template` return a list of token IDs instead of a `BatchEncoding` object, and it is more idiomatic and robust across versions. Alternatively, consider using the existing `apply_chat_template` utility from `miles.utils.chat_template_utils.template`, which already handles this logic and additionally normalizes tools and messages.
Suggested change:

```diff
 def _apply_chat_template_ids(tokenizer, messages, **kwargs) -> list[int]:
-    """Wrapper that always returns list[int] from apply_chat_template(tokenize=True).
-    transformers >=5.0 returns BatchEncoding instead of list[int]."""
-    result = tokenizer.apply_chat_template(messages, tokenize=True, **kwargs)
-    if isinstance(result, list):
-        return result
-    return result["input_ids"]
+    """Wrapper that always returns list[int] from apply_chat_template(tokenize=True).
+    transformers >=5.0 returns BatchEncoding instead of list[int] by default."""
+    return tokenizer.apply_chat_template(
+        messages, tokenize=True, return_dict=False, **kwargs
+    )
```
Summary
transformers 5.x changed `apply_chat_template(tokenize=True)` to return `BatchEncoding` instead of `list[int]`. `mask_utils.py` used direct slicing on the result, which broke (the dict-like result has `len() == 2`, one per key). Added an `_apply_chat_template_ids()` wrapper that normalizes the return type.

Test plan
- test_loss_mask_qwen3_simple passes
- test_loss_mask_qwen3_tools passes