
Implement PartitionSHAP attribution method #1626

@sisird864


Motivation:

Captum currently provides KernelShap and ShapleyValueSampling, but lacks a
fast hierarchical Shapley estimator. PartitionSHAP (Lundberg & Erion 2021)
reduces the number of model evaluations from O(M · d) to O(M · log d) by
recursively partitioning the feature set. Users working with large-parameter
LLMs and computer-vision models will benefit from 10-100× faster attributions
without sacrificing accuracy.


Proposed API (mirrors Captum style):

from typing import Callable, Any, Optional

# PerturbationAttribution is defined in captum/attr/_utils/attribution.py
from captum.attr._utils.attribution import PerturbationAttribution

class PartitionShap(PerturbationAttribution):
    def __init__(
        self,
        forward_func: Callable,
        baselines: Optional[Any] = None,
        max_evals: int = 512,
        perturbations_per_eval: int = 1,
        partition_agg: str = "mean",   # mean | sum | max
        cluster_method: str = "greedy" # future-proof for alt heuristics
    ): ...

Returns → tensor of Shapley values with same shape as inputs.

Design outline:

  1. Hierarchy builder – greedy divisive split on absolute gradients
     (torch.topk for speed, stop at single-feature leaves).
  2. Conditional expectation – Monte-Carlo permutations per internal node,
     batched using perturbations_per_eval.
  3. Reuse Captum helpers
    • _tensorize_baselines, _construct_perturbations
    • infer_target_device_dtype for GPU / dtype safety.
  4. Early-exit guard – stop once eval_count ≥ max_evals, emit UserWarning.
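
Steps 1 and 4 of the outline can be sketched as follows. This is a minimal illustration, not the proposed implementation: it is pure Python so that it runs without PyTorch (sorting stands in for torch.topk), it assumes the per-feature scores are already available as absolute gradients, and the names build_hierarchy, leaves, and EvalBudget are hypothetical helpers, not Captum APIs.

```python
import warnings
from typing import List, Optional, Sequence, Union

# A tree is either a leaf (a feature index) or a (left, right) pair of subtrees.
Tree = Union[int, tuple]


def build_hierarchy(scores: Sequence[float],
                    features: Optional[List[int]] = None) -> Tree:
    """Greedy divisive split: the top half of features by |score| goes left.

    Hypothetical helper; sorted() stands in for torch.topk.
    """
    if features is None:
        features = list(range(len(scores)))
    if len(features) == 1:
        return features[0]  # stop at single-feature leaves
    ranked = sorted(features, key=lambda i: abs(scores[i]), reverse=True)
    k = len(ranked) // 2
    return (build_hierarchy(scores, ranked[:k]),
            build_hierarchy(scores, ranked[k:]))


def leaves(tree: Tree) -> List[int]:
    """Collect the feature indices under a node, left to right."""
    if isinstance(tree, int):
        return [tree]
    left, right = tree
    return leaves(left) + leaves(right)


class EvalBudget:
    """Counts forward calls; warns once when max_evals would be exceeded."""

    def __init__(self, max_evals: int) -> None:
        self.max_evals = max_evals
        self.eval_count = 0
        self._warned = False

    def spend(self, n: int = 1) -> bool:
        """Return True if n more evaluations fit in the budget, else False."""
        if self.eval_count + n > self.max_evals:
            if not self._warned:
                warnings.warn("max_evals reached; stopping early", UserWarning)
                self._warned = True
            return False
        self.eval_count += n
        return True


tree = build_hierarchy([0.1, -0.9, 0.5, 0.2])
budget = EvalBudget(max_evals=2)
```

A recursive attribution pass would call budget.spend() before each batch of forward evaluations and stop descending the tree when it returns False, assigning the remaining credit at the current node.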

Acceptance criteria / test plan

  • Correctness:
    PartitionSHAP ≈ KernelSHAP on a 5-feature linear model
    (torch.testing.assert_close(..., atol=1e-2, rtol=1e-2)).

  • Speed:
    On bert-base-uncased CLS token, PartitionSHAP with max_evals=512 runs
    ≥20× faster than KernelSHAP with n_samples=2048 (CPU benchmark).

  • Coverage ≥99 % for new file (pytest --cov=...).

  • Docs:
    docs/source/algorithms.rst section
    Tutorial notebook: GPT-2 toxicity example (<3 min CPU runtime).
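
For the correctness criterion, a linear model gives a closed-form target: with f(x) = w · x + c and baseline b, the exact Shapley value of feature i is w_i (x_i − b_i), and because the model is additive, any grouping of features yields the same values. The sketch below computes that reference by brute force in pure Python. exact_shapley is a hypothetical helper that enumerates every feature ordering, so it is only practical for small inputs such as the 5-feature case; the weights and inputs are made-up illustration values.

```python
from itertools import permutations
from math import factorial, isclose
from typing import Callable, List, Sequence


def exact_shapley(f: Callable[[Sequence[float]], float],
                  x: Sequence[float],
                  baseline: Sequence[float]) -> List[float]:
    """Exact Shapley values: average marginal contribution over all orderings."""
    d = len(x)
    phi = [0.0] * d
    for order in permutations(range(d)):
        current = list(baseline)
        prev = f(current)
        for i in order:
            current[i] = x[i]  # switch feature i from baseline to input value
            out = f(current)
            phi[i] += out - prev
            prev = out
    n_orders = factorial(d)
    return [p / n_orders for p in phi]


# Illustrative 5-feature linear model (values are arbitrary).
weights = [2.0, -1.0, 0.5, 3.0, 0.0]
bias = 0.25


def linear_model(v: Sequence[float]) -> float:
    return sum(w * vi for w, vi in zip(weights, v)) + bias


x = [1.0, 2.0, -1.0, 0.5, 4.0]
baseline = [0.0] * 5

phi = exact_shapley(linear_model, x, baseline)
expected = [w * (xi - bi) for w, xi, bi in zip(weights, x, baseline)]
```

In a unit test, the outputs of both estimators could be compared against such a reference with torch.testing.assert_close, which also verifies that the attributions sum to f(x) − f(baseline).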

Task checklist:

  • Core implementation (captum/attr/_core/partition_shap.py)
  • Unit tests (tests/attr/test_partition_shap.py)
  • Benchmarks script (optional, skipped on CI)
  • API & math docs
  • Example notebook
  • Changelog entry

References:

G. Lundberg & G. Erion, “Partition SHAP: Explaining complex models via
recursive feature partitioning,” 2021.
(arXiv 2105.14814)

Assignees:

@sisird864 — claiming this feature implementation. Maintainers: please add the appropriate labels (triaged, feature, help-wanted) and let me know of any design concerns before I begin.
Thanks! Looking forward to contributing.