4 changes: 4 additions & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ extend_skip_glob = [
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
]

[tool.yapf]
@@ -63,6 +64,7 @@ ignore_patterns = [
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
]

[tool.codespell]
@@ -97,6 +99,7 @@ exclude = [
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
]


@@ -140,6 +143,7 @@ include = [
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
]
exclude = [
"**3rdparty/**",
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
@@ -26,6 +26,14 @@ class HfWeightLoader(BaseWeightLoader):

def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
        # Some model checkpoint directories contain not only the sharded safetensors but
        # also a consolidated one. When both are present, we favor the shards: there is no
        # need to prefetch the (usually very large) consolidated file into memory in that case.
filtered_weight_files = [
x for x in weight_files if "consolidated" not in os.path.split(x)[1]
]
if len(filtered_weight_files) > 0:
weight_files = filtered_weight_files
if weight_files:
# Prefetch the weight files to CPU memory if the size is less than 90% of the available memory.
# This is a heuristic to avoid prefetching files that are too large and causing file cache thrashing.
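Note that the filter keys on the file's basename rather than the full path, so a checkpoint directory whose own name contains "consolidated" is unaffected (the last test case in the new test file below exercises exactly this). A minimal illustration, with made-up paths:

import os

weight_files = [
    "/ckpts/consolidated-model/model-00001-of-00002.safetensors",
    "/ckpts/consolidated-model/model-00002-of-00002.safetensors",
    "/ckpts/consolidated-model/consolidated.safetensors",
]
# Only the basename is inspected, so the "consolidated-model" directory name
# does not trigger the filter; just the consolidated file itself is dropped.
filtered = [x for x in weight_files if "consolidated" not in os.path.split(x)[1]]
assert filtered == weight_files[:2]

The prefetch code referenced by the comment above is truncated in this diff. A rough sketch of the described heuristic, assuming psutil for the available-memory query; should_prefetch and the threshold handling are illustrative, not the actual implementation:

import os

import psutil


def should_prefetch(weight_files: list[str], threshold: float = 0.9) -> bool:
    # Prefetch only if the combined weight-file size stays under `threshold`
    # (90% by default) of the available memory, to avoid file cache thrashing.
    total_size = sum(os.path.getsize(f) for f in weight_files)
    return total_size < threshold * psutil.virtual_memory().available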
3 changes: 3 additions & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
@@ -16,6 +16,9 @@ l0_a10:
# ------------- PyTorch tests ---------------
- unittest/_torch/modeling/test_modeling_mistral.py
- unittest/_torch/modeling/test_modeling_pixtral.py
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
# test list either).
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
80 changes: 80 additions & 0 deletions tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py
@@ -0,0 +1,80 @@
from unittest import mock

import pytest

from tensorrt_llm._torch.models.checkpoints import HfWeightLoader


class MyError(Exception):
pass


@pytest.mark.parametrize(
"dir_name, safetensor_filenames, expected_safetensor_filenames",
[
(
"foo",
[
"model-00001-of-00002.safetensors",
"model-000002-of-00002.safetensors",
"consolidated.safetensors",
],
["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
),
(
"foo",
[
*(f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)),
"foo-consolidated.safetensors",
],
[f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)],
),
# If there is only a consolidated safetensor, that one should still be used.
(
"foo",
["consolidated.safetensors"],
["consolidated.safetensors"],
),
        # The directory name itself contains "consolidated", but its contents are sharded
        # tensors; only the consolidated file inside should be filtered out.
(
"consolidated-model",
[
"model-00001-of-00002.safetensors",
"model-000002-of-00002.safetensors",
"consolidated.safetensors",
],
["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
),
],
)
def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
tmp_path,
dir_name: str,
safetensor_filenames: list[str],
expected_safetensor_filenames: list[str],
):
checkpoint_dir = tmp_path / dir_name
checkpoint_dir.mkdir()
for filename in safetensor_filenames:
(checkpoint_dir / filename).touch()
expected_safetensor_filenames = set(
str(checkpoint_dir / filename) for filename in expected_safetensor_filenames
)

loader = HfWeightLoader()
with (
mock.patch.object(
loader, "_load_weights_in_parallel", side_effect=MyError
) as load_weights_in_parallel,
mock.patch.object(loader, "prefetch_files") as prefetch_files,
pytest.raises(MyError),
):
loader.load_weights(checkpoint_dir=str(checkpoint_dir))

prefetch_files.assert_called_once()
prefetched_files = prefetch_files.call_args[0][0]
assert set(prefetched_files) == expected_safetensor_filenames

load_weights_in_parallel.assert_called_once()
loaded_weight_files = load_weights_in_parallel.call_args[0][0]
assert set(loaded_weight_files) == expected_safetensor_filenames
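A note on the test's structure: patching _load_weights_in_parallel with side_effect=MyError makes load_weights raise immediately after file selection, so the assertions can inspect the captured call arguments without performing any real safetensors I/O. Using a dedicated MyError class, rather than a built-in exception, ensures pytest.raises cannot accidentally match an unrelated failure.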