Description
Hi, first of all, thanks for providing and working on such a neat library!
I think I found a "bug" in the common.py file, which is used by most attribution methods. The problem arises when a CUDA device is used and a non-tensor target is given, for example a simple list.
I don't know whether you officially support CUDA, but since I couldn't find anything indicating otherwise, I've been using the Saliency function on a CUDA device without any problems while passing a tensor as the target. After some refactoring, I changed the target to a simple list and the error occurred. I traced it to the following line:
return torch.gather(output, 1, torch.tensor(target).reshape(len(output), 1))
As can be seen, a tensor is created without any device information, so by default it is placed on the CPU even when output is on the GPU. A simple fix would be to check where the output tensor lives:
device = "cuda" if output.is_cuda else "cpu"
return torch.gather(output, 1, torch.tensor(target).reshape(len(output), 1).to(device))
# EDIT alternative:
device = output.device
return torch.gather(output, 1, torch.tensor(target, device=device).reshape(len(output), 1))
(I don't know whether this could cause problems if the data lives on a different GPU, as I have no experience with multiple GPUs.)
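A minimal sketch of how the alternative fix could be packaged, assuming a hypothetical helper named `gather_targets` (not Captum's actual code). On the multi-GPU question: `output.device` is a full `torch.device` that includes the GPU index (for example `cuda:1`), so an index tensor created with `device=output.device` lands on the same GPU as `output`, whereas the plain string `"cuda"` refers to the current default GPU. The sketch also accepts a single int and checks that the number of targets matches the batch size.

```python
import torch


def gather_targets(output: torch.Tensor, target) -> torch.Tensor:
    """Select output[i, target[i]] for each row i (hypothetical helper).

    `target` may be an int, a list of ints, or a tensor. The index tensor is
    created directly on output.device, which includes the GPU index
    (e.g. cuda:1), so it always matches where `output` lives.
    """
    if isinstance(target, int):
        # A single class index is applied to every example in the batch.
        target = [target] * len(output)
    # as_tensor converts lists and moves existing tensors to the right device.
    index = torch.as_tensor(target, dtype=torch.long, device=output.device)
    if index.numel() != len(output):
        raise ValueError(
            f"expected {len(output)} target indices, got {index.numel()}"
        )
    return torch.gather(output, 1, index.reshape(len(output), 1))


if __name__ == "__main__":
    out = torch.arange(6.0).reshape(3, 2)  # rows: [0, 1], [2, 3], [4, 5]
    print(gather_targets(out, [0, 1, 0]).flatten().tolist())
```

Run on the CPU, the usage example prints `[0.0, 3.0, 4.0]`; the same call works unchanged when `out` lives on any CUDA device.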
The tensor version works because the user can move the target to the correct device beforehand, and the code only reshapes the tensor it is given:
return torch.gather(output, 1, target.reshape(len(output), 1))
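Until list targets are handled in the library, the user-side workaround described above (converting the target yourself and putting it on the right device) can be wrapped in a small helper. This is a sketch assuming a hypothetical function `target_on_model_device` and a model whose parameters all sit on one device; the `nn.Linear` model is only a stand-in.

```python
import torch
import torch.nn as nn


def target_on_model_device(model: nn.Module, target) -> torch.Tensor:
    """Convert a list/int target to a LongTensor on the model's device.

    Hypothetical helper; assumes all model parameters share one device.
    """
    device = next(model.parameters()).device
    return torch.as_tensor(target, dtype=torch.long, device=device)


if __name__ == "__main__":
    model = nn.Linear(4, 2)  # stands in for the real model
    target = target_on_model_device(model, [0, 1, 0])
    print(target.device, target.dtype)  # cpu torch.int64 for this CPU model
```

The resulting tensor can then be passed as `saliency.attribute(X, target=target)`, which takes the tensor code path that already works on CUDA.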
I've included a minimal example highlighting the problem:
# minimal_example.py
import torch
import torchvision
import numpy as np
from captum.attr import Saliency
# device = "cpu" works!
device = "cuda"
model = torchvision.models.alexnet(pretrained=False)
model.to(device)
model.eval()
X = torch.rand(3, 3, 224, 224, dtype=torch.float)
X = X.to(device)
y_pred = model(X)
saliency = Saliency(model)
X.requires_grad = True
grads = saliency.attribute(X, target=[0, 1, 2])  # one target index per input in the batch of 3; fails on CUDA
# working variant:
# grads = saliency.attribute(X, target=torch.tensor([0, 1, 2]).to(device))
print(grads.shape)
# The error is caused by the torch.gather call, whose index tensor is created without regard to the device of the output.