@@ -527,15 +527,15 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
527
527
), "Must provide baselines to use default interpretable representation transfrom"
528
528
feature_mask = kwargs ["feature_mask" ]
529
529
if isinstance (feature_mask , Tensor ):
530
- binary_mask = curr_sample [0 ][feature_mask ]
530
+ binary_mask = curr_sample [0 ][feature_mask ]. to ( original_inputs . dtype )
531
531
return binary_mask * original_inputs + (1 - binary_mask ) * kwargs ["baselines" ]
532
532
else :
533
533
binary_mask = tuple (
534
534
curr_sample [0 ][feature_mask [j ]] for j in range (len (feature_mask ))
535
535
)
536
536
return tuple (
537
- binary_mask [j ] * original_inputs [j ]
538
- + (1 - binary_mask [j ]) * kwargs ["baselines" ][j ]
537
+ binary_mask [j ]. to ( original_inputs [ j ]. dtype ) * original_inputs [j ]
538
+ + (1 - binary_mask [j ]. to ( original_inputs [ j ]. dtype ) ) * kwargs ["baselines" ][j ]
539
539
for j in range (len (feature_mask ))
540
540
)
541
541
@@ -575,8 +575,8 @@ def get_exp_kernel_similarity_function(
575
575
"""
576
576
577
577
def default_exp_kernel (original_inp , perturbed_inp , __ , ** kwargs ):
578
- flattened_original_inp = _flatten_tensor_or_tuple (original_inp )
579
- flattened_perturbed_inp = _flatten_tensor_or_tuple (perturbed_inp )
578
+ flattened_original_inp = _flatten_tensor_or_tuple (original_inp ). float ()
579
+ flattened_perturbed_inp = _flatten_tensor_or_tuple (perturbed_inp ). float ()
580
580
if distance_mode == "cosine" :
581
581
cos_sim = CosineSimilarity (dim = 0 )
582
582
distance = 1 - cos_sim (flattened_original_inp , flattened_perturbed_inp )
@@ -599,7 +599,7 @@ def default_perturb_func(original_inp, **kwargs):
599
599
device = original_inp [0 ].device
600
600
601
601
probs = torch .ones (1 , kwargs ["num_interp_features" ]) * 0.5
602
- return torch .bernoulli (probs ).to (device = device )
602
+ return torch .bernoulli (probs ).to (device = device ). long ()
603
603
604
604
605
605
class Lime (LimeBase ):
@@ -1130,7 +1130,10 @@ def _convert_output_shape(
1130
1130
is_inputs_tuple : bool ,
1131
1131
) -> Union [Tensor , Tuple [Tensor , ...]]:
1132
1132
coefs = coefs .flatten ()
1133
- attr = [torch .zeros_like (single_inp ) for single_inp in formatted_inp ]
1133
+ attr = [
1134
+ torch .zeros_like (single_inp , dtype = torch .float )
1135
+ for single_inp in formatted_inp
1136
+ ]
1134
1137
for tensor_ind in range (len (formatted_inp )):
1135
1138
for single_feature in range (num_interp_features ):
1136
1139
attr [tensor_ind ] += (
0 commit comments