Skip to content

Commit d15a616

Browse files
Ubospicacyx-6
andauthored
[Feature] Cover JSON schema string format (#266)
This PR covers the most string formats in JSON schema, except "idn-email", "idn-hostname", "iri", and "iri-reference". Co-authored-by: Yaxing Cai <[email protected]>
1 parent 2408ba0 commit d15a616

18 files changed

+1546
-401
lines changed

cpp/grammar_builder.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ class GrammarBuilder {
2828
/*! \brief Default constructor. Creates a new grammar object. */
2929
GrammarBuilder() : grammar_(std::make_shared<Grammar::Impl>()) {}
3030

31+
/*! \brief Constructor. Creates a new grammar object from an existing grammar. */
32+
GrammarBuilder(const Grammar& grammar)
33+
: grammar_(std::make_shared<Grammar::Impl>(*grammar.operator->())) {
34+
for (int i = 0; i < static_cast<int>(grammar->NumRules()); ++i) {
35+
auto rule = grammar->GetRule(i);
36+
rule_name_to_id_[rule.name] = i;
37+
}
38+
}
39+
3140
/*!
3241
* \brief Get the result grammar. This function will also set the root rule to the rule with the
3342
* specified name. The rule should be already added to the grammar.

cpp/grammar_compiler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ bool GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAssertion(
249249

250250
// Find all positions that can come to and end. Then check if the suffix from that position
251251
// can be accepted by the lookahead assertion.
252-
for (int i = static_cast<int>(can_reach_end_stack.size()); i >= 0; --i) {
252+
for (int i = static_cast<int>(can_reach_end_stack.size()) - 1; i >= 0; --i) {
253253
if (!can_reach_end_stack[i]) {
254254
continue;
255255
}

cpp/grammar_functor.cc

Lines changed: 280 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <algorithm>
1111
#include <queue>
12+
#include <set>
1213
#include <unordered_set>
1314
#include <vector>
1415

@@ -284,7 +285,17 @@ class NestedRuleUnwrapper : public GrammarMutator {
284285
}
285286
};
286287

287-
class ByteStringFuser : public GrammarMutator {
288+
class StructureNormalizerImpl : public GrammarMutator {
289+
public:
290+
using GrammarMutator::Apply;
291+
using GrammarMutator::GrammarMutator;
292+
293+
Grammar Apply(const Grammar& grammar) final {
294+
return NestedRuleUnwrapper().Apply(SingleElementExprEliminator().Apply(grammar));
295+
}
296+
};
297+
298+
class ByteStringFuserImpl : public GrammarMutator {
288299
public:
289300
using GrammarMutator::Apply;
290301
using GrammarMutator::GrammarMutator;
@@ -316,6 +327,249 @@ class ByteStringFuser : public GrammarMutator {
316327
return builder_.AddSequence(new_sequence_ids);
317328
}
318329
};
330+
331+
class RuleInlinerImpl : public GrammarMutator {
332+
public:
333+
using GrammarMutator::Apply;
334+
using GrammarMutator::GrammarMutator;
335+
336+
private:
337+
int32_t VisitChoices(const RuleExpr& rule_expr) final {
338+
std::vector<int32_t> new_choice_ids;
339+
for (int i : rule_expr) {
340+
auto choice_expr = base_grammar_->GetRuleExpr(i);
341+
if (choice_expr.type == RuleExprType::kEmptyStr) {
342+
new_choice_ids.push_back(VisitExpr(i));
343+
continue;
344+
}
345+
XGRAMMAR_ICHECK(choice_expr.type == RuleExprType::kSequence);
346+
auto first_element = base_grammar_->GetRuleExpr(choice_expr[0]);
347+
if (first_element.type != RuleExprType::kRuleRef) {
348+
new_choice_ids.push_back(VisitExpr(choice_expr));
349+
continue;
350+
}
351+
auto rule_ref_id = first_element[0];
352+
if (can_rule_be_inlined_.count(rule_ref_id) == 0) {
353+
can_rule_be_inlined_[rule_ref_id] = CheckIfRuleCanBeInlined(rule_ref_id);
354+
}
355+
if (!can_rule_be_inlined_[rule_ref_id]) {
356+
new_choice_ids.push_back(VisitExpr(choice_expr));
357+
continue;
358+
}
359+
360+
// Do inlining
361+
std::vector<int32_t> other_elements;
362+
for (int i = 1; i < choice_expr.size(); ++i) {
363+
other_elements.push_back(VisitExpr(choice_expr[i]));
364+
}
365+
366+
auto ref_rule = base_grammar_->GetRule(rule_ref_id);
367+
auto ref_rule_expr = base_grammar_->GetRuleExpr(ref_rule.body_expr_id);
368+
369+
for (auto ref_choice_id : ref_rule_expr) {
370+
auto ref_choice_expr = base_grammar_->GetRuleExpr(ref_choice_id);
371+
XGRAMMAR_ICHECK(ref_choice_expr.type == RuleExprType::kSequence);
372+
std::vector<int32_t> choice_to_add;
373+
for (auto ref_element_id : ref_choice_expr) {
374+
choice_to_add.push_back(VisitExpr(ref_element_id));
375+
}
376+
choice_to_add.insert(choice_to_add.end(), other_elements.begin(), other_elements.end());
377+
new_choice_ids.push_back(builder_.AddSequence(choice_to_add));
378+
}
379+
}
380+
return builder_.AddChoices(new_choice_ids);
381+
}
382+
383+
/**
384+
* The rule should be: a sequence of choices, cannot be empty, cannot refer to other rules
385+
*/
386+
bool CheckIfRuleCanBeInlined(int32_t rule_id) {
387+
auto rule = base_grammar_->GetRule(rule_id);
388+
auto rule_expr = base_grammar_->GetRuleExpr(rule.body_expr_id);
389+
if (rule_expr.type != RuleExprType::kChoices) {
390+
return false;
391+
}
392+
if (rule_expr.size() == 0) {
393+
return false;
394+
}
395+
for (auto choice_id : rule_expr) {
396+
auto choice_expr = base_grammar_->GetRuleExpr(choice_id);
397+
if (choice_expr.type == RuleExprType::kEmptyStr) {
398+
return false;
399+
}
400+
XGRAMMAR_ICHECK(choice_expr.type == RuleExprType::kSequence);
401+
for (auto element_id : choice_expr) {
402+
auto element_expr = base_grammar_->GetRuleExpr(element_id);
403+
if (element_expr.type == RuleExprType::kRuleRef) {
404+
return false;
405+
}
406+
}
407+
}
408+
return true;
409+
}
410+
411+
std::unordered_map<int32_t, bool> can_rule_be_inlined_;
412+
};
413+
414+
/*!
415+
* \brief Analyze all referenced rules or the main rule. Return a list of all referenced rule ids.
416+
* This is useful for dead code elimination.
417+
*/
418+
class UsedRulesAnalyzer : public GrammarVisitor<std::vector<int32_t>> {
419+
public:
420+
UsedRulesAnalyzer() = default;
421+
422+
std::vector<int32_t> Apply(const Grammar& grammar) final {
423+
base_grammar_ = grammar;
424+
425+
std::set<int32_t> visited;
426+
427+
std::queue<int32_t>().swap(visit_queue_);
428+
429+
visit_queue_.push(base_grammar_->GetRootRuleId());
430+
while (!visit_queue_.empty()) {
431+
auto rule_id = visit_queue_.front();
432+
visit_queue_.pop();
433+
if (visited.count(rule_id)) {
434+
continue;
435+
}
436+
visited.insert(rule_id);
437+
auto rule = base_grammar_->GetRule(rule_id);
438+
VisitExpr(rule.body_expr_id);
439+
}
440+
441+
return std::vector<int32_t>(visited.begin(), visited.end());
442+
}
443+
444+
void VisitTagDispatch(const RuleExpr& rule_expr) {
445+
for (int i = 0; i < rule_expr.size(); i += 2) {
446+
visit_queue_.push(rule_expr[i + 1]);
447+
}
448+
}
449+
450+
void VisitRuleRef(const RuleExpr& rule_expr) { visit_queue_.push(rule_expr[0]); }
451+
452+
private:
453+
std::queue<int32_t> visit_queue_;
454+
};
455+
456+
class DeadCodeEliminatorImpl : public GrammarMutator {
457+
public:
458+
using GrammarMutator::Apply;
459+
using GrammarMutator::GrammarMutator;
460+
461+
Grammar Apply(const Grammar& grammar) final {
462+
Init(grammar);
463+
auto used_rules = UsedRulesAnalyzer().Apply(grammar);
464+
rule_id_map_.clear();
465+
for (auto rule_id : used_rules) {
466+
rule_id_map_[rule_id] = builder_.AddEmptyRule(grammar->GetRule(rule_id).name);
467+
}
468+
for (auto rule_id : used_rules) {
469+
auto rule = grammar->GetRule(rule_id);
470+
auto new_body_expr_id = VisitExpr(rule.body_expr_id);
471+
builder_.UpdateRuleBody(rule_id_map_[rule_id], new_body_expr_id);
472+
builder_.AddLookaheadAssertion(
473+
rule_id_map_[rule_id], VisitLookaheadAssertion(rule.lookahead_assertion_id)
474+
);
475+
}
476+
XGRAMMAR_CHECK(rule_id_map_.count(grammar->GetRootRuleId()) > 0);
477+
return builder_.Get(rule_id_map_[grammar->GetRootRuleId()]);
478+
}
479+
480+
int32_t VisitTagDispatch(const RuleExpr& rule_expr) final {
481+
std::vector<std::pair<int32_t, int32_t>> tag_dispatch_list;
482+
for (int i = 0; i < rule_expr.size(); i += 2) {
483+
XGRAMMAR_DCHECK(rule_id_map_.count(rule_expr[i + 1]) > 0);
484+
auto new_rule_id = rule_id_map_[rule_expr[i + 1]];
485+
tag_dispatch_list.push_back({VisitExpr(rule_expr[i]), new_rule_id});
486+
}
487+
return builder_.AddTagDispatch(tag_dispatch_list);
488+
}
489+
490+
int32_t VisitRuleRef(const RuleExpr& rule_expr) final {
491+
XGRAMMAR_DCHECK(rule_id_map_.count(rule_expr[0]) > 0);
492+
auto new_rule_id = rule_id_map_[rule_expr[0]];
493+
return builder_.AddRuleRef(new_rule_id);
494+
}
495+
496+
private:
497+
std::unordered_map<int32_t, int32_t> rule_id_map_;
498+
};
499+
500+
class LookaheadAssertionAnalyzerImpl : public GrammarMutator {
501+
public:
502+
using GrammarMutator::GrammarMutator;
503+
504+
Grammar Apply(const Grammar& grammar) final {
505+
InitWithCopy(grammar);
506+
auto root_rule = grammar->GetRootRule();
507+
auto root_rule_expr = base_grammar_->GetRuleExpr(root_rule.body_expr_id);
508+
if (root_rule_expr.type == RuleExprType::kTagDispatch) {
509+
return grammar;
510+
}
511+
for (int i = 0; i < static_cast<int>(grammar->NumRules()); ++i) {
512+
auto rule = grammar->GetRule(i);
513+
if (i == grammar->GetRootRuleId() || rule.lookahead_assertion_id != -1) {
514+
continue;
515+
}
516+
auto look_head_assertion_id = DetectLookaheadAssertion(i);
517+
if (look_head_assertion_id != -1) {
518+
builder_.AddLookaheadAssertion(i, look_head_assertion_id);
519+
}
520+
}
521+
return builder_.Get(grammar->GetRootRuleId());
522+
}
523+
524+
int32_t DetectLookaheadAssertion(int32_t rule_id) {
525+
std::vector<int32_t> found_sequence; // Element ids
526+
bool found = false;
527+
for (int i = 0; i < static_cast<int>(base_grammar_->NumRules()); ++i) {
528+
auto rule = base_grammar_->GetRule(i);
529+
auto rule_expr = base_grammar_->GetRuleExpr(rule.body_expr_id);
530+
if (rule_expr.type == RuleExprType::kTagDispatch) {
531+
for (int j = 1; j < rule_expr.size(); j += 2) {
532+
if (rule_expr[j] == rule_id) {
533+
return -1;
534+
}
535+
}
536+
continue;
537+
}
538+
XGRAMMAR_DCHECK(rule_expr.type == RuleExprType::kChoices);
539+
for (auto sequence_id : rule_expr) {
540+
auto sequence_expr = base_grammar_->GetRuleExpr(sequence_id);
541+
if (sequence_expr.type != RuleExprType::kSequence) {
542+
continue;
543+
}
544+
auto last_element = base_grammar_->GetRuleExpr(sequence_expr.end()[-1]);
545+
if (last_element.type == RuleExprType::kRuleRef && last_element[0] == rule_id &&
546+
i != rule_id) {
547+
return -1;
548+
}
549+
550+
for (int j = 0; j < sequence_expr.size() - 1; ++j) {
551+
auto element_expr = base_grammar_->GetRuleExpr(sequence_expr[j]);
552+
if (element_expr.type != RuleExprType::kRuleRef || element_expr[0] != rule_id) {
553+
continue;
554+
}
555+
if (found) {
556+
return -1;
557+
}
558+
found = true;
559+
for (int k = j + 1; k < sequence_expr.size(); ++k) {
560+
found_sequence.push_back(sequence_expr[k]);
561+
}
562+
}
563+
}
564+
}
565+
566+
if (!found) {
567+
return -1;
568+
}
569+
return builder_.AddSequence(found_sequence);
570+
}
571+
};
572+
319573
/*!
320574
* \brief A class that normalizes a grammar by applying a series of transformations.
321575
*
@@ -341,9 +595,11 @@ class GrammarNormalizerImpl : public GrammarMutator {
341595
// Return the list of all normalizers in the class. The normalizers are applied one by one.
342596
std::vector<std::unique_ptr<GrammarMutator>> GetNormalizerList() {
343597
std::vector<std::unique_ptr<GrammarMutator>> normalizer_mutators;
344-
normalizer_mutators.emplace_back(std::make_unique<SingleElementExprEliminator>());
345-
normalizer_mutators.emplace_back(std::make_unique<NestedRuleUnwrapper>());
346-
normalizer_mutators.emplace_back(std::make_unique<ByteStringFuser>());
598+
normalizer_mutators.emplace_back(std::make_unique<StructureNormalizerImpl>());
599+
normalizer_mutators.emplace_back(std::make_unique<ByteStringFuserImpl>());
600+
normalizer_mutators.emplace_back(std::make_unique<RuleInlinerImpl>());
601+
normalizer_mutators.emplace_back(std::make_unique<DeadCodeEliminatorImpl>());
602+
normalizer_mutators.emplace_back(std::make_unique<LookaheadAssertionAnalyzerImpl>());
347603
return normalizer_mutators;
348604
}
349605
};
@@ -515,8 +771,8 @@ class AllowEmptyRuleAnalyzerImpl : public GrammarVisitor<std::vector<int32_t>> {
515771
std::unordered_set<int32_t> empty_rule_id_set;
516772
FindExplicitEmptyRules(&empty_rule_id_set);
517773

518-
// Step 2: Find rules that indirectly allow empty string. Using the Bellman-Ford algorithm on
519-
// the rule reference graph.
774+
// Step 2: Find rules that indirectly allow empty string. Using the Bellman-Ford algorithm
775+
// on the rule reference graph.
520776
std::vector<std::vector<int32_t>> rule_ref_graph = RuleRefGraphFinder().Apply(grammar);
521777
FindIndirectEmptyRules(&empty_rule_id_set, rule_ref_graph);
522778

@@ -703,4 +959,22 @@ Grammar StructuralTagGrammarCreator::Apply(
703959
return StructuralTagGrammarCreatorImpl().Apply(triggers, tag_groups);
704960
}
705961

962+
Grammar RuleInliner::Apply(const Grammar& grammar) { return RuleInlinerImpl().Apply(grammar); }
963+
964+
Grammar ByteStringFuser::Apply(const Grammar& grammar) {
965+
return ByteStringFuserImpl().Apply(grammar);
966+
}
967+
968+
Grammar DeadCodeEliminator::Apply(const Grammar& grammar) {
969+
return DeadCodeEliminatorImpl().Apply(grammar);
970+
}
971+
972+
Grammar StructureNormalizer::Apply(const Grammar& grammar) {
973+
return StructureNormalizerImpl().Apply(grammar);
974+
}
975+
976+
Grammar LookaheadAssertionAnalyzer::Apply(const Grammar& grammar) {
977+
return LookaheadAssertionAnalyzerImpl().Apply(grammar);
978+
}
979+
706980
} // namespace xgrammar

0 commit comments

Comments
 (0)