
TCAV: cannot run interpret() in cuda #721

@pdpino

🐛 Bug

When the model is on the GPU, calling tcav.interpret(...) raises a device-mismatch error.

To Reproduce

Steps to reproduce the behavior:

  1. Work around the bug in issue #719 (TCAV: cannot run compute_cavs() in cuda) by running tcav.compute_cavs() on the CPU. This saves the CAV vectors to ./cav.
  2. Run the following code:
```python
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from captum.concept import TCAV, Concept

DEVICE = 'cuda'

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 10)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.classifier = nn.Linear(10, 1)
    def forward(self, images):
        # images shape: batch_size, 3, height, width
        x = self.conv(images) # shape: batch_size, 10, features-height, features-width
        x = self.pool(x) # shape: batch_size, 10, 1, 1
        x = self.flatten(x) # shape: batch_size, 10
        x = self.classifier(x) # shape: batch_size, 1
        return x

class DummyDataset(Dataset):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
    def __getitem__(self, idx):
        image = torch.zeros(3, 256, 256)
        return image.to(self.device)
    def __len__(self):
        return 10

model = MyModel().to(DEVICE)

concept0 = Concept(0, 'concept0', DataLoader(DummyDataset(device=DEVICE), batch_size=10))
concept1 = Concept(1, 'concept1', DataLoader(DummyDataset(device=DEVICE), batch_size=10))

tcav = TCAV(model, layers='conv')

inputs = torch.rand(7, 3, 256, 256).to(DEVICE)
scores = tcav.interpret(inputs, [[concept0, concept1]])
```

The tcav.interpret(...) call raises `RuntimeError: Input, output and indices must be on the current device`.
Full stack trace:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-6775c2780feb> in <module>
     39 
     40 inputs = torch.rand(7, 3, 256, 256).to(DEVICE)
---> 41 scores = tcav.interpret(inputs, [[concept0, concept1]])

~/software/captum/captum/concept/_core/tcav.py in interpret(self, inputs, experimental_sets, target, additional_forward_args, processes, **kwargs)
    597                     cav_subset,
    598                     classes_subset,
--> 599                     experimental_subset_sorted,
    600                 )
    601                 i += 1

~/software/captum/captum/concept/_core/tcav.py in _tcav_sub_computation(self, scores, layer, attribs, cavs, classes, experimental_sets)
    646             scores[concepts_key][layer] = {
    647                 "sign_count": torch.index_select(
--> 648                     sign_count_score[i, :], dim=0, index=new_ord
    649                 ),
    650                 "magnitude": torch.index_select(

RuntimeError: Input, output and indices must be on the current device
```
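The traceback points at `torch.index_select` inside `_tcav_sub_computation`. The sketch below shows the suspected cause in plain PyTorch, without Captum, and it is consistent with the hot-fix under Additional context. When `torch.tensor(...)` is called without `device=`, it allocates on the default device, which is normally the CPU. The score tensor, however, lives on the GPU because the attributions were computed there. The names `sign_count_score` and `new_ord` mirror the traceback; the values are made up for illustration.

```python
import torch

# Use the GPU when present so the mismatch actually shows up; on a CPU-only
# machine both tensors end up on the CPU and nothing is reported.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for the per-concept scores, which live where the attributions were computed.
sign_count_score = torch.rand(2, 3, device=device)

# No device= argument: the index tensor goes to the default device (normally the CPU).
new_ord = torch.tensor([2, 0, 1])

print("scores:", sign_count_score.device)
print("index: ", new_ord.device)

if new_ord.device != sign_count_score.device:
    # This is the combination seen in the traceback: index_select receives a
    # CUDA input and a CPU index, which torch 1.7.1 rejects with a RuntimeError.
    print("device mismatch: torch.index_select would receive tensors on different devices")
```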

Expected behavior

interpret() should run without errors on either the CPU or the GPU. If only the CPU is supported, the docs should say so.

Environment

  • Captum / PyTorch Version: captum 0.4.0, torch 1.7.1+cu110
  • OS (e.g., Linux): Ubuntu 18.04.5
  • How you installed Captum / PyTorch (conda, pip, source): source
  • Build command you used (if compiling from source): pip install -e ~/software/captum
  • Python version: 3.6
  • CUDA/cuDNN version: CUDA 11.0 (cu110 build)
  • GPU models and configuration: RTX 3090
  • Any other relevant information: I'm running Captum from the master branch at commit f658185

Additional context

  • I was able to hot-fix it by changing the line that builds new_ord in TCAV._tcav_sub_computation() so that the index tensor is created on the same device as the scores, i.e. by passing the device parameter:

```python
new_ord = torch.tensor([concept_ord[cls] for cls in cls_set], device=sign_count_score.device)
```
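You can exercise the hot-fix in isolation with a small helper. The helper `reorder_scores` is hypothetical and not part of Captum. It builds the index on the scores' own device before calling `torch.index_select`. The `concept_ord` and `cls_set` values below are made up. The sketch only assumes, as the hot-fix suggests, that `concept_ord` maps each class to a position in the score row.

```python
from typing import Dict, Hashable, Sequence

import torch


def reorder_scores(
    score_row: torch.Tensor,
    concept_ord: Dict[Hashable, int],
    cls_set: Sequence[Hashable],
) -> torch.Tensor:
    """Hypothetical helper: reorder a 1-D score row to follow cls_set."""
    # Creating the index on score_row.device keeps index_select on a single device,
    # whether the scores are on the CPU or the GPU.
    new_ord = torch.tensor([concept_ord[cls] for cls in cls_set], device=score_row.device)
    return torch.index_select(score_row, dim=0, index=new_ord)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
score_row = torch.tensor([10.0, 20.0, 30.0], device=device)
concept_ord = {7: 0, 3: 1, 5: 2}  # made-up class -> position mapping

reordered = reorder_scores(score_row, concept_ord, [3, 5, 7])
print(reordered.tolist())  # [20.0, 30.0, 10.0]
```

The reordering itself is unchanged; only the index tensor's placement differs. The same call therefore works for CPU and CUDA inputs without any device-specific branching.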
