Skip to content

Commit 3a697e3

Browse files
NarineKfacebook-github-bot
authored andcommitted
Support TCAV on cuda (#725)
Summary: Addresses the issues: #721 #719 #720 Pull Request resolved: #725 Reviewed By: bilalsal Differential Revision: D30356015 Pulled By: NarineK fbshipit-source-id: 010a5263bdfc33e8c4d3f9de523d9d3ba3969f49
1 parent d201bc4 commit 3a697e3

File tree

4 files changed

+23
-9
lines changed

4 files changed

+23
-9
lines changed

captum/_utils/models/linear_model/train.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,9 +343,15 @@ def sklearn_train_linear_model(
343343
if hasattr(sklearn_model, "classes_")
344344
else None
345345
)
346+
347+
# extract model device
348+
device = model.device if hasattr(model, "device") else "cpu"
349+
346350
num_outputs = sklearn_model.coef_.shape[0] if sklearn_model.coef_.ndim > 1 else 1
347-
weight_values = torch.FloatTensor(sklearn_model.coef_) # type: ignore
348-
bias_values = torch.FloatTensor([sklearn_model.intercept_]) # type: ignore
351+
weight_values = torch.FloatTensor(sklearn_model.coef_).to(device) # type: ignore
352+
bias_values = torch.FloatTensor([sklearn_model.intercept_]).to( # type: ignore
353+
device # type: ignore
354+
) # type: ignore
349355
model._construct_model_params(
350356
norm_type=None,
351357
weight_values=weight_values.view(num_outputs, -1),

captum/concept/_core/concept.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#!/usr/bin/env python3
22

3-
from typing import Callable, Iterable, Union
3+
from typing import Callable, Union
44

5+
import torch
56
from torch.nn import Module
67

78

@@ -18,15 +19,19 @@ class Concept:
1819
adjectives and words that convey happiness.
1920
"""
2021

21-
def __init__(self, id: int, name: str, data_iter: Union[None, Iterable]) -> None:
22+
def __init__(
23+
self, id: int, name: str, data_iter: Union[None, torch.utils.data.DataLoader]
24+
) -> None:
2225

2326
r"""
2427
Args:
2528
id (int): The unique identifier of the concept.
2629
name (str): A unique name of the concept.
27-
data_iter (iter): A pytorch Dataloader object. Combines a dataset
30+
data_iter (DataLoader): A pytorch DataLoader object that combines a dataset
2831
and a sampler, and provides an iterable over a given
29-
dataset. For more information, please check:
32+
dataset. Only the input batches are provided by `data_iter`.
33+
Concept ids can be used as labels if necessary.
34+
For more information, please check:
3035
https://pytorch.org/docs/stable/data.html
3136
3237
Example::

captum/concept/_core/tcav.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,9 @@ def _tcav_sub_computation(
640640

641641
# sort classes / concepts in the order specified in concept_keys
642642
concept_ord = {concept.id: ci for ci, concept in enumerate(concepts)}
643-
new_ord = torch.tensor([concept_ord[cls] for cls in cls_set])
643+
new_ord = torch.tensor(
644+
[concept_ord[cls] for cls in cls_set], device=tcav_score.device
645+
)
644646

645647
# sort based on classes
646648
scores[concepts_key][layer] = {

captum/concept/_utils/classifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,16 +170,17 @@ def train_and_eval(
170170
inputs.append(input)
171171
labels.append(label)
172172

173+
device = "cpu" if input is None else input.device
173174
x_train, x_test, y_train, y_test = _train_test_split(
174175
torch.cat(inputs), torch.cat(labels), test_split=test_split_ratio
175176
)
176-
177+
self.lm.device = device
177178
self.lm.fit(DataLoader(TensorDataset(x_train, y_train)))
178179

179180
predict = self.lm(x_test)
180181

181182
predict = self.lm.classes()[torch.argmax(predict, dim=1)]
182-
score = predict.long() == y_test.long()
183+
score = predict.long() == y_test.long().cpu()
183184

184185
accs = score.float().mean()
185186

0 commit comments

Comments
 (0)