9
9
10
10
#include < algorithm>
11
11
#include < queue>
12
+ #include < set>
12
13
#include < unordered_set>
13
14
#include < vector>
14
15
@@ -284,7 +285,17 @@ class NestedRuleUnwrapper : public GrammarMutator {
284
285
}
285
286
};
286
287
287
- class ByteStringFuser : public GrammarMutator {
288
+ class StructureNormalizerImpl : public GrammarMutator {
289
+ public:
290
+ using GrammarMutator::Apply;
291
+ using GrammarMutator::GrammarMutator;
292
+
293
+ Grammar Apply (const Grammar& grammar) final {
294
+ return NestedRuleUnwrapper ().Apply (SingleElementExprEliminator ().Apply (grammar));
295
+ }
296
+ };
297
+
298
+ class ByteStringFuserImpl : public GrammarMutator {
288
299
public:
289
300
using GrammarMutator::Apply;
290
301
using GrammarMutator::GrammarMutator;
@@ -316,6 +327,249 @@ class ByteStringFuser : public GrammarMutator {
316
327
return builder_.AddSequence (new_sequence_ids);
317
328
}
318
329
};
330
+
331
+ class RuleInlinerImpl : public GrammarMutator {
332
+ public:
333
+ using GrammarMutator::Apply;
334
+ using GrammarMutator::GrammarMutator;
335
+
336
+ private:
337
+ int32_t VisitChoices (const RuleExpr& rule_expr) final {
338
+ std::vector<int32_t > new_choice_ids;
339
+ for (int i : rule_expr) {
340
+ auto choice_expr = base_grammar_->GetRuleExpr (i);
341
+ if (choice_expr.type == RuleExprType::kEmptyStr ) {
342
+ new_choice_ids.push_back (VisitExpr (i));
343
+ continue ;
344
+ }
345
+ XGRAMMAR_ICHECK (choice_expr.type == RuleExprType::kSequence );
346
+ auto first_element = base_grammar_->GetRuleExpr (choice_expr[0 ]);
347
+ if (first_element.type != RuleExprType::kRuleRef ) {
348
+ new_choice_ids.push_back (VisitExpr (choice_expr));
349
+ continue ;
350
+ }
351
+ auto rule_ref_id = first_element[0 ];
352
+ if (can_rule_be_inlined_.count (rule_ref_id) == 0 ) {
353
+ can_rule_be_inlined_[rule_ref_id] = CheckIfRuleCanBeInlined (rule_ref_id);
354
+ }
355
+ if (!can_rule_be_inlined_[rule_ref_id]) {
356
+ new_choice_ids.push_back (VisitExpr (choice_expr));
357
+ continue ;
358
+ }
359
+
360
+ // Do inlining
361
+ std::vector<int32_t > other_elements;
362
+ for (int i = 1 ; i < choice_expr.size (); ++i) {
363
+ other_elements.push_back (VisitExpr (choice_expr[i]));
364
+ }
365
+
366
+ auto ref_rule = base_grammar_->GetRule (rule_ref_id);
367
+ auto ref_rule_expr = base_grammar_->GetRuleExpr (ref_rule.body_expr_id );
368
+
369
+ for (auto ref_choice_id : ref_rule_expr) {
370
+ auto ref_choice_expr = base_grammar_->GetRuleExpr (ref_choice_id);
371
+ XGRAMMAR_ICHECK (ref_choice_expr.type == RuleExprType::kSequence );
372
+ std::vector<int32_t > choice_to_add;
373
+ for (auto ref_element_id : ref_choice_expr) {
374
+ choice_to_add.push_back (VisitExpr (ref_element_id));
375
+ }
376
+ choice_to_add.insert (choice_to_add.end (), other_elements.begin (), other_elements.end ());
377
+ new_choice_ids.push_back (builder_.AddSequence (choice_to_add));
378
+ }
379
+ }
380
+ return builder_.AddChoices (new_choice_ids);
381
+ }
382
+
383
+ /* *
384
+ * The rule should be: a sequence of choices, cannot be empty, cannot refer to other rules
385
+ */
386
+ bool CheckIfRuleCanBeInlined (int32_t rule_id) {
387
+ auto rule = base_grammar_->GetRule (rule_id);
388
+ auto rule_expr = base_grammar_->GetRuleExpr (rule.body_expr_id );
389
+ if (rule_expr.type != RuleExprType::kChoices ) {
390
+ return false ;
391
+ }
392
+ if (rule_expr.size () == 0 ) {
393
+ return false ;
394
+ }
395
+ for (auto choice_id : rule_expr) {
396
+ auto choice_expr = base_grammar_->GetRuleExpr (choice_id);
397
+ if (choice_expr.type == RuleExprType::kEmptyStr ) {
398
+ return false ;
399
+ }
400
+ XGRAMMAR_ICHECK (choice_expr.type == RuleExprType::kSequence );
401
+ for (auto element_id : choice_expr) {
402
+ auto element_expr = base_grammar_->GetRuleExpr (element_id);
403
+ if (element_expr.type == RuleExprType::kRuleRef ) {
404
+ return false ;
405
+ }
406
+ }
407
+ }
408
+ return true ;
409
+ }
410
+
411
+ std::unordered_map<int32_t , bool > can_rule_be_inlined_;
412
+ };
413
+
414
+ /* !
415
+ * \brief Analyze all referenced rules or the main rule. Return a list of all referenced rule ids.
416
+ * This is useful for dead code elimination.
417
+ */
418
+ class UsedRulesAnalyzer : public GrammarVisitor <std::vector<int32_t >> {
419
+ public:
420
+ UsedRulesAnalyzer () = default ;
421
+
422
+ std::vector<int32_t > Apply (const Grammar& grammar) final {
423
+ base_grammar_ = grammar;
424
+
425
+ std::set<int32_t > visited;
426
+
427
+ std::queue<int32_t >().swap (visit_queue_);
428
+
429
+ visit_queue_.push (base_grammar_->GetRootRuleId ());
430
+ while (!visit_queue_.empty ()) {
431
+ auto rule_id = visit_queue_.front ();
432
+ visit_queue_.pop ();
433
+ if (visited.count (rule_id)) {
434
+ continue ;
435
+ }
436
+ visited.insert (rule_id);
437
+ auto rule = base_grammar_->GetRule (rule_id);
438
+ VisitExpr (rule.body_expr_id );
439
+ }
440
+
441
+ return std::vector<int32_t >(visited.begin (), visited.end ());
442
+ }
443
+
444
+ void VisitTagDispatch (const RuleExpr& rule_expr) {
445
+ for (int i = 0 ; i < rule_expr.size (); i += 2 ) {
446
+ visit_queue_.push (rule_expr[i + 1 ]);
447
+ }
448
+ }
449
+
450
+ void VisitRuleRef (const RuleExpr& rule_expr) { visit_queue_.push (rule_expr[0 ]); }
451
+
452
+ private:
453
+ std::queue<int32_t > visit_queue_;
454
+ };
455
+
456
+ class DeadCodeEliminatorImpl : public GrammarMutator {
457
+ public:
458
+ using GrammarMutator::Apply;
459
+ using GrammarMutator::GrammarMutator;
460
+
461
+ Grammar Apply (const Grammar& grammar) final {
462
+ Init (grammar);
463
+ auto used_rules = UsedRulesAnalyzer ().Apply (grammar);
464
+ rule_id_map_.clear ();
465
+ for (auto rule_id : used_rules) {
466
+ rule_id_map_[rule_id] = builder_.AddEmptyRule (grammar->GetRule (rule_id).name );
467
+ }
468
+ for (auto rule_id : used_rules) {
469
+ auto rule = grammar->GetRule (rule_id);
470
+ auto new_body_expr_id = VisitExpr (rule.body_expr_id );
471
+ builder_.UpdateRuleBody (rule_id_map_[rule_id], new_body_expr_id);
472
+ builder_.AddLookaheadAssertion (
473
+ rule_id_map_[rule_id], VisitLookaheadAssertion (rule.lookahead_assertion_id )
474
+ );
475
+ }
476
+ XGRAMMAR_CHECK (rule_id_map_.count (grammar->GetRootRuleId ()) > 0 );
477
+ return builder_.Get (rule_id_map_[grammar->GetRootRuleId ()]);
478
+ }
479
+
480
+ int32_t VisitTagDispatch (const RuleExpr& rule_expr) final {
481
+ std::vector<std::pair<int32_t , int32_t >> tag_dispatch_list;
482
+ for (int i = 0 ; i < rule_expr.size (); i += 2 ) {
483
+ XGRAMMAR_DCHECK (rule_id_map_.count (rule_expr[i + 1 ]) > 0 );
484
+ auto new_rule_id = rule_id_map_[rule_expr[i + 1 ]];
485
+ tag_dispatch_list.push_back ({VisitExpr (rule_expr[i]), new_rule_id});
486
+ }
487
+ return builder_.AddTagDispatch (tag_dispatch_list);
488
+ }
489
+
490
+ int32_t VisitRuleRef (const RuleExpr& rule_expr) final {
491
+ XGRAMMAR_DCHECK (rule_id_map_.count (rule_expr[0 ]) > 0 );
492
+ auto new_rule_id = rule_id_map_[rule_expr[0 ]];
493
+ return builder_.AddRuleRef (new_rule_id);
494
+ }
495
+
496
+ private:
497
+ std::unordered_map<int32_t , int32_t > rule_id_map_;
498
+ };
499
+
500
+ class LookaheadAssertionAnalyzerImpl : public GrammarMutator {
501
+ public:
502
+ using GrammarMutator::GrammarMutator;
503
+
504
+ Grammar Apply (const Grammar& grammar) final {
505
+ InitWithCopy (grammar);
506
+ auto root_rule = grammar->GetRootRule ();
507
+ auto root_rule_expr = base_grammar_->GetRuleExpr (root_rule.body_expr_id );
508
+ if (root_rule_expr.type == RuleExprType::kTagDispatch ) {
509
+ return grammar;
510
+ }
511
+ for (int i = 0 ; i < static_cast <int >(grammar->NumRules ()); ++i) {
512
+ auto rule = grammar->GetRule (i);
513
+ if (i == grammar->GetRootRuleId () || rule.lookahead_assertion_id != -1 ) {
514
+ continue ;
515
+ }
516
+ auto look_head_assertion_id = DetectLookaheadAssertion (i);
517
+ if (look_head_assertion_id != -1 ) {
518
+ builder_.AddLookaheadAssertion (i, look_head_assertion_id);
519
+ }
520
+ }
521
+ return builder_.Get (grammar->GetRootRuleId ());
522
+ }
523
+
524
+ int32_t DetectLookaheadAssertion (int32_t rule_id) {
525
+ std::vector<int32_t > found_sequence; // Element ids
526
+ bool found = false ;
527
+ for (int i = 0 ; i < static_cast <int >(base_grammar_->NumRules ()); ++i) {
528
+ auto rule = base_grammar_->GetRule (i);
529
+ auto rule_expr = base_grammar_->GetRuleExpr (rule.body_expr_id );
530
+ if (rule_expr.type == RuleExprType::kTagDispatch ) {
531
+ for (int j = 1 ; j < rule_expr.size (); j += 2 ) {
532
+ if (rule_expr[j] == rule_id) {
533
+ return -1 ;
534
+ }
535
+ }
536
+ continue ;
537
+ }
538
+ XGRAMMAR_DCHECK (rule_expr.type == RuleExprType::kChoices );
539
+ for (auto sequence_id : rule_expr) {
540
+ auto sequence_expr = base_grammar_->GetRuleExpr (sequence_id);
541
+ if (sequence_expr.type != RuleExprType::kSequence ) {
542
+ continue ;
543
+ }
544
+ auto last_element = base_grammar_->GetRuleExpr (sequence_expr.end ()[-1 ]);
545
+ if (last_element.type == RuleExprType::kRuleRef && last_element[0 ] == rule_id &&
546
+ i != rule_id) {
547
+ return -1 ;
548
+ }
549
+
550
+ for (int j = 0 ; j < sequence_expr.size () - 1 ; ++j) {
551
+ auto element_expr = base_grammar_->GetRuleExpr (sequence_expr[j]);
552
+ if (element_expr.type != RuleExprType::kRuleRef || element_expr[0 ] != rule_id) {
553
+ continue ;
554
+ }
555
+ if (found) {
556
+ return -1 ;
557
+ }
558
+ found = true ;
559
+ for (int k = j + 1 ; k < sequence_expr.size (); ++k) {
560
+ found_sequence.push_back (sequence_expr[k]);
561
+ }
562
+ }
563
+ }
564
+ }
565
+
566
+ if (!found) {
567
+ return -1 ;
568
+ }
569
+ return builder_.AddSequence (found_sequence);
570
+ }
571
+ };
572
+
319
573
/* !
320
574
* \brief A class that normalizes a grammar by applying a series of transformations.
321
575
*
@@ -341,9 +595,11 @@ class GrammarNormalizerImpl : public GrammarMutator {
341
595
// Return the list of all normalizers in the class. The normalizers are applied one by one.
342
596
std::vector<std::unique_ptr<GrammarMutator>> GetNormalizerList () {
343
597
std::vector<std::unique_ptr<GrammarMutator>> normalizer_mutators;
344
- normalizer_mutators.emplace_back (std::make_unique<SingleElementExprEliminator>());
345
- normalizer_mutators.emplace_back (std::make_unique<NestedRuleUnwrapper>());
346
- normalizer_mutators.emplace_back (std::make_unique<ByteStringFuser>());
598
+ normalizer_mutators.emplace_back (std::make_unique<StructureNormalizerImpl>());
599
+ normalizer_mutators.emplace_back (std::make_unique<ByteStringFuserImpl>());
600
+ normalizer_mutators.emplace_back (std::make_unique<RuleInlinerImpl>());
601
+ normalizer_mutators.emplace_back (std::make_unique<DeadCodeEliminatorImpl>());
602
+ normalizer_mutators.emplace_back (std::make_unique<LookaheadAssertionAnalyzerImpl>());
347
603
return normalizer_mutators;
348
604
}
349
605
};
@@ -515,8 +771,8 @@ class AllowEmptyRuleAnalyzerImpl : public GrammarVisitor<std::vector<int32_t>> {
515
771
std::unordered_set<int32_t > empty_rule_id_set;
516
772
FindExplicitEmptyRules (&empty_rule_id_set);
517
773
518
- // Step 2: Find rules that indirectly allow empty string. Using the Bellman-Ford algorithm on
519
- // the rule reference graph.
774
+ // Step 2: Find rules that indirectly allow empty string. Using the Bellman-Ford algorithm
775
+ // on the rule reference graph.
520
776
std::vector<std::vector<int32_t >> rule_ref_graph = RuleRefGraphFinder ().Apply (grammar);
521
777
FindIndirectEmptyRules (&empty_rule_id_set, rule_ref_graph);
522
778
@@ -703,4 +959,22 @@ Grammar StructuralTagGrammarCreator::Apply(
703
959
return StructuralTagGrammarCreatorImpl ().Apply (triggers, tag_groups);
704
960
}
705
961
962
+ Grammar RuleInliner::Apply (const Grammar& grammar) { return RuleInlinerImpl ().Apply (grammar); }
963
+
964
+ Grammar ByteStringFuser::Apply (const Grammar& grammar) {
965
+ return ByteStringFuserImpl ().Apply (grammar);
966
+ }
967
+
968
+ Grammar DeadCodeEliminator::Apply (const Grammar& grammar) {
969
+ return DeadCodeEliminatorImpl ().Apply (grammar);
970
+ }
971
+
972
+ Grammar StructureNormalizer::Apply (const Grammar& grammar) {
973
+ return StructureNormalizerImpl ().Apply (grammar);
974
+ }
975
+
976
+ Grammar LookaheadAssertionAnalyzer::Apply (const Grammar& grammar) {
977
+ return LookaheadAssertionAnalyzerImpl ().Apply (grammar);
978
+ }
979
+
706
980
} // namespace xgrammar
0 commit comments