🐛 Bug
When using a model on the GPU, running tcav.interpret(...) throws a wrong-device error.
To Reproduce
Steps to reproduce the behavior:
- Work around #719 (TCAV: cannot run compute_cavs() in cuda) by running tcav.compute_cavs() on the CPU, which saves the CAV vectors to ./cav (a rough sketch of this step follows the stack trace below)
- Run the following code:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from captum.concept import TCAV, Concept

DEVICE = 'cuda'


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 10)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.classifier = nn.Linear(10, 1)

    def forward(self, images):
        # images shape: batch_size, 3, height, width
        x = self.conv(images)       # shape: batch_size, 10, features-height, features-width
        x = self.pool(x)            # shape: batch_size, 10, 1, 1
        x = self.flatten(x)         # shape: batch_size, 10
        x = self.classifier(x)      # shape: batch_size, 1
        return x


class DummyDataset(Dataset):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device

    def __getitem__(self, idx):
        image = torch.zeros(3, 256, 256)
        return image.to(self.device)

    def __len__(self):
        return 10


model = MyModel().to(DEVICE)

concept0 = Concept(0, 'concept0', DataLoader(DummyDataset(device=DEVICE), batch_size=10))
concept1 = Concept(1, 'concept1', DataLoader(DummyDataset(device=DEVICE), batch_size=10))

tcav = TCAV(model, layers='conv')

inputs = torch.rand(7, 3, 256, 256).to(DEVICE)
scores = tcav.interpret(inputs, [[concept0, concept1]])
The tcav.interpret(...) line throws: RuntimeError: Input, output and indices must be on the current device.
The full stack:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-1-6775c2780feb> in <module>
39
40 inputs = torch.rand(7, 3, 256, 256).to(DEVICE)
---> 41 scores = tcav.interpret(inputs, [[concept0, concept1]])
~/software/captum/captum/concept/_core/tcav.py in interpret(self, inputs, experimental_sets, target, additional_forward_args, processes, **kwargs)
597 cav_subset,
598 classes_subset,
--> 599 experimental_subset_sorted,
600 )
601 i += 1
~/software/captum/captum/concept/_core/tcav.py in _tcav_sub_computation(self, scores, layer, attribs, cavs, classes, experimental_sets)
646 scores[concepts_key][layer] = {
647 "sign_count": torch.index_select(
--> 648 sign_count_score[i, :], dim=0, index=new_ord
649 ),
650 "magnitude": torch.index_select(
RuntimeError: Input, output and indices must be on the current device
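For reference, a rough sketch of the step-1 workaround (reusing MyModel and DummyDataset from the snippet above; the exact code I ran may have differed slightly, and it assumes the default ./cav save path):

    # Keep everything on the CPU while the CAVs are trained, so the vectors end up
    # saved under ./cav and the later GPU run can load them instead of re-training.
    model_cpu = MyModel()  # model stays on the CPU here
    concept0_cpu = Concept(0, 'concept0', DataLoader(DummyDataset(device='cpu'), batch_size=10))
    concept1_cpu = Concept(1, 'concept1', DataLoader(DummyDataset(device='cpu'), batch_size=10))

    tcav_cpu = TCAV(model_cpu, layers='conv')
    tcav_cpu.compute_cavs([[concept0_cpu, concept1_cpu]])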
Expected behavior
The method interpret() should run without errors on either CPU or GPU (or the docs should state that only CPU is supported?).
Environment
- Captum / PyTorch Version: captum 0.4.0, torch 1.7.1+cu110
- OS (e.g., Linux): Ubuntu 18.04.5
- How you installed Captum / PyTorch (conda, pip, source): source
- Build command you used (if compiling from source): pip install -e ~/software/captum
- Python version: 3.6
- CUDA/cuDNN version: cu110
- GPU models and configuration: RTX 3090
- Any other relevant information: I'm running captum from the master branch; the latest commit is f658185
Additional context
- I was able to hot-fix it by changing this line in the TCAV._tcav_sub_computation() method to:
  new_ord = torch.tensor([concept_ord[cls] for cls in cls_set], device=sign_count_score.device)
  i.e. explicitly specifying the device param
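For context, the root cause appears to be a plain device mismatch in torch.index_select: the index tensor is created without a device, so it defaults to the CPU while the score tensor lives on the GPU. A minimal standalone illustration (not Captum code):

    import torch

    scores = torch.rand(5, device='cuda')    # lives on the GPU, like sign_count_score
    idx_cpu = torch.tensor([2, 0, 1])        # created without a device, so it defaults to the CPU
    # torch.index_select(scores, dim=0, index=idx_cpu)   # raises the same RuntimeError

    idx_gpu = torch.tensor([2, 0, 1], device=scores.device)   # the hot-fix: match devices
    print(torch.index_select(scores, dim=0, index=idx_gpu))   # works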