Describe the bug
When I activate DDP, `trainer.test()` fails because the model the trainer holds is None when it looks up `test_step`. There is no problem when I run on a single GPU.
To Reproduce
```python
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torch.utils.data as data
from pytorch_lightning import Trainer
import torchvision
import torchvision.transforms as transforms

num_workers = 0
classes = (
    'plane', 'car', 'bird', 'cat', 'deer', 'dog',
    'frog', 'horse', 'ship', 'truck'
)
n_classes = len(classes)
ddp = True


class PlModule(pl.LightningModule):
    def __init__(self):
        super(PlModule, self).__init__()
        model = torchvision.models.squeezenet1_1(True)
        model.n_classes = n_classes
        final_conv = nn.Conv2d(512, n_classes, kernel_size=1)
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.model = model
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        inputs, targets = batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)
        return {"loss": loss}

    def test_step(self, batch, batch_nb):
        inputs, targets = batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)
        return {'test_loss': loss}

    def test_end(self, outputs):
        metric = [o["test_loss"] for o in outputs]
        val_loss = np.sum(metric) / len(outputs)
        tqdm_dict = {"test_loss": val_loss}
        return {
            'test_loss': tqdm_dict["test_loss"],
            'progress_bar': tqdm_dict,
        }

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=0.001)

    @pl.data_loader
    def train_dataloader(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        trainset = torchvision.datasets.CIFAR10(
            root='./data', train=True,
            download=True, transform=transform)
        t_sampler = None
        if ddp:
            t_sampler = data.distributed.DistributedSampler(trainset)
        return torch.utils.data.DataLoader(
            trainset,
            sampler=t_sampler,
            batch_size=400,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
        )

    @pl.data_loader
    def test_dataloader(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        testset = torchvision.datasets.CIFAR10(
            root='./data', train=False,
            download=True, transform=transform)
        test_sampler = None
        if ddp:
            test_sampler = data.distributed.DistributedSampler(testset)
        return torch.utils.data.DataLoader(
            testset,
            sampler=test_sampler,
            batch_size=400,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
        )


if __name__ == "__main__":
    distributed = {
        "gpus": 2 if ddp else 1,
        "distributed_backend": 'ddp' if ddp else None
    }
    trainer = Trainer(
        logger=False,
        checkpoint_callback=False,
        early_stop_callback=False,
        max_nb_epochs=1,
        nb_sanity_val_steps=1,
        **distributed
    )
    model = PlModule()
    trainer.fit(model)
    trainer.test()
```
This gives the following error:
```
Traceback (most recent call last):
  File "test.py", line 128, in <module>
    trainer.test()
  File "/home/j/miniconda3/envs/alp36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 478, in test
    self.run_evaluation(test=True)
  File "/home/j/miniconda3/envs/alp36/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop_mixin.py", line 88, in run_evaluation
    can_run_test_step = self.is_overriden('test_step') and self.is_overriden('test_end')
  File "/home/j/miniconda3/envs/alp36/lib/python3.6/site-packages/pytorch_lightning/trainer/model_hooks_mixin.py", line 16, in is_overriden
    is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__
AttributeError: 'NoneType' object has no attribute 'test_step'
```
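For context, the check that blows up compares a hook's code object against the base `LightningModule` implementation, and the `getattr` call fails because the model reference the trainer holds after the DDP fit is None. A minimal sketch of that check, simplified from the traceback (the `None` stands in for whatever `trainer.model` ends up as after the DDP run):

```python
import pytorch_lightning as pl

def is_overriden(model, f_name):
    # Simplified from model_hooks_mixin.py: a hook counts as overridden when its
    # code object differs from the one on the base LightningModule.
    super_object = pl.LightningModule
    return getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__

# After trainer.fit() with DDP, the model handed to this check is None,
# so the first getattr raises exactly the AttributeError shown above.
is_overriden(None, "test_step")
```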
Changing `ddp = True` to `ddp = False` makes the error go away.
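Until this is fixed, a possible workaround (untested here, and assuming `Trainer.test` accepts an optional model argument in 0.5.3.2) is to hand the module to `test()` explicitly instead of relying on the reference the trainer keeps:

```python
# Hypothetical workaround, not verified against 0.5.3.2: pass the module to
# test() explicitly so the evaluation loop does not depend on trainer.model,
# which is None after the DDP fit.
model = PlModule()
trainer.fit(model)
trainer.test(model)
```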
Version:
- pytorch-lightning: 0.5.3.2
- pytorch: 1.3.1
- torchvision: 0.4.2