Skip to content

Commit 081c81f

Browse files
committed
fix mypy
1 parent f226306 commit 081c81f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def attribute(
319319
total_attrib = [
320320
# attribute w.r.t each output element
321321
torch.zeros(
322-
(n_outputs, *input.shape[1:]),
322+
(n_outputs,) + input.shape[1:],
323323
dtype=attrib_type,
324324
device=input.device,
325325
)
@@ -330,7 +330,7 @@ def attribute(
330330
if self.use_weights:
331331
weights = [
332332
torch.zeros(
333-
(n_outputs, *input.shape[1:]), device=input.device
333+
(n_outputs,) + input.shape[1:], device=input.device
334334
).float()
335335
for input in inputs
336336
]

0 commit comments

Comments
 (0)