From baed4c8466d6c0f3460e72f365351c1abf167d96 Mon Sep 17 00:00:00 2001 From: Thibaut Boissin Date: Tue, 21 Jan 2025 16:55:13 +0100 Subject: [PATCH 01/13] quick update of the cifar10 demo --- scripts/demo_cifar.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/scripts/demo_cifar.py b/scripts/demo_cifar.py index d0a1100..b5a7565 100644 --- a/scripts/demo_cifar.py +++ b/scripts/demo_cifar.py @@ -1,11 +1,12 @@ """ This script is a quick demonstration of the behaviour of the orthogonium library. -It trains a small CNN (4.8M params) on the CIFAR10 dataset. This script does not aim to reach state-of-the-art performance, but to provide -decent performance in a reasonable amount of time on affordable hardware (30min for the first setup on a RTX 3080). +It trains a small CNN (1.9M params) on the CIFAR10 dataset. This script does not aim to reach state-of-the-art performance, but to provide +decent performance in a reasonable amount of time on affordable hardware (20min for the first setup on a RTX 3080). The training can be adapted for 3 different settings: -- non robust training: the loss is the cross-entropy loss, and the model reaches 89% accuracy and 0% verified robust accuracy in 60 epochs. -- mildly robust training: the loss is the cross-entropy loss with a high margin, and the model reaches 80% accuracy and 32% VRA in 150 epochs. -- robust training: the loss is the cross-entropy loss with a high margin, and the model reaches 77% accuracy and 45% VRA in 150 epochs. +- non robust training: the loss is the cross-entropy loss, and the model reaches 88.5% accuracy and 0% verified robust accuracy in 30 epochs. +- mildly robust training: the loss is the cross-entropy loss with a high margin, and the model reaches 75% accuracy and 42% VRA in 150 epochs. +- robust training: the loss is the cross-entropy loss with a high margin, and the model reaches 71% accuracy and 47% VRA in 150 epochs. +you can increase the model size and number of epoch to reach performances closer to the state-of-the-art. """ import argparse @@ -58,8 +59,8 @@ "loss": ClassParam( LossXent, n_classes=10, - # sqrt(2) /0.1983 is the used factor in VRA computation - offset=(math.sqrt(2) / 0.1983) * (8 / 255), + offset=(math.sqrt(2) / 0.1983) + * (36 / 255), # aims for 36/255 verified robust accuracy temperature=0.25, ), "epochs": 150, @@ -69,7 +70,7 @@ LossXent, n_classes=10, offset=(math.sqrt(2) / 0.1983) - * (36 / 255), # aims for 36/255 verified robust accuracy + * (72 / 255), # aims for 36/255 verified robust accuracy temperature=0.25, ), "epochs": 150, @@ -179,6 +180,7 @@ def __init__(self, num_classes=10, loss=None): AdaptiveOrthoConv2d, bias=False, padding="same", + padding_mode="zeros", ), act=ClassParam(MaxMin), pool=ClassParam( From 623ed6727d8abe6c6d59ea2261e60941ce0f71a0 Mon Sep 17 00:00:00 2001 From: Thibaut Boissin Date: Tue, 21 Jan 2025 17:00:11 +0100 Subject: [PATCH 02/13] update of the GAN demo --- scripts/ortho_gan/ortho_GAN.py | 748 ++++++++++++--------------------- 1 file changed, 277 insertions(+), 471 deletions(-) diff --git a/scripts/ortho_gan/ortho_GAN.py b/scripts/ortho_gan/ortho_GAN.py index c1995fb..f133332 100644 --- a/scripts/ortho_gan/ortho_GAN.py +++ b/scripts/ortho_gan/ortho_GAN.py @@ -1,8 +1,21 @@ -from __future__ import print_function +# -*- coding: utf-8 -*- +# this is a slightly modified version of the pytorch gan tutorial +# available at https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html +# it has been modified to use the orthogonium layers instead of regular convolutions + + +# Commented out IPython magic to ensure Python compatibility. +# For tips on running notebooks in Google Colab, see +# https://pytorch.org/tutorials/beginner/colab +# %matplotlib inline -import random -import torch.backends.cudnn as cudnn +# %matplotlib inline +import argparse +import os +import random +from orthogonium.layers.custom_activations import MaxMin +import torch import torch.nn as nn import torch.nn.parallel import torch.optim as optim @@ -10,104 +23,111 @@ import torchvision.datasets as dset import torchvision.transforms as transforms import torchvision.utils as vutils -from torchinfo import summary - -from orthogonium.layers import AdaptiveOrthoConv2d -from orthogonium.layers import MaxMin -from orthogonium.layers.conv.AOC.bcop_x_rko_conv import BcopRkoConv2d -from orthogonium.layers.conv.AOC.bcop_x_rko_conv import BcopRkoConvTranspose2d - -# from orthogonium.layers import LayerCentering - -# from orthogonium.layers import OrthoLinear -# from orthogonium.layers import ScaledAvgPool2d -# from orthogonium.layers import SOC -# from orthogonium.layers import UnitNormLinear -# from orthogonium.layers.custom_activations import Abs -# from orthogonium.layers.custom_activations import HouseHolder -# from orthogonium.layers.custom_activations import HouseHolder_Order_2 +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from orthogonium.layers.conv.AOC import ( + AdaptiveOrthoConv2d, + AdaptiveOrthoConvTranspose2d, +) -cudnn.benchmark = True -torch.set_float32_matmul_precision("medium") +# from IPython.display import HTML -bs = 512 -# set manual seed to a constant get a consistent output -manualSeed = random.randint(1, 10000) +# Set random seed for reproducibility +manualSeed = 999 +# manualSeed = random.randint(1, 10000) # use if you want new results print("Random Seed: ", manualSeed) random.seed(manualSeed) torch.manual_seed(manualSeed) +# torch.use_deterministic_algorithms(True) # Needed for reproducible results + +# Root directory for dataset +dataroot = "./data/celeba" + +# Number of workers for dataloader +workers = 2 + +# Batch size during training +batch_size = 128 + +# Spatial size of training images. All images will be resized to this +# size using a transformer. +image_size = 64 + +# Number of channels in the training images. For color images this is 3 +nc = 3 + +# Size of z latent vector (i.e. size of generator input) +nz = 100 + +# Size of feature maps in generator +ngf = 64 + +# Size of feature maps in discriminator +ndf = 64 + +# Number of training epochs +num_epochs = 20 + +# Learning rate for optimizers +lr = 0.0001 -# loading the dataset +# Beta1 hyperparameter for Adam optimizers +beta1 = 0.5 + +# Number of GPUs available. Use 0 for CPU mode. +ngpu = 1 + +# We can use an image folder dataset the way we have it setup. +# Create the dataset dataset = dset.ImageFolder( - # dataset = dset.CIFAR10( - # root="/local_data/imagenet_cache/ILSVRC/Data/CLS-LOC/train/", - root="/datasets/shared_datasets/imagenette/imagenette2-160/train/", - # root="/mnt/deel/datasets/shared_datasets/imagenet/ILSVRC/Data/CLS-LOC/train/", - # root="./data", - # split="unlabeled", - # download=True, - # ) + root=dataroot, transform=transforms.Compose( [ - # transforms.Resize(64), - transforms.Resize(64), - transforms.RandomResizedCrop(64), - transforms.RandomHorizontalFlip(), + transforms.Resize(image_size), + transforms.CenterCrop(image_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - # transforms.RandomResizedCrop(64, scale=(0.8, 1.0)), ] ), ) -nc = 3 - +# Create the dataloader dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=bs, - shuffle=True, - num_workers=8, - prefetch_factor=4, - pin_memory=True, + dataset, batch_size=batch_size, shuffle=True, num_workers=workers ) -# checking the availability of cuda devices -device = "cuda" if torch.cuda.is_available() else "cpu" - -# number of gpu's available -ngpu = 1 -# input noise dimension -nz = 128 -# number of generator filters -ngf = 64 -# number of discriminator filters -ndf = 64 - -ks = 5 +# Decide which device we want to run on +device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu") + +# Plot some training images +real_batch = next(iter(dataloader)) +plt.figure(figsize=(8, 8)) +plt.axis("off") +plt.title("Training Images") +plt.imshow( + np.transpose( + vutils.make_grid( + real_batch[0].to(device)[:64], padding=2, normalize=True + ).cpu(), + (1, 2, 0), + ) +) +# plt.show() +plt.savefig("real_images.png") -# custom weights initialization called on netG and netD +# custom weights initialization called on ``netG`` and ``netD`` def weights_init(m): classname = m.__class__.__name__ - # if classname.find("Conv") != -1: - # m.weight.data.normal_(0.0, 0.02) - # elif classname.find("BatchNorm") != -1: - # m.weight.data.normal_(1.0, 0.02) - # m.bias.data.fill_(0) + if classname.find("Conv") != -1: + pass + # nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) -class Residual(nn.Module): - def __init__(self, fn): - super().__init__() - self.add_module("fn", fn) - - def forward(self, x): - # split x - # x1, x2 = x.chunk(2, dim=1) - # apply function - out = self.fn(x) - # concat and return - # return torch.cat([x1, out], dim=1) - return (x + out) * 0.5 +# Generator Code class Generator(nn.Module): @@ -116,227 +136,58 @@ def __init__(self, ngpu): self.ngpu = ngpu self.main = nn.Sequential( # input is Z, going into a convolution - BcopRkoConvTranspose2d( - nz, - ngf * 16, - 4, - 2, - padding=(1, 1), - output_padding=0, - bias=False, + AdaptiveOrthoConvTranspose2d( + nz, ngf * 8, 4, 1, 0, bias=False, padding_mode="zeros" ), - nn.BatchNorm2d(ngf * 16), + MaxMin(), + nn.BatchNorm2d(ngf * 8), # nn.ReLU(True), - Residual( - nn.Sequential( - BcopRkoConvTranspose2d( - ngf * 16, - ngf * 16, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf * 16), - nn.ReLU(True), - ) - ), - # state size. (ngf*8) x 4 x 4 - BcopRkoConvTranspose2d( - ngf * 16, - ngf * 8, - 4, - 2, - padding=(1, 1), - output_padding=0, - bias=False, - ), - Residual( - nn.Sequential( - BcopRkoConvTranspose2d( - ngf * 8, - ngf * 8, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf * 8), - nn.ReLU(True), - BcopRkoConvTranspose2d( - ngf * 8, - ngf * 8, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf * 8), - nn.ReLU(True), - ) - ), - # state size. (ngf*8) x x 4 - BcopRkoConvTranspose2d( - ngf * 8, - ngf * 4, - 4, - 2, - padding=(1, 1), - output_padding=0, - bias=False, + # state size. ``(ngf*8) x 4 x 4`` + AdaptiveOrthoConvTranspose2d( + ngf * 8, ngf * 4, 4, 2, 1, bias=False, padding_mode="zeros" ), + MaxMin(), nn.BatchNorm2d(ngf * 4), # nn.ReLU(True), - Residual( - nn.Sequential( - BcopRkoConvTranspose2d( - ngf * 4, - ngf * 4, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf * 4), - nn.ReLU(True), - BcopRkoConvTranspose2d( - ngf * 4, - ngf * 4, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf * 4), - nn.ReLU(True), - ) - ), - # state size. (ngf*4) x 8 x 8 - BcopRkoConvTranspose2d( - ngf * 4, - ngf * 2, - 4, - 2, - padding=(1, 1), - output_padding=0, - bias=False, + # state size. ``(ngf*4) x 8 x 8`` + AdaptiveOrthoConvTranspose2d( + ngf * 4, ngf * 2, 4, 2, 1, bias=False, padding_mode="zeros" ), + MaxMin(), nn.BatchNorm2d(ngf * 2), # nn.ReLU(True), - Residual( - nn.Sequential( - BcopRkoConvTranspose2d( - ngf * 2, - ngf * 2, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf * 2), - nn.ReLU(True), - BcopRkoConvTranspose2d( - ngf * 2, - ngf * 2, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf * 2), - nn.ReLU(True), - ) - ), - # state size. (ngf*2) x 16 x 16 - BcopRkoConvTranspose2d( - ngf * 2, - ngf, - 4, - 2, - padding=(1, 1), - output_padding=0, - bias=False, + # state size. ``(ngf*2) x 16 x 16`` + AdaptiveOrthoConvTranspose2d( + ngf * 2, ngf, 4, 2, 1, bias=False, padding_mode="zeros" ), + MaxMin(), nn.BatchNorm2d(ngf), # nn.ReLU(True), - Residual( - nn.Sequential( - BcopRkoConvTranspose2d( - ngf, - ngf, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf), - nn.ReLU(True), - BcopRkoConvTranspose2d( - ngf, - ngf, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf), - nn.ReLU(True), - ) - ), - # state size. (ngf) x 32 x 32 - BcopRkoConvTranspose2d( - ngf, - ngf, - 4, - 2, - padding=(1, 1), - output_padding=0, - bias=False, - ), - Residual( - nn.Sequential( - BcopRkoConvTranspose2d( - ngf, - ngf, - ks, - 1, - padding=(ks // 2, ks // 2), - output_padding=0, - bias=False, - ), - nn.BatchNorm2d(ngf), - nn.ReLU(True), - ) - ), - BcopRkoConvTranspose2d( - ngf, nc, ks, 1, padding=(ks // 2, ks // 2), output_padding=0, bias=False + # state size. ``(ngf) x 32 x 32`` + AdaptiveOrthoConvTranspose2d( + ngf, nc, 4, 2, 1, bias=False, padding_mode="zeros" ), nn.Tanh(), - # state size. (nc) x 64 x 64 + # state size. ``(nc) x 64 x 64`` ) def forward(self, input): - if input.is_cuda and self.ngpu > 1: - output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) - else: - output = self.main(input) - return output + return self.main(input) +# Create the generator netG = Generator(ngpu).to(device) + +# Handle multi-GPU if desired +if (device.type == "cuda") and (ngpu > 1): + netG = nn.DataParallel(netG, list(range(ngpu))) + +# Apply the ``weights_init`` function to randomly initialize all weights +# to ``mean=0``, ``stdev=0.02``. netG.apply(weights_init) -# load weights to test the model -# netG.load_state_dict(torch.load('weights/netG_epoch_24.pth')) -summary(netG, (1024, nz, 1, 1)) + +# Print the model +print(netG) class Discriminator(nn.Module): @@ -344,239 +195,194 @@ def __init__(self, ngpu): super(Discriminator, self).__init__() self.ngpu = ngpu self.main = nn.Sequential( - # input is (nc) x 64 x 64 - BcopRkoConv2d( - 3, - ndf, - kernel_size=ks, - stride=1, - padding=(ks // 2, ks // 2), - padding_mode="circular", - bias=True, - ), - # LayerCentering(), - MaxMin(), - BcopRkoConv2d( - ndf, - ndf * 2, - kernel_size=ks, - stride=2, - padding=(ks // 2, ks // 2), - padding_mode="circular", - bias=True, - ), - # LayerCentering(), - MaxMin(), - BcopRkoConv2d( - ndf * 2, - ndf * 2, - kernel_size=ks, - stride=1, - padding=(ks // 2, ks // 2), - padding_mode="circular", - bias=True, - ), - # LayerCentering(), + # input is ``(nc) x 64 x 64`` + AdaptiveOrthoConv2d(nc, ndf, 4, 2, 1, bias=False, padding_mode="zeros"), + # nn.LeakyReLU(0.2, inplace=True), MaxMin(), - # state size. (ndf) x 32 x 32 - BcopRkoConv2d( - ndf * 2, - ndf * 4, - kernel_size=ks, - stride=2, - padding=(ks // 2, ks // 2), - padding_mode="circular", - bias=True, - ), - # LayerCentering(), - # MaxMin(), - BcopRkoConv2d( - ndf * 4, - ndf * 4, - kernel_size=ks, - stride=1, - padding=(ks // 2, ks // 2), - padding_mode="circular", - bias=False, - ), - # LayerCentering(), - MaxMin(), - # state size. (ndf*2) x 16 x 16 - BcopRkoConv2d( - ndf * 4, - ndf * 8, - kernel_size=ks, - stride=2, - padding=(ks // 2, ks // 2), - padding_mode="circular", - bias=True, - ), - # LayerCentering(), - MaxMin(), - BcopRkoConv2d( - ndf * 8, - ndf * 8, - kernel_size=ks, - stride=1, - padding=(ks // 2, ks // 2), - padding_mode="circular", - bias=True, - ), - # LayerCentering(), - MaxMin(), - # state size. (ndf*4) x 8 x 8 - BcopRkoConv2d( - ndf * 8, - ndf * 16, - kernel_size=ks, - stride=2, - padding=(ks // 2, ks // 2), - padding_mode="circular", - bias=True, + # state size. ``(ndf) x 32 x 32`` + AdaptiveOrthoConv2d( + ndf, ndf * 2, 4, 2, 1, bias=False, padding_mode="zeros" ), - # LayerCentering(), + # nn.BatchNorm2d(ndf * 2), + # nn.LeakyReLU(0.2, inplace=True), MaxMin(), - BcopRkoConv2d( - ndf * 16, - ndf * 16, - kernel_size=ks, - stride=1, - padding=(ks // 2, ks // 2), - padding_mode="circular", - bias=True, + # state size. ``(ndf*2) x 16 x 16`` + AdaptiveOrthoConv2d( + ndf * 2, ndf * 4, 4, 2, 1, bias=False, padding_mode="zeros" ), - # LayerCentering(), + # nn.BatchNorm2d(ndf * 4), + # nn.LeakyReLU(0.2, inplace=True), MaxMin(), - # state size. (ndf*8) x 4 x 4 + # state size. ``(ndf*4) x 8 x 8`` AdaptiveOrthoConv2d( - ndf * 16, - 1, - kernel_size=4, - stride=4, - padding=(0, 0), - # padding_mode="circular", - bias=False, + ndf * 4, ndf * 8, 4, 2, 1, bias=False, padding_mode="zeros" ), - nn.Flatten(), - # UnitNormLinear(4 * 4, 1), - # nn.Sigmoid(), + # nn.BatchNorm2d(ndf * 8), + # nn.LeakyReLU(0.2, inplace=True), + MaxMin(), + # state size. ``(ndf*8) x 4 x 4`` + AdaptiveOrthoConv2d(ndf * 8, 1, 4, 1, 0, bias=False, padding_mode="zeros"), + nn.Sigmoid(), ) def forward(self, input): - if input.is_cuda and self.ngpu > 1: - output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) - else: - output = self.main(input) - return output - # return output.view(-1, 1).squeeze(1) + return self.main(input) +# Create the Discriminator netD = Discriminator(ngpu).to(device) + +# Handle multi-GPU if desired +if (device.type == "cuda") and (ngpu > 1): + netD = nn.DataParallel(netD, list(range(ngpu))) + +# Apply the ``weights_init`` function to randomly initialize all weights +# like this: ``to mean=0, stdev=0.2``. netD.apply(weights_init) -# load weights to test the model -# netD.load_state_dict(torch.load('weights/netD_epoch_24.pth')) -# print(netD) -summary(netD, (1024, 3, 64, 64)) - - -# criterion = nn.BCELoss() -# use KR criterion -def criterion(output, label): - kr = torch.mean(output * label) - torch.mean(output * (1 - label)) - hinge = torch.mean(torch.nn.functional.relu(0.1 + output * (label * 2 - 1))) - return 0.5 * kr + 0.5 * hinge - # return hkr_loss(output, label, alpha=0, min_margin=0, true_values=(1, 0)) - - -# setup optimizer -optimizerD = optim.Adam(netD.parameters(), lr=0.0001, betas=(0.5, 0.999)) -# optimizerD = schedulefree.AdamWScheduleFree( -# netD.parameters(), lr=0.0005, weight_decay=0.0 -# ) -optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.999)) -# optimizerG = schedulefree.AdamWScheduleFree( -# netG.parameters(), lr=0.0005, weight_decay=0.0 -# ) - -fixed_noise = torch.randn(128, nz, 1, 1, device=device) -real_label = 1 -fake_label = 0 - -niter = 100 -g_loss = [] -d_loss = [] -for epoch in range(niter): + +# Print the model +print(netD) + +# Initialize the ``BCELoss`` function +criterion = nn.BCELoss() + +# Create batch of latent vectors that we will use to visualize +# the progression of the generator +fixed_noise = torch.randn(64, nz, 1, 1, device=device) + +# Establish convention for real and fake labels during training +real_label = 1.0 +fake_label = 0.0 + +# Setup Adam optimizers for both G and D +optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) +optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999)) + +# Commented out IPython magic to ensure Python compatibility. +# Training Loop + +# Lists to keep track of progress +img_list = [] +G_losses = [] +D_losses = [] +iters = 0 + +print("Starting Training Loop...") +# For each epoch +for epoch in range(num_epochs): + # For each batch in the dataloader for i, data in enumerate(dataloader, 0): + ############################ # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) ########################### - # train with real + ## Train with all-real batch netD.zero_grad() + # Format batch real_cpu = data[0].to(device) - batch_size = real_cpu.size(0) - label = torch.full( - (batch_size,), real_label, dtype=real_cpu.dtype, device=device - ) - - output = netD( - real_cpu - ) # + torch.normal(0, 1e-3, size=real_cpu.size()).to(device)) + b_size = real_cpu.size(0) + label = torch.full((b_size,), real_label, dtype=torch.float, device=device) + # Forward pass real batch through D + output = netD(real_cpu).view(-1) + # Calculate loss on all-real batch errD_real = criterion(output, label) + # Calculate gradients for D in backward pass errD_real.backward() D_x = output.mean().item() - # train with fake - noise = torch.randn(batch_size, nz, 1, 1, device=device) + ## Train with all-fake batch + # Generate batch of latent vectors + noise = torch.randn(b_size, nz, 1, 1, device=device) + # Generate fake image batch with G fake = netG(noise) label.fill_(fake_label) - output = netD( - fake.detach() # + torch.normal(0, 1e-3, size=fake.size()).to(device) - ) + # Classify all fake batch with D + output = netD(fake.detach()).view(-1) + # Calculate D's loss on the all-fake batch errD_fake = criterion(output, label) + # Calculate the gradients for this batch, accumulated (summed) with previous gradients errD_fake.backward() D_G_z1 = output.mean().item() + # Compute error of D as sum over the fake and the real batches errD = errD_real + errD_fake + # Update D optimizerD.step() ############################ # (2) Update G network: maximize log(D(G(z))) ########################### netG.zero_grad() - - fake = netG(noise) label.fill_(real_label) # fake labels are real for generator cost - output = netD(fake) + # Since we just updated D, perform another forward pass of all-fake batch through D + output = netD(fake).view(-1) + # Calculate G's loss based on this output errG = criterion(output, label) + # Calculate gradients for G errG.backward() D_G_z2 = output.mean().item() - torch.nn.utils.clip_grad_norm_(netG.parameters(), 1.0) + # Update G optimizerG.step() - print( - "[%d/%d][%d/%d] Loss_G: %.4f Loss_D: %.4f D(x): %.4f D(G(z)): %.4f / %.4f" - % ( - epoch, - niter, - i, - len(dataloader), - errG.item(), - errD.item(), - D_x, - D_G_z1, - D_G_z2, - ) - ) - # save the output - if i % 100 == 1: - print("saving the output") - vutils.save_image(real_cpu[:128], "output/real_samples.png", normalize=True) - fake = netG(fixed_noise) - vutils.save_image( - fake.detach(), - "output/fake_samples_epoch_%03d.png" % (epoch), - normalize=True, + # Output training stats + if i % 50 == 0: + print( + f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}]\t" + f"Loss_D: {errD.item():.4f}\tLoss_G: {errG.item():.4f}\t" + f"D(x): {D_x:.4f}\tD(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}" ) - # Check pointing for every epoch - # torch.save(netG.state_dict(), "weights/netG_epoch_%d.pth" % (epoch)) - # torch.save(netD.state_dict(), "weights/netD_epoch_%d.pth" % (epoch)) + # Save Losses for plotting later + G_losses.append(errG.item()) + D_losses.append(errD.item()) + + # Check how the generator is doing by saving G's output on fixed_noise + if (iters % 500 == 0) or ( + (epoch == num_epochs - 1) and (i == len(dataloader) - 1) + ): + with torch.no_grad(): + fake = netG(fixed_noise).detach().cpu() + img_list.append(vutils.make_grid(fake, padding=2, normalize=True)) + + iters += 1 + +plt.figure(figsize=(10, 5)) +plt.title("Generator and Discriminator Loss During Training") +plt.plot(G_losses, label="G") +plt.plot(D_losses, label="D") +plt.xlabel("iterations") +plt.ylabel("Loss") +plt.legend() +# plt.show() +plt.savefig("training_losses.png") + +fig = plt.figure(figsize=(8, 8)) +plt.axis("off") +ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list] +ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True) + +# HTML(ani.to_jshtml()) + +# Grab a batch of real images from the dataloader +real_batch = next(iter(dataloader)) + +# Plot the real images +plt.figure(figsize=(15, 15)) +plt.subplot(1, 2, 1) +plt.axis("off") +plt.title("Real Images") +plt.imshow( + np.transpose( + vutils.make_grid( + real_batch[0].to(device)[:64], padding=5, normalize=True + ).cpu(), + (1, 2, 0), + ) +) + +# Plot the fake images from the last epoch +plt.subplot(1, 2, 2) +plt.axis("off") +plt.title("Fake Images") +plt.imshow(np.transpose(img_list[-1], (1, 2, 0))) +# plt.show() +plt.savefig("fake_images.png") From 642429cd4f99aec4b974cf437904a055399c2af2 Mon Sep 17 00:00:00 2001 From: Thibaut Boissin Date: Wed, 22 Jan 2025 18:00:56 +0100 Subject: [PATCH 03/13] moved residuals to new module --- orthogonium/layers/residual.py | 57 ++++++++++++++++ orthogonium/model_factory/models_factory.py | 73 +++------------------ scripts/imagenet.py | 4 +- 3 files changed, 68 insertions(+), 66 deletions(-) create mode 100644 orthogonium/layers/residual.py diff --git a/orthogonium/layers/residual.py b/orthogonium/layers/residual.py new file mode 100644 index 0000000..6848c6c --- /dev/null +++ b/orthogonium/layers/residual.py @@ -0,0 +1,57 @@ +import torch +from torch import nn as nn + + +class ConcatResidual(nn.Module): + def __init__(self, fn): + super().__init__() + self.add_module("fn", fn) + + def forward(self, x): + # split x + x1, x2 = x.chunk(2, dim=1) + # apply function + out = self.fn(x2) + # concat and return + return torch.cat([x1, out], dim=1) + + +class L2NormResidual(nn.Module): + def __init__(self, fn, eps=1e-6): + super().__init__() + self.eps = eps + self.add_module("fn", fn) + + def forward(self, x): + # apply function + out = self.fn(x) + # concat and return + return torch.sqrt(x**2 + out**2 + self.eps) + + +class AdditiveResidual(nn.Module): + def __init__(self, fn, init_val=1.0): + super().__init__() + self.add_module("fn", fn) + self.alpha = nn.Parameter(torch.tensor(init_val), requires_grad=True) + + def forward(self, x): + # apply function + out = self.fn(x) + # alpha = self.alpha.clamp(0, 1) + alpha = torch.sigmoid(self.alpha) # check if alpha don't grow to infinity + return alpha * x + (1 - alpha) * out + + +class PrescaledAdditiveResidual(nn.Module): + def __init__(self, fn, init_val=1.0): + super().__init__() + self.add_module("fn", fn) + self.alpha = nn.Parameter(torch.tensor(init_val), requires_grad=True) + + def forward(self, x): + # apply function + out = self.fn(x * self.alpha) + lip_cst = 1.0 + torch.abs(self.alpha) + # we divide by lip const on each branch as it is more numerically stable + return x / lip_cst + (1.0 / lip_cst) * out diff --git a/orthogonium/model_factory/models_factory.py b/orthogonium/model_factory/models_factory.py index 657ab64..8aa0c69 100644 --- a/orthogonium/model_factory/models_factory.py +++ b/orthogonium/model_factory/models_factory.py @@ -1,7 +1,7 @@ -import torch import torch.nn as nn from torch.nn import AvgPool2d +from orthogonium.layers.residual import PrescaledAdditiveResidual from orthogonium.model_factory.classparam import ClassParam from orthogonium.layers import AdaptiveOrthoConv2d from orthogonium.layers import BatchCentering2D @@ -129,61 +129,6 @@ def ResNet50Block(in_channels, out_channels, n_blocks, norm, act, stride=2): return layers -class ConcatResidual(nn.Module): - def __init__(self, fn): - super().__init__() - self.add_module("fn", fn) - - def forward(self, x): - # split x - x1, x2 = x.chunk(2, dim=1) - # apply function - out = self.fn(x2) - # concat and return - return torch.cat([x1, out], dim=1) - - -class L2NormResidual(nn.Module): - def __init__(self, fn, eps=1e-6): - super().__init__() - self.eps = eps - self.add_module("fn", fn) - - def forward(self, x): - # apply function - out = self.fn(x) - # concat and return - return torch.sqrt(x**2 + out**2 + self.eps) - - -# class Residual(nn.Module): -# def __init__(self, fn, init_val=1.0): -# super().__init__() -# self.add_module("fn", fn) -# self.alpha = nn.Parameter(torch.tensor(init_val), requires_grad=True) -# -# def forward(self, x): -# # apply function -# out = self.fn(x) -# # alpha = self.alpha.clamp(0, 1) -# alpha = torch.sigmoid(self.alpha) # check if alpha don't grow to infinity -# return alpha * x + (1 - alpha) * out - - -class Residual(nn.Module): - def __init__(self, fn, init_val=1.0): - super().__init__() - self.add_module("fn", fn) - self.alpha = nn.Parameter(torch.tensor(init_val), requires_grad=True) - - def forward(self, x): - # apply function - out = self.fn(x * self.alpha) - lip_cst = 1.0 + torch.abs(self.alpha) - # we divide by lip const on each branch as it is more numerically stable - return x / lip_cst + (1.0 / lip_cst) * out - - # def dumbNet500M( # img_size=(224, 224), # dim=1024, @@ -261,7 +206,7 @@ def AOCNetV1( embedding_dim=1024, groups=8, skip=ClassParam( - Residual, + PrescaledAdditiveResidual, init_val=1.0, ), conv=ClassParam( @@ -385,7 +330,7 @@ def resblock(in_channels, out_channels, n_blocks, conv, act, norm): padding_mode="zeros", ), skip=ClassParam( - Residual, + PrescaledAdditiveResidual, init_val=2.0, ), act=ClassParam(MaxMin), @@ -400,7 +345,7 @@ def resblock(in_channels, out_channels, n_blocks, conv, act, norm): embedding_dim=1024, groups=4, skip=ClassParam( - Residual, + PrescaledAdditiveResidual, init_val=2.0, ), conv=ClassParam( @@ -425,7 +370,7 @@ def resblock(in_channels, out_channels, n_blocks, conv, act, norm): groups=None, # None is depthwise, 1 is no groups # skip=None, skip=ClassParam( - Residual, + PrescaledAdditiveResidual, init_val=3.0, ), conv=ClassParam( @@ -450,7 +395,7 @@ def resblock(in_channels, out_channels, n_blocks, conv, act, norm): groups=None, # None is depthwise, 1 is no groups # skip=None, skip=ClassParam( - Residual, + PrescaledAdditiveResidual, init_val=3.0, ), conv=ClassParam( @@ -475,7 +420,7 @@ def resblock(in_channels, out_channels, n_blocks, conv, act, norm): groups=None, # None is depthwise, 1 is no groups # skip=None, skip=ClassParam( - Residual, + PrescaledAdditiveResidual, init_val=3.0, ), conv=ClassParam( @@ -499,7 +444,7 @@ def LipResNet( img_shape=(3, 224, 224), n_classes=1000, skip=ClassParam( - Residual, + PrescaledAdditiveResidual, init_val=3.0, ), conv=ClassParam( @@ -715,7 +660,7 @@ def PatchBasedExapandedCNN( norm=ClassParam(LayerCentering2D), ): if skip: - skipco = Residual + skipco = PrescaledAdditiveResidual else: skipco = nn.Sequential return nn.Sequential( diff --git a/scripts/imagenet.py b/scripts/imagenet.py index 8f127fa..4d7f520 100644 --- a/scripts/imagenet.py +++ b/scripts/imagenet.py @@ -29,7 +29,7 @@ from orthogonium.losses import LossXent from orthogonium.losses import VRA from orthogonium.model_factory.models_factory import AOCNetV1 -from orthogonium.model_factory.models_factory import Residual +from orthogonium.layers.residual import PrescaledAdditiveResidual torch.backends.cudnn.benchmark = True torch.set_float32_matmul_precision("medium") @@ -137,7 +137,7 @@ def __init__(self, num_classes=1000): groups=None, # None is depthwise, 1 is no groups # skip=None, skip=ClassParam( - Residual, + PrescaledAdditiveResidual, init_val=3.0, ), conv=ClassParam( From 1fd06a8afb39dd7cc28489e5f2733e2081f4b1f2 Mon Sep 17 00:00:00 2001 From: Thibaut Boissin Date: Fri, 31 Jan 2025 14:48:44 +0100 Subject: [PATCH 04/13] added group support for AOL --- orthogonium/layers/conv/AOL/aol.py | 19 +++++++------- tests/test_aol.py | 41 +++++------------------------- 2 files changed, 17 insertions(+), 43 deletions(-) diff --git a/orthogonium/layers/conv/AOL/aol.py b/orthogonium/layers/conv/AOL/aol.py index 7830950..ff31f1a 100644 --- a/orthogonium/layers/conv/AOL/aol.py +++ b/orthogonium/layers/conv/AOL/aol.py @@ -2,7 +2,11 @@ from torch import nn from torch.nn.utils import parametrize -from orthogonium.layers.conv.AOC.fast_block_ortho_conv import conv_singular_values_numpy +from orthogonium.layers.conv.AOC.fast_block_ortho_conv import ( + conv_singular_values_numpy, + transpose_kernel, + fast_matrix_conv, +) from orthogonium.layers.conv.SLL.sll_layer import safe_inv @@ -11,18 +15,15 @@ def __init__(self, nb_features, groups): super(AOLReparametrizer, self).__init__() self.nb_features = nb_features self.groups = groups - self.q = nn.Parameter(torch.randn(nb_features)) + self.q = nn.Parameter(torch.ones(nb_features, 1, 1, 1)) def forward(self, kernel): - ktk = nn.functional.conv2d( - kernel, - kernel, - groups=1, - padding=kernel.shape[-1] - 1, + ktk = fast_matrix_conv( + transpose_kernel(kernel, self.groups, flip=True), kernel, self.groups ) ktk = torch.abs(ktk) - q = torch.exp(self.q).reshape(-1, 1, 1, 1) - q_inv = torch.exp(-self.q).reshape(-1, 1, 1, 1) + q = torch.exp(self.q) + q_inv = torch.exp(-self.q) t = (q_inv * ktk * q).sum((1, 2, 3)) t = safe_inv(torch.sqrt(t)) t = t.reshape(-1, 1, 1, 1) diff --git a/tests/test_aol.py b/tests/test_aol.py index 92df57b..dca4a64 100644 --- a/tests/test_aol.py +++ b/tests/test_aol.py @@ -1,6 +1,7 @@ import pytest import torch from orthogonium.layers.conv.AOL.aol import AOLConv2D, AOLConvTranspose2D +from orthogonium.layers.conv.singular_values import get_conv_sv @pytest.mark.parametrize("convclass", [AOLConv2D, AOLConvTranspose2D]) @@ -24,10 +25,10 @@ def test_lipschitz_layers(convclass, in_channels, out_channels, kernel_size, gro x = torch.randn((4, in_channels, 8, 8), requires_grad=True) # Input # Pre-optimization Lipschitz constant (if applicable) - pre_lipschitz_constant = compute_lipschitz_constant(layer, x) - print(f"{convclass.__name__} | Before: {pre_lipschitz_constant:.6f}") + pre_lipschitz_constant = get_conv_sv(layer, n_iter=5, agg_groups=True) + print(f"{convclass.__name__} | Before: {pre_lipschitz_constant}") assert ( - pre_lipschitz_constant <= 1 + 1e-4 + pre_lipschitz_constant[0] <= 1 + 1e-4 ), "Pre-optimization Lipschitz constant violation." # Define optimizer and loss function @@ -42,36 +43,8 @@ def test_lipschitz_layers(convclass, in_channels, out_channels, kernel_size, gro optimizer.step() # Post-optimization Lipschitz constant (if applicable) - post_lipschitz_constant = compute_lipschitz_constant(layer, x) - print(f"{convclass.__name__} | After: {post_lipschitz_constant:.6f}") + post_lipschitz_constant = get_conv_sv(layer, n_iter=5, agg_groups=True) + print(f"{convclass.__name__} | After: {post_lipschitz_constant}") assert ( - post_lipschitz_constant <= 1 + 1e-4 + post_lipschitz_constant[0] <= 1 + 1e-4 ), "Post-optimization Lipschitz constant violation." - - -def compute_lipschitz_constant(layer, x): - """ - Calculate the Lipschitz constant for a given layer by computing the - maximum singular value of the Jacobian. - """ - y = layer(x) - - # Compute Jacobian by autograd - jacobian = [] - for i in range(y.numel()): - grad_output = torch.zeros_like(y) - grad_output.view(-1)[i] = 1 - gradients = torch.autograd.grad( - outputs=y, - inputs=x, - grad_outputs=grad_output, - retain_graph=True, - create_graph=True, - allow_unused=True, - )[0] - jacobian.append(gradients.view(-1).detach().cpu().numpy()) - jacobian = torch.tensor(jacobian).view(y.numel(), x.numel()) # Construct Jacobian - - # Compute singular values and return the maximum value - singular_values = torch.linalg.svdvals(jacobian).detach() - return singular_values.max().item() From 59439e428d8379a02431e1ccb15e091a9e25e2ab Mon Sep 17 00:00:00 2001 From: Thibaut Boissin Date: Mon, 3 Feb 2025 21:39:26 +0100 Subject: [PATCH 05/13] added quick demo for robust classification --- .../notebooks/demo_cifar_classification.ipynb | 455 ++++++++++++++++++ 1 file changed, 455 insertions(+) create mode 100644 docs/notebooks/demo_cifar_classification.ipynb diff --git a/docs/notebooks/demo_cifar_classification.ipynb b/docs/notebooks/demo_cifar_classification.ipynb new file mode 100644 index 0000000..94046cc --- /dev/null +++ b/docs/notebooks/demo_cifar_classification.ipynb @@ -0,0 +1,455 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Training a 1-Lipschitz constrained network on CIFAR10 with Orthogonium\n", + "\n", + "## Lipschitz-Constrained Networks and Certifiable Robustness\n", + "\n", + "**What is a Lipschitz Network?**\n", + "A *Lipschitz network* is a neural network in which each layer is constrained to be a 1-Lipschitz function. This means that small changes in the input lead to only small changes in the output, ensuring controlled sensitivity throughout the network. The overall Lipschitz constant of the network is usually estimated as the product of the Lipschitz constants of its individual layers. However, this bound is often loose and difficult to compute exactly.\n", + "\n", + "**How to Build Lipschitz Networks?**\n", + "To construct such networks:\n", + "- **Orthogonal Layers:** Use layers that enforce orthogonality constraints (e.g., Adaptive OrthoConvolutions). These layers are designed to strictly represent 1-Lipschitz functions.\n", + "- **Special Activations:** Incorporate activations like **MaxMin** which, when combined with orthogonal layers, help in obtaining a tight estimation of the network's Lipschitz constant.\n", + "- **Reparametrization Techniques:** Methods such as AOC (Adaptive OrthoConvolutions) ensure that each layer adheres to the 1-Lipschitz constraint, making the overall bound much tighter compared to a simple product of individual bounds.\n", + "\n", + "**Certifiable Robustness**\n", + "Certifiable robustness provides a guarantee on the minimal perturbation needed to alter the network's prediction, independent of any specific adversarial attack. For a 1-Lipschitz classification function \\( f \\) with \\( f(x)_l \\) representing the logit for the true class and \\( f(x)_i \\) for any other class, a robustness certificate in the \\( L_2 \\) norm is given by:\n", + "\\[\n", + "\\epsilon \\geq \\frac{f(x)_l - \\max_{i \\neq l} f(x)_i}{\\sqrt{2}}\n", + "\\]\n", + "This means that as long as the perturbation remains below \\( \\epsilon \\), the classification will not change. This certificate is:\n", + "- **Independent of Attacks:** It does not rely on any particular adversarial attack method, ensuring that the guarantee remains valid even as new attack strategies emerge.\n", + "- **Computationally Efficient:** The certificate can be computed cheaply and even integrated as a loss term during training, leading to models that are robust by design.\n", + "\n", + "**Applications and Benefits**\n", + "Lipschitz-constrained networks are not only crucial for certifiable robustness but also have broader applications:\n", + "- They are tightly linked with generative models like WGANs and concepts in optimal transport.\n", + "- They enable scalable differential privacy and help avoid singularities in models such as diffusion networks.\n", + "- They guarantee existence and uniqueness in classification tasks, making them appealing for reliable machine learning.\n", + "\n", + "In summary, by combining orthogonal layers with appropriate activations and reparametrization techniques, one can build Lipschitz networks that not only deliver competitive performance but also offer provable robustness guarantees.\n" + ], + "id": "430c3d772b2f2bb0" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "import math\n", + "import os\n", + "\n", + "import schedulefree\n", + "import torch\n", + "import torch.utils.data\n", + "import torchmetrics\n", + "from lightning.pytorch import callbacks as pl_callbacks\n", + "from lightning.pytorch import Trainer\n", + "from lightning.pytorch import LightningModule, LightningDataModule\n", + "# from lightning.pytorch.loggers import WandbLogger # Uncomment if using Wandb logging\n", + "from torch.nn import AvgPool2d\n", + "from torch.utils.data import DataLoader\n", + "from torchinfo import summary\n", + "from torchvision.datasets import CIFAR10\n", + "from torchvision.transforms import Compose, Normalize, RandAugment, RandomHorizontalFlip, RandomResizedCrop, ToTensor\n", + "\n", + "from orthogonium.model_factory.classparam import ClassParam\n", + "from orthogonium.layers.conv.AOC import AdaptiveOrthoConv2d\n", + "from orthogonium.layers.linear import OrthoLinear\n", + "from orthogonium.layers.custom_activations import MaxMin\n", + "from orthogonium.losses import LossXent, CosineLoss\n", + "from orthogonium.losses import VRA\n", + "from orthogonium.model_factory.models_factory import StagedCNN, PatchBasedExapandedCNN\n", + "\n", + "# Enable benchmark mode and set matmul precision for performance tuning\n", + "torch.backends.cudnn.benchmark = True\n", + "torch.set_float32_matmul_precision(\"medium\")\n" + ], + "id": "54c7e322929ba6a2" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Training Settings\n", + "\n", + "You can play with the training settings to explore different configurations and compare their performance. The settings include:\n", + "\n", + "**Training settings include:**\n", + "- **non_robust:** Cosine Similarity loss training.\n", + "- **mildly_robust:** Cross Entropy Loss includes a high margin targeting a VRA of 36/255, resulting in 42% VRA.\n", + "- **robust:** Similar to mildly robust, but with settings that push towards 72/255 verified robust accuracy, resulting in 47% VRA.\n", + "\n", + "\n", + "> Note: The aim here is to show the training flow rather than reach state-of-the-art performance.\n", + "\n", + "## Training Settings Performance\n", + "\n", + "| Setting | Epochs | Accuracy | Verified Robust Accuracy (VRA) |\n", + "|---------------|--------|----------|--------------------------------|\n", + "| **non_robust** | 30 | 88.5% | 0% |\n", + "| **mildly_robust** | 150 | 75% | 42% |\n", + "| **robust** | 150 | 71% | 47% |\n", + "\n", + "These configurations are stored in the `settings` dictionary.\n" + ], + "id": "2ee069e5694ccd2d" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "settings = {\n", + " \"non_robust\": {\n", + " \"loss\": CosineLoss,\n", + " \"epochs\": 30,\n", + " },\n", + " \"mildly_robust\": {\n", + " \"loss\": ClassParam(\n", + " LossXent,\n", + " n_classes=10,\n", + " offset=(math.sqrt(2) / 0.1983) * (36 / 255), # aiming for 36/255 verified robust accuracy\n", + " temperature=0.25,\n", + " ),\n", + " \"epochs\": 150,\n", + " },\n", + " \"robust\": {\n", + " \"loss\": ClassParam(\n", + " LossXent,\n", + " n_classes=10,\n", + " offset=(math.sqrt(2) / 0.1983) * (72 / 255), # aiming for 72/255 verified robust accuracy\n", + " temperature=0.25,\n", + " ),\n", + " \"epochs\": 150,\n", + " },\n", + "}\n" + ], + "id": "337a713cc4e13f55" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Data Module: CIFAR10\n", + "\n", + "We create a `LightningDataModule` to load and preprocess the CIFAR10 training and validation datasets.\n", + "\n", + "The training dataloader applies several transforms:\n", + "- Random resized cropping\n", + "- Random horizontal flip\n", + "- Normalization using precomputed mean and standard deviation\n", + "\n", + "The validation dataloader only applies tensor conversion and normalization.\n", + "\n" + ], + "id": "eec8549d5fa315b5" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "class Cifar10DataModule(LightningDataModule):\n", + " # Dataset configuration\n", + " _BATCH_SIZE = 256\n", + " _NUM_WORKERS = 8 # Number of parallel processes for data loading\n", + " _PREPROCESSING_PARAMS = {\n", + " \"img_mean\": (0.41757566, 0.26098573, 0.25888634),\n", + " \"img_std\": (0.21938758, 0.1983, 0.19342837),\n", + " \"crop_size\": 32,\n", + " \"horizontal_flip_prob\": 0.5,\n", + " \"random_resized_crop_params\": {\n", + " \"scale\": (0.5, 1.0),\n", + " \"ratio\": (3.0 / 4.0, 4.0 / 3.0),\n", + " },\n", + " }\n", + "\n", + " def train_dataloader(self):\n", + " # Define the transformations for training data\n", + " transform = Compose(\n", + " [\n", + " RandomResizedCrop(\n", + " self._PREPROCESSING_PARAMS[\"crop_size\"],\n", + " **self._PREPROCESSING_PARAMS[\"random_resized_crop_params\"],\n", + " ),\n", + " RandomHorizontalFlip(self._PREPROCESSING_PARAMS[\"horizontal_flip_prob\"]),\n", + " # Uncomment the following line to use RandAugment\n", + " # RandAugment(**self._PREPROCESSING_PARAMS[\"randaug_params\"]),\n", + " ToTensor(),\n", + " Normalize(\n", + " mean=self._PREPROCESSING_PARAMS[\"img_mean\"],\n", + " std=self._PREPROCESSING_PARAMS[\"img_std\"],\n", + " ),\n", + " ]\n", + " )\n", + "\n", + " train_dataset = CIFAR10(\n", + " root=\"./data\",\n", + " train=True,\n", + " download=True,\n", + " transform=transform,\n", + " )\n", + "\n", + " return DataLoader(\n", + " train_dataset,\n", + " batch_size=self._BATCH_SIZE,\n", + " num_workers=self._NUM_WORKERS,\n", + " prefetch_factor=2,\n", + " shuffle=True,\n", + " )\n", + "\n", + " def val_dataloader(self):\n", + " # Define the transformations for validation data\n", + " transform = Compose(\n", + " [\n", + " ToTensor(),\n", + " Normalize(\n", + " mean=self._PREPROCESSING_PARAMS[\"img_mean\"],\n", + " std=self._PREPROCESSING_PARAMS[\"img_std\"],\n", + " ),\n", + " ]\n", + " )\n", + "\n", + " val_dataset = CIFAR10(\n", + " root=\"./data\",\n", + " train=False,\n", + " download=True,\n", + " transform=transform,\n", + " )\n", + "\n", + " return DataLoader(\n", + " val_dataset,\n", + " batch_size=self._BATCH_SIZE,\n", + " num_workers=self._NUM_WORKERS,\n", + " shuffle=False,\n", + " )\n" + ], + "id": "1e578065843cfeb2" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Classification Model Module\n", + "\n", + "We now define a `LightningModule` that wraps our CNN model. The network is built using the `PatchBasedExapandedCNN` factory method from *orthogonium*.\n", + "\n", + "Key components include:\n", + "- The custom CNN model architecture.\n", + "- The loss function (set based on the selected training configuration).\n", + "- Training and validation steps that compute and log both accuracy and verified robust accuracy (VRA).\n", + "- The `configure_optimizers` method which sets up the Adam optimizer with schedule-free updates.\n" + ], + "id": "617cfc5c06ab2d75" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "class ClassificationLightningModule(LightningModule):\n", + " def __init__(self, num_classes=10, loss=None):\n", + " super().__init__()\n", + " self.num_classes = num_classes\n", + " self.model = PatchBasedExapandedCNN(\n", + " img_shape=(3, 32, 32),\n", + " dim=256,\n", + " depth=12,\n", + " kernel_size=3,\n", + " patch_size=2,\n", + " expand_factor=2,\n", + " groups=None,\n", + " n_classes=10,\n", + " skip=True,\n", + " conv=ClassParam(\n", + " AdaptiveOrthoConv2d,\n", + " bias=False,\n", + " padding=\"same\",\n", + " padding_mode=\"zeros\",\n", + " ),\n", + " act=ClassParam(MaxMin),\n", + " pool=ClassParam(\n", + " AdaptiveOrthoConv2d,\n", + " in_channels=256,\n", + " out_channels=256,\n", + " groups=128,\n", + " bias=False,\n", + " padding=0,\n", + " kernel_size=16,\n", + " stride=16,\n", + " ),\n", + " lin=ClassParam(OrthoLinear, bias=False),\n", + " norm=None,\n", + " )\n", + " self.criteria = loss() if loss is not None else torch.nn.CrossEntropyLoss()\n", + " self.train_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=num_classes)\n", + " self.val_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=num_classes)\n", + " self.train_vra = torchmetrics.MeanMetric()\n", + " self.val_vra = torchmetrics.MeanMetric()\n", + "\n", + " def forward(self, x):\n", + " return self.model(x)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " self.model.train()\n", + " img, label = batch\n", + " y_hat = self.model(img)\n", + " loss = self.criteria(y_hat, label)\n", + " self.train_acc(y_hat, label)\n", + " self.train_vra(\n", + " VRA(\n", + " y_hat,\n", + " label,\n", + " L=1 / min(Cifar10DataModule._PREPROCESSING_PARAMS[\"img_std\"]),\n", + " eps=36 / 255,\n", + " last_layer_type=\"global\",\n", + " )\n", + " )\n", + " self.log(\"loss\", loss, on_epoch=True, on_step=True, prog_bar=True, sync_dist=True)\n", + " self.log(\"accuracy\", self.train_acc, on_epoch=True, on_step=True, prog_bar=True, sync_dist=True)\n", + " self.log(\"vra\", self.train_vra, on_epoch=True, on_step=True, prog_bar=True, sync_dist=False)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " self.model.eval()\n", + " img, label = batch\n", + " y_hat = self.model(img)\n", + " loss = self.criteria(y_hat, label)\n", + " self.val_acc(y_hat, label)\n", + " self.val_vra(\n", + " VRA(\n", + " y_hat,\n", + " label,\n", + " L=1 / min(Cifar10DataModule._PREPROCESSING_PARAMS[\"img_std\"]),\n", + " eps=36 / 255,\n", + " last_layer_type=\"global\",\n", + " )\n", + " )\n", + " self.log(\"val_loss\", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)\n", + " self.log(\"val_accuracy\", self.val_acc, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)\n", + " self.log(\"val_vra\", self.val_vra, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " # Setup the Adam optimizer with schedule-free updates.\n", + " optimizer = schedulefree.AdamWScheduleFree(self.parameters(), lr=5e-3, weight_decay=0)\n", + " optimizer.train()\n", + " self.hparams[\"lr\"] = optimizer.param_groups[0][\"lr\"]\n", + " return optimizer\n" + ], + "id": "d3bc13596b867137" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Training Routine\n", + "\n", + "For example, to run a **non robust** training setting, set:\n", + "\n", + "```python\n", + "train_setting = \"non_robust\"\n" + ], + "id": "7442d6e9fdde2433" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Select the training setting manually.\n", + "train_setting = \"non_robust\" # Options: \"non_robust\", \"mildly_robust\", or \"robust\"\n", + "\n", + "# Get the corresponding loss function and number of epochs from the settings.\n", + "current_setting = settings[train_setting]\n", + "\n", + "# Instantiate the classification model and data module.\n", + "classification_module = ClassificationLightningModule(num_classes=10, loss=current_setting[\"loss\"])\n", + "data_module = Cifar10DataModule()\n", + "\n", + "# Optionally, set up a logger or callbacks if needed.\n", + "# For example, if using Wandb:\n", + "# from lightning.pytorch.loggers import WandbLogger\n", + "# wandb_logger = WandbLogger(project=\"lipschitz-robust-cifar10\", log_model=True)\n", + "# checkpoint_callback = pl_callbacks.ModelCheckpoint(\n", + "# monitor=\"loss\",\n", + "# mode=\"min\",\n", + "# save_top_k=1,\n", + "# save_last=True,\n", + "# dirpath=f\"./checkpoints/{wandb_logger.experiment.dir}\",\n", + "# )\n", + "\n", + "trainer = Trainer(\n", + " accelerator=\"gpu\",\n", + " devices=-1, # Use all available GPUs\n", + " num_nodes=1, # Number of nodes\n", + " strategy=\"ddp\", # Distributed strategy\n", + " precision=\"bf16-mixed\", # Mixed precision training\n", + " max_epochs=current_setting[\"epochs\"],\n", + " enable_model_summary=True,\n", + " # logger=[wandb_logger], # Uncomment to enable Wandb logging\n", + " logger=False,\n", + " callbacks=[\n", + " # You can add callbacks here, e.g.:\n", + " # pl_callbacks.LearningRateFinder(max_lr=0.05),\n", + " # checkpoint_callback,\n", + " ],\n", + ")\n", + "\n", + "# Print a summary of the model\n", + "summary(classification_module, input_size=(1, 3, 32, 32))\n", + "\n", + "# Start training\n", + "trainer.fit(classification_module, data_module)\n", + "\n", + "# Optionally, you can save the trained model afterwards:\n", + "# torch.save(classification_module.model.state_dict(), \"single_stage.pth\")\n" + ], + "id": "17b7575b29534e21" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Next Steps\n", + "\n", + "- **Model Evaluation:** You can add a new cell to perform model evaluation or predictions.\n", + "- **Logging and Checkpoints:** To enable model logging or checkpoint saving, uncomment the corresponding lines and configure as needed.\n", + "- **Experiment with Settings:** Change the `train_setting` variable to `\"mildly_robust\"` or `\"robust\"` to experiment with other training configurations.\n" + ], + "id": "399725e628d1393a" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 7cfa924ef5b38566f2dad4c9d101c53cc5471b34 Mon Sep 17 00:00:00 2001 From: Thibaut Boissin Date: Tue, 4 Feb 2025 19:42:17 +0100 Subject: [PATCH 06/13] ran notebook to showcase results --- .../notebooks/demo_cifar_classification.ipynb | 432 +++++++++++++++--- 1 file changed, 373 insertions(+), 59 deletions(-) diff --git a/docs/notebooks/demo_cifar_classification.ipynb b/docs/notebooks/demo_cifar_classification.ipynb index 94046cc..ae90933 100644 --- a/docs/notebooks/demo_cifar_classification.ipynb +++ b/docs/notebooks/demo_cifar_classification.ipynb @@ -1,8 +1,15 @@ { "cells": [ { - "metadata": {}, "cell_type": "markdown", + "id": "430c3d772b2f2bb0", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, "source": [ "# Training a 1-Lipschitz constrained network on CIFAR10 with Orthogonium\n", "\n", @@ -18,11 +25,11 @@ "- **Reparametrization Techniques:** Methods such as AOC (Adaptive OrthoConvolutions) ensure that each layer adheres to the 1-Lipschitz constraint, making the overall bound much tighter compared to a simple product of individual bounds.\n", "\n", "**Certifiable Robustness**\n", - "Certifiable robustness provides a guarantee on the minimal perturbation needed to alter the network's prediction, independent of any specific adversarial attack. For a 1-Lipschitz classification function \\( f \\) with \\( f(x)_l \\) representing the logit for the true class and \\( f(x)_i \\) for any other class, a robustness certificate in the \\( L_2 \\) norm is given by:\n", - "\\[\n", + "Certifiable robustness provides a guarantee on the minimal perturbation needed to alter the network's prediction, independent of any specific adversarial attack. For a 1-Lipschitz classification function $ f $ with $ f(x)_l $ representing the logit for the true class and $ f(x)_i $ for any other class, a robustness certificate in the $ L_2 $ norm is given by:\n", + "$$\n", "\\epsilon \\geq \\frac{f(x)_l - \\max_{i \\neq l} f(x)_i}{\\sqrt{2}}\n", - "\\]\n", - "This means that as long as the perturbation remains below \\( \\epsilon \\), the classification will not change. This certificate is:\n", + "$$\n", + "This means that as long as the perturbation remains below $ \\epsilon $, the classification will not change. This certificate is:\n", "- **Independent of Attacks:** It does not rely on any particular adversarial attack method, ensuring that the guarantee remains valid even as new attack strategies emerge.\n", "- **Computationally Efficient:** The certificate can be computed cheaply and even integrated as a loss term during training, leading to models that are robust by design.\n", "\n", @@ -33,14 +40,24 @@ "- They guarantee existence and uniqueness in classification tasks, making them appealing for reliable machine learning.\n", "\n", "In summary, by combining orthogonal layers with appropriate activations and reparametrization techniques, one can build Lipschitz networks that not only deliver competitive performance but also offer provable robustness guarantees.\n" - ], - "id": "430c3d772b2f2bb0" + ] }, { + "cell_type": "code", + "execution_count": 1, + "id": "18fd1cee-c6cd-473f-b375-1678983991ab", "metadata": {}, + "outputs": [], + "source": [ + "# !pip install orthogonium lightning rich schedulefree" + ] + }, + { "cell_type": "code", + "execution_count": 2, + "id": "54c7e322929ba6a2", + "metadata": {}, "outputs": [], - "execution_count": null, "source": [ "import math\n", "import os\n", @@ -52,10 +69,11 @@ "from lightning.pytorch import callbacks as pl_callbacks\n", "from lightning.pytorch import Trainer\n", "from lightning.pytorch import LightningModule, LightningDataModule\n", + "from lightning.pytorch.callbacks import RichProgressBar\n", + "from lightning.pytorch.callbacks import RichModelSummary\n", "# from lightning.pytorch.loggers import WandbLogger # Uncomment if using Wandb logging\n", "from torch.nn import AvgPool2d\n", "from torch.utils.data import DataLoader\n", - "from torchinfo import summary\n", "from torchvision.datasets import CIFAR10\n", "from torchvision.transforms import Compose, Normalize, RandAugment, RandomHorizontalFlip, RandomResizedCrop, ToTensor\n", "\n", @@ -70,12 +88,12 @@ "# Enable benchmark mode and set matmul precision for performance tuning\n", "torch.backends.cudnn.benchmark = True\n", "torch.set_float32_matmul_precision(\"medium\")\n" - ], - "id": "54c7e322929ba6a2" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "2ee069e5694ccd2d", + "metadata": {}, "source": [ "## Training Settings\n", "\n", @@ -93,24 +111,24 @@ "\n", "| Setting | Epochs | Accuracy | Verified Robust Accuracy (VRA) |\n", "|---------------|--------|----------|--------------------------------|\n", - "| **non_robust** | 30 | 88.5% | 0% |\n", + "| **non_robust** | 60 | 88.5% | 0% |\n", "| **mildly_robust** | 150 | 75% | 42% |\n", "| **robust** | 150 | 71% | 47% |\n", "\n", "These configurations are stored in the `settings` dictionary.\n" - ], - "id": "2ee069e5694ccd2d" + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": 3, + "id": "337a713cc4e13f55", + "metadata": {}, "outputs": [], - "execution_count": null, "source": [ "settings = {\n", " \"non_robust\": {\n", " \"loss\": CosineLoss,\n", - " \"epochs\": 30,\n", + " \"epochs\": 60,\n", " },\n", " \"mildly_robust\": {\n", " \"loss\": ClassParam(\n", @@ -131,12 +149,12 @@ " \"epochs\": 150,\n", " },\n", "}\n" - ], - "id": "337a713cc4e13f55" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "eec8549d5fa315b5", + "metadata": {}, "source": [ "## Data Module: CIFAR10\n", "\n", @@ -149,14 +167,14 @@ "\n", "The validation dataloader only applies tensor conversion and normalization.\n", "\n" - ], - "id": "eec8549d5fa315b5" + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": 4, + "id": "1e578065843cfeb2", + "metadata": {}, "outputs": [], - "execution_count": null, "source": [ "class Cifar10DataModule(LightningDataModule):\n", " # Dataset configuration\n", @@ -228,34 +246,42 @@ "\n", " return DataLoader(\n", " val_dataset,\n", - " batch_size=self._BATCH_SIZE,\n", + " batch_size=self._BATCH_SIZE * 4,\n", " num_workers=self._NUM_WORKERS,\n", " shuffle=False,\n", " )\n" - ], - "id": "1e578065843cfeb2" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "617cfc5c06ab2d75", + "metadata": {}, "source": [ "## Classification Model Module\n", "\n", - "We now define a `LightningModule` that wraps our CNN model. The network is built using the `PatchBasedExapandedCNN` factory method from *orthogonium*.\n", + "We now define a `LightningModule` that wraps our CNN model. The network uses the `PatchBasedExapandedCNN` factory method from *orthogonium*.\n", + "\n", + "The architecture consists of 4 main parts:\n", + "- The stem is a patch extractor: a convolution whose kernel size equals the stride.\n", + "- A sequence of residual block: each residual features a learnable factor to ensure its Lipschitzness. In each residual, there is one depthwise convolution, A MaxMin activation, and a pointwise convolution. No pooling is performed in this part of the network.\n", + "- A pooling layer: here, we use a depthwise convolution whose kernel size equals the image size. This allows for the localization of features without using a large amount of weight. (this is not mandatory for accurate training but seems to obtain a slightly better accuracy / robustness tradeoff in robust training).\n", + "- a Fully connected layer for classification.\n", + "\n", + "All convolutional layers use AOC, allowing the construction of complex Lipschitz-constrained architectures.\n", "\n", "Key components include:\n", "- The custom CNN model architecture.\n", "- The loss function (set based on the selected training configuration).\n", "- Training and validation steps that compute and log both accuracy and verified robust accuracy (VRA).\n", "- The `configure_optimizers` method which sets up the Adam optimizer with schedule-free updates.\n" - ], - "id": "617cfc5c06ab2d75" + ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": 5, + "id": "d3bc13596b867137", + "metadata": {}, "outputs": [], - "execution_count": null, "source": [ "class ClassificationLightningModule(LightningModule):\n", " def __init__(self, num_classes=10, loss=None):\n", @@ -342,16 +368,16 @@ "\n", " def configure_optimizers(self):\n", " # Setup the Adam optimizer with schedule-free updates.\n", - " optimizer = schedulefree.AdamWScheduleFree(self.parameters(), lr=5e-3, weight_decay=0)\n", + " optimizer = schedulefree.AdamWScheduleFree(self.parameters(), lr=1e-2, weight_decay=0)\n", " optimizer.train()\n", " self.hparams[\"lr\"] = optimizer.param_groups[0][\"lr\"]\n", " return optimizer\n" - ], - "id": "d3bc13596b867137" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "7442d6e9fdde2433", + "metadata": {}, "source": [ "## Training Routine\n", "\n", @@ -359,14 +385,305 @@ "\n", "```python\n", "train_setting = \"non_robust\"\n" - ], - "id": "7442d6e9fdde2433" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], - "execution_count": null, + "execution_count": 6, + "id": "17b7575b29534e21", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using bfloat16 Automatic Mixed Precision (AMP)\n", + "Trainer already configured with model summary callbacks: []. Skipping setting a default `ModelSummary` callback.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/mnt/deel/data/thibaut.boissin/envs/ortho/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/thibaut.boissin/projects/orthogonium/docs/notebooks/checkpoints exists and is not empty.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n" + ] + }, + { + "data": { + "text/html": [ + "
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
+       "┃     Name                              Type                         Params  Mode  ┃\n",
+       "┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
+       "│ 0  │ model                            │ Sequential                  │  1.9 M │ train │\n",
+       "│ 1  │ model.0                          │ ParametrizedRKOConv2d       │  3.3 K │ train │\n",
+       "│ 2  │ model.0.parametrizations         │ ModuleDict                  │  3.3 K │ train │\n",
+       "│ 3  │ model.0.parametrizations.weight  │ ParametrizationList         │  3.3 K │ train │\n",
+       "│ 4  │ model.1                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 5  │ model.1.fn                       │ Sequential                  │  142 K │ train │\n",
+       "│ 6  │ model.1.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 7  │ model.1.fn.1                     │ MaxMin                      │      0 │ train │\n",
+       "│ 8  │ model.1.fn.2                     │ Identity                    │      0 │ train │\n",
+       "│ 9  │ model.1.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 10 │ model.2                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 11 │ model.2.fn                       │ Sequential                  │  142 K │ train │\n",
+       "│ 12 │ model.2.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 13 │ model.2.fn.1                     │ MaxMin                      │      0 │ train │\n",
+       "│ 14 │ model.2.fn.2                     │ Identity                    │      0 │ train │\n",
+       "│ 15 │ model.2.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 16 │ model.3                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 17 │ model.3.fn                       │ Sequential                  │  142 K │ train │\n",
+       "│ 18 │ model.3.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 19 │ model.3.fn.1                     │ MaxMin                      │      0 │ train │\n",
+       "│ 20 │ model.3.fn.2                     │ Identity                    │      0 │ train │\n",
+       "│ 21 │ model.3.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 22 │ model.4                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 23 │ model.4.fn                       │ Sequential                  │  142 K │ train │\n",
+       "│ 24 │ model.4.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 25 │ model.4.fn.1                     │ MaxMin                      │      0 │ train │\n",
+       "│ 26 │ model.4.fn.2                     │ Identity                    │      0 │ train │\n",
+       "│ 27 │ model.4.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 28 │ model.5                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 29 │ model.5.fn                       │ Sequential                  │  142 K │ train │\n",
+       "│ 30 │ model.5.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 31 │ model.5.fn.1                     │ MaxMin                      │      0 │ train │\n",
+       "│ 32 │ model.5.fn.2                     │ Identity                    │      0 │ train │\n",
+       "│ 33 │ model.5.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 34 │ model.6                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 35 │ model.6.fn                       │ Sequential                  │  142 K │ train │\n",
+       "│ 36 │ model.6.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 37 │ model.6.fn.1                     │ MaxMin                      │      0 │ train │\n",
+       "│ 38 │ model.6.fn.2                     │ Identity                    │      0 │ train │\n",
+       "│ 39 │ model.6.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 40 │ model.7                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 41 │ model.7.fn                       │ Sequential                  │  142 K │ train │\n",
+       "│ 42 │ model.7.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 43 │ model.7.fn.1                     │ MaxMin                      │      0 │ train │\n",
+       "│ 44 │ model.7.fn.2                     │ Identity                    │      0 │ train │\n",
+       "│ 45 │ model.7.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 46 │ model.8                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 47 │ model.8.fn                       │ Sequential                  │  142 K │ train │\n",
+       "│ 48 │ model.8.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 49 │ model.8.fn.1                     │ MaxMin                      │      0 │ train │\n",
+       "│ 50 │ model.8.fn.2                     │ Identity                    │      0 │ train │\n",
+       "│ 51 │ model.8.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 52 │ model.9                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 53 │ model.9.fn                       │ Sequential                  │  142 K │ train │\n",
+       "│ 54 │ model.9.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 55 │ model.9.fn.1                     │ MaxMin                      │      0 │ train │\n",
+       "│ 56 │ model.9.fn.2                     │ Identity                    │      0 │ train │\n",
+       "│ 57 │ model.9.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 58 │ model.10                         │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 59 │ model.10.fn                      │ Sequential                  │  142 K │ train │\n",
+       "│ 60 │ model.10.fn.0                    │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 61 │ model.10.fn.1                    │ MaxMin                      │      0 │ train │\n",
+       "│ 62 │ model.10.fn.2                    │ Identity                    │      0 │ train │\n",
+       "│ 63 │ model.10.fn.3                    │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 64 │ model.11                         │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 65 │ model.11.fn                      │ Sequential                  │  142 K │ train │\n",
+       "│ 66 │ model.11.fn.0                    │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 67 │ model.11.fn.1                    │ MaxMin                      │      0 │ train │\n",
+       "│ 68 │ model.11.fn.2                    │ Identity                    │      0 │ train │\n",
+       "│ 69 │ model.11.fn.3                    │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 70 │ model.12                         │ PrescaledAdditiveResidual   │  142 K │ train │\n",
+       "│ 71 │ model.12.fn                      │ Sequential                  │  142 K │ train │\n",
+       "│ 72 │ model.12.fn.0                    │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
+       "│ 73 │ model.12.fn.1                    │ MaxMin                      │      0 │ train │\n",
+       "│ 74 │ model.12.fn.2                    │ Identity                    │      0 │ train │\n",
+       "│ 75 │ model.12.fn.3                    │ ParametrizedRKOConv2d       │  131 K │ train │\n",
+       "│ 76 │ model.13                         │ ParametrizedRKOConv2d       │  196 K │ train │\n",
+       "│ 77 │ model.13.parametrizations        │ ModuleDict                  │  196 K │ train │\n",
+       "│ 78 │ model.13.parametrizations.weight │ ParametrizationList         │  196 K │ train │\n",
+       "│ 79 │ model.14                         │ Flatten                     │      0 │ train │\n",
+       "│ 80 │ model.15                         │ MaxMin                      │      0 │ train │\n",
+       "│ 81 │ model.16                         │ ParametrizedOrthoLinear     │  2.8 K │ train │\n",
+       "│ 82 │ model.16.parametrizations        │ ModuleDict                  │  2.8 K │ train │\n",
+       "│ 83 │ model.16.parametrizations.weight │ ParametrizationList         │  2.8 K │ train │\n",
+       "│ 84 │ criteria                         │ CosineLoss                  │      0 │ train │\n",
+       "│ 85 │ train_acc                        │ MulticlassAccuracy          │      0 │ train │\n",
+       "│ 86 │ val_acc                          │ MulticlassAccuracy          │      0 │ train │\n",
+       "│ 87 │ train_vra                        │ MeanMetric                  │      0 │ train │\n",
+       "│ 88 │ val_vra                          │ MeanMetric                  │      0 │ train │\n",
+       "└────┴──────────────────────────────────┴─────────────────────────────┴────────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n", + "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mMode \u001b[0m\u001b[1;35m \u001b[0m┃\n", + "┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n", + "│\u001b[2m \u001b[0m\u001b[2m0 \u001b[0m\u001b[2m \u001b[0m│ model │ Sequential │ 1.9 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m1 \u001b[0m\u001b[2m \u001b[0m│ model.0 │ ParametrizedRKOConv2d │ 3.3 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m2 \u001b[0m\u001b[2m \u001b[0m│ model.0.parametrizations │ ModuleDict │ 3.3 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m3 \u001b[0m\u001b[2m \u001b[0m│ model.0.parametrizations.weight │ ParametrizationList │ 3.3 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m4 \u001b[0m\u001b[2m \u001b[0m│ model.1 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m5 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m6 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m7 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m8 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m9 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m10\u001b[0m\u001b[2m \u001b[0m│ model.2 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m11\u001b[0m\u001b[2m \u001b[0m│ model.2.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m12\u001b[0m\u001b[2m \u001b[0m│ model.2.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m13\u001b[0m\u001b[2m \u001b[0m│ model.2.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m14\u001b[0m\u001b[2m \u001b[0m│ model.2.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m15\u001b[0m\u001b[2m \u001b[0m│ model.2.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m16\u001b[0m\u001b[2m \u001b[0m│ model.3 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m17\u001b[0m\u001b[2m \u001b[0m│ model.3.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m18\u001b[0m\u001b[2m \u001b[0m│ model.3.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m19\u001b[0m\u001b[2m \u001b[0m│ model.3.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m20\u001b[0m\u001b[2m \u001b[0m│ model.3.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m21\u001b[0m\u001b[2m \u001b[0m│ model.3.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m22\u001b[0m\u001b[2m \u001b[0m│ model.4 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m23\u001b[0m\u001b[2m \u001b[0m│ model.4.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m24\u001b[0m\u001b[2m \u001b[0m│ model.4.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m25\u001b[0m\u001b[2m \u001b[0m│ model.4.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m26\u001b[0m\u001b[2m \u001b[0m│ model.4.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m27\u001b[0m\u001b[2m \u001b[0m│ model.4.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m28\u001b[0m\u001b[2m \u001b[0m│ model.5 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m29\u001b[0m\u001b[2m \u001b[0m│ model.5.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m30\u001b[0m\u001b[2m \u001b[0m│ model.5.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m31\u001b[0m\u001b[2m \u001b[0m│ model.5.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m32\u001b[0m\u001b[2m \u001b[0m│ model.5.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m33\u001b[0m\u001b[2m \u001b[0m│ model.5.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m34\u001b[0m\u001b[2m \u001b[0m│ model.6 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m35\u001b[0m\u001b[2m \u001b[0m│ model.6.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m36\u001b[0m\u001b[2m \u001b[0m│ model.6.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m37\u001b[0m\u001b[2m \u001b[0m│ model.6.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m38\u001b[0m\u001b[2m \u001b[0m│ model.6.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m39\u001b[0m\u001b[2m \u001b[0m│ model.6.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m40\u001b[0m\u001b[2m \u001b[0m│ model.7 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m41\u001b[0m\u001b[2m \u001b[0m│ model.7.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m42\u001b[0m\u001b[2m \u001b[0m│ model.7.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m43\u001b[0m\u001b[2m \u001b[0m│ model.7.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m44\u001b[0m\u001b[2m \u001b[0m│ model.7.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m45\u001b[0m\u001b[2m \u001b[0m│ model.7.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m46\u001b[0m\u001b[2m \u001b[0m│ model.8 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m47\u001b[0m\u001b[2m \u001b[0m│ model.8.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m48\u001b[0m\u001b[2m \u001b[0m│ model.8.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m49\u001b[0m\u001b[2m \u001b[0m│ model.8.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m50\u001b[0m\u001b[2m \u001b[0m│ model.8.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m51\u001b[0m\u001b[2m \u001b[0m│ model.8.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m52\u001b[0m\u001b[2m \u001b[0m│ model.9 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m53\u001b[0m\u001b[2m \u001b[0m│ model.9.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m54\u001b[0m\u001b[2m \u001b[0m│ model.9.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m55\u001b[0m\u001b[2m \u001b[0m│ model.9.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m56\u001b[0m\u001b[2m \u001b[0m│ model.9.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m57\u001b[0m\u001b[2m \u001b[0m│ model.9.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m58\u001b[0m\u001b[2m \u001b[0m│ model.10 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m59\u001b[0m\u001b[2m \u001b[0m│ model.10.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m60\u001b[0m\u001b[2m \u001b[0m│ model.10.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m61\u001b[0m\u001b[2m \u001b[0m│ model.10.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m62\u001b[0m\u001b[2m \u001b[0m│ model.10.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m63\u001b[0m\u001b[2m \u001b[0m│ model.10.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m64\u001b[0m\u001b[2m \u001b[0m│ model.11 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m65\u001b[0m\u001b[2m \u001b[0m│ model.11.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m66\u001b[0m\u001b[2m \u001b[0m│ model.11.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m67\u001b[0m\u001b[2m \u001b[0m│ model.11.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m68\u001b[0m\u001b[2m \u001b[0m│ model.11.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m69\u001b[0m\u001b[2m \u001b[0m│ model.11.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m70\u001b[0m\u001b[2m \u001b[0m│ model.12 │ PrescaledAdditiveResidual │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m71\u001b[0m\u001b[2m \u001b[0m│ model.12.fn │ Sequential │ 142 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m72\u001b[0m\u001b[2m \u001b[0m│ model.12.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m73\u001b[0m\u001b[2m \u001b[0m│ model.12.fn.1 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m74\u001b[0m\u001b[2m \u001b[0m│ model.12.fn.2 │ Identity │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m75\u001b[0m\u001b[2m \u001b[0m│ model.12.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m76\u001b[0m\u001b[2m \u001b[0m│ model.13 │ ParametrizedRKOConv2d │ 196 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m77\u001b[0m\u001b[2m \u001b[0m│ model.13.parametrizations │ ModuleDict │ 196 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m78\u001b[0m\u001b[2m \u001b[0m│ model.13.parametrizations.weight │ ParametrizationList │ 196 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m79\u001b[0m\u001b[2m \u001b[0m│ model.14 │ Flatten │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m80\u001b[0m\u001b[2m \u001b[0m│ model.15 │ MaxMin │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m81\u001b[0m\u001b[2m \u001b[0m│ model.16 │ ParametrizedOrthoLinear │ 2.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m82\u001b[0m\u001b[2m \u001b[0m│ model.16.parametrizations │ ModuleDict │ 2.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m83\u001b[0m\u001b[2m \u001b[0m│ model.16.parametrizations.weight │ ParametrizationList │ 2.8 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m84\u001b[0m\u001b[2m \u001b[0m│ criteria │ CosineLoss │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m85\u001b[0m\u001b[2m \u001b[0m│ train_acc │ MulticlassAccuracy │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m86\u001b[0m\u001b[2m \u001b[0m│ val_acc │ MulticlassAccuracy │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m87\u001b[0m\u001b[2m \u001b[0m│ train_vra │ MeanMetric │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m88\u001b[0m\u001b[2m \u001b[0m│ val_vra │ MeanMetric │ 0 │ train │\n", + "└────┴──────────────────────────────────┴─────────────────────────────┴────────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 1.8 M                                                                                            \n",
+       "Non-trainable params: 130 K                                                                                        \n",
+       "Total params: 1.9 M                                                                                                \n",
+       "Total estimated model params size (MB): 7                                                                          \n",
+       "Modules in train mode: 352                                                                                         \n",
+       "Modules in eval mode: 0                                                                                            \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mTrainable params\u001b[0m: 1.8 M \n", + "\u001b[1mNon-trainable params\u001b[0m: 130 K \n", + "\u001b[1mTotal params\u001b[0m: 1.9 M \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 7 \n", + "\u001b[1mModules in train mode\u001b[0m: 352 \n", + "\u001b[1mModules in eval mode\u001b[0m: 0 \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c71c76c88dbc40cdb2b291fed40e914d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Files already downloaded and verified\n",
+       "
\n" + ], + "text/plain": [ + "Files already downloaded and verified\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Files already downloaded and verified\n",
+       "
\n" + ], + "text/plain": [ + "Files already downloaded and verified\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=60` reached.\n" + ] + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "# Select the training setting manually.\n",
     "train_setting = \"non_robust\"  # Options: \"non_robust\", \"mildly_robust\", or \"robust\"\n",
@@ -392,62 +709,59 @@
     "\n",
     "trainer = Trainer(\n",
     "    accelerator=\"gpu\",\n",
-    "    devices=-1,             # Use all available GPUs\n",
+    "    devices=1,             # Use 1 GPU set to -1 for all GPUs\n",
     "    num_nodes=1,            # Number of nodes\n",
-    "    strategy=\"ddp\",         # Distributed strategy\n",
+    "    # strategy=\"ddp_spawn\",         # Distributed strategy\n",
     "    precision=\"bf16-mixed\", # Mixed precision training\n",
     "    max_epochs=current_setting[\"epochs\"],\n",
-    "    enable_model_summary=True,\n",
+    "    # enable_model_summary=True,\n",
     "    # logger=[wandb_logger],  # Uncomment to enable Wandb logging\n",
     "    logger=False,\n",
     "    callbacks=[\n",
     "        # You can add callbacks here, e.g.:\n",
     "        # pl_callbacks.LearningRateFinder(max_lr=0.05),\n",
     "        # checkpoint_callback,\n",
+    "        RichModelSummary(max_depth=4),\n",
+    "        RichProgressBar(),\n",
     "    ],\n",
     ")\n",
-    "\n",
-    "# Print a summary of the model\n",
-    "summary(classification_module, input_size=(1, 3, 32, 32))\n",
-    "\n",
     "# Start training\n",
     "trainer.fit(classification_module, data_module)\n",
     "\n",
     "# Optionally, you can save the trained model afterwards:\n",
     "# torch.save(classification_module.model.state_dict(), \"single_stage.pth\")\n"
-   ],
-   "id": "17b7575b29534e21"
+   ]
   },
   {
-   "metadata": {},
    "cell_type": "markdown",
+   "id": "399725e628d1393a",
+   "metadata": {},
    "source": [
     "## Next Steps\n",
     "\n",
     "- **Model Evaluation:** You can add a new cell to perform model evaluation or predictions.\n",
     "- **Logging and Checkpoints:** To enable model logging or checkpoint saving, uncomment the corresponding lines and configure as needed.\n",
     "- **Experiment with Settings:** Change the `train_setting` variable to `\"mildly_robust\"` or `\"robust\"` to experiment with other training configurations.\n"
-   ],
-   "id": "399725e628d1393a"
+   ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
-    "version": 2
+    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython2",
-   "version": "2.7.6"
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
   }
  },
  "nbformat": 4,

