Commit f89ed8f

Refactor dpo and ppo recipe by introducing disable_dropout utility function
1 parent 9920d4f commit f89ed8f

File tree

5 files changed (+29, -32 lines)

recipes/full_dpo_distributed.py
recipes/ppo_full_finetune_single_device.py
tests/torchtune/training/test_model_util.py
torchtune/training/__init__.py
torchtune/training/_model_util.py


recipes/full_dpo_distributed.py

Lines changed: 3 additions & 13 deletions
@@ -20,7 +20,7 @@
 from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY
 from torchtune.training.lr_schedulers import get_lr
 from torchtune.utils import get_world_size_and_rank
 from tqdm import tqdm
@@ -494,12 +494,7 @@ def _setup_model(

         # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs
         # between ref policy and current policy
-        for module in model.modules():
-            if isinstance(module, torch.nn.Dropout):
-                warn(
-                    f"Dropout found in {module}. This is likely to cause issues during training. Disabling."
-                )
-                module.p = 0
+        disable_dropout(model)

         # synchronize before training begins
         torch.distributed.barrier()
@@ -581,12 +576,7 @@ def _setup_reference_model(

         # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs
         # between ref policy and current policy
-        for module in model.modules():
-            if isinstance(module, torch.nn.Dropout):
-                warn(
-                    f"Dropout found in {module}. This is likely to cause issues during training. Disabling."
-                )
-                module.p = 0
+        disable_dropout(model)

         for p in model.parameters():
             p.requires_grad = False
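
Both hunks replace the same hand-rolled loop with the new helper. The comment that stays in place explains why dropout is disabled at all: with dropout active, two forward passes over the same inputs produce different activations, so log-probabilities computed for the reference policy and the current policy are not directly comparable. A standalone sketch of that effect, using a toy model that is not part of the recipe:

import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.5))
model.train()  # the recipes run the policy in train mode, so dropout is active

x = torch.randn(2, 8)
out_a, out_b = model(x), model(x)
print(torch.allclose(out_a, out_b))  # almost surely False: the two dropout masks differ

# Zeroing p (what disable_dropout does) makes repeated passes deterministic again.
for module in model.modules():
    if isinstance(module, torch.nn.Dropout):
        module.p = 0
out_a, out_b = model(x), model(x)
print(torch.allclose(out_a, out_b))  # True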

recipes/ppo_full_finetune_single_device.py

Lines changed: 4 additions & 14 deletions
@@ -24,7 +24,7 @@
 from torchtune.modules import local_kv_cache
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.rlhf import PPOStats, Trajectory
-from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY
 from tqdm import tqdm

 log = utils.get_logger("DEBUG")
@@ -568,20 +568,10 @@ def _setup_models(

         # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs
         # between ref policy and current policy
-        for module in policy_model.modules():
-            if isinstance(module, torch.nn.Dropout):
-                warn(
-                    f"Dropout found in {module}. This is likely to cause issues during training. Disabling."
-                )
-                module.p = 0
-        for module in value_model.modules():
-            if isinstance(module, torch.nn.Dropout):
-                warn(
-                    f"Dropout found in {module}. This is likely to cause issues during training. Disabling."
-                )
-                module.p = 0
+        disable_dropout(policy_model)
+        disable_dropout(value_model)

-        # disabling grad and dropout in reward and reference policy models
+        # disabling grad in reward and reference policy models
         reward_model.eval()
         ref_policy_model.eval()

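
Note the trailing comment change as well: the reward and reference-policy models are put into eval() mode, which already turns dropout into a no-op, so the old mention of dropout there was redundant; the policy and value models stay in train mode during PPO, which is why their dropout probability has to be zeroed explicitly. A small illustration of the difference between the two mechanisms (toy layer, not from the recipe):

import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

drop.eval()   # eval mode: dropout is a no-op, p itself is untouched
print(torch.equal(drop(x), x))  # True

drop.train()  # train mode: dropout would be active again...
drop.p = 0    # ...unless p is set to 0, which is what disable_dropout does
print(torch.equal(drop(x), x))  # True even in train mode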

tests/torchtune/training/test_model_util.py

Lines changed: 18 additions & 3 deletions
@@ -7,7 +7,7 @@
 import warnings

 import torch
-from torchtune.training.model_util import disable_dropout
+from torchtune.training._model_util import disable_dropout


 class TestDisableDropout:
@@ -28,17 +28,32 @@ def test_disable_dropout(self):

     def test_disable_dropout_warning(self):
         """
-        Tests that a warning is issued when dropout layers are found and disabled.
+        Tests that the correct number of warnings is issued when dropout layers are found and disabled.
         """
         model = torch.nn.Sequential(
             torch.nn.Linear(10, 10),
             torch.nn.Dropout(p=0.5),
             torch.nn.ReLU(),
             torch.nn.Dropout(p=0.3),
+            torch.nn.Dropout(p=0.0),
         )
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")
             disable_dropout(model)
         assert len(w) == 2, "Expected 2 warnings for 2 dropout layers."
         assert issubclass(w[-1].category, UserWarning)
-        assert "Dropout found in" in str(w[-1].message)
+        assert "Found Dropout with" in str(w[-1].message)
+
+    def test_disable_dropout_no_warning(self):
+        """
+        Tests that no warning is issued when there are no dropout layers.
+        """
+        model = torch.nn.Sequential(
+            torch.nn.Linear(10, 10),
+            torch.nn.ReLU(),
+            torch.nn.Linear(10, 10),
+        )
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            disable_dropout(model)
+        assert len(w) == 0, "Expected no warnings when there are no dropout layers."
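
The tests count warnings via warnings.catch_warnings(record=True), and the extra Dropout(p=0.0) layer checks that the new module.p != 0 guard stays silent for already-disabled layers (two warnings for three dropout modules). A minimal, self-contained sketch of that counting pattern, using a stand-in function rather than the torchtune import:

import warnings

def noisy(n: int) -> None:
    # stand-in for disable_dropout: emits one UserWarning per "layer"
    for i in range(n):
        warnings.warn(f"Found Dropout with value 0.5 in module layer_{i}. Setting to zero.")

with warnings.catch_warnings(record=True) as recorded:
    warnings.simplefilter("always")  # make sure repeated warnings are not deduplicated
    noisy(2)

assert len(recorded) == 2
assert issubclass(recorded[-1].category, UserWarning)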

torchtune/training/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
     validate_no_params_on_meta_device,
 )
 from torchtune.training._grad_scaler import scale_grads
+from torchtune.training._model_util import disable_dropout
 from torchtune.training._profiler import (
     DEFAULT_PROFILE_DIR,
     DEFAULT_PROFILER_ACTIVITIES,
@@ -135,4 +136,5 @@
     "OffloadActivations",
     "FormattedCheckpointFiles",
     "scale_grads",
+    "disable_dropout",
 ]
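
With the private-module import and the __all__ entry in place, callers can pull the helper from the package root, which is exactly how the two recipes above consume it. A minimal usage sketch (the Linear/Dropout model is invented for illustration):

import torch
from torchtune.training import disable_dropout

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Dropout(p=0.1))
disable_dropout(model)  # warns once and sets p to 0 on the single Dropout layer

assert all(m.p == 0 for m in model.modules() if isinstance(m, torch.nn.Dropout))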

torchtune/training/_model_util.py

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ def disable_dropout(model: torch.nn.Module) -> None:
         model (torch.nn.Module): The model in which dropout layers should be disabled.
     """
     for module in model.modules():
-        if isinstance(module, torch.nn.Dropout):
+        if isinstance(module, torch.nn.Dropout) and module.p != 0:
             warnings.warn(
-                f"Dropout found in {module}. This is likely to cause issues during training. Disabling."
+                f"Found Dropout with value {module.p} in module {module}. Setting to zero."
             )
             module.p = 0
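
Only the changed lines are shown above. Pieced together from this hunk and the tests, the helper after this commit reads roughly as follows; the summary line of the docstring is an assumption, since it falls outside the diff context:

import warnings

import torch


def disable_dropout(model: torch.nn.Module) -> None:
    """
    Disables dropout in the given model by setting the probability of every
    torch.nn.Dropout layer to zero.  (Summary wording assumed; not in the diff.)

    Args:
        model (torch.nn.Module): The model in which dropout layers should be disabled.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout) and module.p != 0:
            warnings.warn(
                f"Found Dropout with value {module.p} in module {module}. Setting to zero."
            )
            module.p = 0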
