
Commit 1e92f00

Fix formats and address comments
Signed-off-by: Hui Gao <[email protected]>
1 parent 8de6493 commit 1e92f00

File tree

2 files changed: +31 -21 lines changed


tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 31 additions & 20 deletions
@@ -600,9 +600,9 @@ def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:
 
     def __post_init__(self) -> None:
         super().__post_init__()
-        self.__post_init_with_buffers(self.cuda_graph_buffers)
+        self._post_init_with_buffers(self.cuda_graph_buffers)
 
-    def __post_init_with_buffers(self, buffers) -> None:
+    def _post_init_with_buffers(self, buffers) -> None:
 
         # Set a default value, as max_num_sequences is not always set.
         if self.max_num_sequences is None:
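Note on the rename: in Python, a double leading underscore on a method name triggers name mangling (the interpreter rewrites `__post_init_with_buffers` to `_ClassName__post_init_with_buffers`), which complicates overriding and calling the method from subclasses; a single leading underscore is the usual "internal" convention without that side effect. A minimal illustration with hypothetical class names, not from this file:

class Base:
    def __helper(self):   # mangled to _Base__helper
        return "base"

class Child(Base):
    def __helper(self):   # mangled to _Child__helper; does NOT override Base's
        return "child"

c = Child()
print(c._Base__helper())   # "base" -- Base-internal calls still hit Base's copy
print(c._Child__helper())  # "child"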
@@ -624,8 +624,6 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
             Args:
                 tensor_shape: The required shape.
                 dtype: The required dtype.
-                buffers: A dictionary mapping cache names to lists of buffer tensors.
-                    Can be `None` or empty.
                 cache_name: The key for the specific list of buffers to search in.
 
             Returns:
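For context on what `get_empty` does with `buffers` and `cache_name`: per the docstring, `buffers` maps cache names to lists of buffer tensors, and `get_empty` searches the list under `cache_name` so that repeated setups can reuse the same allocations (stable addresses are what CUDA graph replays require). The actual implementation is not shown in this hunk; the following is a minimal sketch under assumed semantics, with `make_get_empty` and the exact reuse criterion being illustrative:

import math

import torch


def make_get_empty(buffers):
    # buffers: dict mapping cache names to lists of buffer tensors;
    # may be None or empty (see the docstring above).

    def get_empty(tensor_shape, dtype, cache_name):
        numel = math.prod(tensor_shape)
        # Reuse a registered buffer when dtype matches and it is big enough;
        # reuse keeps tensor addresses stable across CUDA graph replays.
        for buf in (buffers or {}).get(cache_name, []):
            if buf.dtype == dtype and buf.numel() >= numel:
                return buf.flatten()[:numel].view(tensor_shape)
        # Cache miss: allocate fresh memory and register it for later reuse.
        new_buf = torch.empty(tensor_shape, dtype=dtype, device='cuda')
        if buffers is not None:
            buffers.setdefault(cache_name, []).append(new_buf)
        return new_buf

    return get_empty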
@@ -652,20 +650,26 @@ def get_empty_like(like_tensor: torch.Tensor,
 
         def get_empty_like(like_tensor: torch.Tensor,
                            cache_name: str) -> torch.Tensor:
-            return get_empty(like_tensor.shape,
-                             cache_name=cache_name,
-                             dtype=like_tensor.dtype)
+            return get_empty(
+                like_tensor.shape,
+                cache_name=cache_name,
+                dtype=like_tensor.dtype,
+            )
 
-        self.prompt_lens_cuda = get_empty((self.max_num_sequences, ),
-                                          cache_name="prompt_lens_cuda",
-                                          dtype=torch.int)
+        self.prompt_lens_cuda = get_empty(
+            (self.max_num_sequences, ),
+            cache_name="prompt_lens_cuda",
+            dtype=torch.int,
+        )
         self.prompt_lens_cpu = torch.empty_like(
             self.prompt_lens_cuda,
             device='cpu',
             pin_memory=True,
         )
-        self.kv_lens_cuda = get_empty_like(self.prompt_lens_cuda,
-                                           cache_name="kv_lens_cuda")
+        self.kv_lens_cuda = get_empty_like(
+            self.prompt_lens_cuda,
+            cache_name="kv_lens_cuda",
+        )
         self.kv_lens = torch.empty_like(self.kv_lens_cuda,
                                         device='cpu',
                                         pin_memory=True)
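The recurring `torch.empty_like(..., device='cpu', pin_memory=True)` mirrors are the standard staging pattern: pinned (page-locked) host memory is what lets a later `copy_(..., non_blocking=True)` overlap the host-to-device transfer with GPU work; with pageable memory the copy degrades to a synchronous one. A small sketch of how such a pair is typically used (sizes and values are illustrative):

import torch

# Device tensor plus its pinned-host staging mirror.
prompt_lens_cuda = torch.empty(8, dtype=torch.int, device='cuda')
prompt_lens_cpu = torch.empty_like(prompt_lens_cuda,
                                   device='cpu',
                                   pin_memory=True)

# Fill the staging tensor on the host, then copy asynchronously.
prompt_lens_cpu[:4] = torch.tensor([5, 9, 3, 7], dtype=torch.int)
prompt_lens_cuda[:4].copy_(prompt_lens_cpu[:4], non_blocking=True)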
@@ -685,7 +689,8 @@ def get_empty_like(like_tensor: torch.Tensor,
                 self.kv_cache_manager.max_blocks_per_seq
             ],
             cache_name="kv_cache_block_offsets",
-            dtype=torch.int32)
+            dtype=torch.int32,
+        )
         self.host_kv_cache_block_offsets = torch.empty_like(
             self.kv_cache_block_offsets,
             device='cpu',
@@ -700,20 +705,23 @@ def get_empty_like(like_tensor: torch.Tensor,
                 self.kv_cache_manager.max_blocks_per_seq
             ],
             cache_name="block_ids_per_seq",
-            dtype=torch.int32)
+            dtype=torch.int32,
+        )
         self.kv_block_ids_per_seq = get_empty(
             [
                 self.kv_cache_manager.max_batch_size,
                 self.kv_cache_manager.max_blocks_per_seq
             ],
             cache_name="kv_block_ids_per_seq",
-            dtype=torch.int32)
+            dtype=torch.int32,
+        )
         if self.enable_paged_context_mla:
             # for kv cache reuse/chunked context in MLA
             self.ctx_cached_token_indptr = get_empty(
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_cached_token_indptr",
-                dtype=torch.int64)
+                dtype=torch.int64,
+            )
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
@@ -722,16 +730,19 @@ def get_empty_like(like_tensor: torch.Tensor,
             self.ctx_uncached_token_indptr = get_empty(
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_uncached_token_indptr",
-                dtype=torch.int64)
+                dtype=torch.int64,
+            )
             self.host_ctx_uncached_token_indptr = torch.zeros_like(
                 self.ctx_uncached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
             # context full seqlens include cached tokens and uncached tokens
-            self.ctx_kv_indptr = get_empty((self.max_num_requests + 1, ),
-                                           cache_name="ctx_kv_indptr",
-                                           dtype=torch.int64)
+            self.ctx_kv_indptr = get_empty(
+                (self.max_num_requests + 1, ),
+                cache_name="ctx_kv_indptr",
+                dtype=torch.int64,
+            )
             self.host_ctx_kv_indptr = torch.zeros_like(
                 self.ctx_kv_indptr,
                 device='cpu',
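The `*_indptr` tensors sized `(self.max_num_requests + 1, )` follow the CSR-style convention: entry i is the running token total before request i, so request i's tokens occupy the half-open range [indptr[i], indptr[i+1]), and the extra leading zero slot is why the length is requests + 1. A short sketch of building one from per-request counts (the counts are illustrative, not from the diff):

import torch

cached_tokens_per_request = torch.tensor([12, 0, 7], dtype=torch.int64)
num_requests = cached_tokens_per_request.numel()

# indptr[0] = 0; indptr[i + 1] = total tokens of requests 0..i.
ctx_cached_token_indptr = torch.zeros(num_requests + 1, dtype=torch.int64)
ctx_cached_token_indptr[1:] = torch.cumsum(cached_tokens_per_request, dim=0)

print(ctx_cached_token_indptr)  # tensor([ 0, 12, 12, 19])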

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -278,7 +278,6 @@ full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b
 full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
 full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)
 full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5410391)
 accuracy/test_llm_api.py::TestMistral_Nemo_12B_Base::test_fp8 SKIP (https://nvbugs/5413197)
 accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5455140)
