Commit c4f0366

Add activations offloading API
1 parent df29d8a commit c4f0366

23 files changed: +412 -3 lines changed

docs/source/tutorials/memory_optimizations.rst

Lines changed: 27 additions & 0 deletions

@@ -83,6 +83,33 @@ and in most cases training can slow-down quite a bit as a result of this activat
 To enable activation checkpointing, use the ``enable_activation_checkpointing`` config entry or flag
 in any of our recipes, e.g. ``enable_activation_checkpointing=True``.
 
+.. _glossary_act_off:
+
+Activation Offloading
+---------------------
+
+*What's going on here?*
+
+You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory
+efficiency technique that saves GPU VRAM by temporarily moving activations to the CPU and bringing
+them back when needed in the backward pass.
+
+See the `PyTorch autograd hook tutorial <https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#saving-tensors-to-cpu>`_
+for more details about how this is implemented through saved_tensors_hooks.
+
+This setting is especially helpful for larger batch sizes or longer context lengths when you're memory constrained.
+However, these memory savings can come at the cost of training speed (i.e. tokens per second), since it takes runtime
+and resources to move tensors from GPU to CPU and back. The implementation in torchtune uses multiple CUDA streams
+to overlap the extra communication with computation and hide the extra runtime. Because the communication
+workload varies with the number and size of the tensors being offloaded, it is common not to offload every
+single activation. In fact, one can use offloading in conjunction with activation checkpointing, where all
+activations will either be recomputed later in the backward pass or brought back from the CPU.
+
+*Sounds great! How do I use it?*
+
+To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag
+in our LoRA finetuning single-device recipe, e.g. ``enable_activation_offloading=True``.
+
 .. _glossary_grad_accm:
 
 Gradient Accumulation
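
For readers who want to see the mechanism the docs above reference, here is a minimal sketch of offloading saved activations to CPU with torch.autograd.graph.saved_tensors_hooks. It illustrates the PyTorch hook API only, is not the implementation added by this commit, and assumes a CUDA device is available:

# Minimal illustration of offloading saved activations to CPU with
# torch.autograd.graph.saved_tensors_hooks. Not torchtune's implementation.
import torch
import torch.nn as nn

def pack_to_cpu(tensor):
    # autograd is about to save `tensor` for backward: keep a CPU copy instead,
    # remembering the original device so it can be restored later.
    return tensor.device, tensor.cpu()

def unpack_from_cpu(packed):
    # backward needs the saved tensor again: copy it back to its original device.
    device, cpu_tensor = packed
    return cpu_tensor.to(device)

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 1)).cuda()
x = torch.randn(4, 512, device="cuda")

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = model(x).sum()  # activations saved for backward now live on the CPU
loss.backward()            # the hooks copy them back to the GPU as needed

PyTorch also ships torch.autograd.graph.save_on_cpu(pin_memory=True), a built-in context manager that performs the same CPU offload with optional pinned-memory staging.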

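The docs also mention using multiple CUDA streams to overlap the offload traffic with compute. The following is a hedged sketch of that idea only; offload_async and bring_back are hypothetical helper names, not torchtune APIs. It issues the device-to-host copy on a side stream against a pinned CPU buffer:

# Sketch of the stream-overlap idea: the GPU-to-CPU copy runs on a side
# stream so it can proceed concurrently with compute on the default stream.
import torch

copy_stream = torch.cuda.Stream()

def offload_async(gpu_tensor):
    # Pinned CPU memory lets the non_blocking copy run truly asynchronously.
    cpu_buf = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True)
    # Don't start copying before the producing kernel on the current stream finishes.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        cpu_buf.copy_(gpu_tensor, non_blocking=True)
    # Keep the GPU memory alive until the copy stream is done reading it.
    gpu_tensor.record_stream(copy_stream)
    return cpu_buf

def bring_back(cpu_buf, device="cuda"):
    # Make the current stream wait for the offload copy before reusing the data.
    torch.cuda.current_stream().wait_stream(copy_stream)
    return cpu_buf.to(device, non_blocking=True)
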
recipes/configs/code_llama2/7B_lora_single_device.yaml

Lines changed: 1 addition & 0 deletions

@@ -73,6 +73,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False
 dtype: bf16
 
 # Logging

recipes/configs/code_llama2/7B_qlora_single_device.yaml

Lines changed: 1 addition & 0 deletions

@@ -73,6 +73,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False
 dtype: bf16
 
 # Logging

recipes/configs/gemma/2B_lora_single_device.yaml

Lines changed: 1 addition & 0 deletions

@@ -70,6 +70,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False
 
 # Reduced precision
 dtype: bf16

recipes/configs/gemma/2B_qlora_single_device.yaml

Lines changed: 1 addition & 0 deletions

@@ -70,6 +70,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False
 
 # Reduced precision
 dtype: bf16

recipes/configs/gemma/7B_lora_single_device.yaml

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False
 
 # Reduced precision
 dtype: bf16

recipes/configs/gemma/7B_qlora_single_device.yaml

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False
 
 # Reduced precision
 dtype: bf16

recipes/configs/llama2/13B_qlora_single_device.yaml

Lines changed: 2 additions & 0 deletions

@@ -80,7 +80,9 @@ log_peak_memory_stats: False
 # Environment
 device: cuda
 dtype: bf16
+
 enable_activation_checkpointing: True
+enable_activation_offloading: False
 
 # Show case the usage of pytorch profiler
 # Set enabled to False as it's only needed for debugging training

recipes/configs/llama2/7B_lora_single_device.yaml

Lines changed: 3 additions & 0 deletions

@@ -80,7 +80,10 @@ log_peak_memory_stats: False
 # Environment
 device: cuda
 dtype: bf16
+
+# Activations Memory
 enable_activation_checkpointing: True
+enable_activation_offloading: False
 
 # Show case the usage of pytorch profiler
 # Set enabled to False as it's only needed for debugging training

recipes/configs/llama2/7B_qlora_single_device.yaml

Lines changed: 3 additions & 0 deletions

@@ -79,7 +79,10 @@ log_peak_memory_stats: False
 # Environment
 device: cuda
 dtype: bf16
+
+# Activations Memory
 enable_activation_checkpointing: True
+enable_activation_offloading: False
 
 # Show case the usage of pytorch profiler
 # Set enabled to False as it's only needed for debugging training
