Skip to content

Commit 8a68879

Browse files
committed
device count
1 parent 4a1560d commit 8a68879

2 files changed

Lines changed: 27 additions & 15 deletions

File tree

src/transformers/modeling_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4304,7 +4304,26 @@ def from_pretrained(
43044304

43054305
# Prepare the full device map
43064306
if device_map is not None:
4307+
import psutil
4308+
from transformers.testing_utils import _cpu_memory_patch_info
4309+
_vm = psutil.virtual_memory()
4310+
_patch_info = (
4311+
f", real_total={_cpu_memory_patch_info['real_total_gb']:.2f} GB"
4312+
f", real_available={_cpu_memory_patch_info['real_available_gb']:.2f} GB"
4313+
f", capped_to={_cpu_memory_patch_info['limit_bytes'] / 1024**3:.2f} GB"
4314+
if _cpu_memory_patch_info else ""
4315+
)
4316+
logger.warning(
4317+
f"[device_map] Before _get_device_map — "
4318+
f"device_map={device_map!r}, "
4319+
f"max_memory={max_memory!r}, "
4320+
f"CPU RAM total={_vm.total / 1024**3:.2f} GB (as seen by psutil){_patch_info}"
4321+
)
43074322
device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
4323+
logger.warning(
4324+
f"[device_map] After _get_device_map — "
4325+
f"device_map={device_map!r}"
4326+
)
43084327

43094328
# Finalize model weight initialization
43104329
load_config = LoadStateDictConfig(

src/transformers/testing_utils.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3430,6 +3430,10 @@ def patched(*args, **kwargs):
34303430
torch.compile = patched
34313431

34323432

3433+
_cpu_memory_patch_info: dict = {}
3434+
"""Stores pre-patch memory stats set by patch_psutil_cpu_memory, for later logging during test execution."""
3435+
3436+
34333437
def patch_psutil_cpu_memory(limit_bytes: int):
34343438
"""
34353439
Patch `psutil.virtual_memory` to cap the reported CPU memory to `limit_bytes`.
@@ -3441,15 +3445,11 @@ def patch_psutil_cpu_memory(limit_bytes: int):
34413445
"""
34423446
import psutil
34433447

3444-
mem_before = psutil.virtual_memory()
3445-
logger.warning(
3446-
f"[patch_psutil_cpu_memory] Before patching — "
3447-
f"total={mem_before.total / 1024**3:.2f} GB, "
3448-
f"available={mem_before.available / 1024**3:.2f} GB, "
3449-
f"limit_bytes={limit_bytes / 1024**3:.2f} GB"
3450-
)
3451-
34523448
_original_virtual_memory = psutil.virtual_memory
3449+
_real_mem = _original_virtual_memory()
3450+
_cpu_memory_patch_info["limit_bytes"] = limit_bytes
3451+
_cpu_memory_patch_info["real_total_gb"] = _real_mem.total / 1024**3
3452+
_cpu_memory_patch_info["real_available_gb"] = _real_mem.available / 1024**3
34533453

34543454
def _capped_virtual_memory():
34553455
mem = _original_virtual_memory()
@@ -3461,13 +3461,6 @@ def _capped_virtual_memory():
34613461

34623462
psutil.virtual_memory = _capped_virtual_memory
34633463

3464-
mem_after = psutil.virtual_memory()
3465-
logger.warning(
3466-
f"[patch_psutil_cpu_memory] After patching — "
3467-
f"total={mem_after.total / 1024**3:.2f} GB, "
3468-
f"available={mem_after.available / 1024**3:.2f} GB"
3469-
)
3470-
34713464

34723465
def _get_test_info():
34733466
"""

0 commit comments

Comments
 (0)