Conversation

@franckma31 (Collaborator)

Add a special layer, BatchLipNorm, to replace the non-Lipschitz BatchNorm.

This layer computes the running_mean and running_variance similarly to standard BatchNorm (i.e., using batch statistics during training), but applies a normalization factor common to all channels:

$$ y_{b,c} = \frac{x_{b,c} - \mu_c}{\alpha} $$

where $\mu_c$ is the per-channel mean over the batch (or the running mean in evaluation mode), and $\alpha = \max_c \sqrt{\text{var}_c}$ is the square root of the largest per-channel variance (or running variance in eval mode), so the same scaling factor is applied to every channel.

This layer optionally supports disabling centering (i.e., applying only the normalization factor without subtracting the mean).
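For illustration, here is a minimal sketch of the forward computation described above, written directly from the formula (this is not the PR's implementation; the `centered` flag and the `eps` term are assumptions added for readability and numerical stability):

```python
import torch

def batch_lip_norm_forward(x, running_mean, running_var, training, centered=True, eps=1e-5):
    # x: (batch, channels, ...) activations.
    if training:
        dims = [d for d in range(x.dim()) if d != 1]   # reduce over all dims except channel
        mean = x.mean(dim=dims)                        # per-channel batch mean
        var = x.var(dim=dims, unbiased=False)          # per-channel batch variance
        # running_mean / running_var would be updated here, as in standard BatchNorm
    else:
        mean, var = running_mean, running_var
    # single scaling factor shared by all channels: the largest per-channel std
    alpha = torch.sqrt(var.max() + eps)
    if centered:
        x = x - mean.view(1, -1, *([1] * (x.dim() - 2)))
    return x / alpha
```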

It is compatible with multi-GPU training via torchrun (torch.distributed).


Since each BatchLipNorm layer introduces a scaling factor, it must be used within a Sequential model. A SharedLipFactory is provided to track the product of all scaling factors, along with a final LipFactor layer that compensates for this scaling to ensure the network remains globally 1-Lipschitz.
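To see why tracking the product is enough (a sketch of the argument, not necessarily the exact rule implemented by LipFactor): if the network composes 1-Lipschitz layers $L_j$ with $k$ BatchLipNorm layers whose factors are $\alpha_1,\dots,\alpha_k$, each normalization is $(1/\alpha_i)$-Lipschitz, so a final output gain of $\prod_i \alpha_i$ bounds the end-to-end Lipschitz constant:

$$ \mathrm{Lip}(f) \;\le\; \Big(\prod_{i=1}^{k} \alpha_i\Big) \, \prod_j \mathrm{Lip}(L_j) \, \prod_{i=1}^{k} \frac{1}{\alpha_i} \;\le\; 1. $$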

BnLipSequential (a subclass of torch.nn.Sequential) offers a convenient way to build such a model using BatchLipNorm while preserving the 1-Lipschitz property throughout the network.
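A hypothetical usage sketch (the `factory=` argument, the constructor signatures of the new layers, and the need to append LipFactor manually are assumptions made for illustration; see the example script below for the actual API):

```python
import torch
from deel import torchlip  # assumed package layout

# Constructor signatures of the new layers are hypothetical, for illustration only.
factory = torchlip.SharedLipFactory()            # tracks the product of all BatchLipNorm factors
model = torchlip.BnLipSequential(
    torchlip.SpectralConv2d(3, 64, kernel_size=3, padding=1),  # 1-Lipschitz conv
    torchlip.BatchLipNorm(64, factory=factory),                # hypothetical: registers its factor
    torchlip.GroupSort2(),
    torch.nn.Flatten(),
    torchlip.SpectralLinear(64 * 32 * 32, 10),
    torchlip.LipFactor(factory),                 # compensates the accumulated scaling at the output
)
```

Depending on the final API, BnLipSequential may handle the shared factory and the trailing LipFactor automatically.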

An example script that trains a robust CIFAR-10 classifier with BatchLipNorm layers is provided in scripts/train_cifar_bn.py.

franckma31 force-pushed the feat/lipbatchnorm branch 2 times, most recently from a6f3c3f to 0a1568a on June 10, 2025 16:07
franckma31 requested a review from thib-s on June 11, 2025 08:46