From 9d83fa5fa31e09100c3e75cbae6579161a8c3595 Mon Sep 17 00:00:00 2001
From: Thibaut Boissin 
Date: Fri, 7 Feb 2025 22:33:25 +0100
Subject: [PATCH 07/13] first attempt to add demo notebook

---
 .../notebooks/demo_cifar_classification.ipynb | 1224 +++++++++++++----
 mkdocs.yml                                    |    4 +-
 2 files changed, 951 insertions(+), 277 deletions(-)

diff --git a/docs/notebooks/demo_cifar_classification.ipynb b/docs/notebooks/demo_cifar_classification.ipynb
index ae90933..b67cbaa 100644
--- a/docs/notebooks/demo_cifar_classification.ipynb
+++ b/docs/notebooks/demo_cifar_classification.ipynb
@@ -69,8 +69,7 @@
     "from lightning.pytorch import callbacks as pl_callbacks\n",
     "from lightning.pytorch import Trainer\n",
     "from lightning.pytorch import LightningModule, LightningDataModule\n",
-    "from lightning.pytorch.callbacks import RichProgressBar\n",
-    "from lightning.pytorch.callbacks import RichModelSummary\n",
+    "from torchinfo import summary\n",
     "# from lightning.pytorch.loggers import WandbLogger  # Uncomment if using Wandb logging\n",
     "from torch.nn import AvgPool2d\n",
     "from torch.utils.data import DataLoader\n",
