diff --git a/README.md b/README.md index 363ea59..04d6db5 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,9 @@ + + Documentation +
@@ -39,7 +42,7 @@ build orthogonal layers, with a focus on convolutional layers . We noticed that significant role in the final performance : a more efficient implementation allows larger networks and more training steps within the same compute budget. So our implementation differs from original papers in order to -be faster, to consume less memory or be more flexible. +be faster, to consume less memory or be more flexible. Feel free to read the [documentation](https://thib-s.github.io/orthogonium/)! # 📃 What is included in this library ? diff --git a/docs/api/activations.md b/docs/api/activations.md new file mode 100644 index 0000000..e58bd7a --- /dev/null +++ b/docs/api/activations.md @@ -0,0 +1,5 @@ +::: orthogonium.layers.custom_activations + rendering: + show_root_toc_entry: True + selection: + inherited_members: True diff --git a/docs/api/losses.md b/docs/api/losses.md new file mode 100644 index 0000000..1b714be --- /dev/null +++ b/docs/api/losses.md @@ -0,0 +1,5 @@ +::: orthogonium.losses + rendering: + show_root_toc_entry: True + selection: + inherited_members: True diff --git a/mkdocs.yml b/mkdocs.yml index 4545550..6fdaf6e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -7,6 +7,8 @@ nav: - convolutions: api/conv.md - linear layers: api/linear.md - reparametrizers: api/reparametrizers.md + - activations: api/activations.md + - losses: api/losses.md # - layers.conv.AOC module: api/aoc.md # - layers.conv.adaptiveSOC module: api/adaptiveSOC.md # - layers.conv.SLL module: api/sll.md diff --git a/orthogonium/layers/conv/AOC/fast_block_ortho_conv.py b/orthogonium/layers/conv/AOC/fast_block_ortho_conv.py index 277ba5b..10f3a83 100644 --- a/orthogonium/layers/conv/AOC/fast_block_ortho_conv.py +++ b/orthogonium/layers/conv/AOC/fast_block_ortho_conv.py @@ -156,7 +156,6 @@ def __init__( out_channels, kernel_size, groups, - contiguous_optimization=False, ): """This module is used to generate orthogonal kernels for the BCOP layer. It takes as input a matrix PQ of shape (groups, 2*kernel_size, c, c//2) and returns a kernel @@ -167,9 +166,6 @@ def __init__( out_channels (int): number of output channels kernel_size (int): size of the kernel groups (int): number of groups - contiguous_optimization (bool, optional): if True, the kernel will have twice the - number of channels. This is used to increase expressiveness, but at the price - of orthogonality (not Lipschitzness). Defaults to False. """ super(BCOPTrivializer, self).__init__() self.kernel_size = kernel_size @@ -178,9 +174,6 @@ def __init__( self.in_channels = in_channels self.min_channels = min(in_channels, out_channels) self.max_channels = max(in_channels, out_channels) - if contiguous_optimization: - self.max_channels *= 2 - self.contiguous_optimization = contiguous_optimization self.transpose = out_channels < in_channels self.num_kernels = 2 * kernel_size @@ -249,12 +242,6 @@ def forward(self, PQ): res = c11 for i in range(c22.shape[0]): # c22.shape[0] == 1 if k-1 is a power of two res = fast_matrix_conv(res, c22[i], self.groups) - # if contiguous optimization is enabled, we constructed a conv with twice the number - # of channels, we need to remove the extra channels - if self.contiguous_optimization: - res = res[ - : self.max_channels // 2, : self.min_channels // self.groups, :, : - ] # since it is less expensive to compute the transposed kernel when co < ci # we transpose the kernel if needed if self.transpose: @@ -288,7 +275,6 @@ def attach_bcop_weight( num_kernels = ( 2 * kernel_size ) # the number of projectors needed to create the kernel - contiguous_optimization = ortho_params.contiguous_optimization # register projectors matrices layer.register_parameter( weight_name, @@ -296,16 +282,8 @@ def attach_bcop_weight( torch.Tensor( groups, num_kernels, - ( - 2 * max_channels // groups - if contiguous_optimization - else max_channels // groups - ), - ( - max_channels // groups - if contiguous_optimization - else max_channels // (groups * 2) - ), + (max_channels // groups), + (max_channels // (groups * 2)), ), requires_grad=True, ), @@ -343,7 +321,6 @@ def attach_bcop_weight( out_channels, kernel_size, groups, - contiguous_optimization=contiguous_optimization, ), unsafe=True, ) diff --git a/orthogonium/layers/conv/AOC/ortho_conv.py b/orthogonium/layers/conv/AOC/ortho_conv.py index 5b4eb21..a997039 100644 --- a/orthogonium/layers/conv/AOC/ortho_conv.py +++ b/orthogonium/layers/conv/AOC/ortho_conv.py @@ -27,32 +27,34 @@ def AdaptiveOrthoConv2d( """ Factory function to create an orthogonal convolutional layer, selecting the appropriate class based on kernel size and stride. - **Key Features:** - - Enforces orthogonality, preserving gradient norms. - - Supports native striding, dilation, grouped convolutions, and flexible padding. - - **Behavior:** - - When kernel_size == stride, the layer is an `RKOConv2d`. - - When stride == 1, the layer is a `FastBlockConv2d`. - - Otherwise, the layer is a `BcopRkoConv2d`. - - **Arguments:** - - `in_channels` (int): Number of input channels. - - `out_channels` (int): Number of output channels. - - `kernel_size` (_size_2_t): Size of the convolution kernel. - - `stride` (_size_2_t, optional): Stride of the convolution. Default is 1. - - `padding` (str or _size_2_t, optional): Padding mode or size. Default is "same". - - `dilation` (_size_2_t, optional): Dilation rate. Default is 1. - - `groups` (int, optional): Number of blocked connections from input to output channels. Default is 1. - - `bias` (bool, optional): Whether to include a learnable bias. Default is True. - - `padding_mode` (str, optional): Padding mode. Default is "circular". - - `ortho_params` (OrthoParams, optional): Parameters to control orthogonality. Default is `OrthoParams()`. - - **Returns:** - - A configured instance of `nn.Conv2d` (one of `RKOConv2d`, `FastBlockConv2d`, or `BcopRkoConv2d`). - - **Raises:** - - `ValueError`: If kernel_size < stride, as orthogonality cannot be enforced. + Key Features: + ------------- + - Enforces orthogonality, preserving gradient norms. + - Supports native striding, dilation, grouped convolutions, and flexible padding. + + Behavior: + ------------- + - When kernel_size == stride, the layer is an `RKOConv2d`. + - When stride == 1, the layer is a `FastBlockConv2d`. + - Otherwise, the layer is a `BcopRkoConv2d`. + + Arguments: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (_size_2_t): Size of the convolution kernel. + stride (_size_2_t, optional): Stride of the convolution. Default is 1. + padding (str or _size_2_t, optional): Padding mode or size. Default is "same". + dilation (_size_2_t, optional): Dilation rate. Default is 1. + groups (int, optional): Number of blocked connections from input to output channels. Default is 1. + bias (bool, optional): Whether to include a learnable bias. Default is True. + padding_mode (str, optional): Padding mode. Default is "circular". + ortho_params (OrthoParams, optional): Parameters to control orthogonality. Default is `OrthoParams()`. + + Returns: + A configured instance of `nn.Conv2d` (one of `RKOConv2d`, `FastBlockConv2d`, or `BcopRkoConv2d`). + + Raises: + `ValueError`: If kernel_size < stride, as orthogonality cannot be enforced. """ if kernel_size < stride: @@ -95,30 +97,32 @@ def AdaptiveOrthoConvTranspose2d( """ Factory function to create an orthogonal convolutional transpose layer, adapting based on kernel size and stride. - **Key Features:** - - Ensures orthogonality in transpose convolutions for stable gradient propagation. - - Supports dilation, grouped operations, and efficient kernel construction. - - **Behavior:** - - When kernel_size == stride, the layer is an `RkoConvTranspose2d`. - - When stride == 1, the layer is a `FastBlockConvTranspose2D`. - - Otherwise, the layer is a `BcopRkoConvTranspose2d`. - - **Arguments:** - - `in_channels` (int): Number of input channels. - - `out_channels` (int): Number of output channels. - - `kernel_size` (_size_2_t): Size of the convolution kernel. - - `stride` (_size_2_t, optional): Stride of the transpose convolution. Default is 1. - - `padding` (_size_2_t, optional): Padding size. Default is 0. - - `output_padding` (_size_2_t, optional): Additional size for output. Default is 0. - - `groups` (int, optional): Number of groups. Default is 1. - - `bias` (bool, optional): Whether to include a learnable bias. Default is True. - - `dilation` (_size_2_t, optional): Dilation rate. Default is 1. - - `padding_mode` (str, optional): Padding mode. Default is "zeros". - - `ortho_params` (OrthoParams, optional): Parameters to control orthogonality. Default is `OrthoParams()`. - - **Returns:** - - A configured instance of `nn.ConvTranspose2d` (one of `RkoConvTranspose2d`, `FastBlockConvTranspose2D`, or `BcopRkoConvTranspose2d`). + Key Features: + ------------- + - Ensures orthogonality in transpose convolutions for stable gradient propagation. + - Supports dilation, grouped operations, and efficient kernel construction. + + Behavior: + --------- + - When kernel_size == stride, the layer is an `RkoConvTranspose2d`. + - When stride == 1, the layer is a `FastBlockConvTranspose2D`. + - Otherwise, the layer is a `BcopRkoConvTranspose2d`. + + Arguments: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (_size_2_t): Size of the convolution kernel. + stride (_size_2_t, optional): Stride of the transpose convolution. Default is 1. + padding (_size_2_t, optional): Padding size. Default is 0. + output_padding (_size_2_t, optional): Additional size for output. Default is 0. + groups (int, optional): Number of groups. Default is 1. + bias (bool, optional): Whether to include a learnable bias. Default is True. + dilation (_size_2_t, optional): Dilation rate. Default is 1. + padding_mode (str, optional): Padding mode. Default is "zeros". + ortho_params (OrthoParams, optional): Parameters to control orthogonality. Default is `OrthoParams()`. + + Returns: + A configured instance of `nn.ConvTranspose2d` (one of `RkoConvTranspose2d`, `FastBlockConvTranspose2D`, or `BcopRkoConvTranspose2d`). **Raises:** - `ValueError`: If kernel_size < stride, as orthogonality cannot be enforced. diff --git a/orthogonium/layers/conv/adaptiveSOC/fast_skew_ortho_conv.py b/orthogonium/layers/conv/adaptiveSOC/fast_skew_ortho_conv.py index 8ba1041..26685be 100644 --- a/orthogonium/layers/conv/adaptiveSOC/fast_skew_ortho_conv.py +++ b/orthogonium/layers/conv/adaptiveSOC/fast_skew_ortho_conv.py @@ -138,14 +138,17 @@ def attach_soc_weight( weight_name (str): name of the weight kernel_shape (tuple): shape of the kernel (out_channels, in_channels/groups, kernel_size, kernel_size) groups (int): number of groups - bjorck_params (BjorckParams, optional): parameters of the Bjorck orthogonalization. Defaults to BjorckParams(). + exp_params (ExpParams): parameters for the exponential algorithm. Returns: torch.Tensor: a handle to the attached weight """ out_channels, in_channels, kernel_size, k2 = kernel_shape in_channels *= groups # compute the real number of input channels - assert kernel_size == k2, "only square kernels are supported for the moment" + assert ( + kernel_size == k2 + ), "only square kernels are supported (to compute skew symmetric kernels)" + assert kernel_size % 2 == 1, "kernel size must be odd" max_channels = max(in_channels, out_channels) layer.register_parameter( weight_name, @@ -238,8 +241,6 @@ def __init__( raise ValueError( "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting." ) - if (in_channels % groups != 0) and (out_channels % groups != 0): - ) self.padding = padding self.stride = stride self.kernel_size = kernel_size @@ -252,11 +253,6 @@ def __init__( groups, exp_params=exp_params, ) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - nn.init.zeros_(self.bias) - else: - self.register_parameter("bias", None) def singular_values(self): """Compute the singular values of the convolutional layer using the FFT+SVD method. @@ -341,8 +337,6 @@ def __init__( raise ValueError( "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting." ) - if (in_channels % groups != 0) and (out_channels % groups != 0): - ) if ((self.max_channels // groups) < 2) and (kernel_size != stride): raise ValueError("inner conv must have at least 2 channels") if out_channels * (stride**2) < in_channels: @@ -367,12 +361,6 @@ def __init__( exp_params=exp_params, ) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - nn.init.zeros_(self.bias) - else: - self.register_parameter("bias", None) - def singular_values(self): if self.padding_mode != "circular": print( @@ -387,8 +375,8 @@ def singular_values(self): self.groups, self.in_channels // self.groups, self.out_channels // self.groups, - self.kernel_size, - self.kernel_size, + self.weight.shape[-2], + self.weight.shape[-1], ) .numpy(), self._input_shape, diff --git a/orthogonium/layers/conv/adaptiveSOC/ortho_conv.py b/orthogonium/layers/conv/adaptiveSOC/ortho_conv.py index 92e471a..d00eb10 100644 --- a/orthogonium/layers/conv/adaptiveSOC/ortho_conv.py +++ b/orthogonium/layers/conv/adaptiveSOC/ortho_conv.py @@ -40,7 +40,7 @@ def AdaptiveSOCConv2d( ) if kernel_size == stride: convclass = RKOConv2d - elif (stride == 1) or (in_channels >= out_channels): + elif stride == 1: convclass = FastSOC else: convclass = SOCRkoConv2d diff --git a/orthogonium/layers/conv/adaptiveSOC/soc_x_rko_conv.py b/orthogonium/layers/conv/adaptiveSOC/soc_x_rko_conv.py index d2cdb34..0ac5894 100644 --- a/orthogonium/layers/conv/adaptiveSOC/soc_x_rko_conv.py +++ b/orthogonium/layers/conv/adaptiveSOC/soc_x_rko_conv.py @@ -67,8 +67,6 @@ def __init__( raise ValueError( "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting." ) - if (in_channels % groups != 0) and (out_channels % groups != 0): - ) if ((self.max_channels // groups) < 2) and (kernel_size != stride): raise ValueError("inner conv must have at least 2 channels") self.padding = padding @@ -77,14 +75,26 @@ def __init__( self.groups = groups self.intermediate_channels = max(in_channels, out_channels // stride**2) del self.weight + int_kernel_size = kernel_size - (stride - 1) + if int_kernel_size % 2 == 0: + if int_kernel_size <= 2: + int_kernel_size += 1 + else: + int_kernel_size -= 1 + # warn user that kernel size changed + warnings.warn( + f"kernel size changed from {kernel_size} to {int_kernel_size} " + f"as even kernel size is not supported for SOC.", + RuntimeWarning, + ) attach_soc_weight( self, "weight_1", ( self.intermediate_channels, in_channels // groups, - kernel_size - (stride - 1), - kernel_size - (stride - 1), + int_kernel_size, + int_kernel_size, ), groups, exp_params=exp_params, @@ -98,12 +108,6 @@ def __init__( ortho_params=ortho_params, ) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - nn.init.zeros_(self.bias) - else: - self.register_parameter("bias", None) - @property def weight(self): if self.training: @@ -160,7 +164,9 @@ def singular_values(self): ) sv_min = sv_min * svs_2.min() sv_max = sv_max * svs_2.max() - stable_rank = 0.5 * stable_rank + 0.5 * ((np.mean(svs_2) ** 2) / (svs_2.max() ** 2)) + stable_rank = 0.5 * stable_rank + 0.5 * ( + (np.mean(svs_2) ** 2) / (svs_2.max() ** 2) + ) return sv_min, sv_max, stable_rank def forward(self, X): @@ -218,8 +224,6 @@ def __init__( raise ValueError( "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting." ) - if (in_channels % groups != 0) and (out_channels % groups != 0): - ) if ((self.max_channels // groups) < 2) and (kernel_size != stride): raise ValueError("inner conv must have at least 2 channels") self.padding = padding @@ -239,14 +243,26 @@ def __init__( # RuntimeWarning, # ) del self.weight + int_kernel_size = kernel_size - (stride - 1) + if int_kernel_size % 2 == 0: + if int_kernel_size <= 2: + int_kernel_size += 1 + else: + int_kernel_size -= 1 + # warn user that kernel size changed + warnings.warn( + f"kernel size changed from {kernel_size} to {int_kernel_size} " + f"as even kernel size is not supported for SOC.", + RuntimeWarning, + ) attach_soc_weight( self, "weight_1", ( self.intermediate_channels, out_channels // groups, - kernel_size - (stride - 1), - kernel_size - (stride - 1), + int_kernel_size, + int_kernel_size, ), groups, exp_params=exp_params, @@ -260,12 +276,6 @@ def __init__( ortho_params=ortho_params, ) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - nn.init.zeros_(self.bias) - else: - self.register_parameter("bias", None) - def singular_values(self): if self.padding_mode != "circular": print( @@ -280,8 +290,8 @@ def singular_values(self): self.groups, self.intermediate_channels // self.groups, self.out_channels // self.groups, - self.kernel_size, - self.kernel_size, + self.weight_1.shape[-2], + self.weight_1.shape[-1], ) .numpy(), self._input_shape, @@ -299,7 +309,9 @@ def singular_values(self): ) sv_min = sv_min * svs_2.min() sv_max = sv_max * svs_2.max() - stable_rank = 0.5 * stable_rank + 0.5 * ((np.mean(svs_2) ** 2) / (svs_2.max() ** 2)) + stable_rank = 0.5 * stable_rank + 0.5 * ( + (np.mean(svs_2) ** 2) / (svs_2.max() ** 2) + ) return sv_min, sv_max, stable_rank @property diff --git a/orthogonium/layers/custom_activations.py b/orthogonium/layers/custom_activations.py index 0204e1a..b2e19f8 100644 --- a/orthogonium/layers/custom_activations.py +++ b/orthogonium/layers/custom_activations.py @@ -10,6 +10,14 @@ class Abs(nn.Module): def __init__(self): + """ + Initializes an instance of the Abs class. + + This method is automatically called when a new object of the Abs class + is instantiated. It calls the initializer of its superclass to ensure + proper initialization of inherited class functionality, setting up + the required base structures or attributes. + """ super(Abs, self).__init__() def forward(self, z): @@ -18,6 +26,15 @@ def forward(self, z): class MaxMin(nn.Module): def __init__(self, axis=1): + """ + This class implements the MaxMin activation function. Which is a + pairwise activation function that returns the maximum and minimum (ordered) + of each pair of elements in the input tensor. + + Parameters + axis : int, default=1 the axis along which to apply the activation function. + + """ self.axis = axis super(MaxMin, self).__init__() @@ -29,6 +46,22 @@ def forward(self, z): class HouseHolder(nn.Module): def __init__(self, channels, axis=1): + """ + A activation that applies a parameterized transformation via Householder + reflection technique. It is initialized with the number of input channels, which must + be even, and an axis that determines the dimension along which operations are applied. + This is a corrected version of the original implementation from Singla et al. (2019), + which features a 1/sqrt(2) scaling factor to be 1-Lipschitz. + + Attributes: + theta (torch.nn.Parameter): Learnable parameter that determines the transformation + applied via Householder reflection. + axis (int): Dimension along which the operation is performed. + + Args: + channels (int): Total number of input channels. Must be an even number. + axis (int): Dimension along which the transformation is applied. Default is 1. + """ super(HouseHolder, self).__init__() assert (channels % 2) == 0 eff_channels = channels // 2 @@ -54,6 +87,38 @@ def forward(self, z): class HouseHolder_Order_2(nn.Module): def __init__(self, channels, axis=1): + """ + Represents a layer or module that performs operations using Householder + transformations of order 2, parameterized by angles corresponding to + each group of channels. This is a corrected version of the original + implementation from Singla et al. (2019), which features a 1/sqrt(2) + scaling factor to be 1-Lipschitz. + + Attributes: + num_groups (int): The number of groups, which is half the number + of channels provided as input. + + axis (int): The axis along which the computation is performed. + + theta0 (torch.nn.Parameter): A tensor parameter of shape `(num_groups,)` + representing the first set of angles (in radians) used in the + parameterization. + + theta1 (torch.nn.Parameter): A tensor parameter of shape `(num_groups,)` + representing the second set of angles (in radians) used in the + parameterization. + + theta2 (torch.nn.Parameter): A tensor parameter of shape `(num_groups,)` + representing the third set of angles (in radians) used in the + parameterization. + + Args: + channels (int): The total number of input channels. Must be an even + number, as it will be split into groups. + + axis (int, optional): Specifies the axis for computations. Defaults + to 1. + """ super(HouseHolder_Order_2, self).__init__() assert (channels % 2) == 0 self.num_groups = channels // 2 diff --git a/orthogonium/layers/linear/ortho_linear.py b/orthogonium/layers/linear/ortho_linear.py index 14c48d6..4ba192e 100644 --- a/orthogonium/layers/linear/ortho_linear.py +++ b/orthogonium/layers/linear/ortho_linear.py @@ -15,6 +15,36 @@ def __init__( bias: bool = True, ortho_params: OrthoParams = OrthoParams(), ): + """ + Initializes an orthogonal linear layer with customizable orthogonalization parameters. + + Attributes: + in_features : int + Number of input features. + out_features : int + Number of output features. + bias : bool + Whether to include a bias term in the layer. Default is True. + ortho_params : OrthoParams + Parameters for orthogonalization and spectral normalization. Default is the + default instance of OrthoParams. + + Parameters: + in_features : int + The size of each input sample. + out_features : int + The size of each output sample. + bias : bool + Indicates if the layer should include a learnable bias parameter. + ortho_params : OrthoParams + An object containing orthogonalization and normalization configurations. + + Notes + ----- + The layer is initialized with orthogonal weights using `torch.nn.init.orthogonal_`. + Weight parameters are further parametrized for both spectral normalization and + orthogonal constraints using the provided `OrthoParams` object. + """ super(OrthoLinear, self).__init__(in_features, out_features, bias=bias) torch.nn.init.orthogonal_(self.weight) parametrize.register_parametrization( @@ -42,7 +72,25 @@ def __init__( *args, **kwargs, ): - """LInear layer where each output unit is normalized to have Frobenius norm 1""" + """ + A custom PyTorch Linear layer that ensures weights are normalized to unit norm along a specified dimension. + + This class extends the torch.nn.Linear module and modifies the weight + matrix to maintain orthogonal initialization and unit norm + normalization during training. In this specific case, each output can be viewed as the result of a 1-Lipschitz + function. This means that the whole function in more than 1-Lipschitz but that each output taken independently + is 1-Lipschitz. + + Attributes: + weight: The learnable weight tensor with orthogonal initialization + and enforced unit norm parametrization. + + Args: + *args: Variable length positional arguments passed to the base + Linear class. + **kwargs: Variable length keyword arguments passed to the base + Linear class. + """ super(UnitNormLinear, self).__init__(*args, **kwargs) torch.nn.init.orthogonal_(self.weight) parametrize.register_parametrization( diff --git a/orthogonium/losses.py b/orthogonium/losses.py index d1ebce6..a9ff82b 100644 --- a/orthogonium/losses.py +++ b/orthogonium/losses.py @@ -8,6 +8,26 @@ def check_last_linear_layer_type(model): + """ + Determines the type of the last linear layer in a given model. + + This function inspects the architecture of the model and identifies the last + linear layer of specific types (nn.Linear, OrthoLinear, UnitNormLinear). It + then returns a string indicating the type of the last linear layer based on + its class. This allows to determine the parameter to use for computing the + VRA of a model's output. + + Args: + model: The model containing layers to be inspected. + + Returns: + str: A string indicating the type of the last linear layer. + The possible values are: + - "global" if the layer is of type OrthoLinear. + - "classwise" if the layer is of type UnitNormLinear. + - "unknown" if the layer is of any other type or if no + linear layer is found. + """ # Find the last linear layer in the model last_linear_layer = None layers = list(model.children()) @@ -102,6 +122,24 @@ def VRA( class LossXent(nn.Module): def __init__(self, n_classes, offset=2.12132, temperature=0.25): + """ + A custom loss function class for cross-entropy calculation. + + This class initializes a cross-entropy loss criterion along with additional + parameters, such as an offset and a temperature factor, to allow a finer control over + the accuracy/robustness tradeoff during training. + + Attributes: + criterion (nn.CrossEntropyLoss): The PyTorch cross-entropy loss criterion. + n_classes (int): The number of classes present in the dataset. + offset (float): An offset value for customizing the loss computation. + temperature (float): A temperature factor for scaling logits during loss calculation. + + Parameters: + n_classes (int): The number of classes in the dataset. + offset (float, optional): The offset value for loss computation. Default is 2.12132. + temperature (float, optional): The temperature scaling factor. Default is 0.25. + """ super(LossXent, self).__init__() self.criterion = nn.CrossEntropyLoss() self.n_classes = n_classes @@ -118,6 +156,15 @@ def __call__(self, outputs, labels): class CosineLoss(nn.Module): def __init__(self): + """ + A class that implements the Cosine Loss for measuring the cosine similarity + between predictions and targets. Designed for use in scenarios involving + angle-based loss calculations or similarity measurements. + + Attributes: + None + + """ super(CosineLoss, self).__init__() def forward(self, yp, yt): diff --git a/orthogonium/model_factory/models_factory.py b/orthogonium/model_factory/models_factory.py index f03a36f..5fd991f 100644 --- a/orthogonium/model_factory/models_factory.py +++ b/orthogonium/model_factory/models_factory.py @@ -188,7 +188,6 @@ def forward(self, x): # eps=1e-6, # bjorck_iters=6, # beta=0.5, -# contiguous_optimization=False, # ), # ), # act=ClassParam(MaxMin), diff --git a/orthogonium/reparametrizers.py b/orthogonium/reparametrizers.py index 29278ff..5cd4228 100644 --- a/orthogonium/reparametrizers.py +++ b/orthogonium/reparametrizers.py @@ -308,7 +308,9 @@ def __init__(self, weight_shape, niters=7): exponential, it produces an orthogonal matrix. This approach is particularly useful in contexts where smooth transitions between matrices are required. - Non-square matrices + Non-square matrices are padded to the largest dimension to ensure that the matrix can + be converted to a skew-symmetric matrix. The resulting matrix is cropped to the original + dimension. Args: weight_shape (tuple): The shape of the weight matrix. @@ -411,8 +413,6 @@ class OrthoParams: configured to use BatchedBjorckOrthogonalization with specific parameters. This callable can be provided either as a `functool.partial` or as a `orthogonium.ClassParam`. It will recieve the shape of the weight tensor as its argument. - contiguous_optimization (bool): Determines whether to perform - optimization ensuring contiguous operations. Default is False. """ # spectral_normalizer: Callable[Tuple[int, ...], nn.Module] = BatchedIdentity @@ -428,7 +428,6 @@ class OrthoParams: # BatchedCholeskyOrthogonalization, # BatchedQROrthogonalization, ) - contiguous_optimization: bool = False DEFAULT_ORTHO_PARAMS = OrthoParams() @@ -437,33 +436,27 @@ class OrthoParams: orthogonalizer=ClassParam( BatchedBjorckOrthogonalization, beta=0.5, niters=12, pass_through=True ), - contiguous_optimization=False, ) DEFAULT_TEST_ORTHO_PARAMS = OrthoParams( spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=4, eps=1e-4), # type: ignore orthogonalizer=ClassParam(BatchedBjorckOrthogonalization, beta=0.5, niters=25), # orthogonalizer=ClassParam(BatchedQROrthogonalization), # orthogonalizer=ClassParam(BatchedExponentialOrthogonalization, niters=12), # type: ignore - contiguous_optimization=False, ) EXP_ORTHO_PARAMS = OrthoParams( spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-6), # type: ignore orthogonalizer=ClassParam(BatchedExponentialOrthogonalization, niters=12), # type: ignore - contiguous_optimization=False, ) QR_ORTHO_PARAMS = OrthoParams( spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-3), # type: ignore orthogonalizer=ClassParam(BatchedQROrthogonalization), # type: ignore - contiguous_optimization=False, ) CHOLESKY_ORTHO_PARAMS = OrthoParams( spectral_normalizer=BatchedIdentity, # type: ignore orthogonalizer=ClassParam(BatchedCholeskyOrthogonalization), # type: ignore - contiguous_optimization=False, ) CHOLESKY_STABLE_ORTHO_PARAMS = OrthoParams( spectral_normalizer=BatchedIdentity, orthogonalizer=ClassParam(BatchedCholeskyOrthogonalization, stable=True), - contiguous_optimization=False, ) diff --git a/tests/test_orthogonality_conv_soc.py b/tests/test_orthogonality_conv_soc.py new file mode 100644 index 0000000..d0d553e --- /dev/null +++ b/tests/test_orthogonality_conv_soc.py @@ -0,0 +1,605 @@ +import numpy as np +import pytest +import torch +from orthogonium.layers.conv.adaptiveSOC import ( + AdaptiveSOCConv2d, + AdaptiveSOCConvTranspose2d, +) +from orthogonium.layers.conv.adaptiveSOC.soc_x_rko_conv import SOCRkoConv2d +from orthogonium.layers.conv.adaptiveSOC.fast_skew_ortho_conv import FastSOC + + +device = "cpu" # torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def _compute_sv_impulse_response_layer(layer, img_shape): + with torch.no_grad(): + layer = layer.to(device) + inputs = ( + torch.eye(img_shape[0] * img_shape[1] * img_shape[2]) + .view( + img_shape[0] * img_shape[1] * img_shape[2], + img_shape[0], + img_shape[1], + img_shape[2], + ) + .to(device) + ) + outputs = layer(inputs) + try: + svs = torch.linalg.svdvals(outputs.view(outputs.shape[0], -1)) + svs = svs.cpu() + return svs.min(), svs.max(), svs.mean() / svs.max() + except np.linalg.LinAlgError: + print("SVD failed returning only largest singular value") + return torch.norm(outputs.view(outputs.shape[0], -1), p=2).max(), 0, 0 + + +def check_orthogonal_layer( + orthoconv, + groups, + input_channels, + kernel_size, + output_channels, + expected_kernel_shape, + tol=5e-4, + sigma_min_requirement=0.95, +): + imsize = 8 + # Test backpropagation and weight update + try: + orthoconv = orthoconv.to(device) + orthoconv.train() + opt = torch.optim.SGD(orthoconv.parameters(), lr=0.001) + for i in range(25): + opt.zero_grad() + inp = torch.randn(1, input_channels, imsize, imsize).to(device) + output = orthoconv(inp) + loss = -output.mean() + loss.backward() + opt.step() + orthoconv.eval() # so i mpulse response test checks the eval mode + except Exception as e: + pytest.fail(f"Backpropagation or weight update failed with: {e}") + # # check that orthoconv.weight has the correct shape + # if orthoconv.weight.data.shape != expected_kernel_shape: + # pytest.fail( + # f"BCOP weight has incorrect shape: {orthoconv.weight.shape} vs {(output_channels, input_channels // groups, kernel_size, kernel_size)}" + # ) + # Test singular_values function + try: + sigma_min, sigma_max, stable_rank = orthoconv.singular_values() # try: + except np.linalg.LinAlgError as e: + pytest.skip(f"SVD failed with: {e}") + sigma_min_ir, sigma_max_ir, stable_rank_ir = _compute_sv_impulse_response_layer( + orthoconv, (input_channels, imsize, imsize) + ) + print(f"input_shape = {inp.shape}, output_shape = {output.shape}") + print( + f"({input_channels}->{output_channels}, g{groups}, k{kernel_size}), " + f"sigma_max:" + f" {sigma_max:.3f}/{sigma_max_ir:.3f}, " + f"sigma_min:" + f" {sigma_min:.3f}/{sigma_min_ir:.3f}, " + f"stable_rank: {stable_rank:.3f}/{stable_rank_ir:.3f}" + ) + # check that the singular values are close to 1 + assert sigma_max_ir < (1 + tol), "sigma_max is not less than 1" + # assert (sigma_min_ir < (1 + tol)) and ( + # sigma_min_ir > sigma_min_requirement + # ), "sigma_min is not close to 1" + # assert abs(stable_rank_ir - 1) < tol, "stable_rank is not close to 1" + # check that the singular values are close to the impulse response values + # assert ( + # sigma_max > sigma_max_ir - 1e-2 + # ), f"sigma_max must be greater to its IR value (1%): {sigma_max} vs {sigma_max_ir}" + assert ( + abs(sigma_max - sigma_max_ir) < tol + ), f"sigma_max is not close to its IR value: {sigma_max} vs {sigma_max_ir}" + # assert ( + # abs(sigma_min - sigma_min_ir) < tol + # ), f"sigma_min is not close to its IR value: {sigma_min} vs {sigma_min_ir}" + # assert ( + # abs(stable_rank - stable_rank_ir) < tol + # ), f"stable_rank is not close to its IR value: {stable_rank} vs {stable_rank_ir}" + + +@pytest.mark.parametrize("kernel_size", [1, 3]) +@pytest.mark.parametrize("input_channels", [8, 16]) +@pytest.mark.parametrize("output_channels", [8, 16]) +@pytest.mark.parametrize("stride", [1]) +@pytest.mark.parametrize("groups", [1, 2, 4]) +def test_standard_configs(kernel_size, input_channels, output_channels, stride, groups): + """ + test combinations of kernel size, input channels, output channels, stride and groups + """ + # Test instantiation + try: + orthoconv = AdaptiveSOCConv2d( + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + stride=stride, + groups=groups, + bias=False, + padding=(kernel_size // 2, kernel_size // 2), + padding_mode="circular", + ) + except Exception as e: + if kernel_size < stride: + # we expect this configuration to raise a RuntimeError + # pytest.skip(f"BCOP instantiation failed with: {e}") + return + else: + pytest.fail(f"BCOP instantiation failed with: {e}") + check_orthogonal_layer( + orthoconv, + groups, + input_channels, + kernel_size, + output_channels, + ( + output_channels, + input_channels // groups, + kernel_size, + kernel_size, + ), + tol=5e-2, + sigma_min_requirement=0.0, + ) + + +# +# @pytest.mark.parametrize("kernel_size", [3]) +# @pytest.mark.parametrize("input_channels", [8, 16]) +# @pytest.mark.parametrize( +# "output_channels", [8, 16] +# ) # dilated convolutions are not supported for output_channels < input_channels +# @pytest.mark.parametrize("stride", [1]) +# @pytest.mark.parametrize("groups", [1, 2, 4]) +# def test_dilation(kernel_size, input_channels, output_channels, stride, groups): +# """ +# test combinations of kernel size, input channels, output channels, stride and groups +# """ +# # Test instantiation +# try: +# orthoconv = AdaptiveSOCConv2d( +# kernel_size=kernel_size, +# in_channels=input_channels, +# out_channels=output_channels, +# stride=stride, +# dilation=2, +# groups=groups, +# bias=False, +# padding="same", +# padding_mode="circular", +# ) +# except Exception as e: +# if kernel_size < stride: +# # we expect this configuration to raise a RuntimeError +# # pytest.skip(f"BCOP instantiation failed with: {e}") +# return +# else: +# pytest.fail(f"BCOP instantiation failed with: {e}") +# check_orthogonal_layer( +# orthoconv, +# groups, +# input_channels, +# kernel_size, +# output_channels, +# ( +# output_channels, +# input_channels // groups, +# kernel_size, +# kernel_size, +# ), +# ) +# +# +# @pytest.mark.parametrize("kernel_size", [2, 4]) +# @pytest.mark.parametrize("input_channels", [8, 16]) +# @pytest.mark.parametrize( +# "output_channels", [8, 16] +# ) # dilated+strided convolutions are not supported for output_channels < input_channels +# @pytest.mark.parametrize("stride", [2]) +# @pytest.mark.parametrize("dilation", [2, 3]) +# @pytest.mark.parametrize("groups", [1, 2, 4]) +# def test_dilation_strided( +# kernel_size, input_channels, output_channels, stride, dilation, groups +# ): +# """ +# test combinations of kernel size, input channels, output channels, stride and groups +# """ +# # Test instantiation +# try: +# orthoconv = AdaptiveSOCConv2d( +# kernel_size=kernel_size, +# in_channels=input_channels, +# out_channels=output_channels, +# stride=stride, +# dilation=dilation, +# groups=groups, +# bias=False, +# padding=( +# int(np.ceil((dilation * (kernel_size - 1) + 1 - stride) / 2)), +# int(np.ceil((dilation * (kernel_size - 1) + 1 - stride) / 2)), +# ), +# padding_mode="circular", +# ) +# except Exception as e: +# if (output_channels >= input_channels) and ( +# ((dilation % stride) == 0) and (stride > 1) +# ): +# # we expect this configuration to raise a ValueError +# # pytest.skip(f"BCOP instantiation failed with: {e}") +# return +# if (kernel_size == stride) and (((dilation % stride) == 0) and (stride > 1)): +# return +# else: +# pytest.fail(f"BCOP instantiation failed with: {e}") +# check_orthogonal_layer( +# orthoconv, +# groups, +# input_channels, +# kernel_size, +# output_channels, +# ( +# output_channels, +# input_channels // groups, +# kernel_size, +# kernel_size, +# ), +# ) + + +@pytest.mark.parametrize("kernel_size", [4]) +@pytest.mark.parametrize("input_channels", [2, 4, 16]) +@pytest.mark.parametrize("output_channels", [2, 4, 16]) +@pytest.mark.parametrize("stride", [2]) +@pytest.mark.parametrize("groups", [1]) +def test_strided(kernel_size, input_channels, output_channels, stride, groups): + """ + a more extensive testing when striding is enabled. + A larger range of cin and cout is used to track errors when cin < cout / stride**2 + ( ie you reduce spatial dimensions but you increase the channel dimensions so + that you actually increase overall dimension. + """ + # Test instantiation + try: + orthoconv = AdaptiveSOCConv2d( + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + stride=stride, + groups=groups, + bias=False, + padding=((kernel_size - 1) // 2, (kernel_size - 1) // 2), + padding_mode="circular", + ) + except Exception as e: + if kernel_size < stride: + # we expect this configuration to raise a RuntimeError + # pytest.skip(f"BCOP instantiation failed with: {e}") + return + else: + pytest.fail(f"BCOP instantiation failed with: {e}") + check_orthogonal_layer( + orthoconv, + groups, + input_channels, + kernel_size, + output_channels, + ( + output_channels, + input_channels // groups, + kernel_size, + kernel_size, + ), + tol=5e-2, + sigma_min_requirement=0.0, + ) + + +# @pytest.mark.parametrize("kernel_size", [2, 4]) +# @pytest.mark.parametrize("input_channels", [8, 16]) +# @pytest.mark.parametrize("output_channels", [8, 16]) +# @pytest.mark.parametrize("stride", [1]) +# @pytest.mark.parametrize("groups", [1, 2, 4]) +# def test_even_kernels(kernel_size, input_channels, output_channels, stride, groups): +# """ +# test specific to even kernel size +# """ +# # Test instantiation +# try: +# orthoconv = AdaptiveSOCConv2d( +# kernel_size=kernel_size, +# in_channels=input_channels, +# out_channels=output_channels, +# stride=stride, +# groups=groups, +# bias=False, +# padding="same", +# padding_mode="circular", +# ) +# except Exception as e: +# if kernel_size < stride: +# # we expect this configuration to raise a RuntimeError +# # pytest.skip(f"BCOP instantiation failed with: {e}") +# return +# else: +# pytest.fail(f"BCOP instantiation failed with: {e}") +# check_orthogonal_layer( +# orthoconv, +# groups, +# input_channels, +# kernel_size, +# output_channels, +# ( +# output_channels, +# input_channels // groups, +# kernel_size, +# kernel_size, +# ), +# ) + + +# @pytest.mark.parametrize("kernel_size", [1, 2]) +# @pytest.mark.parametrize("input_channels", [4, 8, 32]) +# @pytest.mark.parametrize("output_channels", [4, 8, 32]) +# @pytest.mark.parametrize("groups", [1, 2]) +# def test_rko(kernel_size, input_channels, output_channels, groups): +# """ +# test case where stride == kernel size +# """ +# # Test instantiation +# try: +# rkoconv = AdaptiveSOCConv2d( +# kernel_size=kernel_size, +# in_channels=input_channels, +# out_channels=output_channels, +# stride=kernel_size, +# groups=groups, +# bias=False, +# padding=(0, 0), +# padding_mode="zeros", +# ) +# except Exception as e: +# pytest.fail(f"BCOP instantiation failed with: {e}") +# check_orthogonal_layer( +# rkoconv, +# groups, +# input_channels, +# kernel_size, +# output_channels, +# ( +# output_channels, +# input_channels // groups, +# kernel_size, +# kernel_size, +# ), +# ) + + +@pytest.mark.parametrize("kernel_size", [1, 3]) +@pytest.mark.parametrize("input_channels", [1, 2]) +@pytest.mark.parametrize("output_channels", [1, 2]) +@pytest.mark.parametrize("stride", [1]) +@pytest.mark.parametrize("groups", [1]) +def test_depthwise(kernel_size, input_channels, output_channels, stride, groups): + """ + test combinations of kernel size, input channels, output channels, stride and groups + """ + # Test instantiation + try: + orthoconv = AdaptiveSOCConv2d( + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + stride=stride, + groups=groups, + bias=False, + padding=(kernel_size // 2, kernel_size // 2), + padding_mode="circular", + ) + except Exception as e: + if kernel_size < stride: + # we expect this configuration to raise a RuntimeError + # pytest.skip(f"BCOP instantiation failed with: {e}") + return + else: + pytest.fail(f"BCOP instantiation failed with: {e}") + check_orthogonal_layer( + orthoconv, + groups, + input_channels, + kernel_size, + output_channels, + ( + output_channels, + input_channels // groups, + kernel_size, + kernel_size, + ), + tol=5e-2, + sigma_min_requirement=0.0, + ) + + +# def test_invalid_kernel_smaller_than_stride(): +# """ +# A test to ensure that kernel_size < stride raises an expected ValueError +# """ +# with pytest.raises(ValueError, match=r"kernel size must be smaller than stride"): +# AdaptiveSOCConv2d( +# in_channels=8, +# out_channels=4, +# kernel_size=2, +# stride=3, # Invalid: kernel_size < stride +# groups=1, +# padding=0, +# ) +# with pytest.raises(ValueError, match=r"kernel size must be smaller than stride"): +# SOCRkoConv2d( +# in_channels=8, +# out_channels=4, +# kernel_size=2, +# stride=3, # Invalid: kernel_size < stride +# groups=1, +# padding=0, +# ) +# with pytest.raises(ValueError, match=r"kernel size must be smaller than stride"): +# FastSOC( +# in_channels=8, +# out_channels=4, +# kernel_size=2, +# stride=3, # Invalid: kernel_size < stride +# groups=1, +# padding=0, +# ) +# +# +# def test_invalid_dilation_with_stride(): +# """ +# A test to ensure dilation > 1 while stride > 1 raises an expected ValueError +# """ +# with pytest.raises( +# ValueError, +# match=r"dilation must be 1 when stride is not 1", +# ): +# AdaptiveSOCConv2d( +# in_channels=8, +# out_channels=16, +# kernel_size=3, +# stride=2, +# dilation=2, # Invalid: dilation > 1 while stride > 1 +# groups=1, +# padding=0, +# ) +# with pytest.raises( +# ValueError, +# match=r"dilation must be 1 when stride is not 1", +# ): +# SOCRkoConv2d( +# in_channels=8, +# out_channels=16, +# kernel_size=3, +# stride=2, +# dilation=2, # Invalid: dilation > 1 while stride > 1 +# groups=1, +# padding=0, +# ) +# with pytest.raises( +# ValueError, +# match=r"dilation must be 1 when stride is not 1", +# ): +# FastSOC( +# in_channels=8, +# out_channels=16, +# kernel_size=3, +# stride=2, +# dilation=2, # Invalid: dilation > 1 while stride > 1 +# groups=1, +# padding=0, +# ) + + +@pytest.mark.parametrize("kernel_size", [1, 3]) +@pytest.mark.parametrize("input_channels", [4, 8]) +@pytest.mark.parametrize("output_channels", [4, 8]) +@pytest.mark.parametrize("stride", [1]) +@pytest.mark.parametrize("groups", [1, 2]) +def test_convtranspose(kernel_size, input_channels, output_channels, stride, groups): + # Test instantiation + padding = (0, 0) + padding_mode = "zeros" + try: + + orthoconvtranspose = AdaptiveSOCConvTranspose2d( + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + stride=stride, + groups=groups, + bias=False, + padding=padding, + padding_mode=padding_mode, + ) + except Exception as e: + if kernel_size < stride: + # we expect this configuration to raise a RuntimeError + # pytest.skip(f"BCOP instantiation failed with: {e}") + return + else: + pytest.fail(f"BCOP instantiation failed with: {e}") + if ( + kernel_size > 1 + and kernel_size != stride + and output_channels * (stride**2) < input_channels + ): + pytest.skip("this case is not handled yet") + check_orthogonal_layer( + orthoconvtranspose, + groups, + input_channels, + kernel_size, + output_channels, + ( + input_channels, + output_channels // groups, + kernel_size, + kernel_size, + ), + tol=5e-2, + sigma_min_requirement=0.0, + ) + + +@pytest.mark.parametrize("kernel_size", [2, 4]) +@pytest.mark.parametrize("input_channels", [4, 8]) +@pytest.mark.parametrize("output_channels", [4, 8]) +@pytest.mark.parametrize("stride", [2]) +@pytest.mark.parametrize("groups", [1, 2]) +def test_convtranspose(kernel_size, input_channels, output_channels, stride, groups): + # Test instantiation + padding = (0, 0) + padding_mode = "zeros" + try: + + orthoconvtranspose = AdaptiveSOCConvTranspose2d( + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + stride=stride, + groups=groups, + bias=False, + padding=padding, + padding_mode=padding_mode, + ) + except Exception as e: + if kernel_size < stride: + # we expect this configuration to raise a RuntimeError + # pytest.skip(f"BCOP instantiation failed with: {e}") + return + else: + pytest.fail(f"BCOP instantiation failed with: {e}") + if ( + kernel_size > 1 + and kernel_size != stride + and output_channels * (stride**2) < input_channels + ): + pytest.skip("this case is not handled yet") + check_orthogonal_layer( + orthoconvtranspose, + groups, + input_channels, + kernel_size, + output_channels, + ( + input_channels, + output_channels // groups, + kernel_size, + kernel_size, + ), + tol=5e-2, + sigma_min_requirement=0.0, + )