Bug in *.attribute -- New tensor in gathering ignoring used device #316

@kai-tub

Description


Hi, first of all, thanks for providing and working on such a neat library!
I think I found a "bug" in the common.py file, which is used by most attribution methods. The problem arises when a CUDA device is used and a non-tensor target is given, for example a simple list.

I don't know whether you officially support CUDA or not, but as I couldn't find any hints indicating otherwise, I have been using the Saliency attribution on a CUDA device without any problems while passing a tensor as the target. After some refactoring, I changed the target to a simple list and the error occurred. I traced it down to the following line:
return torch.gather(output, 1, torch.tensor(target).reshape(len(output), 1))

As can be seen, a tensor is created without any device information. A simple fix would be to check where the output tensor lives:

device = "cuda" if output.is_cuda else "cpu"
return torch.gather(output, 1, torch.tensor(target).reshape(len(output), 1).to(device))
# EDIT alternative:
device = output.device
return torch.gather(output, 1, torch.tensor(target, device=device).reshape(len(output), 1))

(I don't know whether this could cause problems if the data lives on a different GPU, as I have no experience with multi-GPU setups.)
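For reference, a device-aware version could look roughly like the sketch below. The function name gather_target_scores and the use of torch.as_tensor are my own choices here, not Captum's implementation; since output.device carries the device index (e.g. cuda:1), this should also behave correctly when the output lives on a GPU other than cuda:0:

import torch

def gather_target_scores(output, target):
    # Accept a list or a tensor; torch.as_tensor creates (or moves) the index
    # tensor on the same device as the output, including its device index.
    index = torch.as_tensor(target, device=output.device).reshape(len(output), 1)
    return torch.gather(output, 1, index)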

The tensor version works because the user can move it to the correct device beforehand; the library then only reshapes it:
return torch.gather(output, 1, target.reshape(len(output), 1))

I've included a minimal example highlighting the problem:

# minimal_example.py
import torch
import torchvision
from captum.attr import Saliency

# device = "cpu" works!
device = "cuda"

model = torchvision.models.alexnet(pretrained=False)
model.to(device)
model.eval()

X = torch.rand(3, 3, 224, 224, dtype=torch.float)  # batch of 3 images
X = X.to(device)
y_pred = model(X)  # a plain forward pass on CUDA works fine

saliency = Saliency(model)
X.requires_grad = True
# Fails: the list target is turned into a CPU tensor inside the gather call.
grads = saliency.attribute(X, target=[0, 1, 2])
# Working alternative: a tensor target can be moved to the device beforehand.
# grads = saliency.attribute(X, target=torch.tensor([0, 1, 2]).to(device))
print(grads.shape)
# The reason is the gather call, which ignores the device of the output tensor.
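For completeness, the underlying error can also be reproduced without Captum. This is just a sketch with placeholder shapes; it only needs a CUDA device:

import torch

output = torch.rand(3, 5, device="cuda")  # e.g. logits living on the GPU
target = [0, 1, 2]                        # plain Python list target

try:
    # Mirrors the line in common.py: the index tensor is created on the CPU,
    # so gather fails with a device-mismatch RuntimeError.
    torch.gather(output, 1, torch.tensor(target).reshape(len(output), 1))
except RuntimeError as err:
    print(err)

# Creating the index on the output's device works:
print(torch.gather(output, 1, torch.tensor(target, device=output.device).reshape(len(output), 1)))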
