
Commit 7b654ea

Add support for StatefulDataLoader (#2410)

1 parent cf0142b

File tree: 5 files changed, +49 −36 lines

  recipes/full_finetune_single_device.py
  tests/recipes/test_full_finetune_single_device.py
  torchtune/training/__init__.py
  torchtune/training/checkpointing/__init__.py
  torchtune/training/checkpointing/_utils.py

recipes/full_finetune_single_device.py

Lines changed: 30 additions & 33 deletions
@@ -7,15 +7,15 @@
 import sys
 import time
 from functools import partial
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
 from warnings import warn
 
 import torch
 from omegaconf import DictConfig, ListConfig
 
 from torch import nn
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchtune import config, modules, training, utils
 from torchtune.config._utils import _get_component_from_path
@@ -302,11 +302,16 @@ def setup(self, cfg: DictConfig) -> None:
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
         collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
-        self._sampler, self._dataloader = self._setup_data(
+        self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
             collate_fn=collate_name,
+            dataloader_state_dict=(
+                ckpt_dict[training.DATALOADER_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )
 
         # Finally update the recipe state which can only be correctly set after all of the
@@ -548,11 +553,12 @@ def _setup_data(
         shuffle: bool,
         batch_size: int,
         collate_fn: str,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+        dataloader_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> StatefulDataLoader:
         """
-        All data related setup happens here. Currently this recipe only supports the
-        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
-        iterable datasets and streaming datasets are not supported.
+        All data related setup happens here. This recipe currently supports only
+        map-style datasets. If a state_dict is provided (meaning we are resuming a training run),
+        it is loaded into the dataloader.
         """
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -570,19 +576,10 @@ def _setup_data(
             raise RuntimeError("left_pad_sequence collator is only for inference.")
         collate_fn = _get_component_from_path(collate_fn)
 
-        sampler = DistributedSampler(
-            ds,
-            num_replicas=1,
-            rank=0,
-            shuffle=shuffle,
-            seed=0,
-        )
-        dataloader = DataLoader(
+        dataloader = StatefulDataLoader(
             dataset=ds,
             batch_size=batch_size,
-            sampler=sampler,
-            # dropping last avoids shape issues with compile + flex attention
-            drop_last=True,
+            shuffle=shuffle,
             collate_fn=(
                 partial(
                     collate_fn,
@@ -592,11 +589,12 @@ def _setup_data(
                 if not packed
                 else padded_collate_packed
             ),
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
         )
-
-        log.info("Dataset and Sampler are initialized.")
-
-        return sampler, dataloader
+        if dataloader_state_dict is not None:
+            dataloader.load_state_dict(dataloader_state_dict)
+        return dataloader
 
     def save_checkpoint(self, epoch: int) -> None:
         """
@@ -606,12 +604,16 @@ def save_checkpoint(self, epoch: int) -> None:
         ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
         # if training is in-progress, checkpoint the optimizer state as well
         if epoch + 1 < self.total_epochs:
+            dataloader_sd = self._dataloader.state_dict()
+            # Hardcode _iterator_finished to True to avoid issues with resuming from a checkpoint
+            dataloader_sd["_iterator_finished"] = True
             ckpt_dict.update(
                 {
                     training.SEED_KEY: self.seed,
                     training.EPOCHS_KEY: self.epochs_run,
                     training.TOTAL_EPOCHS_KEY: self.total_epochs,
                     training.MAX_STEPS_KEY: self.max_steps_per_epoch,
+                    training.DATALOADER_KEY: dataloader_sd,
                 }
             )
         if not self._optimizer_in_bwd:
@@ -669,19 +671,8 @@ def train(self) -> None:
         self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
-            # Update the sampler to ensure data is correctly shuffled across epochs
-            # in case shuffle is True
-            self._sampler.set_epoch(curr_epoch)
-
             pbar = tqdm(total=self._steps_per_epoch)
             for idx, batch in enumerate(self._dataloader):
-                if (
-                    self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
-                ):
-                    break
-
                 # Start tracking CUDA memory for active steps for just the first epoch
                 if (
                     curr_epoch == 0
@@ -777,6 +768,12 @@ def train(self) -> None:
                 # if the schedule cycle doesn't align with gradient accumulation.
                 self._profiler.step()
 
+                # Check if we should stop training for this epoch
+                if (
+                    (idx + 1) // self._gradient_accumulation_steps
+                ) == self.max_steps_per_epoch:
+                    break
+
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
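
Taken together, the recipe changes follow torchdata's standard snapshot/restore pattern: state_dict() captures the shuffling RNG and the position of an in-flight iterator, and load_state_dict() on a fresh loader resumes the interrupted epoch, which is also why the explicit DistributedSampler and the set_epoch() call could be dropped. A minimal self-contained sketch of that pattern (assuming torchdata with the stateful_dataloader module installed; the toy dataset and batch counts are illustrative, not from this commit):

from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(20))  # any map-style dataset works here

# First run: consume a few batches, then snapshot mid-epoch.
loader = StatefulDataLoader(dataset, batch_size=2, shuffle=True)
it = iter(loader)
seen = [next(it) for _ in range(3)]
state = loader.state_dict()  # captures sampler RNG + iterator position

# Resumed run: a fresh loader picks up where the snapshot left off.
resumed = StatefulDataLoader(dataset, batch_size=2, shuffle=True)
resumed.load_state_dict(state)
remaining = list(resumed)  # the batches the first run never yielded
assert len(seen) + len(remaining) == len(loader)  # 3 + 7 == 10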

tests/recipes/test_full_finetune_single_device.py

Lines changed: 12 additions & 3 deletions
@@ -56,8 +56,8 @@ def _get_test_config_overrides(self):
 
     def _fetch_expected_loss_values(self, model_type):
         loss_values_map = {
-            "llama2": [10.5201, 10.5217, 10.4945, 10.5136],
-            "llama3": [11.9839, 11.9684, 11.9596, 11.9366],
+            "llama2": [10.5219, 10.5292, 10.5475, 10.5195],
+            "llama3": [11.9611, 11.9432, 11.9326, 11.9807],
         }
 
         return loss_values_map[model_type]
@@ -153,7 +153,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd):
         ckpt = "llama2_hf"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
         ckpt_dir = ckpt_path.parent
-        log_file = gen_log_file_name(tmpdir)
+        first_log_file = gen_log_file_name(tmpdir, suffix="first")
 
         # Config file needed for model conversion.
         # Create a second copy for training resume
@@ -173,6 +173,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd):
             checkpointer.model_type=LLAMA2 \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
+            metric_logger.filename={first_log_file} \
             optimizer_in_bwd={optimizer_in_bwd} \
         """.split()

@@ -183,7 +184,15 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd):
         with pytest.raises(SystemExit, match=""):
             runpy.run_path(TUNE_PATH, run_name="__main__")
 
+        # Sanity check that the loss values are expected for the initial run
+        expected_loss_values = self._fetch_expected_loss_values("llama2")
+        loss_values = get_loss_values_from_metric_logger(first_log_file)
+        torch.testing.assert_close(
+            loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
+        )
+
         # Resume training
+        log_file = gen_log_file_name(tmpdir, suffix="resume")
         epoch_folder = get_largest_iter_folder(tmpdir)
         epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
         suffix = ".safetensors"

torchtune/training/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@
     ADAPTER_CONFIG,
     ADAPTER_KEY,
     Checkpointer,
+    DATALOADER_KEY,
     DistributedCheckpointer,
     EPOCHS_KEY,
     FormattedCheckpointFiles,
@@ -138,4 +139,5 @@
     "scale_grads",
     "get_distributed_backend",
     "disable_dropout",
+    "DATALOADER_KEY",
 ]

torchtune/training/checkpointing/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 from torchtune.training.checkpointing._utils import (
     ADAPTER_CONFIG,
     ADAPTER_KEY,
+    DATALOADER_KEY,
     EPOCHS_KEY,
     FormattedCheckpointFiles,
     get_largest_iter_folder,
@@ -55,4 +56,5 @@
     "STEPS_KEY",
     "TOTAL_EPOCHS_KEY",
     "FormattedCheckpointFiles",
+    "DATALOADER_KEY",
 ]

torchtune/training/checkpointing/_utils.py

Lines changed: 3 additions & 0 deletions
@@ -71,6 +71,9 @@
 # rng state for ensuring correct training resuming in PPO
 RNG_KEY = "rng_state"
 
+# key used for dataloader state
+DATALOADER_KEY = "dataloader"
+
 
 class ModelType(Enum):
     """ModelType is used by the checkpointer to distinguish between different model architectures.