@@ -135,7 +134,7 @@
     "            LossXent,\n",
     "            n_classes=10,\n",
     "            offset=(math.sqrt(2) / 0.1983) * (36 / 255),  # aiming for 36/255 verified robust accuracy\n",
-    "            temperature=0.25,\n",
+    "            temperature=0.125,\n",
     "        ),\n",
     "        \"epochs\": 150,\n",
     "    },\n",
@@ -368,7 +367,7 @@
     "\n",
     "    def configure_optimizers(self):\n",
     "        # Setup the Adam optimizer with schedule-free updates.\n",
-    "        optimizer = schedulefree.AdamWScheduleFree(self.parameters(), lr=1e-2, weight_decay=0)\n",
+    "        optimizer = schedulefree.AdamWScheduleFree(self.parameters(), lr=5e-3, weight_decay=0)\n",
     "        optimizer.train()\n",
     "        self.hparams[\"lr\"] = optimizer.param_groups[0][\"lr\"]\n",
     "        return optimizer\n"
@@ -398,229 +397,114 @@
      "output_type": "stream",
      "text": [
       "Using bfloat16 Automatic Mixed Precision (AMP)\n",
-      "Trainer already configured with model summary callbacks: []. Skipping setting a default `ModelSummary` callback.\n",
       "GPU available: True (cuda), used: True\n",
       "TPU available: False, using: 0 TPU cores\n",
       "HPU available: False, using: 0 HPUs\n",
       "/mnt/deel/data/thibaut.boissin/envs/ortho/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/thibaut.boissin/projects/orthogonium/docs/notebooks/checkpoints exists and is not empty.\n",
-      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
+      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
+      "\n",
+      "  | Name      | Type               | Params | Mode \n",
+      "---------------------------------------------------------\n",
+      "0 | model     | Sequential         | 1.9 M  | train\n",
+      "1 | criteria  | CosineLoss         | 0      | train\n",
+      "2 | train_acc | MulticlassAccuracy | 0      | train\n",
+      "3 | val_acc   | MulticlassAccuracy | 0      | train\n",
+      "4 | train_vra | MeanMetric         | 0      | train\n",
+      "5 | val_vra   | MeanMetric         | 0      | train\n",
+      "---------------------------------------------------------\n",
+      "1.8 M     Trainable params\n",
+      "130 K     Non-trainable params\n",
+      "1.9 M     Total params\n",
+      "7.657     Total estimated model params size (MB)\n",
+      "352       Modules in train mode\n",
+      "0         Modules in eval mode\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "==================================================================================================================================\n",
+      "Layer (type:depth-idx)                                                           Output Shape              Param #\n",
+      "==================================================================================================================================\n",
+      "ClassificationLightningModule                                                    [1, 10]                   --\n",
+      "├─Sequential: 1-1                                                                [1, 10]                   --\n",
+      "│    └─ParametrizedRKOConv2d: 2-1                                                [1, 256, 16, 16]          --\n",
+      "│    │    └─ModuleDict: 3-1                                                      --                        3,340\n",
+      "│    └─AdditiveResidual: 2-2                                                     [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-2                                                      [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-3                                                     [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-3                                                      [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-4                                                     [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-4                                                      [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-5                                                     [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-5                                                      [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-6                                                     [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-6                                                      [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-7                                                     [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-7                                                      [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-8                                                     [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-8                                                      [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-9                                                     [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-9                                                      [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-10                                                    [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-10                                                     [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-11                                                    [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-11                                                     [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-12                                                    [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-12                                                     [1, 256, 16, 16]          142,592\n",
+      "│    └─AdditiveResidual: 2-13                                                    [1, 256, 16, 16]          1\n",
+      "│    │    └─Sequential: 3-13                                                     [1, 256, 16, 16]          142,592\n",
+      "│    └─ParametrizedRKOConv2d: 2-14                                               [1, 256, 1, 1]            --\n",
+      "│    │    └─ModuleDict: 3-14                                                     --                        196,864\n",
+      "│    └─Flatten: 2-15                                                             [1, 256]                  --\n",
+      "│    └─MaxMin: 2-16                                                              [1, 256]                  --\n",
+      "│    └─ParametrizedOrthoLinear: 2-17                                             [1, 10]                   --\n",
+      "│    │    └─ModuleDict: 3-15                                                     --                        2,826\n",
+      "==================================================================================================================================\n",
+      "Total params: 1,914,146\n",
+      "Trainable params: 1,783,308\n",
+      "Non-trainable params: 130,838\n",
+      "Total mult-adds (Units.MEGABYTES): 0\n",
+      "==================================================================================================================================\n",
+      "Input size (MB): 0.01\n",
+      "Forward/backward pass size (MB): 0.00\n",
+      "Params size (MB): 0.00\n",
+      "Estimated Total Size (MB): 0.01\n",
+      "==================================================================================================================================\n"
      ]
     },
     {
      "data": {
-      "text/html": [
-       "
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
-       "┃     Name                              Type                         Params  Mode  ┃\n",
-       "┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
-       "│ 0  │ model                            │ Sequential                  │  1.9 M │ train │\n",
-       "│ 1  │ model.0                          │ ParametrizedRKOConv2d       │  3.3 K │ train │\n",
-       "│ 2  │ model.0.parametrizations         │ ModuleDict                  │  3.3 K │ train │\n",
-       "│ 3  │ model.0.parametrizations.weight  │ ParametrizationList         │  3.3 K │ train │\n",
-       "│ 4  │ model.1                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 5  │ model.1.fn                       │ Sequential                  │  142 K │ train │\n",
-       "│ 6  │ model.1.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 7  │ model.1.fn.1                     │ MaxMin                      │      0 │ train │\n",
-       "│ 8  │ model.1.fn.2                     │ Identity                    │      0 │ train │\n",
-       "│ 9  │ model.1.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 10 │ model.2                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 11 │ model.2.fn                       │ Sequential                  │  142 K │ train │\n",
-       "│ 12 │ model.2.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 13 │ model.2.fn.1                     │ MaxMin                      │      0 │ train │\n",
-       "│ 14 │ model.2.fn.2                     │ Identity                    │      0 │ train │\n",
-       "│ 15 │ model.2.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 16 │ model.3                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 17 │ model.3.fn                       │ Sequential                  │  142 K │ train │\n",
-       "│ 18 │ model.3.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 19 │ model.3.fn.1                     │ MaxMin                      │      0 │ train │\n",
-       "│ 20 │ model.3.fn.2                     │ Identity                    │      0 │ train │\n",
-       "│ 21 │ model.3.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 22 │ model.4                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 23 │ model.4.fn                       │ Sequential                  │  142 K │ train │\n",
-       "│ 24 │ model.4.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 25 │ model.4.fn.1                     │ MaxMin                      │      0 │ train │\n",
-       "│ 26 │ model.4.fn.2                     │ Identity                    │      0 │ train │\n",
-       "│ 27 │ model.4.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 28 │ model.5                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 29 │ model.5.fn                       │ Sequential                  │  142 K │ train │\n",
-       "│ 30 │ model.5.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 31 │ model.5.fn.1                     │ MaxMin                      │      0 │ train │\n",
-       "│ 32 │ model.5.fn.2                     │ Identity                    │      0 │ train │\n",
-       "│ 33 │ model.5.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 34 │ model.6                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 35 │ model.6.fn                       │ Sequential                  │  142 K │ train │\n",
-       "│ 36 │ model.6.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 37 │ model.6.fn.1                     │ MaxMin                      │      0 │ train │\n",
-       "│ 38 │ model.6.fn.2                     │ Identity                    │      0 │ train │\n",
-       "│ 39 │ model.6.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 40 │ model.7                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 41 │ model.7.fn                       │ Sequential                  │  142 K │ train │\n",
-       "│ 42 │ model.7.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 43 │ model.7.fn.1                     │ MaxMin                      │      0 │ train │\n",
-       "│ 44 │ model.7.fn.2                     │ Identity                    │      0 │ train │\n",
-       "│ 45 │ model.7.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 46 │ model.8                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 47 │ model.8.fn                       │ Sequential                  │  142 K │ train │\n",
-       "│ 48 │ model.8.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 49 │ model.8.fn.1                     │ MaxMin                      │      0 │ train │\n",
-       "│ 50 │ model.8.fn.2                     │ Identity                    │      0 │ train │\n",
-       "│ 51 │ model.8.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 52 │ model.9                          │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 53 │ model.9.fn                       │ Sequential                  │  142 K │ train │\n",
-       "│ 54 │ model.9.fn.0                     │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 55 │ model.9.fn.1                     │ MaxMin                      │      0 │ train │\n",
-       "│ 56 │ model.9.fn.2                     │ Identity                    │      0 │ train │\n",
-       "│ 57 │ model.9.fn.3                     │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 58 │ model.10                         │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 59 │ model.10.fn                      │ Sequential                  │  142 K │ train │\n",
-       "│ 60 │ model.10.fn.0                    │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 61 │ model.10.fn.1                    │ MaxMin                      │      0 │ train │\n",
-       "│ 62 │ model.10.fn.2                    │ Identity                    │      0 │ train │\n",
-       "│ 63 │ model.10.fn.3                    │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 64 │ model.11                         │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 65 │ model.11.fn                      │ Sequential                  │  142 K │ train │\n",
-       "│ 66 │ model.11.fn.0                    │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 67 │ model.11.fn.1                    │ MaxMin                      │      0 │ train │\n",
-       "│ 68 │ model.11.fn.2                    │ Identity                    │      0 │ train │\n",
-       "│ 69 │ model.11.fn.3                    │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 70 │ model.12                         │ PrescaledAdditiveResidual   │  142 K │ train │\n",
-       "│ 71 │ model.12.fn                      │ Sequential                  │  142 K │ train │\n",
-       "│ 72 │ model.12.fn.0                    │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n",
-       "│ 73 │ model.12.fn.1                    │ MaxMin                      │      0 │ train │\n",
-       "│ 74 │ model.12.fn.2                    │ Identity                    │      0 │ train │\n",
-       "│ 75 │ model.12.fn.3                    │ ParametrizedRKOConv2d       │  131 K │ train │\n",
-       "│ 76 │ model.13                         │ ParametrizedRKOConv2d       │  196 K │ train │\n",
-       "│ 77 │ model.13.parametrizations        │ ModuleDict                  │  196 K │ train │\n",
-       "│ 78 │ model.13.parametrizations.weight │ ParametrizationList         │  196 K │ train │\n",
-       "│ 79 │ model.14                         │ Flatten                     │      0 │ train │\n",
-       "│ 80 │ model.15                         │ MaxMin                      │      0 │ train │\n",
-       "│ 81 │ model.16                         │ ParametrizedOrthoLinear     │  2.8 K │ train │\n",
-       "│ 82 │ model.16.parametrizations        │ ModuleDict                  │  2.8 K │ train │\n",
-       "│ 83 │ model.16.parametrizations.weight │ ParametrizationList         │  2.8 K │ train │\n",
-       "│ 84 │ criteria                         │ CosineLoss                  │      0 │ train │\n",
-       "│ 85 │ train_acc                        │ MulticlassAccuracy          │      0 │ train │\n",
-       "│ 86 │ val_acc                          │ MulticlassAccuracy          │      0 │ train │\n",
-       "│ 87 │ train_vra                        │ MeanMetric                  │      0 │ train │\n",
-       "│ 88 │ val_vra                          │ MeanMetric                  │      0 │ train │\n",
-       "└────┴──────────────────────────────────┴─────────────────────────────┴────────┴───────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n", - "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mMode \u001b[0m\u001b[1;35m \u001b[0m┃\n", - "┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n", - "│\u001b[2m \u001b[0m\u001b[2m0 \u001b[0m\u001b[2m \u001b[0m│ model │ Sequential │ 1.9 M │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m1 \u001b[0m\u001b[2m \u001b[0m│ model.0 │ ParametrizedRKOConv2d │ 3.3 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m2 \u001b[0m\u001b[2m \u001b[0m│ model.0.parametrizations │ ModuleDict │ 3.3 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m3 \u001b[0m\u001b[2m \u001b[0m│ model.0.parametrizations.weight │ ParametrizationList │ 3.3 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m4 \u001b[0m\u001b[2m \u001b[0m│ model.1 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m5 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m6 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m7 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m8 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m9 \u001b[0m\u001b[2m \u001b[0m│ model.1.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m10\u001b[0m\u001b[2m \u001b[0m│ model.2 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m11\u001b[0m\u001b[2m \u001b[0m│ model.2.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m12\u001b[0m\u001b[2m \u001b[0m│ model.2.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m13\u001b[0m\u001b[2m \u001b[0m│ model.2.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m14\u001b[0m\u001b[2m \u001b[0m│ model.2.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m15\u001b[0m\u001b[2m \u001b[0m│ model.2.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m16\u001b[0m\u001b[2m \u001b[0m│ model.3 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m17\u001b[0m\u001b[2m \u001b[0m│ model.3.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m18\u001b[0m\u001b[2m \u001b[0m│ model.3.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m19\u001b[0m\u001b[2m \u001b[0m│ model.3.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m20\u001b[0m\u001b[2m \u001b[0m│ model.3.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m21\u001b[0m\u001b[2m \u001b[0m│ model.3.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m22\u001b[0m\u001b[2m \u001b[0m│ model.4 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m23\u001b[0m\u001b[2m \u001b[0m│ model.4.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m24\u001b[0m\u001b[2m \u001b[0m│ model.4.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m25\u001b[0m\u001b[2m \u001b[0m│ model.4.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m26\u001b[0m\u001b[2m \u001b[0m│ model.4.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m27\u001b[0m\u001b[2m \u001b[0m│ model.4.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m28\u001b[0m\u001b[2m \u001b[0m│ model.5 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m29\u001b[0m\u001b[2m \u001b[0m│ model.5.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m30\u001b[0m\u001b[2m \u001b[0m│ model.5.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m31\u001b[0m\u001b[2m \u001b[0m│ model.5.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m32\u001b[0m\u001b[2m \u001b[0m│ model.5.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m33\u001b[0m\u001b[2m \u001b[0m│ model.5.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m34\u001b[0m\u001b[2m \u001b[0m│ model.6 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m35\u001b[0m\u001b[2m \u001b[0m│ model.6.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m36\u001b[0m\u001b[2m \u001b[0m│ model.6.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m37\u001b[0m\u001b[2m \u001b[0m│ model.6.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m38\u001b[0m\u001b[2m \u001b[0m│ model.6.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m39\u001b[0m\u001b[2m \u001b[0m│ model.6.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m40\u001b[0m\u001b[2m \u001b[0m│ model.7 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m41\u001b[0m\u001b[2m \u001b[0m│ model.7.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m42\u001b[0m\u001b[2m \u001b[0m│ model.7.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m43\u001b[0m\u001b[2m \u001b[0m│ model.7.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m44\u001b[0m\u001b[2m \u001b[0m│ model.7.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m45\u001b[0m\u001b[2m \u001b[0m│ model.7.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m46\u001b[0m\u001b[2m \u001b[0m│ model.8 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m47\u001b[0m\u001b[2m \u001b[0m│ model.8.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m48\u001b[0m\u001b[2m \u001b[0m│ model.8.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m49\u001b[0m\u001b[2m \u001b[0m│ model.8.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m50\u001b[0m\u001b[2m \u001b[0m│ model.8.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m51\u001b[0m\u001b[2m \u001b[0m│ model.8.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m52\u001b[0m\u001b[2m \u001b[0m│ model.9 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m53\u001b[0m\u001b[2m \u001b[0m│ model.9.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m54\u001b[0m\u001b[2m \u001b[0m│ model.9.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m55\u001b[0m\u001b[2m \u001b[0m│ model.9.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m56\u001b[0m\u001b[2m \u001b[0m│ model.9.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m57\u001b[0m\u001b[2m \u001b[0m│ model.9.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m58\u001b[0m\u001b[2m \u001b[0m│ model.10 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m59\u001b[0m\u001b[2m \u001b[0m│ model.10.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m60\u001b[0m\u001b[2m \u001b[0m│ model.10.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m61\u001b[0m\u001b[2m \u001b[0m│ model.10.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m62\u001b[0m\u001b[2m \u001b[0m│ model.10.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m63\u001b[0m\u001b[2m \u001b[0m│ model.10.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m64\u001b[0m\u001b[2m \u001b[0m│ model.11 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m65\u001b[0m\u001b[2m \u001b[0m│ model.11.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m66\u001b[0m\u001b[2m \u001b[0m│ model.11.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m67\u001b[0m\u001b[2m \u001b[0m│ model.11.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m68\u001b[0m\u001b[2m \u001b[0m│ model.11.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m69\u001b[0m\u001b[2m \u001b[0m│ model.11.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m70\u001b[0m\u001b[2m \u001b[0m│ model.12 │ PrescaledAdditiveResidual │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m71\u001b[0m\u001b[2m \u001b[0m│ model.12.fn │ Sequential │ 142 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m72\u001b[0m\u001b[2m \u001b[0m│ model.12.fn.0 │ ParametrizedFastBlockConv2d │ 10.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m73\u001b[0m\u001b[2m \u001b[0m│ model.12.fn.1 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m74\u001b[0m\u001b[2m \u001b[0m│ model.12.fn.2 │ Identity │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m75\u001b[0m\u001b[2m \u001b[0m│ model.12.fn.3 │ ParametrizedRKOConv2d │ 131 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m76\u001b[0m\u001b[2m \u001b[0m│ model.13 │ ParametrizedRKOConv2d │ 196 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m77\u001b[0m\u001b[2m \u001b[0m│ model.13.parametrizations │ ModuleDict │ 196 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m78\u001b[0m\u001b[2m \u001b[0m│ model.13.parametrizations.weight │ ParametrizationList │ 196 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m79\u001b[0m\u001b[2m \u001b[0m│ model.14 │ Flatten │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m80\u001b[0m\u001b[2m \u001b[0m│ model.15 │ MaxMin │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m81\u001b[0m\u001b[2m \u001b[0m│ model.16 │ ParametrizedOrthoLinear │ 2.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m82\u001b[0m\u001b[2m \u001b[0m│ model.16.parametrizations │ ModuleDict │ 2.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m83\u001b[0m\u001b[2m \u001b[0m│ model.16.parametrizations.weight │ ParametrizationList │ 2.8 K │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m84\u001b[0m\u001b[2m \u001b[0m│ criteria │ CosineLoss │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m85\u001b[0m\u001b[2m \u001b[0m│ train_acc │ MulticlassAccuracy │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m86\u001b[0m\u001b[2m \u001b[0m│ val_acc │ MulticlassAccuracy │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m87\u001b[0m\u001b[2m \u001b[0m│ train_vra │ MeanMetric │ 0 │ train │\n", - "│\u001b[2m \u001b[0m\u001b[2m88\u001b[0m\u001b[2m \u001b[0m│ val_vra │ MeanMetric │ 0 │ train │\n", - "└────┴──────────────────────────────────┴─────────────────────────────┴────────┴───────┘\n" + "application/vnd.jupyter.widget-view+json": { + "model_id": "2226827bf89b409084b8ddffca572e0b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | …" ] }, "metadata": {}, "output_type": "display_data" }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + }, { "data": { - "text/html": [ - "
Trainable params: 1.8 M                                                                                            \n",
-       "Non-trainable params: 130 K                                                                                        \n",
-       "Total params: 1.9 M                                                                                                \n",
-       "Total estimated model params size (MB): 7                                                                          \n",
-       "Modules in train mode: 352                                                                                         \n",
-       "Modules in eval mode: 0                                                                                            \n",
-       "
\n" - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "fcf0b6135375483284ee822ed7415541", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "\u001b[1mTrainable params\u001b[0m: 1.8 M \n", - "\u001b[1mNon-trainable params\u001b[0m: 130 K \n", - "\u001b[1mTotal params\u001b[0m: 1.9 M \n", - "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 7 \n", - "\u001b[1mModules in train mode\u001b[0m: 352 \n", - "\u001b[1mModules in eval mode\u001b[0m: 0 \n" + "Training: | …" ] }, "metadata": {}, @@ -629,12 +513,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c71c76c88dbc40cdb2b291fed40e914d", + "model_id": "e69ec3170eb643649405d9c085cc474f", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Output()" + "Validation: | …" ] }, "metadata": {}, @@ -642,12 +526,13 @@ }, { "data": { - "text/html": [ - "
Files already downloaded and verified\n",
-       "
\n" - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "d66dfd3ba2444f8caa958060c8e19750", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "Files already downloaded and verified\n" + "Validation: | …" ] }, "metadata": {}, @@ -655,76 +540,865 @@ }, { "data": { - "text/html": [ - "
Files already downloaded and verified\n",
-       "
\n" - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "d9a67fb7924944d785d075ecf527523e", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "Files already downloaded and verified\n" + "Validation: | …" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "`Trainer.fit` stopped: `max_epochs=60` reached.\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1263a873ec9245ae86ae078ef88fed29", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f7208738bb03464bb1fd66c6f7c094e8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
      },
      "metadata": {},
      "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "# Select the training setting manually.\n",
-    "train_setting = \"non_robust\"  # Options: \"non_robust\", \"mildly_robust\", or \"robust\"\n",
-    "\n",
-    "# Get the corresponding loss function and number of epochs from the settings.\n",
-    "current_setting = settings[train_setting]\n",
-    "\n",
-    "# Instantiate the classification model and data module.\n",
-    "classification_module = ClassificationLightningModule(num_classes=10, loss=current_setting[\"loss\"])\n",
-    "data_module = Cifar10DataModule()\n",
-    "\n",
-    "# Optionally, set up a logger or callbacks if needed.\n",
-    "# For example, if using Wandb:\n",
-    "# from lightning.pytorch.loggers import WandbLogger\n",
-    "# wandb_logger = WandbLogger(project=\"lipschitz-robust-cifar10\", log_model=True)\n",
-    "# checkpoint_callback = pl_callbacks.ModelCheckpoint(\n",
-    "#     monitor=\"loss\",\n",
-    "#     mode=\"min\",\n",
-    "#     save_top_k=1,\n",
-    "#     save_last=True,\n",
-    "#     dirpath=f\"./checkpoints/{wandb_logger.experiment.dir}\",\n",
-    "# )\n",
-    "\n",
-    "trainer = Trainer(\n",
-    "    accelerator=\"gpu\",\n",
-    "    devices=1,             # Use 1 GPU set to -1 for all GPUs\n",
-    "    num_nodes=1,            # Number of nodes\n",
-    "    # strategy=\"ddp_spawn\",         # Distributed strategy\n",
-    "    precision=\"bf16-mixed\", # Mixed precision training\n",
-    "    max_epochs=current_setting[\"epochs\"],\n",
-    "    # enable_model_summary=True,\n",
-    "    # logger=[wandb_logger],  # Uncomment to enable Wandb logging\n",
-    "    logger=False,\n",
-    "    callbacks=[\n",
-    "        # You can add callbacks here, e.g.:\n",
-    "        # pl_callbacks.LearningRateFinder(max_lr=0.05),\n",
-    "        # checkpoint_callback,\n",
-    "        RichModelSummary(max_depth=4),\n",
-    "        RichProgressBar(),\n",
-    "    ],\n",
-    ")\n",
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e76a042444de434aa4ea4d21daa9519a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7fd21fb9cc74430092ae0d8b460fa0de",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6bf278b0c03045c5b23b5bba43972a68",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "86fe5ff8be2741418456e1f3a0c72fff",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "999d206a612c4294b894da5aa640617f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "55b0a51a70ab41d289b659f0c0c42c72",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fff0fb909c7447f7a1a8bec3faa9f012",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7111ce7a8b064f6089dc09cc0f250fb9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "766fa42f0e72472c809eae654d1e5e58",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2c5abe8a3f6e4196b5b83255a1419739",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c64cb38a9417429d825744c5013ab987",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1ae862f2a38a477c84cb30a97ee7acb4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e4129e4c37f24df2b5861bc15d8cae23",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b924d9fcb7e341b6a702e06dd9f60b48",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c40b8d0e8c61427985b91fc161a138fb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "20d0eef1ab6c494998a6375f9f27da45",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "836a21a6cc254089a315b6b32e6860dc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f2460340f2ee40d1940f5e68f03c8aaf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "394bc1541b1e45b483cb0f6fc3249fd8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "02d9185d09264052b1fdacced023a225",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "507158623c2f46809d15db8d587c13bd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c2c9a08f36954982b035ff3dcd47a348",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b85495eabe7f420c9786514d860e5a55",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dad5e8c0db064762af69e4b86c29621b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8c05ef38da6c4d8da8ff02592f4ec3fa",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d5c543f2ec8141f69c249e94ca5bd548",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b872909a37c647749afc42040861d2d9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8b47094818724bf8bf458660f0c18d7a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "afb573baec2b47828f4e2aabd98a06fe",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9c45134f0de04a1a87331011898a92a0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f93009ef44c04659812e1ab6b35f652e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "128f34b1199a47dea368948200a7a576",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e0ea951a09ef4c7a9f5f04677f9b5d77",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d0a977768f1840629231625fc4a6c0fd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cb27b60727bb424fac80de3e2d09faa3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3bf8927c76d740dea7357fd8883810de",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e2b9c52d53764fdf8583215f0ec0f4c8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9b5d9ac27fee420d97ce033fb0f5b932",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "742daccf242143d3bda7e464f25eb899",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b580c8ec9d324d368f6c211ab940bc66",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eff45c67c208455a9389cb0117e4f609",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "da5a6854c4e843429fb20eebe0d6dab7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4fe9f0ba14d143bab98d64809bd577e7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "66a22411171a4b3582500253b5634553",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "647bef5c2b044718a9a511f5fdf33814",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a69650f346894dc4925403eed72fb704",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8832554888034a9a8b4e8967c71b1c46",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1b8d35bff60a41ac838b00ed0a42d147",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8cddca1f1d4241edb26dab4e3a4629f9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f7e33c719dd24d1b9e2574e038e044dc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f09d28d453d343828d9f89188b5fe285",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a709badae748425aacf30afab9e7bf03",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "08fcad2b42724844ad68b4715ea183c3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fa05739a92f6468883dd0b609715d376",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d88fb502d8004927b664794ad314752b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation: |                                                                                                 …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "`Trainer.fit` stopped: `max_epochs=60` reached.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Select the training setting manually.\n",
+    "train_setting = \"non_robust\"  # Options: \"non_robust\", \"mildly_robust\", or \"robust\"\n",
+    "\n",
+    "# Get the corresponding loss function and number of epochs from the settings.\n",
+    "current_setting = settings[train_setting]\n",
+    "\n",
+    "# Instantiate the classification model and data module.\n",
+    "classification_module = ClassificationLightningModule(num_classes=10, loss=current_setting[\"loss\"])\n",
+    "data_module = Cifar10DataModule()\n",
+    "\n",
+    "# Optionally, set up a logger or callbacks if needed.\n",
+    "# For example, if using Wandb:\n",
+    "# from lightning.pytorch.loggers import WandbLogger\n",
+    "# wandb_logger = WandbLogger(project=\"lipschitz-robust-cifar10\", log_model=True)\n",
+    "# checkpoint_callback = pl_callbacks.ModelCheckpoint(\n",
+    "#     monitor=\"loss\",\n",
+    "#     mode=\"min\",\n",
+    "#     save_top_k=1,\n",
+    "#     save_last=True,\n",
+    "#     dirpath=f\"./checkpoints/{wandb_logger.experiment.dir}\",\n",
+    "# )\n",
+    "\n",
+    "trainer = Trainer(\n",
+    "    accelerator=\"gpu\",\n",
+    "    devices=[1],             # Use 1 GPU set to -1 for all GPUs\n",
+    "    num_nodes=1,            # Number of nodes\n",
+    "    # strategy=\"ddp_spawn\",         # Distributed strategy\n",
+    "    precision=\"bf16-mixed\", # Mixed precision training\n",
+    "    max_epochs=current_setting[\"epochs\"],\n",
+    "    enable_model_summary=False,\n",
+    "    # logger=[wandb_logger],  # Uncomment to enable Wandb logging\n",
+    "    logger=False,\n",
+    "    callbacks=[\n",
+    "        # You can add callbacks here, e.g.:\n",
+    "        # pl_callbacks.LearningRateFinder(max_lr=0.05),\n",
+    "        # checkpoint_callback,\n",
+    "    ],\n",
+    ")\n",
+    "\n",
+    "print(summary(classification_module, input_size=(1, 3, 32, 32)))\n",
     "# Start training\n",
     "trainer.fit(classification_module, data_module)\n",
     "\n",
diff --git a/mkdocs.yml b/mkdocs.yml
index bbf9c1b..ef5e1fb 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -13,8 +13,8 @@ nav:
 #      - layers.conv.AOC module: api/aoc.md
 #      - layers.conv.adaptiveSOC module: api/adaptiveSOC.md
 #      - layers.conv.SLL module: api/sll.md
-#  - Tutorials:
-#    - "Demo 0: How to use notebook in documentation": notebooks/demo_fake.ipynb
+  - Tutorials:
+    - "Demo 1: Certifiable robustness with 1-Lipschitz networks": notebooks/demo_cifar_classification.ipynb
   - Contributing: CONTRIBUTING.md
 
 theme:

From 3befeb5ca0c3be7d035eb12d27650a1f023dd1a8 Mon Sep 17 00:00:00 2001
From: Thibaut Boissin 
Date: Wed, 12 Feb 2025 22:07:17 +0100
Subject: [PATCH 08/13] SLL handle groups now

---
 orthogonium/layers/conv/SLL/sll_layer.py | 80 +++++++++++++++++-------
 tests/test_sll.py                        | 15 +++++
 2 files changed, 71 insertions(+), 24 deletions(-)

diff --git a/orthogonium/layers/conv/SLL/sll_layer.py b/orthogonium/layers/conv/SLL/sll_layer.py
index 8693f70..7b4eba7 100644
--- a/orthogonium/layers/conv/SLL/sll_layer.py
+++ b/orthogonium/layers/conv/SLL/sll_layer.py
@@ -69,7 +69,9 @@ def safe_inv(x):
 
 
 class SLLxAOCLipschitzResBlock(nn.Module):
-    def __init__(self, cin, cout, inner_dim_factor, kernel_size=3, stride=2, **kwargs):
+    def __init__(
+        self, cin, cout, inner_dim_factor, kernel_size=3, stride=2, groups=1, **kwargs
+    ):
         """
         Extended SLL-based convolutional residual block. Supports arbitrary kernel sizes,
         strides, and changes in the number of channels by integrating additional
@@ -111,12 +113,15 @@ def __init__(self, cin, cout, inner_dim_factor, kernel_size=3, stride=2, **kwarg
         inner_dim = int(cout * inner_dim_factor)
         self.activation = nn.ReLU()
         self.stride = stride
+        self.groups = groups
         self.padding = kernel_size // 2
         self.kernel = nn.Parameter(
-            torch.randn(inner_dim, cin, inner_kernel_size, inner_kernel_size)
+            torch.randn(
+                inner_dim, cin // self.groups, inner_kernel_size, inner_kernel_size
+            )
         )
         self.bias = nn.Parameter(torch.empty(1, inner_dim, 1, 1))
-        self.q = nn.Parameter(torch.randn(inner_dim))
+        self.q = nn.Parameter(torch.ones(inner_dim, 1, 1, 1))
 
         nn.init.xavier_normal_(self.kernel)
         fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.kernel)
@@ -124,42 +129,62 @@ def __init__(self, cin, cout, inner_dim_factor, kernel_size=3, stride=2, **kwarg
         nn.init.uniform_(self.bias, -bound, bound)  # bias init
 
         self.pre_conv = AdaptiveOrthoConv2d(
-            cin, cin, kernel_size=stride, stride=1, bias=False, padding=0
+            cin, cin, kernel_size=stride, stride=1, bias=False, padding=0, groups=groups
         )
         self.post_conv = AdaptiveOrthoConv2d(
-            cin, cout, kernel_size=stride, stride=stride, bias=False, padding=0
+            cin,
+            cout,
+            kernel_size=stride,
+            stride=stride,
+            bias=False,
+            padding=0,
+            groups=groups,
         )
 
     def compute_t(self):
-        ktk = F.conv2d(self.kernel, self.kernel, padding=self.kernel.shape[-1] - 1)
+        ktk = fast_matrix_conv(
+            transpose_kernel(self.kernel, self.groups, flip=True),
+            self.kernel,
+            self.groups,
+        )
         ktk = torch.abs(ktk)
-        q = torch.exp(self.q).reshape(-1, 1, 1, 1)
-        q_inv = torch.exp(-self.q).reshape(-1, 1, 1, 1)
+        q = torch.exp(self.q)
+        q_inv = torch.exp(-self.q)
         t = (q_inv * ktk * q).sum((1, 2, 3))
         t = safe_inv(t)
+        t = t.reshape(-1, 1, 1, 1)
         return t
 
     def forward(self, x):
         # compute t
         t = self.compute_t()
-        t = t.reshape(1, -1, 1, 1)
         # print(self.pre_conv.weight.shape, self.kernel.shape, self.post_conv.weight.shape)
-        kernel_1a = fast_matrix_conv(self.pre_conv.weight, self.kernel, groups=1)
+        kernel_1a = fast_matrix_conv(
+            self.pre_conv.weight, self.kernel, groups=self.groups
+        )
         kernel_1b = fast_matrix_conv(
-            transpose_kernel(self.kernel, groups=1), self.post_conv.weight, groups=1
+            transpose_kernel(self.kernel, groups=self.groups),
+            self.post_conv.weight,
+            groups=self.groups,
         )
         kernel_2 = fast_matrix_conv(
-            self.pre_conv.weight, self.post_conv.weight, groups=1
+            self.pre_conv.weight, self.post_conv.weight, groups=self.groups
         )
         # first branch
         # fuse pre conv with kernel
-        res = F.conv2d(x, kernel_1a, padding=self.padding)
+        res = F.conv2d(x, kernel_1a, padding=self.padding, groups=self.groups)
         res = res + self.bias
         res = t * self.activation(res)
-        res = 2 * F.conv2d(res, kernel_1b, padding=self.padding, stride=self.stride)
+        res = 2 * F.conv2d(
+            res, kernel_1b, padding=self.padding, stride=self.stride, groups=self.groups
+        )
         # residual branch
         x = F.conv2d(
-            x, kernel_2, padding=self.skip_kernel_size // 2, stride=self.stride
+            x,
+            kernel_2,
+            padding=self.skip_kernel_size // 2,
+            stride=self.stride,
+            groups=self.groups,
         )
         # skip connection
         out = x - res
@@ -167,7 +192,7 @@ def forward(self, x):
 
 
 class SDPBasedLipschitzResBlock(nn.Module):
-    def __init__(self, cin, inner_dim_factor, kernel_size=3, **kwargs):
+    def __init__(self, cin, inner_dim_factor, kernel_size=3, groups=1, **kwargs):
         """
          Original 1-Lipschitz convolutional residual block, based on the SDP-based Lipschitz
         layer (SLL) approach [1]. It has a structure akin to:
@@ -200,14 +225,15 @@ def __init__(self, cin, inner_dim_factor, kernel_size=3, **kwargs):
 
         inner_dim = int(cin * inner_dim_factor)
         self.activation = nn.ReLU()
+        self.groups = groups
 
         self.padding = kernel_size // 2
 
         self.kernel = nn.Parameter(
-            torch.randn(inner_dim, cin, kernel_size, kernel_size)
+            torch.randn(inner_dim, cin // groups, kernel_size, kernel_size)
         )
         self.bias = nn.Parameter(torch.empty(1, inner_dim, 1, 1))
-        self.q = nn.Parameter(torch.randn(inner_dim))
+        self.q = nn.Parameter(torch.ones(inner_dim, 1, 1, 1))
 
         nn.init.xavier_normal_(self.kernel)
         fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.kernel)
@@ -215,21 +241,27 @@ def __init__(self, cin, inner_dim_factor, kernel_size=3, **kwargs):
         nn.init.uniform_(self.bias, -bound, bound)  # bias init
 
     def compute_t(self):
-        ktk = F.conv2d(self.kernel, self.kernel, padding=self.kernel.shape[-1] - 1)
+        ktk = fast_matrix_conv(
+            transpose_kernel(self.kernel, self.groups, flip=True),
+            self.kernel,
+            self.groups,
+        )
         ktk = torch.abs(ktk)
-        q = torch.exp(self.q).reshape(-1, 1, 1, 1)
-        q_inv = torch.exp(-self.q).reshape(-1, 1, 1, 1)
+        q = torch.exp(self.q)
+        q_inv = torch.exp(-self.q)
         t = (q_inv * ktk * q).sum((1, 2, 3))
         t = safe_inv(t)
+        t = t.reshape(-1, 1, 1, 1)
         return t
 
     def forward(self, x):
         t = self.compute_t()
-        t = t.reshape(1, -1, 1, 1)
-        res = F.conv2d(x, self.kernel, padding=1)
+        res = F.conv2d(x, self.kernel, padding=self.padding, groups=self.groups)
         res = res + self.bias
         res = t * self.activation(res)
-        res = 2 * F.conv_transpose2d(res, self.kernel, padding=1)
+        res = 2 * F.conv_transpose2d(
+            res, self.kernel, padding=self.padding, groups=self.groups
+        )
         out = x - res
         return out
 
diff --git a/tests/test_sll.py b/tests/test_sll.py
index cd78fee..62abd6d 100644
--- a/tests/test_sll.py
+++ b/tests/test_sll.py
@@ -16,11 +16,21 @@
             {"cin": 4, "inner_dim_factor": 2, "kernel_size": 3},
             (8, 4, 8, 8),
         ),
+        (
+            SDPBasedLipschitzResBlock,
+            {"cin": 4, "inner_dim_factor": 2, "kernel_size": 3, "groups": 2},
+            (8, 4, 8, 8),
+        ),
         (
             SLLxAOCLipschitzResBlock,
             {"cin": 4, "cout": 4, "inner_dim_factor": 2, "kernel_size": 3},
             (8, 4, 8, 8),
         ),
+        (
+            SLLxAOCLipschitzResBlock,
+            {"cin": 4, "cout": 4, "inner_dim_factor": 2, "kernel_size": 3, "groups": 2},
+            (8, 4, 8, 8),
+        ),
         (
             SDPBasedLipschitzDense,
             {"in_features": 64, "out_features": 64, "inner_dim": 64},
@@ -31,6 +41,11 @@
             {"in_channels": 4, "inner_dim_factor": 2, "kernel_size": 3},
             (8, 4, 8, 8),
         ),
+        (
+            AOCLipschitzResBlock,
+            {"in_channels": 4, "inner_dim_factor": 2, "kernel_size": 3, "groups": 2},
+            (8, 4, 8, 8),
+        ),
     ],
 )
 def test_lipschitz_layers(layer_class, init_params, batch_shape):

From 87fac0b2e8316b7d18c7088f761c3b9bec3f0c2e Mon Sep 17 00:00:00 2001
From: Thibaut Boissin 
Date: Thu, 13 Feb 2025 14:30:21 +0100
Subject: [PATCH 09/13] bump version

---
 orthogonium/VERSION | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/orthogonium/VERSION b/orthogonium/VERSION
index 4e379d2..bcab45a 100644
--- a/orthogonium/VERSION
+++ b/orthogonium/VERSION
@@ -1 +1 @@
-0.0.2
+0.0.3

From 5a22fd6b2ef5be0b93970014a5b753f741ad4ea6 Mon Sep 17 00:00:00 2001
From: Thibaut Boissin 
Date: Tue, 18 Mar 2025 18:00:02 +0100
Subject: [PATCH 10/13] fixed issue in SLL

---
 orthogonium/layers/conv/AOL/aol.py       |   9 +-
 orthogonium/layers/conv/SLL/sll_layer.py | 121 +++++++++++------------
 2 files changed, 62 insertions(+), 68 deletions(-)

diff --git a/orthogonium/layers/conv/AOL/aol.py b/orthogonium/layers/conv/AOL/aol.py
index ff31f1a..90be21f 100644
--- a/orthogonium/layers/conv/AOL/aol.py
+++ b/orthogonium/layers/conv/AOL/aol.py
@@ -3,11 +3,16 @@
 from torch.nn.utils import parametrize
 
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import (
-    conv_singular_values_numpy,
     transpose_kernel,
     fast_matrix_conv,
 )
-from orthogonium.layers.conv.SLL.sll_layer import safe_inv
+
+
+def safe_inv(x):
+    mask = x == 0
+    x_inv = x ** (-1)
+    x_inv[mask] = 0
+    return x_inv
 
 
 class AOLReparametrizer(nn.Module):
diff --git a/orthogonium/layers/conv/SLL/sll_layer.py b/orthogonium/layers/conv/SLL/sll_layer.py
index 7b4eba7..0459d28 100644
--- a/orthogonium/layers/conv/SLL/sll_layer.py
+++ b/orthogonium/layers/conv/SLL/sll_layer.py
@@ -54,20 +54,15 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.common_types import _size_2_t
+from torch.nn.utils import parametrize
 
 from orthogonium.layers import AdaptiveOrthoConv2d
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import fast_matrix_conv
 from orthogonium.layers.conv.AOC.fast_block_ortho_conv import transpose_kernel
+from orthogonium.layers.conv.AOL.aol import AOLReparametrizer, safe_inv
 from orthogonium.reparametrizers import OrthoParams
 
 
-def safe_inv(x):
-    mask = x == 0
-    x_inv = x ** (-1)
-    x_inv[mask] = 0
-    return x_inv
-
-
 class SLLxAOCLipschitzResBlock(nn.Module):
     def __init__(
         self, cin, cout, inner_dim_factor, kernel_size=3, stride=2, groups=1, **kwargs
@@ -98,6 +93,8 @@ def __init__(
           - `cin` (int): Number of input channels.
           - `inner_dim_factor` (float): Multiplier for the internal channel dimension.
           - `kernel_size` (int, optional): Base kernel size for the SLL portion. Default is 3.
+          - `stride` (int, optional): Stride for the skip connection. Default is 2.
+          - `groups` (int, optional): Number of groups for the convolution. Default is 1.
           - `**kwargs`: Additional options (unused).
 
 
@@ -120,6 +117,14 @@ def __init__(
                 inner_dim, cin // self.groups, inner_kernel_size, inner_kernel_size
             )
         )
+        parametrize.register_parametrization(
+            self,
+            "kernel",
+            AOLReparametrizer(
+                inner_dim,
+                groups=groups,
+            ),
+        )
         self.bias = nn.Parameter(torch.empty(1, inner_dim, 1, 1))
         self.q = nn.Parameter(torch.ones(inner_dim, 1, 1, 1))
 
@@ -141,51 +146,41 @@ def __init__(
             groups=groups,
         )
 
-    def compute_t(self):
-        ktk = fast_matrix_conv(
-            transpose_kernel(self.kernel, self.groups, flip=True),
-            self.kernel,
-            self.groups,
-        )
-        ktk = torch.abs(ktk)
-        q = torch.exp(self.q)
-        q_inv = torch.exp(-self.q)
-        t = (q_inv * ktk * q).sum((1, 2, 3))
-        t = safe_inv(t)
-        t = t.reshape(-1, 1, 1, 1)
-        return t
-
     def forward(self, x):
         # compute t
-        t = self.compute_t()
         # print(self.pre_conv.weight.shape, self.kernel.shape, self.post_conv.weight.shape)
         kernel_1a = fast_matrix_conv(
             self.pre_conv.weight, self.kernel, groups=self.groups
         )
-        kernel_1b = fast_matrix_conv(
-            transpose_kernel(self.kernel, groups=self.groups),
-            self.post_conv.weight,
-            groups=self.groups,
-        )
-        kernel_2 = fast_matrix_conv(
-            self.pre_conv.weight, self.post_conv.weight, groups=self.groups
-        )
-        # first branch
-        # fuse pre conv with kernel
-        res = F.conv2d(x, kernel_1a, padding=self.padding, groups=self.groups)
-        res = res + self.bias
-        res = t * self.activation(res)
-        res = 2 * F.conv2d(
-            res, kernel_1b, padding=self.padding, stride=self.stride, groups=self.groups
-        )
-        # residual branch
-        x = F.conv2d(
-            x,
-            kernel_2,
-            padding=self.skip_kernel_size // 2,
-            stride=self.stride,
-            groups=self.groups,
-        )
+        with parametrize.cached():
+            kernel_1b = fast_matrix_conv(
+                transpose_kernel(self.kernel, groups=self.groups),
+                self.post_conv.weight,
+                groups=self.groups,
+            )
+            kernel_2 = fast_matrix_conv(
+                self.pre_conv.weight, self.post_conv.weight, groups=self.groups
+            )
+            # first branch
+            # fuse pre conv with kernel
+            res = F.conv2d(x, kernel_1a, padding=self.padding, groups=self.groups)
+            res = res + self.bias
+            res = self.activation(res)
+            res = 2 * F.conv2d(
+                res,
+                kernel_1b,
+                padding=self.padding,
+                stride=self.stride,
+                groups=self.groups,
+            )
+            # residual branch
+            x = F.conv2d(
+                x,
+                kernel_2,
+                padding=self.skip_kernel_size // 2,
+                stride=self.stride,
+                groups=self.groups,
+            )
         # skip connection
         out = x - res
         return out
@@ -211,7 +206,7 @@ def __init__(self, cin, inner_dim_factor, kernel_size=3, groups=1, **kwargs):
           - `cout` (int): Number of output channels.
           - `inner_dim_factor` (float): Multiplier for the intermediate dimensionality.
           - `kernel_size` (int, optional): Size of the convolution kernel. Default is 3.
-          - `stride` (int, optional): Stride for the skip connection. Default is 2.
+          - `groups` (int, optional): Number of groups for the convolution. Default is 1.
           - `**kwargs`: Additional keyword arguments (unused).
 
 
@@ -232,6 +227,14 @@ def __init__(self, cin, inner_dim_factor, kernel_size=3, groups=1, **kwargs):
         self.kernel = nn.Parameter(
             torch.randn(inner_dim, cin // groups, kernel_size, kernel_size)
         )
+        parametrize.register_parametrization(
+            self,
+            "kernel",
+            AOLReparametrizer(
+                inner_dim,
+                groups=groups,
+            ),
+        )
         self.bias = nn.Parameter(torch.empty(1, inner_dim, 1, 1))
         self.q = nn.Parameter(torch.ones(inner_dim, 1, 1, 1))
 
@@ -240,28 +243,14 @@ def __init__(self, cin, inner_dim_factor, kernel_size=3, groups=1, **kwargs):
         bound = 1 / np.sqrt(fan_in)
         nn.init.uniform_(self.bias, -bound, bound)  # bias init
 
-    def compute_t(self):
-        ktk = fast_matrix_conv(
-            transpose_kernel(self.kernel, self.groups, flip=True),
-            self.kernel,
-            self.groups,
-        )
-        ktk = torch.abs(ktk)
-        q = torch.exp(self.q)
-        q_inv = torch.exp(-self.q)
-        t = (q_inv * ktk * q).sum((1, 2, 3))
-        t = safe_inv(t)
-        t = t.reshape(-1, 1, 1, 1)
-        return t
-
     def forward(self, x):
-        t = self.compute_t()
         res = F.conv2d(x, self.kernel, padding=self.padding, groups=self.groups)
         res = res + self.bias
-        res = t * self.activation(res)
-        res = 2 * F.conv_transpose2d(
-            res, self.kernel, padding=self.padding, groups=self.groups
-        )
+        res = self.activation(res)
+        with parametrize.cached():
+            res = 2 * F.conv_transpose2d(
+                res, self.kernel, padding=self.padding, groups=self.groups
+            )
         out = x - res
         return out
 

From 95db2c5ee1b4b7f3398d92fee2d2cbcb969e02f4 Mon Sep 17 00:00:00 2001
From: Thibaut Boissin 
Date: Tue, 18 Mar 2025 18:00:50 +0100
Subject: [PATCH 11/13] improved testing in SLL only compute largest singular
 value

---
 tests/test_sll.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/test_sll.py b/tests/test_sll.py
index 62abd6d..6d490c6 100644
--- a/tests/test_sll.py
+++ b/tests/test_sll.py
@@ -14,12 +14,12 @@
         (
             SDPBasedLipschitzResBlock,
             {"cin": 4, "inner_dim_factor": 2, "kernel_size": 3},
-            (8, 4, 8, 8),
+            (1, 4, 8, 8),
         ),
         (
             SDPBasedLipschitzResBlock,
-            {"cin": 4, "inner_dim_factor": 2, "kernel_size": 3, "groups": 2},
-            (8, 4, 8, 8),
+            {"cin": 2, "inner_dim_factor": 2, "kernel_size": 3, "groups": 2},
+            (1, 2, 8, 8),
         ),
         (
             SLLxAOCLipschitzResBlock,
@@ -108,5 +108,6 @@ def compute_lipschitz_constant(layer, x):
     jacobian = torch.tensor(jacobian).view(y.numel(), x.numel())  # Construct Jacobian
 
     # Compute singular values and return the maximum value
-    singular_values = torch.linalg.svdvals(jacobian)
-    return singular_values.max().item()
+    # singular_values = torch.linalg.svdvals(jacobian)
+    # return singular_values.max().item()
+    return torch.linalg.matrix_norm(jacobian, ord=2).item()

From 1ab7a0b10d0cf5c6e2e1f32e684e95b057e73cf8 Mon Sep 17 00:00:00 2001
From: Thibaut Boissin 
Date: Tue, 18 Mar 2025 18:01:27 +0100
Subject: [PATCH 12/13] improved test speed for normalization layers

---
 tests/test_normalization_layers.py | 44 ++++++++++++++----------------
 1 file changed, 21 insertions(+), 23 deletions(-)

diff --git a/tests/test_normalization_layers.py b/tests/test_normalization_layers.py
index d7807b0..112be99 100644
--- a/tests/test_normalization_layers.py
+++ b/tests/test_normalization_layers.py
@@ -37,34 +37,32 @@ def test_lipschitz_constant_with_various_distributions(
     x = torch.randn(batch_size, num_features, h, w) * std + mean
     x.requires_grad_(True)  # Enable gradient tracking
 
-    # Forward pass through the layer
     y = layer(x)
+    x.requires_grad_(True)  # Enable gradient tracking
+
+    # Compute the Jacobian using jacrev
+    batch_jacobian = torch.func.jacrev(layer)(x)
+
+    # Reshape the Jacobian to match the desired shape
+    batch_size = x.shape[0]
+    ydim = torch.prod(torch.tensor(y.shape)).item()
+    xdim = torch.prod(torch.tensor(x.shape)).item()
 
-    # Calculate Jacobian
-    jacobian = []
-    for i in range(y.numel()):
-        grad_output = torch.zeros_like(y)
-        grad_output.view(-1)[i] = 1
-        gradients = torch.autograd.grad(
-            outputs=y,
-            inputs=x,
-            grad_outputs=grad_output,
-            retain_graph=True,
-            create_graph=True,
-            allow_unused=True,
-        )[0]
-        jacobian.append(gradients.view(-1).detach().cpu().numpy())
-    jacobian = torch.tensor(jacobian).view(
-        y.numel(), x.numel()
-    )  # Construct Jacobian matrix
+    jacobian = batch_jacobian.view(ydim, xdim)
 
     # Validate Lipschitz constant
-    singular_values = torch.linalg.svdvals(jacobian)
-    assert singular_values.max() <= 1 + 1e-4, (
-        f"Lipschitz constraint violated for input distribution with mean={mean}, std={std}; "
-        f"max singular value: {singular_values.max()}"
-    )
     if orthogonal:
+        singular_values = torch.linalg.svdvals(jacobian)
+        assert singular_values.max() <= 1 + 1e-4, (
+            f"Lipschitz constraint violated for input distribution with mean={mean}, std={std}; "
+            f"max singular value: {singular_values.max()}"
+        )
         assert (
             singular_values.min() >= 1 - 1e-4
         ), f"Orthogonality constraint violated for input distribution with mean={mean}, std={std}; "
+    else:
+        lipschitz_constant = torch.linalg.matrix_norm(jacobian, ord=2).item()
+        assert lipschitz_constant <= 1 + 1e-4, (
+            f"Lipschitz constraint violated for input distribution with mean={mean}, std={std}; "
+            f"Lipschitz constant: {lipschitz_constant}"
+        )

From 9af33570d452afb5af070b0772bee5e3fe3b77f4 Mon Sep 17 00:00:00 2001
From: Thibaut Boissin 
Date: Thu, 20 Mar 2025 16:29:36 +0100
Subject: [PATCH 13/13] updated README.md

---
 README.md | 35 ++++++++++++++++++++++-------------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index 5a12dd5..90a3670 100644
--- a/README.md
+++ b/README.md
@@ -88,20 +88,23 @@ an explicit kernel, that support all features like stride, conv transposed,
 grouped convolutions and dilation (and all compositions of these parameters). This approach is highly scalable, and can
 be applied to problems like Imagenet-1K.
 
-## Adaptive-SC-FAC:
+[//]: # (## Adaptive-SC-FAC:)
 
-As AOC is built on top of BCOP method, we can construct an equivalent method constructed on top of SC-Fac instead.
-This will allow to compare performance of the two methods given that they have very similar parametrization. (See our 
-paper for discussions about the similarities and differences between the two methods).
+[//]: # ()
+[//]: # (As AOC is built on top of BCOP method, we can construct an equivalent method constructed on top of SC-Fac instead.)
+
+[//]: # (This will allow to compare performance of the two methods given that they have very similar parametrization. (See our )
+
+[//]: # (paper for discussions about the similarities and differences between the two methods).)
 
 ## Adaptive-SOC:
 
 Adaptive-SOC blend the approach of AOC and SOC. It differs from SOC in the way that it is more memory efficient and 
-sometimes faster. It also allows to handle stride, groups, dilation and transposed convolutions. However, it does not allow to 
-control the kernel size explicitly as the resulting kernel size is larger than the requested kernel size. 
-It is due to the computation to the exponential of a kernel that increases the kernel size at each iteration.
-
-Its development is still in progress, so extra testing is still require to ensure exact orthogonality.
+sometimes faster. It also allows to handle stride, groups, dilation and transposed convolutions. Also, our 
+implementation uses AOL to normalize the kernel, which is more stable, more efficient and allows a convergence with less 
+iterations. However, it does not allow to control the kernel size explicitly as the resulting kernel size is larger 
+than the requested kernel size. It is due to the computation to the exponential of a kernel that increases the kernel 
+size at each iteration.
 
 ## SLL:
 
@@ -109,6 +112,12 @@ SLL is a method that allows to construct small residual blocks with ReLU activat
 implementation, and added `SLLxAOCLipschitzResBlock` that construct a down-sampling residual block by fusing SLL with 
 $AOC.
 
+## AOL:
+
+AOL is a method that constructs "almost orthogonal" layers. It ensures lipschitzness of the layer while pushing toward 
+orthogonality. It is a good alternative when the orthogonality constraint is not necessary, or when the orthogonality
+constraint is too expensive to compute.
+
 ## more layers are coming soon !
 
 # 🏠 Install the library:
@@ -118,7 +127,7 @@ The library is available on pip,so you can install it by running the following c
 pip install orthogonium
 ```
 
-If you wish to deep dive in the code and edit your local versin, you can clone the repository and run the following command 
+If you wish to deep dive in the code and edit your local version, you can clone the repository and run the following command 
 to install it locally:
 ```
 git clone git@github.com:thib-s/orthogonium.git
@@ -183,7 +192,7 @@ in a larger scale setting.
 - LOT: [github](https://github.com/AI-secure/Layerwise-Orthogonal-Training) and [paper](https://arxiv.org/abs/2210.11620)
 - ProjUNN-T: [github](https://github.com/facebookresearch/projUNN) and [paper](https://arxiv.org/abs/2203.05483)
 - SLL: [github](https://github.com/araujoalexandre/Lipschitz-SLL-Networks) and [paper](https://arxiv.org/abs/2303.03169)
-- Sandwish: [github](https://github.com/acfr/LBDN) and [paper](https://arxiv.org/abs/2301.11526)
+- Sandwich: [github](https://github.com/acfr/LBDN) and [paper](https://arxiv.org/abs/2301.11526)
 - AOL: [github](https://github.com/berndprach/AOL) and [paper](https://arxiv.org/abs/2208.03160)
 - SOC: [github](https://github.com/singlasahil14/SOC) and [paper 1](https://arxiv.org/abs/2105.11417), [paper 2](https://arxiv.org/abs/2211.08453)
 
@@ -233,9 +242,9 @@ Layers:
   - enable support for native stride, transposition and dilation
 - AOL:
   - torch implementation of AOL
-- Sandwish:
+- Sandwich:
   - import code
-  - plug AOC into Sandwish conv
+  - plug AOC into Sandwich conv
 
 ZOO:
 - models from the paper