Training a 1-Lipschitz constrained network on CIFAR10 with Orthogonium
Lipschitz-Constrained Networks and Certifiable Robustness
What is a Lipschitz Network?
A Lipschitz network is a neural network in which each layer is constrained to be a 1-Lipschitz function, i.e. a function $ f $ such that $ \|f(x) - f(y)\|_2 \leq \|x - y\|_2 $ for all inputs $ x, y $. Small changes in the input therefore lead to at most equally small changes in the output, which keeps the sensitivity of the network under control. The Lipschitz constant of a whole network is usually bounded by the product of the Lipschitz constants of its individual layers. This product is a valid upper bound, but it is often loose, while the exact constant is difficult to compute.
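To see why the per-layer product can be loose, consider a toy two-layer linear "network" $ x \mapsto BAx $. The plain-Python sketch below estimates each layer's Lipschitz constant (its largest singular value) by power iteration; the helper names and the matrices are hypothetical and only serve as an illustration.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def spectral_norm(w, iters=200):
    """Largest singular value of w (its L2 Lipschitz constant), via power iteration on w^T w."""
    wtw = matmul(transpose(w), w)
    v = [1.0] * len(wtw)
    for _ in range(iters):
        u = [sum(row[j] * v[j] for j in range(len(v))) for row in wtw]
        norm = math.sqrt(sum(x * x for x in u))
        v = [x / norm for x in u]
    # For a unit vector v, v^T (w^T w) v approximates sigma_max^2.
    wtw_v = [sum(row[j] * v[j] for j in range(len(v))) for row in wtw]
    return math.sqrt(sum(p * q for p, q in zip(v, wtw_v)))


# Two hypothetical "layers": A stretches the first axis, B stretches the second one.
A = [[2.0, 0.0], [0.0, 0.5]]
B = [[0.5, 0.0], [0.0, 2.0]]

product_bound = spectral_norm(A) * spectral_norm(B)  # bound from per-layer constants
true_constant = spectral_norm(matmul(B, A))          # constant of the composed map x -> B(Ax)

print(f"product of layer bounds: {product_bound:.3f}")
print(f"true Lipschitz constant: {true_constant:.3f}")
```

Each layer stretches some direction by a factor of 2, so the product bound is 4, yet the composition is the identity, whose Lipschitz constant is 1. Constraining every layer to be orthogonal (hence norm-preserving) removes this particular kind of slack by construction.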
How to Build Lipschitz Networks?
To construct such networks:
- Orthogonal Layers: Use layers that enforce orthogonality constraints (e.g., Adaptive OrthoConvolutions). These layers are 1-Lipschitz by construction.
- Special Activations: Use activations such as MaxMin, which, when combined with orthogonal layers, keep the estimate of the network's Lipschitz constant tight.
- Reparametrization Techniques: Methods such as AOC (Adaptive OrthoConvolutions) reparametrize the weights so that each layer satisfies the 1-Lipschitz constraint throughout training. Because orthogonal layers and MaxMin preserve (or nearly preserve) norms, the overall bound, a product of ones, is typically much closer to the network's true Lipschitz constant than a simple product of per-layer bounds for an unconstrained network.
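The two building blocks above can be illustrated on plain vectors. The sketch below is a toy, not orthogonium's implementation: `maxmin` assumes the channels are split into two halves that are compared elementwise (orthogonium's exact pairing convention may differ), and a 2x2 rotation stands in for an orthogonal layer. Both preserve the L2 norm of their input, and `maxmin` is 1-Lipschitz.

```python
import math
import random


def maxmin(x):
    """Toy MaxMin: pair x[i] with x[i + n/2], output all maxima then all minima.
    The pairing convention is an assumption made for illustration."""
    half = len(x) // 2
    a, b = x[:half], x[half:]
    return [max(p, q) for p, q in zip(a, b)] + [min(p, q) for p, q in zip(a, b)]


def rotate(x, theta):
    """Apply a 2x2 rotation matrix (an orthogonal matrix) to a 2-vector."""
    c, s = math.cos(theta), math.sin(theta)
    return [c * x[0] - s * x[1], s * x[0] + c * x[1]]


def l2(x):
    return math.sqrt(sum(v * v for v in x))


def dist(x, y):
    return l2([p - q for p, q in zip(x, y)])


x = [3.0, -1.0, 0.5, 2.0]
print(maxmin(x))  # [3.0, 2.0, 0.5, -1.0]: each pair is sorted, so the norm is unchanged
print(l2(x), l2(maxmin(x)))
print(l2([3.0, 4.0]), l2(rotate([3.0, 4.0], 0.7)))

# Empirical check of the 1-Lipschitz property of maxmin on random pairs of inputs.
random.seed(0)
worst_ratio = 0.0
for _ in range(1000):
    u = [random.uniform(-1, 1) for _ in range(6)]
    w = [random.uniform(-1, 1) for _ in range(6)]
    worst_ratio = max(worst_ratio, dist(maxmin(u), maxmin(w)) / dist(u, w))
print(f"largest observed ||maxmin(u) - maxmin(w)|| / ||u - w||: {worst_ratio:.3f}")
```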
Certifiable Robustness
Certifiable robustness provides a guarantee on the minimal perturbation needed to alter the network's prediction, independent of any specific adversarial attack. For a 1-Lipschitz classification function $ f $, let $ f(x)_l $ be the logit of the true class $ l $ and $ f(x)_i $ the logit of any other class $ i $. The $ L_2 $ norm $ \epsilon $ of the smallest perturbation that changes the prediction then satisfies:
$$ \epsilon \geq \frac{f(x)_l - \max_{i \neq l} f(x)_i}{\sqrt{2}} $$
The factor $ \sqrt{2} $ appears because the difference of two outputs of a 1-Lipschitz map, $ f(x)_l - f(x)_i $, is $ \sqrt{2} $-Lipschitz. In other words, as long as the $ L_2 $ norm of the perturbation stays below the right-hand side, the classification will not change. This certificate is:
- Independent of Attacks: It does not rely on any particular adversarial attack method, so the guarantee remains valid even as new attack strategies emerge.
- Computationally Efficient: It is computed directly from the logits of a single forward pass and can even be integrated as a loss term during training, leading to models that are robust by design.
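The certificate can be computed directly from a vector of logits. The plain-Python helper below (a hypothetical name, not an orthogonium function) returns $ (f(x)_l - \max_{i \neq l} f(x)_i) / (\sqrt{2} L) $, where the optional factor `lipschitz` covers an $ L $-Lipschitz network: dividing by $ L $ rescales it to the 1-Lipschitz case. A negative value means the input is misclassified.

```python
import math


def certified_radius(logits, label, lipschitz=1.0):
    """L2 radius within which the predicted class provably cannot change.

    Implements (f_l - max_{i != l} f_i) / (sqrt(2) * L) for an L-Lipschitz network.
    """
    runner_up = max(v for i, v in enumerate(logits) if i != label)
    return (logits[label] - runner_up) / (math.sqrt(2) * lipschitz)


logits = [2.0, 0.5, 1.0]
r = certified_radius(logits, label=0)
print(f"certified radius: {r:.4f}")  # 0.7071, i.e. a margin of 1.0 divided by sqrt(2)
```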
Applications and Benefits
Lipschitz-constrained networks are not only crucial for certifiable robustness but also have broader applications:
- They are tightly linked with generative models like WGANs and with concepts from optimal transport.
- They enable scalable differential privacy and help avoid singularities in models such as diffusion networks.
- They guarantee existence and uniqueness in classification tasks, making them appealing for reliable machine learning.
In summary, by combining orthogonal layers with appropriate activations and reparametrization techniques, one can build Lipschitz networks that not only deliver competitive performance but also offer provable robustness guarantees.
# !pip install orthogonium lightning rich schedulefree torchinfo torchmetrics torchvision
import math
import os
import schedulefree
import torch
import torch.utils.data
import torchmetrics
from lightning.pytorch import callbacks as pl_callbacks
from lightning.pytorch import Trainer
from lightning.pytorch import LightningModule, LightningDataModule
from torchinfo import summary
# from lightning.pytorch.loggers import WandbLogger # Uncomment if using Wandb logging
from torch.nn import AvgPool2d
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, RandAugment, RandomHorizontalFlip, RandomResizedCrop, ToTensor
from orthogonium.model_factory.classparam import ClassParam
from orthogonium.layers.conv.AOC import AdaptiveOrthoConv2d
from orthogonium.layers.linear import OrthoLinear
from orthogonium.layers.custom_activations import MaxMin
from orthogonium.losses import LossXent, CosineLoss
from orthogonium.losses import VRA
from orthogonium.model_factory.models_factory import StagedCNN, PatchBasedExapandedCNN
# Enable benchmark mode and set matmul precision for performance tuning
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("medium")
Training Settings
You can play with the training settings to explore different configurations and compare their performance. Three settings are available:
- non_robust: training with a cosine-similarity loss (CosineLoss), without any robustness margin.
- mildly_robust: cross-entropy loss (LossXent) with a high margin targeting a certified radius of 36/255, resulting in 42% VRA.
- robust: similar to mildly_robust, but with a larger margin (targeting a certified radius of 72/255) and a higher temperature, resulting in 47% VRA.
In both robust settings, VRA is measured at a radius of 36/255.
Note: The aim here is to show the training flow rather than reach state-of-the-art performance.
Training Settings Performance
| Setting | Epochs | Clean Accuracy | Verified Robust Accuracy (VRA) at ε = 36/255 |
|---|---|---|---|
| non_robust | 60 | 88.5% | 0% |
| mildly_robust | 150 | 75% | 42% |
| robust | 150 | 71% | 47% |
These configurations are stored in the `settings` dictionary.
settings = {
"non_robust": {
"loss": CosineLoss,
"epochs": 60,
},
"mildly_robust": {
"loss": ClassParam(
LossXent,
n_classes=10,
offset=(math.sqrt(2) / 0.1983) * (36 / 255), # margin aiming for robustness certified at radius 36/255
temperature=0.125,
),
"epochs": 150,
},
"robust": {
"loss": ClassParam(
LossXent,
n_classes=10,
offset=(math.sqrt(2) / 0.1983) * (72 / 255), # margin aiming for robustness certified at radius 72/255
temperature=0.25,
),
"epochs": 150,
},
}
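One way to read the `offset` values: if `offset` acts as a logit margin (our assumption about `LossXent`, not confirmed here), the certificate says that a margin $ m $ certifies an input-space radius $ m / (\sqrt{2} L) $. Solving for the margin that certifies a radius $ \epsilon $ gives $ m = \sqrt{2} \, \epsilon L $, and taking $ L = 1/0.1983 $ yields exactly the expression used in `settings`. The sketch below evaluates it; note that the smallest entry of `img_std` is 0.19342837, so using `min(img_std)` instead of 0.1983 would give a slightly larger margin.

```python
import math

IMG_STD = (0.21938758, 0.1983, 0.19342837)  # per-channel std used by the data module


def margin_for_radius(eps, std):
    """Logit margin m such that m / (sqrt(2) * L) = eps, with L = 1 / std."""
    return math.sqrt(2) * eps / std


for eps in (36 / 255, 72 / 255):
    as_in_settings = (math.sqrt(2) / 0.1983) * eps
    print(f"eps={eps:.4f}  offset={as_in_settings:.4f}  "
          f"with min(std)={margin_for_radius(eps, min(IMG_STD)):.4f}")
```

Doubling the target radius doubles the margin, which is the only difference between the two `offset` values.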
Data Module: CIFAR10
We create a `LightningDataModule` to load and preprocess the CIFAR10 training and validation datasets.
The training dataloader applies several transforms:
- Random resized cropping
- Random horizontal flip
- Tensor conversion and normalization using a precomputed per-channel mean and standard deviation
The validation dataloader only applies tensor conversion and normalization.
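Normalization matters for the certificate: `Normalize` maps each channel $ c $ to $ (x_c - \mu_c)/\sigma_c $, which stretches $ L_2 $ distances by at most $ 1/\min_c \sigma_c $. A radius certified in normalized space must therefore be divided by that factor to obtain a radius in pixel space, which is why the training code passes `L=1 / min(img_std)` to `VRA`. The plain-Python sketch below checks the factor on random perturbations; each perturbation is represented, for simplicity, as three per-channel lists of pixel differences.

```python
import random

IMG_STD = (0.21938758, 0.1983, 0.19342837)  # per-channel std from the data module


def normalize_delta(delta):
    """Effect of Normalize on a perturbation: the mean cancels, each channel is divided by its std."""
    return [[d / s for d in channel] for channel, s in zip(delta, IMG_STD)]


def l2(delta):
    return sum(d * d for channel in delta for d in channel) ** 0.5


lipschitz = 1 / min(IMG_STD)
random.seed(0)
worst = 0.0
for _ in range(1000):
    delta = [[random.gauss(0, 1) for _ in range(16)] for _ in IMG_STD]
    worst = max(worst, l2(normalize_delta(delta)) / l2(delta))

print(f"1 / min(std) = {lipschitz:.4f}, largest observed stretch = {worst:.4f}")

# A perturbation confined to the channel with the smallest std reaches the bound exactly.
only_last = [[0.0] * 16, [0.0] * 16, [0.01] * 16]
print(l2(normalize_delta(only_last)) / l2(only_last))
```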
class Cifar10DataModule(LightningDataModule):
    # Dataset configuration
    _BATCH_SIZE = 256
    _NUM_WORKERS = 8  # Number of parallel processes for data loading
    _PREPROCESSING_PARAMS = {
        "img_mean": (0.41757566, 0.26098573, 0.25888634),
        "img_std": (0.21938758, 0.1983, 0.19342837),
        "crop_size": 32,
        "horizontal_flip_prob": 0.5,
        "random_resized_crop_params": {
            "scale": (0.25, 1.0),
            "ratio": (3.0 / 4.0, 4.0 / 3.0),
        },
    }

    def train_dataloader(self):
        # Define the transformations for training data
        transform = Compose(
            [
                RandomResizedCrop(
                    self._PREPROCESSING_PARAMS["crop_size"],
                    **self._PREPROCESSING_PARAMS["random_resized_crop_params"],
                ),
                RandomHorizontalFlip(self._PREPROCESSING_PARAMS["horizontal_flip_prob"]),
                # Uncomment the following line to use RandAugment
                # (num_ops=2, magnitude=9 are torchvision's defaults):
                # RandAugment(num_ops=2, magnitude=9),
                ToTensor(),
                Normalize(
                    mean=self._PREPROCESSING_PARAMS["img_mean"],
                    std=self._PREPROCESSING_PARAMS["img_std"],
                ),
            ]
        )
        train_dataset = CIFAR10(
            root="./data",
            train=True,
            download=True,
            transform=transform,
        )
        return DataLoader(
            train_dataset,
            batch_size=self._BATCH_SIZE,
            num_workers=self._NUM_WORKERS,
            prefetch_factor=2,
            shuffle=True,
        )

    def val_dataloader(self):
        # Define the transformations for validation data (no augmentation)
        transform = Compose(
            [
                ToTensor(),
                Normalize(
                    mean=self._PREPROCESSING_PARAMS["img_mean"],
                    std=self._PREPROCESSING_PARAMS["img_std"],
                ),
            ]
        )
        val_dataset = CIFAR10(
            root="./data",
            train=False,
            download=True,
            transform=transform,
        )
        return DataLoader(
            val_dataset,
            batch_size=self._BATCH_SIZE * 4,
            num_workers=self._NUM_WORKERS,
            shuffle=False,
        )
Classification Model Module
We now define a `LightningModule` that wraps our CNN model. The network is built with the `PatchBasedExapandedCNN` factory from orthogonium.
The architecture consists of four main parts:
- The stem is a patch extractor: a convolution whose kernel size equals its stride (here 2, so the 32×32 input becomes a 16×16 feature map).
- A sequence of residual blocks: each residual block has a learnable factor that ensures it remains 1-Lipschitz. Each block contains one depthwise convolution, a MaxMin activation, and a pointwise convolution. No pooling is performed in this part of the network.
- A pooling layer: here, we use a grouped convolution (groups=128) whose kernel size equals the size of the feature map (16×16). This lets the network take the location of features into account without using a large number of weights. (This is not mandatory for accurate training, but it seems to give a slightly better accuracy/robustness tradeoff in robust training.)
- A MaxMin activation followed by a fully connected layer for classification.
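The spatial sizes quoted above follow from the usual convolution output-size formula $ \lfloor (H + 2p - k)/s \rfloor + 1 $ (without dilation). The quick check below assumes the stem uses kernel = stride = `patch_size` = 2 and the pooling layer uses `kernel_size=16`, `stride=16`, `padding=0`; the residual blocks use `padding="same"` and keep the size unchanged.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


stem = conv_out_size(32, kernel=2, stride=2)        # patch extractor: kernel == stride == 2
pooled = conv_out_size(stem, kernel=16, stride=16)  # pooling conv covering the whole map
print(stem, pooled)  # 16 1
```

These values match the `[1, 256, 16, 16]` and `[1, 256, 1, 1]` output shapes reported in the model summary.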
All convolutional layers use AOC, allowing the construction of complex Lipschitz-constrained architectures.
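The description above only says that each residual block has a learnable factor ensuring its Lipschitzness; it does not spell out the formula. One common way to obtain this, shown here as a hypothetical sketch (not necessarily what orthogonium's residual implements), is a convex combination $ y = (1 - t)\,x + t\,g(x) $ with $ t \in [0, 1] $ and $ g $ 1-Lipschitz, which is 1-Lipschitz by the triangle inequality. Here $ t $ comes from an unconstrained learnable scalar through a sigmoid, and $ g $ is a toy 1-Lipschitz branch.

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def g(x):
    """Toy 1-Lipschitz branch: sorts a 2-vector (MaxMin applied to a single pair)."""
    return [max(x), min(x)]


def residual(x, alpha):
    """Convex-combination residual; alpha is an unconstrained learnable scalar."""
    t = sigmoid(alpha)
    gx = g(x)
    return [(1 - t) * xi + t * gi for xi, gi in zip(x, gx)]


def dist(u, v):
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(u, v)))


random.seed(0)
worst = 0.0
for _ in range(1000):
    alpha = random.uniform(-5, 5)
    x = [random.uniform(-1, 1), random.uniform(-1, 1)]
    y = [random.uniform(-1, 1), random.uniform(-1, 1)]
    worst = max(worst, dist(residual(x, alpha), residual(y, alpha)) / dist(x, y))
print(f"largest observed stretch: {worst:.3f}")
```

By contrast, a plain skip connection $ x + g(x) $ can be up to 2-Lipschitz, which is why some form of learnable weighting is needed.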
Key components include:
- The custom CNN model architecture.
- The loss function (set based on the selected training configuration).
- Training and validation steps that compute and log both accuracy and verified robust accuracy (VRA).
- The `configure_optimizers` method, which sets up a schedule-free AdamW optimizer (`schedulefree.AdamWScheduleFree`).
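Verified robust accuracy (VRA) at radius $ \epsilon $ is the fraction of samples that are both correctly classified and certified at radius $ \epsilon $. The plain-Python sketch below assumes "certified" means that the certificate radius is at least $ \epsilon $; it is our own helper, not orthogonium's `VRA`, whose `last_layer_type` option is not reproduced here. The `lipschitz` argument plays the role of the `L` value passed in the training code.

```python
import math


def verified_robust_accuracy(batch_logits, labels, eps, lipschitz=1.0):
    """Fraction of samples that are correct AND whose certified L2 radius is >= eps."""
    certified = 0
    for logits, label in zip(batch_logits, labels):
        runner_up = max(v for i, v in enumerate(logits) if i != label)
        radius = (logits[label] - runner_up) / (math.sqrt(2) * lipschitz)
        if radius > 0 and radius >= eps:  # radius > 0 means correctly classified
            certified += 1
    return certified / len(labels)


IMG_STD = (0.21938758, 0.1983, 0.19342837)
batch = [[4.0, 0.0, 0.0], [1.2, 1.0, 0.0], [0.0, 2.0, 0.0]]
labels = [0, 0, 0]
vra = verified_robust_accuracy(batch, labels, eps=36 / 255, lipschitz=1 / min(IMG_STD))
print(vra)
```

In this toy batch, the first sample is certified, the second is correct but its margin is too small, and the third is misclassified, so the VRA is 1/3 while the clean accuracy is 2/3.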
class ClassificationLightningModule(LightningModule):
    def __init__(self, num_classes=10, loss=None):
        super().__init__()
        self.num_classes = num_classes
        self.model = PatchBasedExapandedCNN(
            img_shape=(3, 32, 32),
            dim=256,
            depth=12,
            kernel_size=3,
            patch_size=2,
            expand_factor=2,
            groups=None,
            n_classes=num_classes,
            skip=True,
            conv=ClassParam(
                AdaptiveOrthoConv2d,
                bias=False,
                padding="same",
                padding_mode="zeros",
            ),
            act=ClassParam(MaxMin),
            pool=ClassParam(
                AdaptiveOrthoConv2d,
                in_channels=256,
                out_channels=256,
                groups=128,
                bias=False,
                padding=0,
                kernel_size=16,
                stride=16,
            ),
            lin=ClassParam(OrthoLinear, bias=False),
            norm=None,
        )
        # Loss from the selected setting, or plain cross-entropy by default
        self.criteria = loss() if loss is not None else torch.nn.CrossEntropyLoss()
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.train_vra = torchmetrics.MeanMetric()
        self.val_vra = torchmetrics.MeanMetric()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        self.model.train()
        img, label = batch
        y_hat = self.model(img)
        loss = self.criteria(y_hat, label)
        self.train_acc(y_hat, label)
        # VRA at eps=36/255 in pixel space; L accounts for the Lipschitz factor of Normalize
        self.train_vra(
            VRA(
                y_hat,
                label,
                L=1 / min(Cifar10DataModule._PREPROCESSING_PARAMS["img_std"]),
                eps=36 / 255,
                last_layer_type="global",
            )
        )
        self.log("loss", loss, on_epoch=True, on_step=True, prog_bar=True, sync_dist=True)
        self.log("accuracy", self.train_acc, on_epoch=True, on_step=True, prog_bar=True, sync_dist=True)
        self.log("vra", self.train_vra, on_epoch=True, on_step=True, prog_bar=True, sync_dist=False)
        return loss

    def validation_step(self, batch, batch_idx):
        self.model.eval()
        img, label = batch
        y_hat = self.model(img)
        loss = self.criteria(y_hat, label)
        self.val_acc(y_hat, label)
        self.val_vra(
            VRA(
                y_hat,
                label,
                L=1 / min(Cifar10DataModule._PREPROCESSING_PARAMS["img_std"]),
                eps=36 / 255,
                last_layer_type="global",
            )
        )
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
        self.log("val_accuracy", self.val_acc, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
        self.log("val_vra", self.val_vra, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        # Schedule-free AdamW: no learning-rate schedule is needed.
        optimizer = schedulefree.AdamWScheduleFree(self.parameters(), lr=5e-3, weight_decay=0)
        optimizer.train()
        self.hparams["lr"] = optimizer.param_groups[0]["lr"]
        return optimizer

    def on_train_epoch_start(self):
        # Schedule-free optimizers must be in train mode while taking gradient steps...
        self.optimizers().optimizer.train()

    def on_validation_epoch_start(self):
        # ...and in eval mode during validation, so metrics are computed on the averaged weights.
        # The guard skips this when no optimizer is attached (e.g. a standalone validation run).
        opt = self.optimizers()
        if hasattr(opt, "optimizer"):
            opt.optimizer.eval()
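A schedule-free optimizer distinguishes a train mode and an eval mode because it tracks several points: a base sequence $ z_t $, an interpolated point $ y_t $ where gradients are taken, and a running average $ x_t $ that is meant to be evaluated. The sketch below is the plain schedule-free SGD recursion (Defazio et al., 2024) on the 1-D objective $ (w - 3)^2 $, with hypothetical step size and interpolation values. The `schedulefree` library's `AdamWScheduleFree` applies the same idea with AdamW-style steps plus details omitted here (such as warmup and weighting); its `optimizer.train()` and `optimizer.eval()` methods switch the parameters between the point used for gradients and the averaged point.

```python
def grad(w):
    """Gradient of the toy objective (w - 3)^2."""
    return 2.0 * (w - 3.0)


def schedule_free_sgd(w0, lr=0.1, beta=0.9, steps=1000):
    z = x = w0
    for t in range(steps):
        y = (1 - beta) * z + beta * x  # point where the gradient is evaluated ("train" weights)
        z = z - lr * grad(y)           # plain SGD step on the base sequence
        c = 1.0 / (t + 1)
        x = (1 - c) * x + c * z        # uniform running average ("eval" weights)
    return x, y


x_final, y_final = schedule_free_sgd(0.0)
print(f"averaged (eval) point: {x_final:.6f}, gradient (train) point: {y_final:.6f}")
```

Evaluating at $ y_t $ instead of $ x_t $ would measure a different set of weights than the one the method is designed to return, which is why the mode switch matters around validation.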
Training Routine
To run a given configuration, set `train_setting` to its name; for example, `train_setting = "non_robust"` selects the non-robust setting.
# Select the training setting manually.
train_setting = "non_robust" # Options: "non_robust", "mildly_robust", or "robust"
# Get the corresponding loss function and number of epochs from the settings.
current_setting = settings[train_setting]
# Instantiate the classification model and data module.
classification_module = ClassificationLightningModule(num_classes=10, loss=current_setting["loss"])
data_module = Cifar10DataModule()
# Optionally, set up a logger or callbacks if needed.
# For example, if using Wandb:
# from lightning.pytorch.loggers import WandbLogger
# wandb_logger = WandbLogger(project="lipschitz-robust-cifar10", log_model=True)
# checkpoint_callback = pl_callbacks.ModelCheckpoint(
# monitor="loss",
# mode="min",
# save_top_k=1,
# save_last=True,
# dirpath=f"./checkpoints/{wandb_logger.experiment.dir}",
# )
trainer = Trainer(
accelerator="gpu",
devices=1, # Number of GPUs; use -1 for all GPUs, or a list such as [1] to select specific GPU indices
num_nodes=1, # Number of nodes
# strategy="ddp_spawn", # Distributed strategy
precision="bf16-mixed", # Mixed precision training
max_epochs=current_setting["epochs"],
enable_model_summary=False,
# logger=[wandb_logger], # Uncomment to enable Wandb logging
logger=False,
enable_progress_bar=False,
callbacks=[
# You can add callbacks here, e.g.:
# pl_callbacks.LearningRateFinder(max_lr=0.05),
# checkpoint_callback,
],
)
print(summary(classification_module, input_size=(1, 3, 32, 32)))
# Start training
trainer.fit(classification_module, data_module)
Optionally, you can save the trained model afterwards:
torch.save(classification_module.model.state_dict(), "single_stage.pth")
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/mnt/deel/data/thibaut.boissin/envs/ortho/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/thibaut.boissin/projects/orthogonium/scripts/pareto/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
==================================================================================================================================
Layer (type:depth-idx) Output Shape Param #
==================================================================================================================================
ClassificationLightningModule [1, 10] --
├─Sequential: 1-1 [1, 10] --
│ └─ParametrizedRKOConv2d: 2-1 [1, 256, 16, 16] --
│ │ └─ModuleDict: 3-1 -- 3,340
│ └─AdditiveResidual: 2-2 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-2 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-3 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-3 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-4 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-4 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-5 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-5 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-6 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-6 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-7 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-7 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-8 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-8 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-9 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-9 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-10 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-10 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-11 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-11 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-12 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-12 [1, 256, 16, 16] 142,592
│ └─AdditiveResidual: 2-13 [1, 256, 16, 16] 1
│ │ └─Sequential: 3-13 [1, 256, 16, 16] 142,592
│ └─ParametrizedRKOConv2d: 2-14 [1, 256, 1, 1] --
│ │ └─ModuleDict: 3-14 -- 196,864
│ └─Flatten: 2-15 [1, 256] --
│ └─MaxMin: 2-16 [1, 256] --
│ └─ParametrizedOrthoLinear: 2-17 [1, 10] --
│ │ └─ModuleDict: 3-15 -- 2,826
==================================================================================================================================
Total params: 1,914,146
Trainable params: 1,783,308
Non-trainable params: 130,838
Total mult-adds (Units.MEGABYTES): 0
==================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.01
==================================================================================================================================
Files already downloaded and verified
Files already downloaded and verified
`Trainer.fit` stopped: `max_epochs=60` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Files already downloaded and verified
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Validate metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ val_accuracy │ 0.8831999897956848 │
│ val_loss │ -0.8909505605697632 │
│ val_vra │ 0.0 │
└───────────────────────────┴───────────────────────────┘
[{'val_loss': -0.8909505605697632,
'val_accuracy': 0.8831999897956848,
'val_vra': 0.0}]
Next Steps
- Model Evaluation: You can add a new cell to perform model evaluation or predictions.
- Logging and Checkpoints: To enable model logging or checkpoint saving, uncomment the corresponding lines and configure them as needed.
- Experiment with Settings: Change the `train_setting` variable to `"mildly_robust"` or `"robust"` to experiment with other training configurations.