We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f226306 commit 081c81fCopy full SHA for 081c81f
captum/attr/_core/feature_ablation.py
@@ -319,7 +319,7 @@ def attribute(
319
total_attrib = [
320
# attribute w.r.t each output element
321
torch.zeros(
322
- (n_outputs, *input.shape[1:]),
+ (n_outputs,) + input.shape[1:],
323
dtype=attrib_type,
324
device=input.device,
325
)
@@ -330,7 +330,7 @@ def attribute(
330
if self.use_weights:
331
weights = [
332
333
- (n_outputs, *input.shape[1:]), device=input.device
+ (n_outputs,) + input.shape[1:], device=input.device
334
).float()
335
for input in inputs
336
]
0 commit comments