
Commit 3eb02a8

HuggingFace LogitsProcessor to also accept a list of compiled grammars (#275)
My use case is that I need to follow a different grammar for each row in a batch. I couldn't find a way to do this natively with the current API, so I started hacking and came up with the changes in this PR. I'm not sure about the broader context, so I would appreciate a comment on whether this approach makes sense.

```python
class CountryPoland(ClassSchema):
    country: Literal["Poland"]

class CountryGermany(ClassSchema):
    country: Literal["Germany"]

tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammars = [
    grammar_compiler.compile_json_schema(CountryPoland),
    grammar_compiler.compile_json_schema(CountryGermany),
]

logit_processor = LogitsProcessor(compiled_grammars)
logit_processor = LogitsProcessorList([logit_processor])
model.generate(**model_inputs, max_new_tokens=512, logits_processor=logit_processor)
```

[Linked issue](#276).
1 parent d15a616 · commit 3eb02a8

File tree

1 file changed: +17 −7 lines changed

  • python/xgrammar/contrib


python/xgrammar/contrib/hf.py

Lines changed: 17 additions & 7 deletions

```diff
@@ -3,7 +3,7 @@
 transformers.LogitsProcessor, which is to be fed to `model.generate()`.
 """
 
-from typing import List
+from typing import List, Union
 
 import torch
 import transformers
@@ -40,17 +40,19 @@ class LogitsProcessor(transformers.LogitsProcessor):
     - Note that this implementation may contain extra overhead.
     """
 
-    def __init__(self, compiled_grammar: xgr.CompiledGrammar):
+    def __init__(self, compiled_grammar: Union[xgr.CompiledGrammar, List[xgr.CompiledGrammar]]):
         """Initialize the LogitsProcessor.
 
         Parameters
         ----------
-        compiled_grammar : xgr.CompiledGrammar
-            A grammar compiled according to the given grammar and the model's tokenizer_info.
+        compiled_grammar : xgr.CompiledGrammar | List[xgr.CompiledGrammar]
+            One or more grammars compiled according to the given grammar and the model's tokenizer_info.
         """
         self.matchers: List[xgr.GrammarMatcher] = []
-        self.compiled_grammar = compiled_grammar
-        self.full_vocab_size = self.compiled_grammar.tokenizer_info.vocab_size
+        self.compiled_grammars: List[xgr.CompiledGrammar] = (
+            compiled_grammar if isinstance(compiled_grammar, list) else [compiled_grammar]
+        )
+        self.full_vocab_size = self.compiled_grammars[0].tokenizer_info.vocab_size
         self.token_bitmask = None
         self.prefilled = False
         self.batch_size = 0
@@ -65,8 +67,16 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # Lazily initialize GrammarMatchers and bitmask
         if len(self.matchers) == 0:
             self.batch_size = input_ids.shape[0]
+            self.compiled_grammars = (
+                self.compiled_grammars
+                if len(self.compiled_grammars) > 1
+                else self.compiled_grammars * self.batch_size
+            )
+            assert (
+                len(self.compiled_grammars) == self.batch_size
+            ), "The number of compiled grammars must be equal to the batch size."
             self.matchers = [
-                xgr.GrammarMatcher(self.compiled_grammar) for _ in range(self.batch_size)
+                xgr.GrammarMatcher(self.compiled_grammars[i]) for i in range(self.batch_size)
             ]
             self.token_bitmask = xgr.allocate_token_bitmask(self.batch_size, self.full_vocab_size)
```
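For readers skimming the diff, here is a minimal usage sketch of the two calling conventions the patch supports. It is not part of the commit: `tokenizer`, `config`, `model`, and `model_inputs` are assumed to come from a standard transformers setup, the JSON Schema strings are made up for illustration, and the `from xgrammar.contrib.hf import LogitsProcessor` path is inferred from the file location above.

```python
# Sketch only (not from the commit): `tokenizer`, `config`, `model`, and
# `model_inputs` are assumed to exist from a transformers setup, and the
# import path is inferred from python/xgrammar/contrib/hf.py.
import xgrammar as xgr
from transformers import LogitsProcessorList

from xgrammar.contrib.hf import LogitsProcessor

tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
compiler = xgr.GrammarCompiler(tokenizer_info)

# Hypothetical schemas pinning "country" to a different value per batch row.
schema_poland = '{"type": "object", "properties": {"country": {"const": "Poland"}}}'
schema_germany = '{"type": "object", "properties": {"country": {"const": "Germany"}}}'

# New behavior: one compiled grammar per row; the assert added in __call__
# requires the list length to equal the batch size.
per_row = LogitsProcessor([
    compiler.compile_json_schema(schema_poland),
    compiler.compile_json_schema(schema_germany),
])

# Unchanged behavior: a single grammar is wrapped in a list in __init__ and
# replicated across the batch on the first __call__.
shared = LogitsProcessor(compiler.compile_json_schema(schema_poland))

model.generate(
    **model_inputs, max_new_tokens=512, logits_processor=LogitsProcessorList([per_row])
)
```

Because the replication happens lazily, existing single-grammar callers keep working for any batch size; only the per-row case needs the new list form.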