 import time

 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 from warnings import warn

 import torch
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torchdata.stateful_dataloader import StatefulDataLoader
+from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, modules, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.data import padded_collate_packed
@@ -347,7 +348,7 @@ def setup(self, cfg: DictConfig) -> None:
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
         collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
-        self._sampler, self._dataloader = self._setup_data(
+        self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
@@ -686,11 +687,12 @@ def _setup_data(
         shuffle: bool,
         batch_size: int,
         collate_fn: str,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+        dataloader_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> StatefulDataLoader:
         """
-        All data related setup happens here. Currently this recipe only supports the
-        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
-        iterable datasets and streaming datasets are not supported.
+        All data related setup happens here. This recipe currently supports only
+        map-style datasets. If a state_dict is provided (meaning we are resuming a training run),
+        it is loaded into the dataloader.
         """
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -708,15 +710,13 @@ def _setup_data(
             raise RuntimeError("left_pad_sequence collator is only for inference.")
         collate_fn = _get_component_from_path(collate_fn)

-        sampler = DistributedSampler(
-            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
+        sampler = StatefulDistributedSampler(
+            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle
         )
-        dataloader = DataLoader(
+        dataloader = StatefulDataLoader(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            # dropping last avoids shape issues with compile + flex attention
-            drop_last=True,
             collate_fn=(
                 partial(
                     collate_fn,
@@ -726,11 +726,15 @@ def _setup_data(
                 if not packed
                 else padded_collate_packed
             ),
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
         )
-
-        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
-
-        return sampler, dataloader
+        if dataloader_state_dict is not None:
+            dataloader.load_state_dict(dataloader_state_dict)
+            # B/c we currently only save at epoch boundaries, if we cut the previous epoch short
+            # we need to force the dataloader to finish the last iteration before it's actually used
+            list(dataloader)
+        return dataloader

     def train(self) -> None:
         """
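# --- Editor's note, not part of the diff ---
# A minimal sketch of the StatefulDataLoader behavior the new _setup_data leans on,
# using only the torchdata API imported above. The toy dataset, batch size, and
# expected batch values are illustrative assumptions, and the distributed sampler is
# left out so the snippet runs on a single process.
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(10))  # any map-style dataset
loader = StatefulDataLoader(dataset, batch_size=2)

it = iter(loader)
next(it)                        # consume one batch: tensor([0, 1])
snapshot = loader.state_dict()  # captures mid-epoch iteration progress

resumed = StatefulDataLoader(dataset, batch_size=2)
resumed.load_state_dict(snapshot)
next(iter(resumed))             # resumes at tensor([2, 3]) instead of restarting

# Because the recipe only checkpoints at epoch boundaries, a state dict captured after
# a cut-short epoch is mid-epoch; the `list(dataloader)` call in the diff drains that
# partial epoch up front so training then begins from a clean epoch boundary.
list(resumed)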
@@ -754,19 +758,9 @@ def train(self) -> None:
         self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
-            # Update the sampler to ensure data is correctly shuffled across epochs
-            # in case shuffle is True
-            self._sampler.set_epoch(curr_epoch)
-
             pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
+            self._dataloader.sampler.set_epoch(curr_epoch)
             for idx, batch in enumerate(self._dataloader):
-                if (
-                    self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
-                ):
-                    break
-
                 # Start tracking CUDA memory for active steps for just the first epoch
                 if (
                     self._is_rank_zero
@@ -908,6 +902,11 @@ def train(self) -> None:
                 # will include multiple forward / backward passes if gradient accumulation > 1
                 self._profiler.step()

+                if (
+                    (idx + 1) // self._gradient_accumulation_steps
+                ) == self.max_steps_per_epoch:
+                    break
+
             self.epochs_run += 1
             self._checkpoint_client.save_checkpoint(
                 model=self._model,
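# --- Editor's note, not part of the diff ---
# Step accounting for the relocated max_steps_per_epoch check, with assumed example
# values. Both the removed top-of-loop guard and this bottom-of-loop guard stop after
# the same number of trained batches; one plausible reason for the move is that the old
# placement let enumerate() pull one extra, never-trained batch from the now-stateful
# iterator before breaking.
gradient_accumulation_steps = 4  # assumed
max_steps_per_epoch = 10         # assumed

# New guard: first true at idx == 39, right after the 10th optimizer step completes.
assert (39 + 1) // gradient_accumulation_steps == max_steps_per_epoch
# Old guard: first true at idx == 40, i.e. only once batch 40 had already been fetched.
assert 40 // gradient_accumulation_steps == max_steps_per_epoch
# With max_steps_per_epoch = None the new comparison is simply never true, which is why
# the explicit `is not None` check could be dropped.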
@@ -921,6 +920,7 @@ def train(self) -> None:
                     epochs_run=self.epochs_run,
                     total_epochs=self.total_epochs,
                     max_steps_per_epoch=self.max_steps_per_epoch,
+                    dataloader_state_dict=self._dataloader.state_dict(),
                 ),
                 epoch=curr_epoch,
             )
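# --- Editor's note, not part of the diff ---
# The round trip this final hunk enables, sketched with plain torch.save / torch.load
# standing in for the recipe's CheckpointClient, whose checkpoint layout and resume-side
# plumbing are not shown here; the file name and toy dataset are assumptions.
import torch
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(list(range(100)), batch_size=5)
for _batch in loader:  # one full epoch, mirroring the epoch-boundary checkpoint
    pass
torch.save({"dataloader_state_dict": loader.state_dict()}, "progress.pt")

# A later run would pass the saved state back through
# _setup_data(..., dataloader_state_dict=...) as added above.
state = torch.load("progress.pt", weights_only=False)  # state may hold non-tensor objects
resumed = StatefulDataLoader(list(range(100)), batch_size=5)
resumed.load_state_dict(state["dataloader_state_dict"])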