Skip to content

Commit 4432d4b

Browse files
committed
Lime Type Fixes (#570)
Summary: Pull Request resolved: #570 This makes Lime work appropriately with int / long features; currently input only worked appropriately with float features. Reviewed By: bilalsal Differential Revision: D25693888 fbshipit-source-id: b96477f8c6805f554b324ffadbb00e971c12051f
1 parent 177bd3b commit 4432d4b

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

captum/attr/_core/lime.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -527,15 +527,15 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
527527
), "Must provide baselines to use default interpretable representation transfrom"
528528
feature_mask = kwargs["feature_mask"]
529529
if isinstance(feature_mask, Tensor):
530-
binary_mask = curr_sample[0][feature_mask]
530+
binary_mask = curr_sample[0][feature_mask].to(original_inputs.dtype)
531531
return binary_mask * original_inputs + (1 - binary_mask) * kwargs["baselines"]
532532
else:
533533
binary_mask = tuple(
534534
curr_sample[0][feature_mask[j]] for j in range(len(feature_mask))
535535
)
536536
return tuple(
537-
binary_mask[j] * original_inputs[j]
538-
+ (1 - binary_mask[j]) * kwargs["baselines"][j]
537+
binary_mask[j].to(original_inputs[j].dtype) * original_inputs[j]
538+
+ (1 - binary_mask[j].to(original_inputs[j].dtype)) * kwargs["baselines"][j]
539539
for j in range(len(feature_mask))
540540
)
541541

@@ -575,8 +575,8 @@ def get_exp_kernel_similarity_function(
575575
"""
576576

577577
def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
578-
flattened_original_inp = _flatten_tensor_or_tuple(original_inp)
579-
flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp)
578+
flattened_original_inp = _flatten_tensor_or_tuple(original_inp).float()
579+
flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp).float()
580580
if distance_mode == "cosine":
581581
cos_sim = CosineSimilarity(dim=0)
582582
distance = 1 - cos_sim(flattened_original_inp, flattened_perturbed_inp)
@@ -599,7 +599,7 @@ def default_perturb_func(original_inp, **kwargs):
599599
device = original_inp[0].device
600600

601601
probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5
602-
return torch.bernoulli(probs).to(device=device)
602+
return torch.bernoulli(probs).to(device=device).long()
603603

604604

605605
class Lime(LimeBase):
@@ -1130,7 +1130,10 @@ def _convert_output_shape(
11301130
is_inputs_tuple: bool,
11311131
) -> Union[Tensor, Tuple[Tensor, ...]]:
11321132
coefs = coefs.flatten()
1133-
attr = [torch.zeros_like(single_inp) for single_inp in formatted_inp]
1133+
attr = [
1134+
torch.zeros_like(single_inp, dtype=torch.float)
1135+
for single_inp in formatted_inp
1136+
]
11341137
for tensor_ind in range(len(formatted_inp)):
11351138
for single_feature in range(num_interp_features):
11361139
attr[tensor_ind] += (

0 commit comments

Comments
 (0)