diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 2d00cee05f0..6aea018bded 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -648,6 +648,14 @@ def release_batch(result: ScheduledRequests | None): return with contextlib.ExitStack() as stack: + + def clean_up_kv_cache(): + # Zero the KV cache; NaNs may be introduced during warmup + for layer_idx in kv_cache_manager.layer_offsets.keys(): + kv_cache_manager.get_buffers(layer_idx).zero_() + + stack.callback(clean_up_kv_cache) + if self._torch_compile_enabled: def disable_optimization(backend: Backend):