Motivation:
Captum currently provides `KernelShap` and `ShapleyValueSampling`, but lacks a fast hierarchical Shapley estimator.
PartitionSHAP (Lundberg & Erion 2021) reduces the number of model evaluations from O(M · d) to O(M · log d) by recursively partitioning the feature set. Users working with large LLMs and computer-vision models would benefit from 10-100× faster attributions without sacrificing accuracy.
Proposed API (mirrors Captum style):

```python
from typing import Any, Callable, Optional

from captum.attr._utils.attribution import PerturbationAttribution


class PartitionShap(PerturbationAttribution):
    def __init__(
        self,
        forward_func: Callable,
        baselines: Optional[Any] = None,
        max_evals: int = 512,
        perturbations_per_eval: int = 1,
        partition_agg: str = "mean",  # "mean" | "sum" | "max"
        cluster_method: str = "greedy",  # future-proof for alternative heuristics
    ) -> None: ...
```

Returns: a tensor of Shapley values with the same shape as the inputs.
Design outline:
- Hierarchy builder: greedy divisive split on absolute gradients (`torch.topk` for speed; stop at single-feature leaves).
- Conditional expectation: Monte-Carlo permutations per internal node, batched using `perturbations_per_eval`.
- Reuse Captum helpers: `_tensorize_baselines`, `_construct_perturbations`, and `infer_target_device_dtype` for GPU/dtype safety.
- Early-exit guard: stop once `eval_count ≥ max_evals` and emit a `UserWarning`.
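The outline above can be illustrated with a minimal, framework-free sketch. It is not Captum code: plain Python lists stand in for tensors, a caller-supplied `scores` list stands in for the absolute gradients fed to `torch.topk`, and every name (`build_hierarchy`, `CountingValue`, `owen`, `partition_shap`) is hypothetical. One substitution is deliberate: each internal node has exactly two children, so there are only two orderings to average over, and the sketch enumerates them exactly instead of sampling Monte-Carlo permutations. The result is each feature's Owen value for the given hierarchy, which coincides with its Shapley value for additive models. The even split applied once `max_evals` is reached is an assumed placeholder policy, not part of the proposal.

```python
import warnings
from typing import Callable, Dict, FrozenSet, List, Sequence, Tuple, Union

# A hierarchy node is either a single feature index (leaf) or a (left, right) pair.
Node = Union[int, Tuple["Node", "Node"]]


def build_hierarchy(features: Sequence[int], scores: Sequence[float]) -> Node:
    """Greedy divisive split: rank features by |score| (stand-in for torch.topk
    on absolute gradients), send the top half to one child and the rest to the
    other, and recurse down to single-feature leaves."""
    if len(features) == 1:
        return features[0]
    ranked = sorted(features, key=lambda i: abs(scores[i]), reverse=True)
    mid = len(ranked) // 2
    return (build_hierarchy(ranked[:mid], scores), build_hierarchy(ranked[mid:], scores))


def leaves(node: Node) -> List[int]:
    return [node] if isinstance(node, int) else leaves(node[0]) + leaves(node[1])


class CountingValue:
    """v(S) = f(z) with z_i = x_i for i in S and z_i = baseline_i otherwise.
    Results are cached, so `count` is the number of distinct model evaluations."""

    def __init__(self, f: Callable[[List[float]], float], x: Sequence[float], baseline: Sequence[float]):
        self.f, self.x, self.baseline = f, x, baseline
        self.cache: Dict[FrozenSet[int], float] = {}
        self.warned = False

    @property
    def count(self) -> int:
        return len(self.cache)

    def __call__(self, on: FrozenSet[int]) -> float:
        if on not in self.cache:
            z = [xi if i in on else bi for i, (xi, bi) in enumerate(zip(self.x, self.baseline))]
            self.cache[on] = self.f(z)
        return self.cache[on]


def owen(node: Tuple[Node, Node], context: FrozenSet[int], value: CountingValue, max_evals: int) -> Dict[int, float]:
    """Attributions for the leaves under an internal node, given `context` (features
    outside this node that are switched on). Each child first receives its exact
    two-player Shapley credit, which is then refined recursively."""
    left, right = node
    lset, rset = frozenset(leaves(left)), frozenset(leaves(right))
    v_c, v_l, v_r = value(context), value(context | lset), value(context | rset)
    v_n = value(context | lset | rset)
    out: Dict[int, float] = {}
    for child, own, sib, v_own, v_sib in ((left, lset, rset, v_l, v_r), (right, rset, lset, v_r, v_l)):
        credit = 0.5 * ((v_own - v_c) + (v_n - v_sib))
        if isinstance(child, int):
            out[child] = credit
        elif value.count >= max_evals:
            # Early-exit guard (assumed policy): stop refining, split credit evenly.
            if not value.warned:
                warnings.warn("max_evals reached; remaining groups split evenly", UserWarning)
                value.warned = True
            for i in own:
                out[i] = credit / len(own)
        else:
            # Average the child's attributions with its sibling switched off and on.
            a = owen(child, context, value, max_evals)
            b = owen(child, context | sib, value, max_evals)
            for i in own:
                out[i] = 0.5 * (a[i] + b[i])
    return out


def partition_shap(f, x, baseline, scores, max_evals: int = 512) -> List[float]:
    tree = build_hierarchy(list(range(len(x))), scores)
    value = CountingValue(f, x, baseline)
    if isinstance(tree, int):  # a single feature needs no partition
        return [value(frozenset({tree})) - value(frozenset())]
    phi = owen(tree, frozenset(), value, max_evals)
    return [phi[i] for i in range(len(x))]
```

For an interaction model, `partition_shap(lambda z: z[0] * z[1], [1.0, 1.0], [0.0, 0.0], scores=[1.0, 1.0])` returns `[0.5, 0.5]`. Because the fallback splits each child's credit rather than dropping it, the attributions still sum to f(x) − f(baseline) after an early exit. In a Captum implementation, the four coalitions evaluated per node could be batched into a single forward pass via `perturbations_per_eval`.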
Acceptance criteria / test plan:
- Correctness: PartitionSHAP ≈ KernelSHAP on a 5-feature linear model (`torch.testing.assert_close(..., atol=1e-2, rtol=1e-2)`).
- Speed: on the `bert-base-uncased` CLS token, PartitionShap with `max_evals=512` runs ≥20× faster than KernelShap with `n_samples=2048` (CPU benchmark).
- Coverage: ≥99% for the new file (`pytest --cov=...`).
- Docs:
  - `docs/source/algorithms.rst` section
  - Tutorial notebook: GPT-2 toxicity example (<3 min CPU runtime)
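The correctness check could additionally be pinned to an exact reference, because 5 features give only 2^5 = 32 coalitions. The sketch below is plain Python; `exact_shapley` is a hypothetical helper, and features outside a coalition take their baseline values. It enumerates the Shapley formula directly. For a linear model f(z) = w·z it reproduces w_i (x_i − b_i), which is also the Owen value of every feature under any partition of an additive model, so both PartitionShap and KernelShap outputs could be compared against it with `torch.testing.assert_close`.

```python
import itertools
import math
from typing import Callable, List, Sequence


def exact_shapley(f: Callable[[List[float]], float], x: Sequence[float], baseline: Sequence[float]) -> List[float]:
    """Brute-force Shapley values: O(2^d) model calls, only viable for small d."""
    d = len(x)

    def value(coalition) -> float:
        # Features in the coalition take their input value; the rest take the baseline.
        return f([x[i] if i in coalition else baseline[i] for i in range(d)])

    phi = [0.0] * d
    for i in range(d):
        others = [j for j in range(d) if j != i]
        for k in range(d):
            # Shapley weight |S|! (d - |S| - 1)! / d! for coalitions of size k.
            weight = math.factorial(k) * math.factorial(d - k - 1) / math.factorial(d)
            for combo in itertools.combinations(others, k):
                coalition = set(combo)
                phi[i] += weight * (value(coalition | {i}) - value(coalition))
    return phi
```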
Task checklist:
- Core implementation (captum/attr/_core/partition_shap.py)
- Unit tests (tests/attr/test_partition_shap.py)
- Benchmarks script (optional, skipped on CI)
- API & math docs
- Example notebook
- Changelog entry
References:
G. Lundberg & G. Erion, “Partition SHAP: Explaining complex models via recursive feature partitioning,” 2021 (arXiv 2105.14814).
Assignees:
@sisird864: I'm claiming this feature implementation. Maintainers, please add the appropriate labels (triaged, feature, help-wanted) and let me know of any design concerns before I begin.
Thanks! Looking forward to contributing.
— @sisird864