Skip to content

Support different reg_reduction in Captum STG #1090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion captum/module/binary_concrete_stochastic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
lower_bound: float = -0.1,
upper_bound: float = 1.1,
eps: float = 1e-8,
reg_reduction: str = "sum",
):
"""
Args:
Expand Down Expand Up @@ -93,8 +94,18 @@ def __init__(
eps (float): term to improve numerical stability in binary concrete
sampling
Default: 1e-8

reg_reduction (str, optional): the reduction to apply to
the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
applied and it will be the same as the return of get_active_probs,
'mean': the sum of the gates non-zero probabilities will be divided by
the number of gates, 'sum': the gates non-zero probabilities will
be summed.
Default: 'sum'
"""
super().__init__(n_gates, mask=mask, reg_weight=reg_weight)
super().__init__(
n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
)

# avoid changing the tensor's variable name
# when the module is used after compilation,
Expand Down
12 changes: 11 additions & 1 deletion captum/module/gaussian_stochastic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
mask: Optional[Tensor] = None,
reg_weight: Optional[float] = 1.0,
std: Optional[float] = 0.5,
reg_reduction: str = "sum",
):
"""
Args:
Expand All @@ -58,8 +59,17 @@ def __init__(
std (Optional[float]): standard deviation that will be fixed throughout.
Default: 0.5 (by paper reference)

reg_reduction (str, optional): the reduction to apply to
the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
applied and it will be the same as the return of get_active_probs,
'mean': the sum of the gates non-zero probabilities will be divided by
the number of gates, 'sum': the gates non-zero probabilities will
be summed.
Default: 'sum'
"""
super().__init__(n_gates, mask=mask, reg_weight=reg_weight)
super().__init__(
n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
)

