How to switch dataloader every n training steps #12415

Hey @matprst!

You can set:

  • limit_train_batches=n. This ensures that each training epoch runs for only n batches.
  • reload_dataloaders_every_n_epochs=1. This ensures that the train dataloader is reloaded after every epoch (see the Trainer sketch after this list).
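
For concreteness, here is a minimal sketch wiring these two flags into a Trainer. The values n = 500 and max_epochs=10 are arbitrary placeholders, and model is assumed to be your LightningModule; neither comes from the thread itself:

import pytorch_lightning as pl

# Each "epoch" is capped at 500 batches, and the train dataloader is
# rebuilt after every epoch, so the dataset flips every 500 steps.
trainer = pl.Trainer(
    max_epochs=10,                        # 10 stints of 500 steps each
    limit_train_batches=500,              # n: steps per dataloader stint
    reload_dataloaders_every_n_epochs=1,  # reload (and flip) each epoch
)
trainer.fit(model)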

Then, inside train_dataloader, flip the dataset on each reload, something like this:

from torch.utils.data import DataLoader

def train_dataloader(self):
    # Pick whichever dataset is up next on this reload.
    if self.some_flag:
        dataset = Dataset1()
    else:
        dataset = Dataset2()

    # Flip the flag so the next reload returns the other dataset.
    self.some_flag = not self.some_flag

    return DataLoader(dataset, batch_size=self.batch_size)
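
For this to work, some_flag needs an initial value; a natural place to set it is the LightningModule's __init__. A minimal sketch follows; taking batch_size as a constructor argument is an assumption, not something specified in the thread:

def __init__(self, batch_size=32):
    super().__init__()
    self.batch_size = batch_size
    self.some_flag = True  # start with Dataset1 on the first reload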

Answer selected by matprst