How to switch dataloader every n training steps #12415
-
Hi everyone, this is what I would like to achieve: switch the training dataset every n steps. I found a solution on the old forum, but it only switches the dataset after each epoch. Here is my current attempt at switching it every n batches:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class SimpleModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ...
        self.batch_size = ...
        self.change_every_n_batch = 20

    def train_dataloader(self):
        # Pick which dataset to use based on how many steps have elapsed.
        self.current_dataset = (self.global_step // self.change_every_n_batch) % 2
        if self.current_dataset == 0:
            dataset = Dataset1()
        elif self.current_dataset == 1:
            dataset = Dataset2()
        return DataLoader(dataset, batch_size=self.batch_size)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Ask the trainer to rebuild the dataloader once the window rolls over.
        new_dataset = (self.global_step // self.change_every_n_batch) % 2
        if new_dataset != self.current_dataset:
            self.trainer.reset_train_dataloader(self)
```

Any idea what could be going wrong? Or do you have a solution for what I want to achieve? Thanks!
-
hey @matprst ! you can set:

- `limit_train_batches=n`: this will ensure that every training epoch progresses for only n batches.
- `reload_dataloaders_every_n_epochs=1`: this will ensure that the train dataloader is reloaded after every epoch.

and inside `train_dataloader`, flip the dataloader on each reload. something like:

```python
def train_dataloader(self):
    # self.some_flag should be initialized (e.g. to True) in __init__
    if self.some_flag:
        dataset = Dataset1()
    else:
        dataset = Dataset2()
    self.some_flag = not self.some_flag
    return DataLoader(dataset, batch_size=self.batch_size)
```
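For completeness, here is a minimal end-to-end sketch of how those two Trainer flags fit together. The module name, the toy TensorDatasets, and all hyperparameter values below are illustrative assumptions standing in for Dataset1/Dataset2 from the thread; the load-bearing parts are `limit_train_batches`, `reload_dataloaders_every_n_epochs`, and the flag flip inside `train_dataloader`:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class AlternatingModule(pl.LightningModule):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.batch_size = 8
        self.some_flag = True  # flipped on every dataloader reload

    def train_dataloader(self):
        # Two toy stand-ins for Dataset1/Dataset2 from the thread.
        if self.some_flag:
            dataset = TensorDataset(torch.zeros(160, 4), torch.zeros(160, 1))
        else:
            dataset = TensorDataset(torch.ones(160, 4), torch.ones(160, 1))
        self.some_flag = not self.some_flag
        return DataLoader(dataset, batch_size=self.batch_size)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Cap each "epoch" at 20 batches and rebuild the dataloader every epoch,
# so the data source alternates every 20 training steps.
trainer = pl.Trainer(
    max_epochs=4,
    limit_train_batches=20,
    reload_dataloaders_every_n_epochs=1,
)
trainer.fit(AlternatingModule())
```

With `limit_train_batches=20`, each Lightning "epoch" lasts exactly 20 batches, so the reload (and hence the dataset swap) happens every 20 training steps, which is the every-n-steps behaviour asked for.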
-
Works like a charm, and much cleaner than what I had in mind! Thanks for the reply! I realise now that since I am using iterable datasets (they are large and don't fit into memory), the reloading restarts the iterable from the beginning rather than continuing where it stopped (or at least returning a random batch). That is a separate problem with the dataset, so I will consider the question answered.
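For anyone who hits the same restart issue with iterable datasets: one possible workaround (not from this thread, just a minimal single-process sketch; `ResumableStream` and `make_stream` are hypothetical names) is to hand every reload the same underlying iterator, so a rebuilt DataLoader resumes the stream instead of restarting it:

```python
import itertools
from torch.utils.data import IterableDataset, DataLoader


class ResumableStream(IterableDataset):
    """Keeps its position across dataloader reloads by reusing one
    underlying iterator. Illustrative sketch, not from this thread."""

    def __init__(self, make_stream):
        self._iterator = iter(make_stream())

    def __iter__(self):
        # Return the *same* iterator every time, so a rebuilt
        # DataLoader continues where the previous one stopped.
        return self._iterator


# Hypothetical usage: an endless counter standing in for a large dataset.
stream = ResumableStream(lambda: itertools.count())
loader1 = DataLoader(stream, batch_size=4, num_workers=0)
first = next(iter(loader1))   # tensor([0, 1, 2, 3])

loader2 = DataLoader(stream, batch_size=4, num_workers=0)  # the "reload"
second = next(iter(loader2))  # tensor([4, 5, 6, 7]) (resumes, not restarts)
```

The trade-off is that this only behaves correctly with `num_workers=0`; with worker processes, each worker gets its own copy of the dataset, so the shared-iterator trick no longer applies.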