@@ -600,9 +600,9 @@ def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:
 
     def __post_init__(self) -> None:
         super().__post_init__()
-        self.__post_init_with_buffers(self.cuda_graph_buffers)
+        self._post_init_with_buffers(self.cuda_graph_buffers)
 
-    def __post_init_with_buffers(self, buffers) -> None:
+    def _post_init_with_buffers(self, buffers) -> None:
 
         # Set a default value, as max_num_sequences is not always set.
         if self.max_num_sequences is None:
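
Note on the rename: dropping the double leading underscore from `__post_init_with_buffers` avoids Python's name mangling, which rewrites a `__name` attribute inside a class body to `_ClassName__name` and makes the helper awkward to call or override from subclasses; a single underscore keeps it conventionally internal without that behavior. The diff does not state the motivation, so this is an inference. A minimal standalone sketch of the mangling behavior (not code from this PR):

    class Base:
        def run(self):
            # Compiled as self._Base__helper(), so a subclass cannot
            # override the helper by defining its own __helper.
            self.__helper()

        def __helper(self):
            print("Base.__helper")

    class Child(Base):
        def __helper(self):  # becomes _Child__helper; never reached via run()
            print("Child.__helper")

    Child().run()  # prints "Base.__helper"
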
@@ -624,8 +624,6 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
             Args:
                 tensor_shape: The required shape.
                 dtype: The required dtype.
-                buffers: A dictionary mapping cache names to lists of buffer tensors.
-                    Can be `None` or empty.
                 cache_name: The key for the specific list of buffers to search in.
 
             Returns:
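
The `buffers` entry disappears from the Args list because `buffers` is not (or is no longer) a parameter of `get_empty` itself; presumably it is captured from the enclosing `_post_init_with_buffers(self, buffers)` scope instead. For orientation only, a rough standalone sketch of the reuse-or-allocate pattern the docstring describes; names and details here are illustrative, not copied from the file:

    from typing import Optional
    import torch

    def get_empty_sketch(buffers: Optional[dict[str, list[torch.Tensor]]],
                         tensor_shape: list[int], dtype: torch.dtype,
                         cache_name: str) -> torch.Tensor:
        numel = 1
        for dim in tensor_shape:
            numel *= dim
        # Reuse a cached buffer registered under cache_name if one is compatible.
        for candidate in (buffers or {}).get(cache_name, []):
            if candidate.dtype == dtype and candidate.numel() >= numel:
                return candidate.flatten()[:numel].view(tensor_shape)
        # Otherwise fall back to a fresh allocation.
        return torch.empty(tensor_shape, dtype=dtype)
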
@@ -652,20 +650,26 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
 
         def get_empty_like(like_tensor: torch.Tensor,
                            cache_name: str) -> torch.Tensor:
-            return get_empty(like_tensor.shape,
-                             cache_name=cache_name,
-                             dtype=like_tensor.dtype)
+            return get_empty(
+                like_tensor.shape,
+                cache_name=cache_name,
+                dtype=like_tensor.dtype,
+            )
 
-        self.prompt_lens_cuda = get_empty((self.max_num_sequences, ),
-                                          cache_name="prompt_lens_cuda",
-                                          dtype=torch.int)
+        self.prompt_lens_cuda = get_empty(
+            (self.max_num_sequences, ),
+            cache_name="prompt_lens_cuda",
+            dtype=torch.int,
+        )
         self.prompt_lens_cpu = torch.empty_like(
             self.prompt_lens_cuda,
             device='cpu',
             pin_memory=True,
         )
-        self.kv_lens_cuda = get_empty_like(self.prompt_lens_cuda,
-                                           cache_name="kv_lens_cuda")
+        self.kv_lens_cuda = get_empty_like(
+            self.prompt_lens_cuda,
+            cache_name="kv_lens_cuda",
+        )
         self.kv_lens = torch.empty_like(self.kv_lens_cuda,
                                         device='cpu',
                                         pin_memory=True)
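
The surrounding pattern in this hunk pairs each CUDA-side buffer from `get_empty`/`get_empty_like` with a pinned (`pin_memory=True`) CPU mirror created via `torch.empty_like(..., device='cpu')`. Pinned host memory is what allows later host-to-device copies to run asynchronously. A small self-contained illustration of that pairing, not code from this file:

    import torch

    if torch.cuda.is_available():
        prompt_lens_cuda = torch.empty((8, ), dtype=torch.int, device='cuda')
        prompt_lens_cpu = torch.empty_like(prompt_lens_cuda,
                                           device='cpu',
                                           pin_memory=True)
        prompt_lens_cpu.copy_(torch.arange(8, dtype=torch.int))
        # With a pinned source, this copy can be issued asynchronously on the
        # current CUDA stream instead of blocking the host.
        prompt_lens_cuda.copy_(prompt_lens_cpu, non_blocking=True)
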
@@ -685,7 +689,8 @@ def get_empty_like(like_tensor: torch.Tensor,
                 self.kv_cache_manager.max_blocks_per_seq
             ],
             cache_name="kv_cache_block_offsets",
-            dtype=torch.int32)
+            dtype=torch.int32,
+        )
         self.host_kv_cache_block_offsets = torch.empty_like(
             self.kv_cache_block_offsets,
             device='cpu',
@@ -700,20 +705,23 @@ def get_empty_like(like_tensor: torch.Tensor,
                 self.kv_cache_manager.max_blocks_per_seq
             ],
             cache_name="block_ids_per_seq",
-            dtype=torch.int32)
+            dtype=torch.int32,
+        )
         self.kv_block_ids_per_seq = get_empty(
             [
                 self.kv_cache_manager.max_batch_size,
                 self.kv_cache_manager.max_blocks_per_seq
             ],
             cache_name="kv_block_ids_per_seq",
-            dtype=torch.int32)
+            dtype=torch.int32,
+        )
         if self.enable_paged_context_mla:
             # for kv cache reuse/chunked context in MLA
             self.ctx_cached_token_indptr = get_empty(
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_cached_token_indptr",
-                dtype=torch.int64)
+                dtype=torch.int64,
+            )
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
@@ -722,16 +730,19 @@ def get_empty_like(like_tensor: torch.Tensor,
             self.ctx_uncached_token_indptr = get_empty(
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_uncached_token_indptr",
-                dtype=torch.int64)
+                dtype=torch.int64,
+            )
             self.host_ctx_uncached_token_indptr = torch.zeros_like(
                 self.ctx_uncached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
             # context full seqlens include cached tokens and uncached tokens
-            self.ctx_kv_indptr = get_empty((self.max_num_requests + 1, ),
-                                           cache_name="ctx_kv_indptr",
-                                           dtype=torch.int64)
+            self.ctx_kv_indptr = get_empty(
+                (self.max_num_requests + 1, ),
+                cache_name="ctx_kv_indptr",
+                dtype=torch.int64,
+            )
             self.host_ctx_kv_indptr = torch.zeros_like(
                 self.ctx_kv_indptr,
                 device='cpu',
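
Context for the `*_indptr` tensors touched in the last two hunks: each is sized `(max_num_requests + 1, )` with dtype int64, the usual shape for CSR-style index pointers, i.e. exclusive prefix sums of per-request token counts, and the in-line comment notes that `ctx_kv_indptr` covers cached plus uncached tokens. That reading is an inference from the shapes and names; the toy values and helper code below are hypothetical:

    import torch

    cached_lens = torch.tensor([3, 0, 5], dtype=torch.int64)    # cached context tokens per request
    uncached_lens = torch.tensor([2, 4, 1], dtype=torch.int64)  # uncached context tokens per request

    ctx_cached_token_indptr = torch.zeros(len(cached_lens) + 1, dtype=torch.int64)
    ctx_cached_token_indptr[1:] = torch.cumsum(cached_lens, dim=0)

    # Full context KV length per request = cached + uncached tokens.
    ctx_kv_indptr = torch.zeros(len(cached_lens) + 1, dtype=torch.int64)
    ctx_kv_indptr[1:] = torch.cumsum(cached_lens + uncached_lens, dim=0)

    print(ctx_cached_token_indptr.tolist())  # [0, 3, 3, 8]
    print(ctx_kv_indptr.tolist())            # [0, 5, 9, 15]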