diff --git a/captum/module/binary_concrete_stochastic_gates.py b/captum/module/binary_concrete_stochastic_gates.py
index b45ca58e55..d6fa318d87 100644
--- a/captum/module/binary_concrete_stochastic_gates.py
+++ b/captum/module/binary_concrete_stochastic_gates.py
@@ -60,6 +60,7 @@ def __init__(
         lower_bound: float = -0.1,
         upper_bound: float = 1.1,
         eps: float = 1e-8,
+        reg_reduction: str = "sum",
     ):
         """
         Args:
@@ -93,8 +94,18 @@ def __init__(
             eps (float): term to improve numerical stability in binary concerete
                 sampling
                 Default: 1e-8
+
+            reg_reduction (str, optional): the reduction to apply to
+                the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
+                applied and it will be the same as the return of get_active_probs,
+                'mean': the sum of the gates non-zero probabilities will be divided by
+                the number of gates, 'sum': the gates non-zero probabilities will
+                be summed.
+                Default: 'sum'
         """
-        super().__init__(n_gates, mask=mask, reg_weight=reg_weight)
+        super().__init__(
+            n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
+        )
 
         # avoid changing the tensor's variable name
         # when the module is used after compilation,
diff --git a/captum/module/gaussian_stochastic_gates.py b/captum/module/gaussian_stochastic_gates.py
index ebaa692c32..b10f837dc1 100644
--- a/captum/module/gaussian_stochastic_gates.py
+++ b/captum/module/gaussian_stochastic_gates.py
@@ -38,6 +38,7 @@ def __init__(
         mask: Optional[Tensor] = None,
         reg_weight: Optional[float] = 1.0,
         std: Optional[float] = 0.5,
+        reg_reduction: str = "sum",
     ):
         """
         Args:
@@ -58,8 +59,17 @@ def __init__(
             std (Optional[float]): standard deviation that will be fixed throughout.
                 Default: 0.5 (by paper reference)
 
+            reg_reduction (str, optional): the reduction to apply to
+                the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
+                applied and it will be the same as the return of get_active_probs,
+                'mean': the sum of the gates non-zero probabilities will be divided by
+                the number of gates, 'sum': the gates non-zero probabilities will
+                be summed.
+                Default: 'sum'
         """
-        super().__init__(n_gates, mask=mask, reg_weight=reg_weight)
+        super().__init__(
+            n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
+        )
 
         mu = torch.empty(n_gates)
         nn.init.normal_(mu, mean=0.5, std=0.01)
diff --git a/captum/module/stochastic_gates_base.py b/captum/module/stochastic_gates_base.py
index c10d32d596..75eebb2d65 100644
--- a/captum/module/stochastic_gates_base.py
+++ b/captum/module/stochastic_gates_base.py
@@ -29,7 +29,11 @@ class StochasticGatesBase(Module, ABC):
     """
 
     def __init__(
-        self, n_gates: int, mask: Optional[Tensor] = None, reg_weight: float = 1.0
+        self,
+        n_gates: int,
+        mask: Optional[Tensor] = None,
+        reg_weight: float = 1.0,
+        reg_reduction: str = "sum",
     ):
         """
         Args:
@@ -46,6 +50,14 @@ def __init__(
             reg_weight (Optional[float]): rescaling weight for L0 regularization term.
                 Default: 1.0
+
+            reg_reduction (str, optional): the reduction to apply to
+                the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
+                applied and it will be the same as the return of get_active_probs,
+                'mean': the sum of the gates non-zero probabilities will be divided by
+                the number of gates, 'sum': the gates non-zero probabilities will
+                be summed.
+                Default: 'sum'
         """
         super().__init__()
 
@@ -57,6 +69,12 @@ def __init__(
                 " should correspond to a gate"
             )
+        valid_reg_reduction = ["none", "mean", "sum"]
+        assert (
+            reg_reduction in valid_reg_reduction
+        ), f"reg_reduction must be one of [none, mean, sum], received: {reg_reduction}"
+        self.reg_reduction = reg_reduction
+
         self.n_gates = n_gates
         self.register_buffer(
             "mask", mask.detach().clone() if mask is not None else None
         )
@@ -106,7 +124,14 @@ def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]:
         gated_input = input_tensor * gate_values
 
         prob_density = self._get_gate_active_probs()
-        l0_reg = self.reg_weight * prob_density.mean()
+        if self.reg_reduction == "sum":
+            l0_reg = prob_density.sum()
+        elif self.reg_reduction == "mean":
+            l0_reg = prob_density.mean()
+        else:
+            l0_reg = prob_density
+
+        l0_reg *= self.reg_weight
 
         return gated_input, l0_reg
 
diff --git a/tests/module/test_binary_concrete_stochastic_gates.py b/tests/module/test_binary_concrete_stochastic_gates.py
index c910370350..25efbb26ad 100644
--- a/tests/module/test_binary_concrete_stochastic_gates.py
+++ b/tests/module/test_binary_concrete_stochastic_gates.py
@@ -32,7 +32,7 @@ def test_bcstg_1d_input(self) -> None:
         ).to(self.testing_device)
 
         gated_input, reg = bcstg(input_tensor)
-        expected_reg = 0.8316
+        expected_reg = 2.4947
 
         if self.testing_device == "cpu":
             expected_gated_input = [[0.0000, 0.0212, 0.1892], [0.1839, 0.3753, 0.4937]]
@@ -42,6 +42,30 @@ def test_bcstg_1d_input(self) -> None:
         assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
         assertTensorAlmostEqual(self, reg, expected_reg)
 
+    def test_bcstg_1d_input_with_reg_reduction(self) -> None:
+
+        dim = 3
+        mean_bcstg = BinaryConcreteStochasticGates(dim, reg_reduction="mean").to(
+            self.testing_device
+        )
+        none_bcstg = BinaryConcreteStochasticGates(dim, reg_reduction="none").to(
+            self.testing_device
+        )
+        input_tensor = torch.tensor(
+            [
+                [0.0, 0.1, 0.2],
+                [0.3, 0.4, 0.5],
+            ]
+        ).to(self.testing_device)
+
+        mean_gated_input, mean_reg = mean_bcstg(input_tensor)
+        none_gated_input, none_reg = none_bcstg(input_tensor)
+        expected_mean_reg = 0.8316
+        expected_none_reg = torch.tensor([0.8321, 0.8310, 0.8325])
+
+        assertTensorAlmostEqual(self, mean_reg, expected_mean_reg)
+        assertTensorAlmostEqual(self, none_reg, expected_none_reg)
+
     def test_bcstg_1d_input_with_n_gates_error(self) -> None:
 
         dim = 3
@@ -85,7 +109,7 @@ def test_bcstg_1d_input_with_mask(self) -> None:
         ).to(self.testing_device)
 
         gated_input, reg = bcstg(input_tensor)
-        expected_reg = 0.8321
+        expected_reg = 1.6643
 
         if self.testing_device == "cpu":
             expected_gated_input = [[0.0000, 0.0000, 0.1679], [0.0000, 0.0000, 0.2223]]
@@ -118,7 +142,7 @@ def test_bcstg_2d_input(self) -> None:
 
         gated_input, reg = bcstg(input_tensor)
-        expected_reg = 0.8317
+        expected_reg = 4.9903
 
         if self.testing_device == "cpu":
             expected_gated_input = [
                 [[0.0000, 0.0990], [0.0261, 0.2431], [0.0551, 0.3863]],
@@ -179,7 +203,7 @@ def test_bcstg_2d_input_with_mask(self) -> None:
         ).to(self.testing_device)
 
         gated_input, reg = bcstg(input_tensor)
-        expected_reg = 0.8316
+        expected_reg = 2.4947
 
         if self.testing_device == "cpu":
             expected_gated_input = [
diff --git a/tests/module/test_gaussian_stochastic_gates.py b/tests/module/test_gaussian_stochastic_gates.py
index 06baaa8947..03df56c51f 100644
--- a/tests/module/test_gaussian_stochastic_gates.py
+++ b/tests/module/test_gaussian_stochastic_gates.py
@@ -25,6 +25,7 @@ def test_gstg_1d_input(self) -> None:
         dim = 3
         gstg = GaussianStochasticGates(dim).to(self.testing_device)
+
         input_tensor = torch.tensor(
             [
                 [0.0, 0.1, 0.2],
                 [0.3, 0.4, 0.5],
             ]
@@ -33,7 +34,7 @@ def test_gstg_1d_input(self) -> None:
         ).to(self.testing_device)
 
         gated_input, reg = gstg(input_tensor)
-        expected_reg = 0.8404
+        expected_reg = 2.5213
 
         if self.testing_device == "cpu":
             expected_gated_input = [[0.0000, 0.0198, 0.1483], [0.1848, 0.3402, 0.1782]]
@@ -43,6 +44,30 @@ def test_gstg_1d_input(self) -> None:
         assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
         assertTensorAlmostEqual(self, reg, expected_reg)
 
+    def test_gstg_1d_input_with_reg_reduction(self) -> None:
+        dim = 3
+        mean_gstg = GaussianStochasticGates(dim, reg_reduction="mean").to(
+            self.testing_device
+        )
+        none_gstg = GaussianStochasticGates(dim, reg_reduction="none").to(
+            self.testing_device
+        )
+
+        input_tensor = torch.tensor(
+            [
+                [0.0, 0.1, 0.2],
+                [0.3, 0.4, 0.5],
+            ]
+        ).to(self.testing_device)
+
+        _, mean_reg = mean_gstg(input_tensor)
+        _, none_reg = none_gstg(input_tensor)
+        expected_mean_reg = 0.8404
+        expected_none_reg = torch.tensor([0.8424, 0.8384, 0.8438])
+
+        assertTensorAlmostEqual(self, mean_reg, expected_mean_reg)
+        assertTensorAlmostEqual(self, none_reg, expected_none_reg)
+
     def test_gstg_1d_input_with_n_gates_error(self) -> None:
 
         dim = 3
@@ -65,7 +90,7 @@ def test_gstg_1d_input_with_mask(self) -> None:
         ).to(self.testing_device)
 
         gated_input, reg = gstg(input_tensor)
-        expected_reg = 0.8424
+        expected_reg = 1.6849
 
         if self.testing_device == "cpu":
             expected_gated_input = [[0.0000, 0.0000, 0.1225], [0.0583, 0.0777, 0.3779]]
@@ -111,7 +136,7 @@ def test_gstg_2d_input(self) -> None:
         ).to(self.testing_device)
 
         gated_input, reg = gstg(input_tensor)
-        expected_reg = 0.8410
+        expected_reg = 5.0458
 
         if self.testing_device == "cpu":
             expected_gated_input = [
@@ -173,7 +198,7 @@ def test_gstg_2d_input_with_mask(self) -> None:
         ).to(self.testing_device)
 
         gated_input, reg = gstg(input_tensor)
-        expected_reg = 0.8404
+        expected_reg = 2.5213
 
         if self.testing_device == "cpu":
             expected_gated_input = [
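
Usage note (not part of the patch): a minimal sketch of how the new reg_reduction argument behaves, assuming the gate modules are importable from captum.module and using illustrative tensor shapes.

    import torch

    from captum.module import BinaryConcreteStochasticGates

    n_gates = 3
    inputs = torch.rand(4, n_gates)  # illustrative batch of 4 samples, one gate per feature

    # "sum" (the default) adds up the gates' non-zero probabilities,
    # "mean" divides that sum by the number of gates, and "none" skips
    # the reduction and returns the per-gate probabilities.
    sum_gates = BinaryConcreteStochasticGates(n_gates, reg_weight=0.1)  # reg_reduction="sum"
    mean_gates = BinaryConcreteStochasticGates(n_gates, reg_weight=0.1, reg_reduction="mean")
    none_gates = BinaryConcreteStochasticGates(n_gates, reg_weight=0.1, reg_reduction="none")

    _, sum_reg = sum_gates(inputs)        # scalar tensor
    _, mean_reg = mean_gates(inputs)      # scalar tensor
    gated, none_reg = none_gates(inputs)  # shape (n_gates,); caller picks its own reduction

    # In every mode the returned term is already scaled by reg_weight, so it can be
    # added directly to a task loss, e.g. a toy objective:
    loss = gated.sum() + none_reg.sum()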