
Commit f1178c7

Address Felipe's prev review I missed, try docs again
1 parent c4f0366 commit f1178c7

File tree

4 files changed: +17 -10 lines changed

docs/Makefile

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ ifneq ($(EXAMPLES_PATTERN),)
 endif
 
 # You can set these variables from the command line.
-SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS)
+SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS) -T -v
 SPHINXBUILD = sphinx-build
 SPHINXPROJ = torchtune
 SOURCEDIR = source
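
The two new flags make failing doc builds easier to diagnose: -T tells sphinx-build to print a full traceback when an unhandled exception occurs, and -v raises logging verbosity. Combined with the existing options (-W promotes warnings to errors, -j auto parallelizes the build across CPU cores), the Makefile ends up running roughly the following; the html builder and output paths are illustrative of a standard Sphinx layout, not taken from this commit:

    sphinx-build -W -j auto -T -v -b html source build/html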

recipes/lora_finetune_single_device.py

Lines changed: 6 additions & 2 deletions
@@ -31,8 +31,12 @@
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.training import DummyProfiler, PROFILER_KEY
-from torchtune.training._activation_offloading import NoOpManager, OffloadActivations
+from torchtune.training import (
+    DummyProfiler,
+    NoOpManager,
+    OffloadActivations,
+    PROFILER_KEY,
+)
 from tqdm import tqdm
 
 log = utils.get_logger("DEBUG")
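
The recipe change is purely an import cleanup: NoOpManager and OffloadActivations now come from the public torchtune.training namespace rather than the private _activation_offloading module. For context, a minimal sketch of how the two are used together; the toy model, tensor sizes, and use_streams=False are illustrative and assume a CUDA-capable machine, not the recipe's exact code:

    import torch
    import torch.nn as nn

    from torchtune.training import NoOpManager, OffloadActivations

    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8)).cuda()
    x = torch.randn(4, 64, device="cuda")

    # OffloadActivations moves tensors that autograd saves for backward off
    # the GPU during forward and copies them back when backward needs them.
    with OffloadActivations(use_streams=False):
        hidden = model[:2](x)
        # Entering NoOpManager inside the offloading context disables
        # offloading for this region, keeping these activations on the GPU.
        with NoOpManager():
            out = model[2](hidden)

    out.sum().backward()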

torchtune/training/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from torchtune.training._activation_offloading import NoOpManager, OffloadActivations
 from torchtune.training._compile import compile_loss, compile_model
 from torchtune.training._distributed import (
     contains_fsdp,
@@ -122,4 +123,6 @@
     "setup_torch_profiler",
     "compile_loss",
     "compile_model",
+    "NoOpManager",
+    "OffloadActivations",
 ]
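
Adding the import and the __all__ entries makes the pair part of torchtune.training's public surface, which is what lets the recipe above drop the underscore-prefixed path. A quick sanity check (illustrative, not from the commit) that the public name is a re-export rather than a copy:

    from torchtune.training import OffloadActivations
    from torchtune.training._activation_offloading import (
        OffloadActivations as _PrivateOffloadActivations,
    )

    # Both import paths resolve to the same class object.
    assert OffloadActivations is _PrivateOffloadActivations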

torchtune/training/_activation_offloading.py

Lines changed: 7 additions & 7 deletions
@@ -137,11 +137,11 @@ def pack_tensor(activation: torch.Tensor) -> int:
         if use_streams:
             # First, sync back and dereference previously offloaded tensors
             # as the offloading should be done sufficiently long ago.
-            for k in [x for x in self.fwd_stash.keys()]:
-                if k <= tensor_id - self.max_fwd_stash_size:
-                    _, ev = self.fwd_stash[k]
+            for id in [k for k in self.fwd_stash.keys()]:
+                if id <= tensor_id - self.max_fwd_stash_size:
+                    _, ev = self.fwd_stash[id]
                     self.s0.wait_event(ev)
-                    del self.fwd_stash[k]
+                    del self.fwd_stash[id]
                 else:
                     break
 
@@ -266,10 +266,10 @@ def hook(outputs, inputs):
             self.bwd_ev_stash[unpack_tensor_id] = event
 
             # if there are still things in the fwd_stash, get rid of them as we're in bwd now
-            for k in [x for x in self.fwd_stash.keys()]:
-                _, ev = self.fwd_stash[k]
+            for id in [k for k in self.fwd_stash.keys()]:
+                _, ev = self.fwd_stash[id]
                 self.s0.wait_event(ev)
-                del self.fwd_stash[k]
+                del self.fwd_stash[id]
 
             # wait on prev node's events and del those
             for id in prev_node_ids:
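
Both hunks only rename the loop variable from k to id to match the surrounding code (at the cost of shadowing the id builtin inside the loop); the list comprehension over .keys() is a defensive snapshot, equivalent to list(self.fwd_stash), needed because entries are deleted mid-iteration. The underlying pattern is a bounded stash whose oldest entries are waited on and evicted. A standalone sketch, with plain callables standing in for the CUDA events that self.s0.wait_event consumes in the real code:

    # fwd_stash maps monotonically increasing tensor ids to (tensor, event)
    # pairs; dict insertion order lets the scan stop at the first entry that
    # is still young enough to keep.
    fwd_stash = {}
    max_fwd_stash_size = 5

    def evict_stale(current_tensor_id):
        # Snapshot the keys because we delete entries while iterating.
        for stash_id in list(fwd_stash.keys()):
            if stash_id <= current_tensor_id - max_fwd_stash_size:
                _, ev = fwd_stash[stash_id]
                ev()  # stand-in for self.s0.wait_event(ev)
                del fwd_stash[stash_id]
            else:
                break  # everything after this entry is newer

    for i in range(8):
        fwd_stash[i] = (f"tensor-{i}", lambda: None)
        evict_stale(i)

    print(sorted(fwd_stash))  # [3, 4, 5, 6, 7]: at most max_fwd_stash_size ids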

0 commit comments