Example and usage

To keep things simple, the following rules were followed during development:

  • deel-torchlip follows the torch.nn package structure.

  • When a k-Lipschitz module overrides a standard torch.nn module, it uses the same interface and the same parameters. The only difference is an extra parameter that controls the Lipschitz constant of the layer (see the sketch below).
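
For instance, here is a minimal sketch of that interface, assuming the k_coef_lip keyword exposed by deel-torchlip layers (check the API reference for the exact signature):

from deel import torchlip

# Same signature as torch.nn.Linear(16, 32), plus a parameter controlling
# the Lipschitz constant (k_coef_lip is assumed here).
m1 = torchlip.SpectralLinear(16, 32)                  # 1-Lipschitz by default
m2 = torchlip.SpectralLinear(16, 32, k_coef_lip=2.0)  # 2-Lipschitz variant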

Which modules are safe to use?

Modules from deel-torchlip are mostly wrappers around initializers and normalization hooks that enforce their 1-Lipschitz property. For instance, the SpectralLinear module is simply a torch.nn.Linear module with automatic orthogonal initialization and normalization hooks:

# This code is roughly equivalent to SpectralLinear(16, 32)
import torch
from deel import torchlip

m = torch.nn.Linear(16, 32)
torch.nn.init.orthogonal_(m.weight)  # orthogonal initialization
m.bias.data.fill_(0.0)

# Hooks that re-normalize the weight before each forward pass:
torch.nn.utils.spectral_norm(m, "weight", 3)  # spectral normalization, 3 power iterations
torchlip.utils.bjorck_norm(m, "weight", 15)   # Björck orthonormalization, 15 iterations
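
As a quick sanity check (an illustrative sketch, not part of the library), the 1-Lipschitz property can be probed empirically by comparing output and input distances on random pairs:

x1, x2 = torch.randn(8, 16), torch.randn(8, 16)
y1, y2 = m(x1), m(x2)
# For a 1-Lipschitz map, ||m(x1) - m(x2)|| <= ||x1 - x2|| for every pair.
ratios = (y1 - y2).norm(dim=1) / (x1 - x2).norm(dim=1)
print(ratios.max())  # should not exceed 1, up to numerical tolerance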

The following table indicates which modules are safe to use in a Lipschitz network, and which are not.

| torch.nn | 1-Lipschitz? | deel-torchlip equivalent | Comments |
|----------|--------------|--------------------------|----------|
| torch.nn.Linear | no | SpectralLinear, FrobeniusLinear | SpectralLinear and FrobeniusLinear are similar when there is a single output. |
| torch.nn.Conv2d | no | SpectralConv2d, FrobeniusConv2d | SpectralConv2d also implements Björck normalization. |
| torch.nn.MaxPool2d, torch.nn.AdaptiveMaxPool2d | yes | n/a | |
| torch.nn.AvgPool2d, torch.nn.AdaptiveAvgPool2d | no | ScaledAvgPool2d, ScaledAdaptiveAvgPool2d, ScaledL2NormPool2d | The Lipschitz constant is bounded by sqrt(pool_h * pool_w). |
| torch.nn.Flatten | yes | n/a | |
| torch.nn.Dropout | no | None | The Lipschitz constant is bounded by the dropout factor (1 / (1 - p) during training). |
| torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d | no | None | We suspect that layer normalization already limits internal covariate shift. |
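
The Dropout entry is easy to verify: during training, PyTorch scales the surviving activations by 1 / (1 - p), so the layer can stretch distances by that factor. A short, self-contained check:

import torch

torch.manual_seed(0)
drop = torch.nn.Dropout(p=0.5)
drop.train()  # dropout is only active in training mode
x = torch.ones(1, 8)
print(drop(x))  # surviving entries are scaled by 1 / (1 - p) = 2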

How to use it?

Here is a simple example showing how to build a 1-Lipschitz network:

import torch
from deel import torchlip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# deel-torchlip layers can be used like any torch.nn layers in
# Sequential or other types of container modules.
model = torch.nn.Sequential(
    torchlip.SpectralConv2d(1, 32, (3, 3), padding=1),
    torchlip.SpectralConv2d(32, 32, (3, 3), padding=1),
    torch.nn.MaxPool2d(kernel_size=(2, 2)),
    torchlip.SpectralConv2d(32, 32, (3, 3), padding=1),
    torchlip.SpectralConv2d(32, 32, (3, 3), padding=1),
    torch.nn.MaxPool2d(kernel_size=(2, 2)),
    torch.nn.Flatten(),
    torchlip.SpectralLinear(1568, 256),
    torchlip.SpectralLinear(256, 1)
).to(device)
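
# Sanity check (illustrative addition): two 2x2 max-poolings reduce the
# 28x28 MNIST input to 7x7, so the flattened size is 32 * 7 * 7 = 1568,
# matching the input size of the first SpectralLinear layer.
dummy = torch.randn(4, 1, 28, 28, device=device)
assert model(dummy).shape == (4, 1)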

# Training can be done as usual, except that we are doing binary
# classification with -1 and +1 labels, so the targets from the
# dataset must be converted accordingly.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for data, target in mnist_08:  # mnist_08: a DataLoader of (image, label) batches, not defined here
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = torchlip.functional.hkr_loss(output, target, alpha=10, min_margin=1)
    loss.backward()
    optimizer.step()
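
Because the model is 1-Lipschitz, the magnitude of its output directly yields a certified robustness radius for this binary task: an L2 perturbation with norm smaller than |f(x)| cannot flip the sign of f(x). A minimal sketch, reusing the last batch from the loop above:

model.eval()
with torch.no_grad():
    logits = model(data)
    # An L2 perturbation smaller than this radius cannot change the
    # predicted sign (Lipschitz constant L = 1, radius = |f(x)| / L).
    certified_radius = logits.abs()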

See deel.torchlip for a complete API description.

