We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b35fa29 | commit 206cc88 | Copy full SHA for 206cc88
torchtune/training/_activation_offloading.py
@@ -269,9 +269,7 @@ def wait_and_del_remaining_references() -> None:
269
else:
270
# Kick off the process to bring tensors back
271
with torch.cuda.stream(self.s1):
272
- gpu_tensor = maybe_gpu_tensor.to(
273
- device="cuda", non_blocking=True
274
- )
+ gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
275
maybe_gpu_tensor = gpu_tensor
276
277
# Tell comp stream to wait for the info to be loaded before executing
0 commit comments