
Demo 0: Example and usage

In order to keep things simple, the following rules were followed during development:

  • deel-lip follows the keras package structure.
  • All elements (layers, activations, initializers, ...) are compatible with the standard keras elements.
  • When a k-Lipschitz layer overrides a standard keras layer, it uses the same interface and the same parameters. The only difference is an extra parameter that controls the Lipschitz constant of the layer (see the sketch below).
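
For instance, replacing a Dense layer by its Lipschitz counterpart only requires changing the class name; a minimal sketch (the layer size and activation are arbitrary):

from tensorflow.keras.layers import Dense
from deel.lip.layers import SpectralDense

# standard keras layer
dense = Dense(64, activation="relu")
# Lipschitz counterpart: same interface, plus k_coef_lip which controls
# the Lipschitz constant of the layer (defaults to 1.0)
lip_dense = SpectralDense(64, activation="relu", k_coef_lip=1.0)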

Which layers are safe to use?

The following table indicates which layers are safe to use in a Lipschitz network, and which are not.

| layer | 1-Lipschitz? | deel-lip equivalent | comments |
| --- | --- | --- | --- |
| Dense | no | SpectralDense, FrobeniusDense | SpectralDense and FrobeniusDense are similar when there is a single output. |
| Conv2D | no | SpectralConv2D, FrobeniusConv2D | SpectralConv2D also implements Björck normalization. |
| MaxPooling, GlobalMaxPooling | yes | n/a | |
| AveragePooling2D, GlobalAveragePooling2D | no | ScaledAveragePooling2D, ScaledGlobalAveragePooling2D | The Lipschitz constant is bounded by sqrt(pool_h * pool_w). |
| Flatten | yes | n/a | |
| Dropout | no | None | The Lipschitz constant is bounded by the dropout factor. |
| BatchNormalization | no | None | We suspect that layer normalization already limits internal covariate shift. |
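
For instance, making a convolutional block 1-Lipschitz amounts to swapping each unsafe layer for its equivalent in the table above; a minimal sketch (filter counts and pool sizes are arbitrary, and the ScaledAveragePooling2D signature is assumed to mirror AveragePooling2D):

from tensorflow.keras.layers import Conv2D, AveragePooling2D
from deel.lip.layers import SpectralConv2D, ScaledAveragePooling2D

# standard stack: not 1-Lipschitz
conv = Conv2D(16, (3, 3))
pool = AveragePooling2D(pool_size=(2, 2))

# 1-Lipschitz equivalents with the same interface
lip_conv = SpectralConv2D(16, (3, 3))
lip_pool = ScaledAveragePooling2D(pool_size=(2, 2))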

Design tips

Designing Lipschitz networks requires care in order to avoid vanishing/exploding gradient problems.

Choosing pooling layers:

| layer | advantages | disadvantages |
| --- | --- | --- |
| ScaledAveragePooling2D and MaxPooling2D | very similar to the original implementation (just adds a scaling factor for the average). | neither norm preserving nor gradient norm preserving. |
| InvertibleDownSampling | norm preserving and gradient norm preserving. | increases the number of channels (and the number of parameters of the next layer). |
| ScaledL2NormPooling2D (sqrt(avgpool(x**2))) | norm preserving. | lower numerical stability of the gradient when inputs are close to zero. |
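
The norm-preserving claim for ScaledL2NormPooling2D can be checked numerically; a minimal sketch (shapes are arbitrary; the scaling by the pool area turns sqrt(avgpool(x**2)) into an L2 norm over each pooling window):

import tensorflow as tf
from deel.lip.layers import ScaledL2NormPooling2D

x = tf.random.normal((1, 4, 4, 8))
y = ScaledL2NormPooling2D(pool_size=(2, 2), data_format="channels_last")(x)
# the two norms should match up to numerical error
print(tf.norm(x).numpy(), tf.norm(y).numpy())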

Choosing activations:

| layer | advantages | disadvantages |
| --- | --- | --- |
| ReLU | | creates a strong vanishing gradient effect. If you manage to learn with it, please call 911. |
| MaxMin (stack([ReLU(x), ReLU(-x)])) | has properties similar to ReLU, but is norm and gradient norm preserving. | doubles the number of outputs. |
| GroupSort | norm and gradient norm preserving. Also limits the need for biases (as it is shift invariant). | more computationally expensive when its parameter n is large. |
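
To make GroupSort concrete: with n=2 it sorts each consecutive pair of pre-activations, which is equivalent to MaxMin. A minimal sketch (assuming values are sorted in ascending order within each group):

import tensorflow as tf
from deel.lip.activations import GroupSort

gs = GroupSort(2)  # sort within groups of 2, equivalent to MaxMin
x = tf.constant([[1.0, -2.0, 3.0, 0.5]])
# pairs (1, -2) and (3, 0.5) are each sorted: expected [[-2., 1., 0.5, 3.]]
print(gs(x).numpy())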

Please note that when training with the HKR and MulticlassHKR losses, no activation is required on the last layer.

How to use it?


Here is an example of a 1-Lipschitz network trained on MNIST:

from deel.lip.layers import (
    SpectralDense,
    SpectralConv2D,
    ScaledL2NormPooling2D,
    FrobeniusDense,
)
from deel.lip.model import Sequential
from deel.lip.activations import GroupSort
from deel.lip.losses import MulticlassHKR, MulticlassKR
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np

# Sequential (resp. Model) from deel.lip.model has the same interface as its
# keras equivalent. It acts only as a container, with extra features specific
# to Lipschitz functions (condensation, vanilla export, ...); the layers remain
# fully compatible with tf.keras.models.Sequential/Model
model = Sequential(
    [
        Input(shape=(28, 28, 1)),
        # Lipschitz layers preserve the API of their superclass (here Conv2D).
        # An optional parameter, k_coef_lip, controls the Lipschitz
        # constant of the layer
        SpectralConv2D(
            filters=16,
            kernel_size=(3, 3),
            activation=GroupSort(2),
            use_bias=True,
            kernel_initializer="orthogonal",
        ),
        # the usual pooling layers are implemented (avg, max, ...), but new layers are also available
        ScaledL2NormPooling2D(pool_size=(2, 2), data_format="channels_last"),
        SpectralConv2D(
            filters=16,
            kernel_size=(3, 3),
            activation=GroupSort(2),
            use_bias=True,
            kernel_initializer="orthogonal",
        ),
        ScaledL2NormPooling2D(pool_size=(2, 2), data_format="channels_last"),
        # our layers are fully interoperable with existing keras layers
        Flatten(),
        SpectralDense(
            32,
            activation=GroupSort(2),
            use_bias=True,
            kernel_initializer="orthogonal",
        ),
        FrobeniusDense(
            10, activation=None, use_bias=False, kernel_initializer="orthogonal"
        ),
    ],
    # similarly, the model has a k_coef_lip parameter that automatically
    # sets the Lipschitz constant of each layer
    k_coef_lip=1.0,
    name="hkr_model",
)

# the HKR (hinge-Kantorovich-Rubinstein) loss optimizes robustness along with accuracy
model.compile(
    # decreasing alpha and increasing min_margin improve robustness (at the cost of accuracy)
    # note also that for Lipschitz networks, more robustness requires more parameters
    loss=MulticlassHKR(alpha=50, min_margin=0.05),
    optimizer=Adam(1e-3),
    metrics=["accuracy", MulticlassKR()],
)

model.summary()

# load data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# standardize and reshape the data
x_train = np.expand_dims(x_train, -1)
mean = x_train.mean()
std = x_train.std()
x_train = (x_train - mean) / std
x_test = np.expand_dims(x_test, -1)
x_test = (x_test - mean) / std
# one hot encode the labels
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# fit the model
model.fit(
    x_train,
    y_train,
    batch_size=2048,
    epochs=30,
    validation_data=(x_test, y_test),
    shuffle=True,
)

# once training is finished you can convert
# SpectralDense layers into Dense layers and SpectralConv2D into Conv2D,
# which optimizes performance for inference
vanilla_model = model.vanilla_export()
/home/thibaut.boissin/projects/deel-lip/deel/lip/model.py:56: UserWarning: Sequential model contains a layer wich is not a Lipschitz layer: flatten_2
  layer.name

Model: "hkr_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
spectral_conv2d_4 (SpectralC (None, 28, 28, 16)        321       
_________________________________________________________________
scaled_l2norm_pooling2d_4 (S (None, 14, 14, 16)        0         
_________________________________________________________________
spectral_conv2d_5 (SpectralC (None, 14, 14, 16)        4641      
_________________________________________________________________
scaled_l2norm_pooling2d_5 (S (None, 7, 7, 16)          0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
spectral_dense_2 (SpectralDe (None, 32)                50241     
_________________________________________________________________
frobenius_dense_2 (Frobenius (None, 10)                640       
=================================================================
Total params: 55,843
Trainable params: 27,920
Non-trainable params: 27,923
_________________________________________________________________
Epoch 1/30
30/30 [==============================] - 3s 43ms/step - loss: 6.5323 - accuracy: 0.1522 - MulticlassKR: 0.0183 - val_loss: 2.3933 - val_accuracy: 0.4873 - val_MulticlassKR: 0.0942
Epoch 2/30
30/30 [==============================] - 1s 39ms/step - loss: 2.0856 - accuracy: 0.5528 - MulticlassKR: 0.1149 - val_loss: 1.3480 - val_accuracy: 0.7091 - val_MulticlassKR: 0.1653
Epoch 3/30
30/30 [==============================] - 1s 35ms/step - loss: 1.2743 - accuracy: 0.7298 - MulticlassKR: 0.1743 - val_loss: 0.9228 - val_accuracy: 0.7942 - val_MulticlassKR: 0.2097
Epoch 4/30
30/30 [==============================] - 1s 36ms/step - loss: 0.9001 - accuracy: 0.7975 - MulticlassKR: 0.2168 - val_loss: 0.6864 - val_accuracy: 0.8368 - val_MulticlassKR: 0.2486
Epoch 5/30
30/30 [==============================] - 1s 35ms/step - loss: 0.6889 - accuracy: 0.8338 - MulticlassKR: 0.2546 - val_loss: 0.5352 - val_accuracy: 0.8650 - val_MulticlassKR: 0.2835
Epoch 6/30
30/30 [==============================] - 1s 37ms/step - loss: 0.5256 - accuracy: 0.8609 - MulticlassKR: 0.2879 - val_loss: 0.4442 - val_accuracy: 0.8805 - val_MulticlassKR: 0.3166
Epoch 7/30
30/30 [==============================] - 1s 34ms/step - loss: 0.4469 - accuracy: 0.8735 - MulticlassKR: 0.3186 - val_loss: 0.3349 - val_accuracy: 0.8911 - val_MulticlassKR: 0.3470
Epoch 8/30
30/30 [==============================] - 1s 34ms/step - loss: 0.3493 - accuracy: 0.8835 - MulticlassKR: 0.3480 - val_loss: 0.2641 - val_accuracy: 0.8961 - val_MulticlassKR: 0.3787
Epoch 9/30
30/30 [==============================] - 1s 34ms/step - loss: 0.2722 - accuracy: 0.8938 - MulticlassKR: 0.3818 - val_loss: 0.2122 - val_accuracy: 0.8993 - val_MulticlassKR: 0.4127
Epoch 10/30
30/30 [==============================] - 1s 34ms/step - loss: 0.2036 - accuracy: 0.9013 - MulticlassKR: 0.4153 - val_loss: 0.1330 - val_accuracy: 0.9079 - val_MulticlassKR: 0.4487
Epoch 11/30
30/30 [==============================] - 1s 35ms/step - loss: 0.1472 - accuracy: 0.9059 - MulticlassKR: 0.4505 - val_loss: 0.0799 - val_accuracy: 0.9126 - val_MulticlassKR: 0.4861
Epoch 12/30
30/30 [==============================] - 1s 35ms/step - loss: 0.0939 - accuracy: 0.9103 - MulticlassKR: 0.4915 - val_loss: 0.0371 - val_accuracy: 0.9142 - val_MulticlassKR: 0.5313
Epoch 13/30
30/30 [==============================] - 1s 40ms/step - loss: 0.0499 - accuracy: 0.9100 - MulticlassKR: 0.5346 - val_loss: -0.0211 - val_accuracy: 0.9206 - val_MulticlassKR: 0.5729
Epoch 14/30
30/30 [==============================] - 1s 39ms/step - loss: -0.0216 - accuracy: 0.9162 - MulticlassKR: 0.5760 - val_loss: -0.0682 - val_accuracy: 0.9200 - val_MulticlassKR: 0.6219
Epoch 15/30
30/30 [==============================] - 1s 35ms/step - loss: -0.0666 - accuracy: 0.9168 - MulticlassKR: 0.6248 - val_loss: -0.1301 - val_accuracy: 0.9236 - val_MulticlassKR: 0.6742
Epoch 16/30
30/30 [==============================] - 1s 35ms/step - loss: -0.1223 - accuracy: 0.9197 - MulticlassKR: 0.6778 - val_loss: -0.1777 - val_accuracy: 0.9270 - val_MulticlassKR: 0.7275
Epoch 17/30
30/30 [==============================] - 1s 36ms/step - loss: -0.1605 - accuracy: 0.9199 - MulticlassKR: 0.7291 - val_loss: -0.2426 - val_accuracy: 0.9272 - val_MulticlassKR: 0.7900
Epoch 18/30
30/30 [==============================] - 1s 36ms/step - loss: -0.2278 - accuracy: 0.9218 - MulticlassKR: 0.7886 - val_loss: -0.2883 - val_accuracy: 0.9305 - val_MulticlassKR: 0.8471
Epoch 19/30
30/30 [==============================] - 1s 40ms/step - loss: -0.2246 - accuracy: 0.9170 - MulticlassKR: 0.8478 - val_loss: -0.3104 - val_accuracy: 0.9183 - val_MulticlassKR: 0.9070
Epoch 20/30
30/30 [==============================] - 1s 34ms/step - loss: -0.3066 - accuracy: 0.9213 - MulticlassKR: 0.9085 - val_loss: -0.3778 - val_accuracy: 0.9284 - val_MulticlassKR: 0.9754
Epoch 21/30
30/30 [==============================] - 1s 39ms/step - loss: -0.3736 - accuracy: 0.9241 - MulticlassKR: 0.9739 - val_loss: -0.4258 - val_accuracy: 0.9280 - val_MulticlassKR: 1.0388
Epoch 22/30
30/30 [==============================] - 1s 35ms/step - loss: -0.4180 - accuracy: 0.9229 - MulticlassKR: 1.0337 - val_loss: -0.4805 - val_accuracy: 0.9302 - val_MulticlassKR: 1.1069
Epoch 23/30
30/30 [==============================] - 1s 38ms/step - loss: -0.4624 - accuracy: 0.9234 - MulticlassKR: 1.1055 - val_loss: -0.5607 - val_accuracy: 0.9312 - val_MulticlassKR: 1.1803
Epoch 24/30
30/30 [==============================] - 1s 35ms/step - loss: -0.5279 - accuracy: 0.9257 - MulticlassKR: 1.1797 - val_loss: -0.5866 - val_accuracy: 0.9275 - val_MulticlassKR: 1.2456
Epoch 25/30
30/30 [==============================] - 1s 38ms/step - loss: -0.5482 - accuracy: 0.9218 - MulticlassKR: 1.2388 - val_loss: -0.6441 - val_accuracy: 0.9310 - val_MulticlassKR: 1.3125
Epoch 26/30
30/30 [==============================] - 1s 35ms/step - loss: -0.6375 - accuracy: 0.9263 - MulticlassKR: 1.3103 - val_loss: -0.6890 - val_accuracy: 0.9295 - val_MulticlassKR: 1.3795
Epoch 27/30
30/30 [==============================] - 1s 42ms/step - loss: -0.6668 - accuracy: 0.9230 - MulticlassKR: 1.3719 - val_loss: -0.7413 - val_accuracy: 0.9271 - val_MulticlassKR: 1.4496
Epoch 28/30
30/30 [==============================] - 1s 35ms/step - loss: -0.7483 - accuracy: 0.9264 - MulticlassKR: 1.4371 - val_loss: -0.7748 - val_accuracy: 0.9296 - val_MulticlassKR: 1.5096
Epoch 29/30
30/30 [==============================] - 1s 49ms/step - loss: -0.7495 - accuracy: 0.9229 - MulticlassKR: 1.4900 - val_loss: -0.8622 - val_accuracy: 0.9332 - val_MulticlassKR: 1.5644
Epoch 30/30
30/30 [==============================] - 1s 35ms/step - loss: -0.8047 - accuracy: 0.9246 - MulticlassKR: 1.5530 - val_loss: -0.8732 - val_accuracy: 0.9297 - val_MulticlassKR: 1.6220
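
Since the exported network only contains standard Keras layers, the usual inference workflow applies; a minimal sketch reusing the objects defined above (the compile settings and file name are illustrative):

# vanilla_model holds standard Dense/Conv2D layers with the learned,
# already-normalized weights, so it behaves like any Keras model
vanilla_model.compile(
    loss=MulticlassHKR(alpha=50, min_margin=0.05),
    optimizer=Adam(1e-3),
    metrics=["accuracy"],
)
vanilla_model.evaluate(x_test, y_test)
vanilla_model.save("hkr_model_vanilla.h5")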