mu = torch.empty(n_gates)
nn.init.normal_(mu, mean=0.5, std=0.01)
Expand Down
29 changes: 27 additions & 2 deletions captum/module/stochastic_gates_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ class StochasticGatesBase(Module, ABC):
"""

def __init__(
self, n_gates: int, mask: Optional[Tensor] = None, reg_weight: float = 1.0
self,
n_gates: int,
mask: Optional[Tensor] = None,
reg_weight: float = 1.0,
reg_reduction: str = "sum",
):
"""
Args:
Expand All @@ -46,6 +50,14 @@ def __init__(

reg_weight (Optional[float]): rescaling weight for L0 regularization term.
Default: 1.0

reg_reduction (str, optional): the reduction to apply to
the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
applied and it will be the same as the return of get_active_probs,
'mean': the sum of the gates non-zero probabilities will be divided by
the number of gates, 'sum': the gates non-zero probabilities will
be summed.
Default: 'sum'
"""
super().__init__()

Expand All @@ -57,6 +69,12 @@ def __init__(
" should correspond to a gate"
)

valid_reg_reduction = ["none", "mean", "sum"]
assert (
reg_reduction in valid_reg_reduction
), f"reg_reduction must be one of [none, mean, sum], received: {reg_reduction}"
self.reg_reduction = reg_reduction

self.n_gates = n_gates
self.register_buffer(
"mask", mask.detach().clone() if mask is not None else None
Expand Down Expand Up @@ -106,7 +124,14 @@ def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]:
gated_input = input_tensor * gate_values

prob_density = self._get_gate_active_probs()
l0_reg = self.reg_weight * prob_density.mean()
if self.reg_reduction == "sum":
l0_reg = prob_density.sum()
elif self.reg_reduction == "mean":
l0_reg = prob_density.mean()
else:
l0_reg = prob_density

l0_reg *= self.reg_weight

return gated_input, l0_reg

Expand Down
32 changes: 28 additions & 4 deletions tests/module/test_binary_concrete_stochastic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_bcstg_1d_input(self) -> None:
).to(self.testing_device)

gated_input, reg = bcstg(input_tensor)
expected_reg = 0.8316
expected_reg = 2.4947

if self.testing_device == "cpu":
expected_gated_input = [[0.0000, 0.0212, 0.1892], [0.1839, 0.3753, 0.4937]]
Expand All @@ -42,6 +42,30 @@ def test_bcstg_1d_input(self) -> None:
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
assertTensorAlmostEqual(self, reg, expected_reg)

def test_bcstg_1d_input_with_reg_reduction(self) -> None:
    """Check the 'mean' and 'none' reg_reduction modes of
    BinaryConcreteStochasticGates on a 1D (unmasked) input.
    """
    n_gates = 3
    # Construction order matters for RNG reproducibility: mean first, then none.
    gate_mean = BinaryConcreteStochasticGates(n_gates, reg_reduction="mean").to(
        self.testing_device
    )
    gate_none = BinaryConcreteStochasticGates(n_gates, reg_reduction="none").to(
        self.testing_device
    )
    inp = torch.tensor(
        [
            [0.0, 0.1, 0.2],
            [0.3, 0.4, 0.5],
        ]
    ).to(self.testing_device)

    # Call order also preserved (stochastic sampling consumes RNG state).
    _, reg_mean = gate_mean(inp)
    _, reg_none = gate_none(inp)

    # 'mean' divides the summed active probabilities by the gate count;
    # 'none' yields the per-gate probabilities unreduced.
    expected_mean_reg = 0.8316
    expected_none_reg = torch.tensor([0.8321, 0.8310, 0.8325])

    assertTensorAlmostEqual(self, reg_mean, expected_mean_reg)
    assertTensorAlmostEqual(self, reg_none, expected_none_reg)

def test_bcstg_1d_input_with_n_gates_error(self) -> None:

dim = 3
Expand Down Expand Up @@ -85,7 +109,7 @@ def test_bcstg_1d_input_with_mask(self) -> None:
).to(self.testing_device)

gated_input, reg = bcstg(input_tensor)
expected_reg = 0.8321
expected_reg = 1.6643

if self.testing_device == "cpu":
expected_gated_input = [[0.0000, 0.0000, 0.1679], [0.0000, 0.0000, 0.2223]]
Expand Down Expand Up @@ -118,7 +142,7 @@ def test_bcstg_2d_input(self) -> None:

gated_input, reg = bcstg(input_tensor)

expected_reg = 0.8317
expected_reg = 4.9903
if self.testing_device == "cpu":
expected_gated_input = [
[[0.0000, 0.0990], [0.0261, 0.2431], [0.0551, 0.3863]],
Expand Down Expand Up @@ -179,7 +203,7 @@ def test_bcstg_2d_input_with_mask(self) -> None:
).to(self.testing_device)

gated_input, reg = bcstg(input_tensor)
expected_reg = 0.8316
expected_reg = 2.4947

if self.testing_device == "cpu":
expected_gated_input = [
Expand Down
33 changes: 29 additions & 4 deletions tests/module/test_gaussian_stochastic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_gstg_1d_input(self) -> None:

dim = 3
gstg = GaussianStochasticGates(dim).to(self.testing_device)

input_tensor = torch.tensor(
[
[0.0, 0.1, 0.2],
Expand All @@ -33,7 +34,7 @@ def test_gstg_1d_input(self) -> None:
).to(self.testing_device)

gated_input, reg = gstg(input_tensor)
expected_reg = 0.8404
expected_reg = 2.5213

if self.testing_device == "cpu":
expected_gated_input = [[0.0000, 0.0198, 0.1483], [0.1848, 0.3402, 0.1782]]
Expand All @@ -43,6 +44,30 @@ def test_gstg_1d_input(self) -> None:
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
assertTensorAlmostEqual(self, reg, expected_reg)

def test_gstg_1d_input_with_reg_reduction(self) -> None:
    """Check the 'mean' and 'none' reg_reduction modes of
    GaussianStochasticGates on a 1D (unmasked) input.
    """
    n_gates = 3
    # Construction order matters for RNG reproducibility: mean first, then none.
    gate_mean = GaussianStochasticGates(n_gates, reg_reduction="mean").to(
        self.testing_device
    )
    gate_none = GaussianStochasticGates(n_gates, reg_reduction="none").to(
        self.testing_device
    )

    inp = torch.tensor(
        [
            [0.0, 0.1, 0.2],
            [0.3, 0.4, 0.5],
        ]
    ).to(self.testing_device)

    # Call order also preserved (stochastic sampling consumes RNG state).
    _, reg_mean = gate_mean(inp)
    _, reg_none = gate_none(inp)

    # 'mean' divides the summed active probabilities by the gate count;
    # 'none' yields the per-gate probabilities unreduced.
    expected_mean_reg = 0.8404
    expected_none_reg = torch.tensor([0.8424, 0.8384, 0.8438])

    assertTensorAlmostEqual(self, reg_mean, expected_mean_reg)
    assertTensorAlmostEqual(self, reg_none, expected_none_reg)

def test_gstg_1d_input_with_n_gates_error(self) -> None:

dim = 3
Expand All @@ -65,7 +90,7 @@ def test_gstg_1d_input_with_mask(self) -> None:
).to(self.testing_device)

gated_input, reg = gstg(input_tensor)
expected_reg = 0.8424
expected_reg = 1.6849

if self.testing_device == "cpu":
expected_gated_input = [[0.0000, 0.0000, 0.1225], [0.0583, 0.0777, 0.3779]]
Expand Down Expand Up @@ -111,7 +136,7 @@ def test_gstg_2d_input(self) -> None:
).to(self.testing_device)

gated_input, reg = gstg(input_tensor)
expected_reg = 0.8410
expected_reg = 5.0458

if self.testing_device == "cpu":
expected_gated_input = [
Expand Down Expand Up @@ -173,7 +198,7 @@ def test_gstg_2d_input_with_mask(self) -> None:
).to(self.testing_device)

gated_input, reg = gstg(input_tensor)
expected_reg = 0.8404
expected_reg = 2.5213

if self.testing_device == "cpu":
expected_gated_input = [
Expand Down