tensorrt_llm/bench/dataclasses/reporting.py — 30 changes: 18 additions & 12 deletions
@@ -273,6 +273,22 @@ def get_statistics_dict(self) -> Dict[str, Any]:
                 },
             }
 
+        # Retrieve KV cache information.
+        kv_cache_config = self.kwargs.get("kv_cache_config", KvCacheConfig())
+        if isinstance(kv_cache_config, KvCacheConfig):
+            kv_cache_dtype = kv_cache_config.dtype
+            kv_cache_mem_percent = kv_cache_config.free_gpu_memory_fraction
+        elif isinstance(kv_cache_config, dict):
+            kv_cache_dtype = kv_cache_config.get("dtype", "auto")
+            kv_cache_mem_percent = kv_cache_config.get(
+                "free_gpu_memory_fraction")
+        else:
+            raise ValueError(
+                f"Invalid kv_cache_config type: {type(kv_cache_config)}.")
+
+        kv_cache_mem_percent = f"{kv_cache_mem_percent * 100.0:.2f}%" \
+            if kv_cache_mem_percent is not None else "None"
+
         # Engine/Backend details
         if self.rt_cfg.backend not in ('pytorch', '_autodeploy'):
             config_path = self.rt_cfg.engine_dir / "config.json"
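Note: the added block accepts either a KvCacheConfig object or a plain dict, then renders the memory fraction as a display string. A minimal standalone sketch of that logic, assuming a stubbed-out KvCacheConfig (the stub and sample values are hypothetical, for illustration only):

from dataclasses import dataclass
from typing import Optional, Tuple, Union

@dataclass
class KvCacheConfig:
    # Hypothetical stand-in for tensorrt_llm's KvCacheConfig.
    dtype: str = "auto"
    free_gpu_memory_fraction: Optional[float] = None

def extract_kv_cache_info(config: Union[KvCacheConfig, dict]) -> Tuple[str, str]:
    # Accept either the config object or a plain dict, as in the diff above.
    if isinstance(config, KvCacheConfig):
        dtype, fraction = config.dtype, config.free_gpu_memory_fraction
    elif isinstance(config, dict):
        dtype = config.get("dtype", "auto")
        fraction = config.get("free_gpu_memory_fraction")
    else:
        raise ValueError(f"Invalid kv_cache_config type: {type(config)}.")
    # Render the fraction as a percentage string, or "None" when unset.
    percent = f"{fraction * 100.0:.2f}%" if fraction is not None else "None"
    return dtype, percent

print(extract_kv_cache_info(KvCacheConfig()))                    # ('auto', 'None')
print(extract_kv_cache_info({"free_gpu_memory_fraction": 0.9}))  # ('auto', '90.00%')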
@@ -302,15 +318,6 @@ def get_statistics_dict(self) -> Dict[str, Any]:
             model = self.rt_cfg.model_path or self.rt_cfg.model
             model_config = ModelConfig.from_pretrained(model,
                                                        trust_remote_code=True)
-            kv_cache_config = self.kwargs.get("kv_cache_config",
-                                              KvCacheConfig())
-            if isinstance(kv_cache_config, KvCacheConfig):
-                kv_cache_dtype = kv_cache_config.dtype
-            elif isinstance(kv_cache_config, dict):
-                kv_cache_dtype = kv_cache_config.get("dtype", "auto")
-            else:
-                raise ValueError(
-                    f"Invalid kv_cache_config type: {type(kv_cache_config)}.")
 
             validate_and_set_kv_cache_quant(model_config, kv_cache_dtype)
 
@@ -336,8 +343,7 @@ def get_statistics_dict(self) -> Dict[str, Any]:
             "max_batch_size": self.rt_cfg.settings_config.max_batch_size,
             "max_num_tokens": self.rt_cfg.settings_config.max_num_tokens,
             "scheduling_policy": self.rt_cfg.settings_config.scheduler_policy,
-            "kv_cache_percentage":
-            self.rt_cfg.settings_config.kv_cache_percent * 100.0,
+            "kv_cache_percentage": kv_cache_mem_percent,
             "issue_rate": self.convert_rate_to_s(self.statistics.issue_rate_ns)
         }
 
@@ -526,7 +532,7 @@ def report_statistics(self) -> None:
             f"Max Runtime Batch Size: {world_info['max_batch_size']}\n"
             f"Max Runtime Tokens: {world_info['max_num_tokens']}\n"
             f"Scheduling Policy: {world_info['scheduling_policy']}\n"
-            f"KV Memory Percentage: {world_info['kv_cache_percentage']:.2f}%\n"
+            f"KV Memory Percentage: {world_info['kv_cache_percentage']}\n"
             f"Issue Rate (req/sec): {world_info['issue_rate']:.4E}\n"
             f"\n")
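Note: because kv_cache_percentage now carries a preformatted string (possibly the literal "None"), the report line drops its :.2f float format spec; applying one to a string would raise at report time. A hypothetical minimal illustration:

world_info = {"kv_cache_percentage": "90.00%"}  # preformatted upstream

# New behavior: the string is printed as-is.
print(f"KV Memory Percentage: {world_info['kv_cache_percentage']}")

# The old format spec would now fail, since the value is no longer a float:
# f"{world_info['kv_cache_percentage']:.2f}"  -> ValueError: Unknown format code 'f'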