
Example and usage

To keep things simple, the following rules were followed during development:

  • deel-torchlip follows the torch.nn package structure.

  • When a k-Lipschitz module overrides a standard torch.nn module, it uses the same interface and the same parameters. The only difference is an additional parameter to control the Lipschitz constant of the layer, as sketched after this list.
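As an illustration, here is a minimal sketch of this interface parity. The k_coef_lip keyword used below is an assumption based on the deel-torchlip API (it mirrors the parameter of the same name in deel-lip); check the API reference of your version:

import torch
from deel import torchlip

# Same constructor signature as torch.nn.Linear...
standard = torch.nn.Linear(16, 32, bias=True)

# ...plus one extra keyword controlling the Lipschitz constant
# (k_coef_lip=2.0 makes the layer 2-Lipschitz instead of 1-Lipschitz)
lipschitz = torchlip.SpectralLinear(16, 32, bias=True, k_coef_lip=2.0)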

Which modules are safe to use?

Modules from deel-torchlip are mostly wrappers around initializers and normalization hooks that ensure their 1-Lipschitz property. For instance, the SpectralLinear module is simply a torch.nn.Linear module with automatic orthogonal initialization and normalization hooks:

import torch
from deel import torchlip

# This code is roughly equivalent to SpectralLinear(16, 32)
m = torch.nn.Linear(16, 32)
torch.nn.init.orthogonal_(m.weight)  # orthogonal initialization
m.bias.data.fill_(0.0)

# Hooks that re-normalize the weight on every forward pass
torch.nn.utils.spectral_norm(m, "weight", eps=1e-3)
torchlip.utils.bjorck_norm(m, "weight", eps=1e-3)
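To check that the hooks do their job, run a forward pass (which triggers the re-normalization) and inspect the largest singular value of the weight; a quick sketch:

# After a forward pass the hooks have re-normalized m.weight, so its
# largest singular value should be close to 1.
x = torch.randn(8, 16)
_ = m(x)
print(torch.linalg.matrix_norm(m.weight.detach(), ord=2))  # ~ 1.0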

The following table indicates which modules are safe to use in a Lipschitz network, and which are not.

| torch.nn | 1-Lipschitz? | deel-torchlip equivalent | Comments |
|----------|--------------|--------------------------|----------|
| torch.nn.Linear | no | SpectralLinear, FrobeniusLinear | SpectralLinear and FrobeniusLinear are similar when there is a single output. |
| torch.nn.Conv2d | no | SpectralConv2d, FrobeniusConv2d | SpectralConv2d also implements Björck normalization. |
| torch.nn.Conv1d | no | SpectralConv1d | SpectralConv1d also implements Björck normalization. |
| MaxPooling, GlobalMaxPooling | yes | n/a | |
| torch.nn.AvgPool2d, torch.nn.AdaptiveAvgPool2d | no | ScaledAvgPool2d, ScaledAdaptiveAvgPool2d, ScaledL2NormPool2d, ScaledAdaptativeL2NormPool2d | The Lipschitz constant is bounded by sqrt(pool_h * pool_w). |
| Flatten | yes | n/a | |
| torch.nn.ConvTranspose2d | no | SpectralConvTranspose2d | SpectralConvTranspose2d also implements Björck normalization. |
| torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d | no | BatchCentering | Applies a bias computed from batch statistics, but no normalization factor, so the layer is 1-Lipschitz. |
| torch.nn.LayerNorm | no | LayerCentering | Applies a bias computed from per-sample statistics, but no normalization factor, so the layer is 1-Lipschitz. |
| Residual connections | no | LipResidual | Learns a factor for mixing the residual with a 1-Lipschitz branch. |
| torch.nn.Dropout | no | None | The Lipschitz constant is bounded by the dropout factor. |
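When in doubt about a layer, a cheap empirical sanity check is to lower-bound its Lipschitz constant with random input pairs: a 1-Lipschitz module should never produce a ratio above 1 (up to numerical tolerance). A minimal sketch:

import torch
from deel import torchlip

layer = torchlip.SpectralLinear(16, 32)
layer.eval()

with torch.no_grad():
    x1 = torch.randn(1024, 16)
    x2 = torch.randn(1024, 16)
    # ||f(x1) - f(x2)|| / ||x1 - x2|| lower-bounds the Lipschitz constant
    ratios = (layer(x1) - layer(x2)).norm(dim=1) / (x1 - x2).norm(dim=1)

print(f"max observed ratio: {ratios.max().item():.4f}")  # expected <= 1.0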

How to use it?

Here is a simple example showing how to build a 1-Lipschitz network:

import torch
from deel import torchlip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# deel-torchlip layers can be used like any torch.nn layers in
# Sequential or other types of container modules.
model = torch.nn.Sequential(
    torchlip.SpectralConv2d(1, 32, (3, 3), padding=1),
    torchlip.SpectralConv2d(32, 32, (3, 3), padding=1),
    torch.nn.MaxPool2d(kernel_size=(2, 2)),
    torchlip.SpectralConv2d(32, 32, (3, 3), padding=1),
    torchlip.SpectralConv2d(32, 32, (3, 3), padding=1),
    torch.nn.MaxPool2d(kernel_size=(2, 2)),
    torch.nn.Flatten(),
    torchlip.SpectralLinear(1568, 256),
    torchlip.SpectralLinear(256, 1)
).to(device)

# Training can be done as usual, except that we are doing binary
# classification with -1 and +1 labels, so the targets from the
# dataset must be converted accordingly.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
hkr_loss = torchlip.HKRLoss(alpha=10, min_margin=1)

# mnist_08 is a DataLoader over MNIST digits 0 and 8 (defined elsewhere)
for data, target in mnist_08:
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = hkr_loss(output, target)
    loss.backward()
    optimizer.step()
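A practical payoff of the 1-Lipschitz property is worth noting: for binary classification with -1/+1 labels, the magnitude of the output directly yields a certified robustness radius. A minimal sketch reusing the model and data above:

# Since the model is 1-Lipschitz, no L2 perturbation with norm smaller
# than |model(x)| can change the sign of the output, so |output| is a
# certified robustness radius for each sample.
with torch.no_grad():
    output = model(data)
    certificates = output.abs().squeeze()
print(certificates[:5])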

See deel.torchlip for a complete API description.

