
Commit 98f803c

Allow offload_modules to handle single module and container offloading cases (#51)
* fix: handle single module in offload_modules

  Fix TypeError when passing a single module (e.g., CrossLayerTranscoder) to offload_modules instead of a list. Now properly handles single modules, lists, and PyTorch container types (ModuleList, ModuleDict, Sequential).

* minor aesthetic change, slight simplification of logic
1 parent 744956f commit 98f803c
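As a rough illustration of the behavior described in the commit message, all of the following call forms are now accepted. This is a minimal sketch, not part of the commit; the CrossLayerTranscoder constructor arguments simply mirror the test fixtures added below.

import torch

from circuit_tracer.transcoder.cross_layer_transcoder import CrossLayerTranscoder
from circuit_tracer.utils.disk_offload import offload_modules

clt = CrossLayerTranscoder(
    n_layers=2, d_transcoder=16, d_model=8,
    lazy_decoder=False, lazy_encoder=False, device=torch.device("cpu"),
)

# Each call returns a list of reload handles, one per module offloaded.
handles = offload_modules(clt, offload_type="cpu")    # bare module (previously raised TypeError)
handles = offload_modules([clt], offload_type="cpu")  # plain list, as before
handles = offload_modules(torch.nn.ModuleDict({"clt": clt}), offload_type="cpu")  # PyTorch container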

File tree

2 files changed: +228 −4 lines changed


circuit_tracer/utils/disk_offload.py

Lines changed: 27 additions & 4 deletions
@@ -3,6 +3,7 @@
 import tempfile
 from typing import Literal
 
+from torch import nn
 from safetensors.torch import load_file, save_file
 
 _offload_files = set()
@@ -13,7 +14,8 @@
 @atexit.register
 def cleanup_offload_files():
     for f in _offload_files:
-        os.remove(f)
+        if os.path.exists(f):
+            os.remove(f)
 
 
 def cleanup_all_offload_files():
@@ -35,7 +37,8 @@ def disk_offload_module(module):
     module.to(device="meta")
 
     def reload_handle(device=None):
-        module.load_state_dict(load_file(f.name, device=(device or str(org_device))), assign=True)
+        target_device = str(device or org_device)
+        module.load_state_dict(load_file(f.name, device=target_device), assign=True)
         os.remove(f.name)
         _offload_files.remove(f.name)
 
@@ -52,6 +55,26 @@ def reload_handle():
     return reload_handle
 
 
-def offload_modules(modules, offload_type: Literal["cpu", "disk"]):
+def offload_modules(
+    modules: list | nn.Module | nn.ModuleList | nn.ModuleDict | nn.Sequential,
+    offload_type: Literal["cpu", "disk"],
+) -> list:
+    """Offload one or more modules to CPU or disk.
+
+    Args:
+        modules: A single module, list of modules, or PyTorch module container
+            (ModuleList, ModuleDict, Sequential)
+        offload_type: Type of offload - "cpu" or "disk"
+
+    Returns:
+        List of reload handles, one per module
+    """
     offload_fn = disk_offload_module if offload_type == "disk" else cpu_offload_module
-    return [offload_fn(module) for module in modules]
+
+    if isinstance(modules, nn.ModuleDict):
+        mods = modules.values()
+    elif isinstance(modules, (list, nn.ModuleList, nn.Sequential)):
+        mods = modules
+    else:
+        mods = [modules]
+    return [offload_fn(module) for module in mods]
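
For context on the reload_handle change above: the handle returned by disk offloading restores the saved weights to the original device by default, or to an explicitly requested one. A minimal sketch, using a plain nn.Linear as a stand-in for a transcoder module (an assumption; the tests below exercise the same flow with real transcoders):

import torch
from torch import nn

from circuit_tracer.utils.disk_offload import disk_offload_module

module = nn.Linear(8, 16)                    # stand-in module for illustration
reload_handle = disk_offload_module(module)

# While offloaded, the parameters live on the meta device and the weights sit in a temp file.
assert next(module.parameters()).device.type == "meta"

# Omitting device= restores the original device; passing one overrides it.
reload_handle(device=torch.device("cpu"))
assert next(module.parameters()).device.type == "cpu"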

tests/utils/test_disk_offload.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
+"""Tests for disk_offload module functions."""
+
+import pytest
+import torch
+
+from circuit_tracer.transcoder.cross_layer_transcoder import CrossLayerTranscoder
+from circuit_tracer.transcoder.single_layer_transcoder import SingleLayerTranscoder
+from circuit_tracer.utils.disk_offload import (
+    cleanup_all_offload_files,
+    cpu_offload_module,
+    disk_offload_module,
+    offload_modules,
+)
+
+
+@pytest.fixture
+def clt_module():
+    """Create a small CLT."""
+    return CrossLayerTranscoder(
+        n_layers=2,
+        d_transcoder=16,
+        d_model=8,
+        lazy_decoder=False,
+        lazy_encoder=False,
+        device=torch.device("cpu"),
+    )
+
+
+@pytest.fixture
+def plt_module():
+    """Create a small PLT."""
+    return SingleLayerTranscoder(
+        d_model=8,
+        d_transcoder=16,
+        activation_function=torch.nn.functional.relu,
+        layer_idx=0,
+        lazy_decoder=False,
+        lazy_encoder=False,
+        device=torch.device("cpu"),
+    )
+
+
+@pytest.mark.parametrize("module_fixture", ["clt_module", "plt_module"])
+@pytest.mark.parametrize("explicit_device", [True, False])
+def test_disk_offload_module(module_fixture, explicit_device, request):
+    """Test disk offload with CLT and PLT architectures."""
+    module = request.getfixturevalue(module_fixture)
+
+    # Store original state
+    orig_param = next(module.parameters()).data.clone()
+    orig_device = next(module.parameters()).device
+
+    # Offload to disk
+    reload_handle = disk_offload_module(module)
+
+    # Verify module is on meta device
+    assert next(module.parameters()).device.type == "meta"
+
+    # Reload with or without explicit device
+    if explicit_device:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        reload_handle(device=device)
+        # Should be on the explicitly requested device
+        assert next(module.parameters()).device.type == device.type
+        assert torch.allclose(next(module.parameters()).data, orig_param.to(device))
+    else:
+        reload_handle()
+        # Should be restored to original device
+        assert next(module.parameters()).device.type == orig_device.type
+        assert torch.allclose(next(module.parameters()).data, orig_param)
+
+
+@pytest.mark.parametrize("module_fixture", ["clt_module", "plt_module"])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_cpu_offload_module_cuda(module_fixture, request):
+    """Test CPU offload with CLT and PLT on CUDA."""
+    module = request.getfixturevalue(module_fixture)
+
+    # Move to CUDA
+    module.to("cuda")
+    orig_param = next(module.parameters()).data.clone()
+
+    # Offload to CPU
+    reload_handle = cpu_offload_module(module)
+    assert next(module.parameters()).device.type == "cpu"
+
+    # Reload to CUDA
+    reload_handle()
+    assert next(module.parameters()).device.type == "cuda"
+    assert torch.allclose(next(module.parameters()).data, orig_param.to("cuda"))
+
+
+def test_cpu_offload_module_cpu(clt_module):
+    """Test CPU offload when already on CPU."""
+    orig_device = next(clt_module.parameters()).device
+
+    reload_handle = cpu_offload_module(clt_module)
+    assert next(clt_module.parameters()).device.type == "cpu"
+
+    reload_handle()
+    assert next(clt_module.parameters()).device == orig_device
+
+
+@pytest.mark.parametrize(
+    "modules_factory,expected_count",
+    [
+        # Single module
+        (
+            lambda: CrossLayerTranscoder(
+                n_layers=2, d_transcoder=16, d_model=8, lazy_decoder=False, lazy_encoder=False
+            ),
+            1,
+        ),
+        # List of CLTs
+        (
+            lambda: [
+                CrossLayerTranscoder(
+                    n_layers=2, d_transcoder=16, d_model=8, lazy_decoder=False, lazy_encoder=False
+                ),
+                CrossLayerTranscoder(
+                    n_layers=2, d_transcoder=16, d_model=8, lazy_decoder=False, lazy_encoder=False
+                ),
+            ],
+            2,
+        ),
+        # ModuleDict with CLTs
+        (
+            lambda: torch.nn.ModuleDict(
+                {
+                    "clt1": CrossLayerTranscoder(
+                        n_layers=2,
+                        d_transcoder=16,
+                        d_model=8,
+                        lazy_decoder=False,
+                        lazy_encoder=False,
+                    ),
+                    "clt2": CrossLayerTranscoder(
+                        n_layers=2,
+                        d_transcoder=16,
+                        d_model=8,
+                        lazy_decoder=False,
+                        lazy_encoder=False,
+                    ),
+                }
+            ),
+            2,
+        ),
+    ],
+    ids=["single_clt", "list_clt", "moduledict_clt"],
+)
+@pytest.mark.parametrize("offload_type", ["cpu", "disk"])
+def test_offload_modules(modules_factory, expected_count, offload_type):
+    """Test offload_modules with various container types using CLT architecture."""
+    modules = modules_factory()
+    expected_device = "cpu" if offload_type == "cpu" else "meta"
+
+    handles = offload_modules(modules, offload_type=offload_type)
+
+    # Verify handles
+    assert isinstance(handles, list)
+    assert len(handles) == expected_count
+    for handle in handles:
+        assert callable(handle)
+
+    # Verify modules are offloaded
+    if isinstance(modules, torch.nn.Module) and not isinstance(
+        modules, (torch.nn.ModuleList, torch.nn.ModuleDict, torch.nn.Sequential)
+    ):
+        assert next(modules.parameters()).device.type == expected_device
+    else:
+        module_iter = modules.values() if isinstance(modules, torch.nn.ModuleDict) else modules
+        for module in module_iter:
+            assert next(module.parameters()).device.type == expected_device
+
+    # Cleanup disk offloads
+    if offload_type == "disk":
+        for handle in handles:
+            handle()
+
+
+def test_cleanup_offload_files(clt_module):
+    """Test cleanup removes offload files."""
+    # Create some offload files
+    modules = [clt_module]
+    offload_modules(modules, offload_type="disk")
+
+    # Cleanup
+    n_removed = cleanup_all_offload_files()
+
+    # Should have removed files
+    assert n_removed >= 1
+
+
+def test_cleanup_when_no_files():
+    """Test cleanup when no offload files exist."""
+    # First cleanup any existing files
+    cleanup_all_offload_files()
+
+    # Second cleanup should find nothing
+    n_removed = cleanup_all_offload_files()
+    assert n_removed == 0
