@@ -478,6 +478,82 @@ def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
         self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
         )
 
+    def forward_impl(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        metadata: FlashInferAttentionMetadata,
+        attention_mask_type: int,
+        output: torch.Tensor,
+        attention_mask_data: Optional[torch.Tensor] = None,
+        attention_window_size: Optional[int] = None,
+    ) -> None:
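+        """Run FlashInfer attention for this layer, writing results into `output` in place.
+
+        Appends the new K/V tokens to the paged KV cache, then runs the prefill
+        wrapper over context-phase tokens and the decode wrapper over
+        generation-phase tokens.
+        """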
+        # Query
+        q = q.view(-1, self.num_heads, self.head_dim)
+
+        # Key and Value
+        kv_cache = metadata.kv_cache_manager.get_buffers(self.layer_idx)
+
+        if k is not None and v is not None:
+            k = k.view(-1, self.num_kv_heads, self.head_dim)
+            v = v.view(-1, self.num_kv_heads, self.head_dim)
+
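+            # Cast K/V to fp8 so they match the quantized KV cache buffer.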
+            if self.has_fp8_kv_cache:
+                assert kv_cache.dtype == torch.float8_e4m3fn, (
+                    f"KV cache should have fp8 dtype, but got {kv_cache.dtype}")
+                k = k.to(torch.float8_e4m3fn)
+                v = v.to(torch.float8_e4m3fn)
+            assert k.dtype == v.dtype == kv_cache.dtype, (
+                f"KV cache dtype {kv_cache.dtype} does not match k/v dtype {k.dtype}/{v.dtype}"
+            )
+
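+            # Append the new K/V tokens to this layer's paged KV cache.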
+            flashinfer.page.append_paged_kv_cache(
+                append_key=k,
+                append_value=v,
+                batch_indices=metadata.batch_indices,
+                positions=metadata.positions,
+                paged_kv_cache=kv_cache,
+                kv_indices=metadata.paged_kv_indices,
+                kv_indptr=metadata.paged_kv_indptr,
+                kv_last_page_len=metadata.paged_kv_last_page_len,
+                kv_layout=metadata.kv_layout)
+
+        num_contexts = metadata.num_contexts
+        num_generations = metadata.num_generations
+        num_ctx_tokens = metadata.num_ctx_tokens
+
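+        # The flattened batch lays out context (prefill) tokens first, then
+        # generation (decode) tokens; each helper writes its slice of `output` in place.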
+        def prefill_forward(plan_params: PlanParams, out: torch.Tensor):
+            wrapper = metadata.get_prefill_wrapper(plan_params)
+            wrapper.run(q[:num_ctx_tokens],
+                        kv_cache,
+                        out=out.view(-1, self.num_heads, self.head_dim))
+
+        def decode_forward(plan_params: PlanParams, out: torch.Tensor):
+            wrapper = metadata.get_decode_wrapper(plan_params)
+            wrapper.run(q[num_ctx_tokens:],
+                        kv_cache,
+                        out=out.view(-1, self.num_heads, self.head_dim))
+
+        # this will do nothing if the last forward pass had the same parameters
+        plan_params = metadata.plan(self.num_heads,
+                                    self.num_kv_heads,
+                                    self.head_dim,
+                                    q_dtype=q.dtype,
+                                    kv_dtype=kv_cache.dtype,
+                                    q_scaling=self.q_scaling,
+                                    attention_window_size=attention_window_size,
+                                    attention_mask_type=attention_mask_type,
+                                    attention_mask_data=attention_mask_data)
+
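+        # Decode-only, prefill-only, or mixed batch: mixed batches split `output`
+        # by token position so each wrapper fills its own rows.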
+        if num_contexts == 0:
+            decode_forward(plan_params, output)
+        elif num_generations == 0:
+            prefill_forward(plan_params, output)
+        else:
+            prefill_forward(plan_params, output[:num_ctx_tokens, :])
+            decode_forward(plan_params, output[num_ctx_tokens:, :])
+
     def forward(self,
                 q: torch.Tensor,
                 k: Optional[torch.Tensor],
@@ -487,6 +563,7 @@ def forward(self,
                 attention_window_size: Optional[int] = None,
                 attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
                 attention_mask_data: Optional[torch.Tensor] = None,
+                output: Optional[torch.Tensor] = None,
                 **kwargs) -> torch.Tensor:
         if attention_mask == CustomAttentionMask.CUSTOM:
             assert attention_mask_data is not None, "attention_mask_data is required for custom attention mask."
@@ -502,133 +579,15 @@ def forward(self,
         else:
             raise ValueError("Unexpected attention mask type")
 
-        return forward_pattern(q=q,
-                               k=k,
-                               v=v,
-                               num_heads=self.num_heads,
-                               head_dim=self.head_dim,
-                               num_kv_heads=self.num_kv_heads,
-                               layer_idx=self.layer_idx,
-                               has_fp8_kv_cache=self.has_fp8_kv_cache,
-                               attention_mask_type=attention_mask_type,
-                               q_scaling=self.q_scaling,
-                               attention_mask_data=attention_mask_data,
-                               attention_window_size=attention_window_size)
-
-
-@torch.library.custom_op("trtllm::flashinfer_forward", mutates_args=())
-def forward_pattern(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    num_heads: int,
-    head_dim: int,
-    num_kv_heads: int,
-    layer_idx: int,
-    has_fp8_kv_cache: bool,
-    attention_mask_type: int,
-    q_scaling: Optional[float] = None,
-    attention_mask_data: Optional[torch.Tensor] = None,
-    attention_window_size: Optional[int] = None,
-) -> torch.Tensor:
-    '''
-    Wrapping the flashinfer forward as a custom op is required to fix `torch.compile` graph breaks,
-    otherwise it will graph break when calling `metadata.num_contexts` since it convert tensor's sum directly to int.
-    '''
-    # torch.compile does not support custom object as arguments, so we have to use global function to get the metadata.
-    extra_attrs = get_model_extra_attrs()
-    if extra_attrs is not None:
-        metadata_ref = extra_attrs.get("attention_metadata", None)
-        metadata = metadata_ref() if metadata_ref is not None else None
-    else:
-        metadata = get_global_attrs().attention_metadata()
-
-    assert isinstance(
-        metadata,
-        FlashInferAttentionMetadata,
-    )
-
-    # Query
-    q = q.view(-1, num_heads, head_dim)
-
-    # Key and Value
-    kv_cache = metadata.kv_cache_manager.get_buffers(layer_idx)
-
-    if k is not None and v is not None:
-        k = k.view(-1, num_kv_heads, head_dim)
-        v = v.view(-1, num_kv_heads, head_dim)
-
-        if has_fp8_kv_cache:
-            assert kv_cache.dtype == torch.float8_e4m3fn, f"KV cache should have fp8 dtype, but get {kv_cache.dtype}"
-            k = k.to(torch.float8_e4m3fn)
-            v = v.to(torch.float8_e4m3fn)
-        assert k.dtype == v.dtype == kv_cache.dtype, f"KV cache dtype {kv_cache.dtype} does not match k/v dtype {k.dtype}/{v.dtype}"
-
-        flashinfer.page.append_paged_kv_cache(
-            append_key=k,
-            append_value=v,
-            batch_indices=metadata.batch_indices,
-            positions=metadata.positions,
-            paged_kv_cache=kv_cache,
-            kv_indices=metadata.paged_kv_indices,
-            kv_indptr=metadata.paged_kv_indptr,
-            kv_last_page_len=metadata.paged_kv_last_page_len,
-            kv_layout=metadata.kv_layout)
-
-    num_contexts = metadata.num_contexts
-    num_generations = metadata.num_generations
-    num_ctx_tokens = metadata.num_ctx_tokens
-
-    def prefill_forward(plan_params: PlanParams):
-        wrapper = metadata.get_prefill_wrapper(plan_params)
-        output = wrapper.run(q[:num_ctx_tokens], kv_cache)
-        output = output.view(num_ctx_tokens, -1)
-        return output
-
-    def decode_forward(plan_params: PlanParams):
-        wrapper = metadata.get_decode_wrapper(plan_params)
-        output = wrapper.run(q[num_ctx_tokens:], kv_cache)
-        output = output.view(num_generations, -1)
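+        # Reuse the caller-provided output buffer when given; otherwise allocate
+        # a fresh one matching q.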
+        if output is None:
+            output = torch.empty_like(q)
+
+        self.forward_impl(q=q,
+                          k=k,
+                          v=v,
+                          metadata=metadata,
+                          attention_mask_type=attention_mask_type,
+                          attention_mask_data=attention_mask_data,
+                          attention_window_size=attention_window_size,
+                          output=output)
         return output
-
-    # this will do nothing if the last forward pass had the same parameters
-    plan_params = metadata.plan(num_heads,
-                                num_kv_heads,
-                                head_dim,
-                                q_dtype=q.dtype,
-                                kv_dtype=kv_cache.dtype,
-                                q_scaling=q_scaling,
-                                attention_window_size=attention_window_size,
-                                attention_mask_type=attention_mask_type,
-                                attention_mask_data=attention_mask_data)
-
-    if num_contexts > 0:
-        ctx_output = prefill_forward(plan_params)
-
-    if num_generations > 0:
-        gen_output = decode_forward(plan_params)
-
-    if num_contexts > 0 and num_generations > 0:
-        output = torch.cat([ctx_output, gen_output], dim=0)
-    elif num_contexts > 0:
-        output = ctx_output
-    elif num_generations > 0:
-        output = gen_output
-
-    return output
-
-
-@forward_pattern.register_fake
-def _(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    num_heads: int,
-    head_dim: int,
-    num_kv_heads: int,
-    layer_idx: int,
-    has_fp8_kv_cache: bool,
-    attention_mask_type: int,
-    attention_mask_data: Optional[torch.Tensor],
-):
-    return torch.empty_like(q)