Example and usage
=================

To keep things simple, the following rules were followed during development:

* ``deel-torchlip`` follows the ``torch.nn`` package structure.
* When a k-Lipschitz module overrides a standard ``torch.nn`` module, it uses the same
  interface and the same parameters. The only difference is an extra parameter to
  control the Lipschitz constant of the layer, as sketched below.
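
For instance, a minimal sketch of this convention (assuming the Lipschitz constant is
exposed through a ``k_coef_lip`` parameter, as in ``deel-lip``):

.. code-block:: python

    import torch

    from deel import torchlip

    # Same interface as torch.nn.Linear(16, 32)...
    m1 = torchlip.SpectralLinear(16, 32)

    # ...plus an extra parameter controlling the Lipschitz constant
    # of the layer (here 2-Lipschitz instead of the default 1).
    m2 = torchlip.SpectralLinear(16, 32, k_coef_lip=2.0)
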
Which modules are safe to use?
------------------------------

Modules from ``deel-torchlip`` are mostly standard modules wrapped with initializers
and normalization hooks that enforce their 1-Lipschitz property.
For instance, the :class:`SpectralLinear` module is simply a :class:`torch.nn.Linear`
module with automatic orthogonal initialization and normalization hooks:

.. code-block:: python

    import torch

    from deel import torchlip

    # This code is roughly equivalent to SpectralLinear(16, 32):
    # orthogonal initialization, zero bias, then spectral and
    # Björck normalization hooks on the weight.
    m = torch.nn.Linear(16, 32)
    torch.nn.init.orthogonal_(m.weight)
    m.bias.data.fill_(0.0)
    torch.nn.utils.spectral_norm(m, "weight", eps=1e-3)
    torchlip.utils.bjorck_norm(m, "weight", eps=1e-3)
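
A quick sanity check, as a sketch: after a forward pass triggers the hooks, the
singular values of the effective weight should all be close to 1.

.. code-block:: python

    # Run a forward pass so the normalization hooks are applied,
    # then inspect the singular values of the effective weight.
    _ = m(torch.randn(4, 16))
    print(torch.linalg.svdvals(m.weight.detach()))  # all close to 1
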

The following table indicates which modules are safe to use in a Lipschitz network,
and which are not.

.. list-table::
   :header-rows: 1

   * - ``torch.nn``
     - 1-Lipschitz?
     - ``deel-torchlip`` equivalent
     - comments
   * - :class:`torch.nn.Linear`
     - no
     - :class:`.SpectralLinear`, :class:`.FrobeniusLinear`
     - :class:`.SpectralLinear` and :class:`.FrobeniusLinear` are similar when there is a single output.
   * - :class:`torch.nn.Conv2d`
     - no
     - :class:`.SpectralConv2d`, :class:`.FrobeniusConv2d`
     - :class:`.SpectralConv2d` also implements Björck normalization.
   * - :class:`torch.nn.Conv1d`
     - no
     - :class:`.SpectralConv1d`
     - :class:`.SpectralConv1d` also implements Björck normalization.
   * - :class:`torch.nn.MaxPool2d`, :class:`torch.nn.AdaptiveMaxPool2d`
     - yes
     - n/a
     -
   * - :class:`torch.nn.AvgPool2d`, :class:`torch.nn.AdaptiveAvgPool2d`
     - no
     - :class:`.ScaledAvgPool2d`, :class:`.ScaledAdaptiveAvgPool2d`, :class:`.ScaledL2NormPool2d`, :class:`.ScaledAdaptativeL2NormPool2d`
     - The Lipschitz constant is bounded by ``sqrt(pool_h * pool_w)``.
   * - :class:`torch.nn.Flatten`
     - yes
     - n/a
     -
   * - :class:`torch.nn.ConvTranspose2d`
     - no
     - :class:`.SpectralConvTranspose2d`
     - :class:`.SpectralConvTranspose2d` also implements Björck normalization.
   * - :class:`torch.nn.BatchNorm1d`, :class:`torch.nn.BatchNorm2d`, :class:`torch.nn.BatchNorm3d`
     - no
     - :class:`.BatchCentering`
     - This layer applies a bias computed from batch statistics, but no normalization factor, so it is 1-Lipschitz.
   * - :class:`torch.nn.LayerNorm`
     - no
     - :class:`.LayerCentering`
     - This layer applies a bias computed from per-sample statistics, but no normalization factor, so it is 1-Lipschitz.
   * - Residual connections
     - no
     - :class:`.LipResidual`
     - Learns a mixing factor between the residual connection and a 1-Lipschitz branch.
   * - :class:`torch.nn.Dropout`
     - no
     - None
     - The Lipschitz constant is bounded by the dropout scaling factor during training.
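
In practice, converting an architecture amounts to swapping each unsafe module for its
equivalent from the table. A minimal sketch (assuming the scaled pooling layers mirror
the ``torch.nn`` pooling signatures, as the rules above state):

.. code-block:: python

    import torch

    from deel import torchlip

    # A non-Lipschitz baseline block...
    baseline = torch.nn.Sequential(
        torch.nn.Conv2d(1, 16, (3, 3), padding=1),
        torch.nn.AvgPool2d(kernel_size=(2, 2)),
    )

    # ...and its 1-Lipschitz counterpart built from the table above.
    lipschitz = torch.nn.Sequential(
        torchlip.SpectralConv2d(1, 16, (3, 3), padding=1),
        torchlip.ScaledAvgPool2d(kernel_size=(2, 2)),
    )
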
How to use it?
--------------

Here is a simple example showing how to build a 1-Lipschitz network:

.. code-block:: python

    import torch

    from deel import torchlip

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # deel-torchlip layers can be used like any torch.nn layers in
    # Sequential or other types of container modules.
    model = torch.nn.Sequential(
        torchlip.SpectralConv2d(1, 32, (3, 3), padding=1),
        torchlip.SpectralConv2d(32, 32, (3, 3), padding=1),
        torch.nn.MaxPool2d(kernel_size=(2, 2)),
        torchlip.SpectralConv2d(32, 32, (3, 3), padding=1),
        torchlip.SpectralConv2d(32, 32, (3, 3), padding=1),
        torch.nn.MaxPool2d(kernel_size=(2, 2)),
        torch.nn.Flatten(),
        torchlip.SpectralLinear(1568, 256),
        torchlip.SpectralLinear(256, 1),
    ).to(device)

    # Training can be done as usual, except that we perform binary
    # classification with -1/+1 labels, so the targets coming from
    # the dataset must be remapped accordingly.
    optimizer = torch.optim.Adam(lr=0.01, params=model.parameters())
    hkr_loss = torchlip.HKRLoss(alpha=10, min_margin=1)

    for data, target in mnist_08:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = hkr_loss(output, target)
        loss.backward()
        optimizer.step()
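
The example assumes that ``mnist_08`` yields targets already remapped to -1/+1. If it
yields the raw digit labels instead, a hypothetical helper (``to_signed_labels`` is
illustrative, not part of ``deel-torchlip``) could perform the remapping:

.. code-block:: python

    def to_signed_labels(target: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: map raw MNIST digit labels {0, 8}
        # to the {-1, +1} labels expected by the HKR loss.
        return (target == 8).float() * 2 - 1
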
See :ref:`deel-torchlip-api` for a complete API description.