 import sys
 import time
 from functools import partial
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
 from warnings import warn

 import torch
 from omegaconf import DictConfig, ListConfig

 from torch import nn
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torchdata.stateful_dataloader import StatefulDataLoader

 from torchtune import config, modules, training, utils
 from torchtune.config._utils import _get_component_from_path
@@ -302,11 +302,16 @@ def setup(self, cfg: DictConfig) -> None:
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
         collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
-        self._sampler, self._dataloader = self._setup_data(
+        self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
             collate_fn=collate_name,
+            dataloader_state_dict=(
+                ckpt_dict[training.DATALOADER_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )

         # Finally update the recipe state which can only be correctly set after all of the
@@ -548,11 +553,12 @@ def _setup_data(
         shuffle: bool,
         batch_size: int,
         collate_fn: str,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+        dataloader_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> StatefulDataLoader:
         """
-        All data related setup happens here. Currently this recipe only supports the
-        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
-        iterable datasets and streaming datasets are not supported.
+        All data-related setup happens here. This recipe currently supports only
+        map-style datasets. If a state_dict is provided (meaning we are resuming a
+        training run), it is loaded into the dataloader.
         """
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -570,19 +576,10 @@ def _setup_data(
             raise RuntimeError("left_pad_sequence collator is only for inference.")
         collate_fn = _get_component_from_path(collate_fn)

-        sampler = DistributedSampler(
-            ds,
-            num_replicas=1,
-            rank=0,
-            shuffle=shuffle,
-            seed=0,
-        )
-        dataloader = DataLoader(
+        dataloader = StatefulDataLoader(
             dataset=ds,
             batch_size=batch_size,
-            sampler=sampler,
-            # dropping last avoids shape issues with compile + flex attention
-            drop_last=True,
+            shuffle=shuffle,
             collate_fn=(
                 partial(
                     collate_fn,
@@ -592,11 +589,12 @@ def _setup_data(
                 if not packed
                 else padded_collate_packed
             ),
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
         )
-
-        log.info("Dataset and Sampler are initialized.")
-
-        return sampler, dataloader
+        if dataloader_state_dict is not None:
+            dataloader.load_state_dict(dataloader_state_dict)
+        return dataloader

     def save_checkpoint(self, epoch: int) -> None:
         """
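
A minimal sketch of the checkpoint/resume round trip that the new dataloader_state_dict argument enables, assuming torchdata's StatefulDataLoader; the toy dataset, batch size, and variable names below are illustrative only, not part of the recipe:

from torchdata.stateful_dataloader import StatefulDataLoader

# Any map-style dataset works; a plain list stands in for the recipe's dataset.
dataset = list(range(16))

loader = StatefulDataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
it = iter(loader)
next(it)                              # consume one batch
state = loader.state_dict()           # snapshot of the loader's position and sampler state

# A fresh loader (e.g. after restarting the job) resumes from that snapshot,
# mirroring the load_state_dict call at the end of _setup_data above.
resumed = StatefulDataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
resumed.load_state_dict(state)
for batch in resumed:                 # continues from the second batch onward
    print(batch)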
@@ -606,12 +604,16 @@ def save_checkpoint(self, epoch: int) -> None:
         ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
         # if training is in-progress, checkpoint the optimizer state as well
         if epoch + 1 < self.total_epochs:
+            dataloader_sd = self._dataloader.state_dict()
+            # Hardcode _iterator_finished to True to avoid issues with resuming from a checkpoint
+            dataloader_sd["_iterator_finished"] = True
             ckpt_dict.update(
                 {
                     training.SEED_KEY: self.seed,
                     training.EPOCHS_KEY: self.epochs_run,
                     training.TOTAL_EPOCHS_KEY: self.total_epochs,
                     training.MAX_STEPS_KEY: self.max_steps_per_epoch,
+                    training.DATALOADER_KEY: dataloader_sd,
                 }
             )
             if not self._optimizer_in_bwd:
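For context, a sketch of what this epoch-boundary save amounts to. The exact contents of StatefulDataLoader.state_dict() are an implementation detail; the recipe only pins the _iterator_finished key (per the comment above) before storing the state alongside the rest of the checkpoint. The toy model, loader, and file name here are illustrative:

import torch
from torch import nn
from torchdata.stateful_dataloader import StatefulDataLoader

model = nn.Linear(8, 2)                          # stand-in for the recipe's model
loader = StatefulDataLoader(list(range(12)), batch_size=4)
for _batch in loader:                            # run the epoch to completion
    pass

dataloader_sd = loader.state_dict()
dataloader_sd["_iterator_finished"] = True       # hardcoded to True, as in the recipe above

ckpt = {
    "model": model.state_dict(),
    "dataloader": dataloader_sd,                 # stored under training.DATALOADER_KEY in the recipe
}
torch.save(ckpt, "epoch_checkpoint.pt")          # illustrative path
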
@@ -669,19 +671,8 @@ def train(self) -> None:
         self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
-            # Update the sampler to ensure data is correctly shuffled across epochs
-            # in case shuffle is True
-            self._sampler.set_epoch(curr_epoch)
-
             pbar = tqdm(total=self._steps_per_epoch)
             for idx, batch in enumerate(self._dataloader):
-                if (
-                    self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
-                ):
-                    break
-
                 # Start tracking CUDA memory for active steps for just the first epoch
                 if (
                     curr_epoch == 0
@@ -777,6 +768,12 @@ def train(self) -> None:
                 # if the schedule cycle doesn't align with gradient accumulation.
                 self._profiler.step()

+                # Stop the epoch early once max_steps_per_epoch optimizer steps have been taken
+                if (
+                    (idx + 1) // self._gradient_accumulation_steps
+                ) == self.max_steps_per_epoch:
+                    break
+
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)

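With the check at the end of the loop body, the batch that triggers the stop has already been trained on, and the loop no longer fetches an extra batch just to discard it before breaking; when max_steps_per_epoch is None the comparison never matches, so the loop runs over the full dataloader. A small worked example of the stop condition, with illustrative values:

# Illustrative numbers: 2 optimizer steps per epoch, 4 batches per step.
gradient_accumulation_steps = 4
max_steps_per_epoch = 2

consumed = 0
for idx in range(1000):                          # stands in for iterating the dataloader
    consumed += 1                                # forward/backward happens here
    # the optimizer steps whenever (idx + 1) % gradient_accumulation_steps == 0
    if (idx + 1) // gradient_accumulation_steps == max_steps_per_epoch:
        break

print(consumed)                                  # 8 batches: exactly 2 full optimizer steps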