
Demo 4: HKR multiclass and fooling

This notebook shows how to train a Lipschitz network in a multiclass setup. The HKR loss is extended to multiclass using a one-vs-all setup. We will go through the process of designing and training the network, show how to build robustness certificates from its output, and finally check these certificates by attacking the network.

installation

First, we install the required libraries. Foolbox will allow us to perform adversarial attacks on the trained network.

# pip install deel-lip foolbox -qqq
from deel.lip.layers import (
    SpectralDense,
    SpectralConv2D,
    ScaledL2NormPooling2D,
    ScaledAveragePooling2D,
    FrobeniusDense,
)
from deel.lip.model import Sequential
from deel.lip.activations import GroupSort, FullSort
from deel.lip.losses import MulticlassHKR, MulticlassKR
from deel.lip.callbacks import CondenseCallback
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np

For this example, the dataset fashion_mnist will be used. In order to keep things simple, no data augmentation will be performed.

# load data
(x_train, y_train_ord), (x_test, y_test_ord) = fashion_mnist.load_data()
# standardize and reshape the data
x_train = np.expand_dims(x_train, -1) / 255
x_test = np.expand_dims(x_test, -1) / 255
# one hot encode the labels
y_train = to_categorical(y_train_ord)
y_test = to_categorical(y_test_ord)

Let's build the network.

the architecture

The original one-vs-all setup would require 10 different networks (one per class). In practice, we use a single network with a common body and 10 1-Lipschitz heads. Experiments have shown that this setup does not affect the network performance. In order to ease the creation of such a network, the FrobeniusDense layer has a dedicated parameter: when disjoint_neurons=True it acts as the stacking of 10 single-neuron heads. Note that, although each head is a 1-Lipschitz function, the overall network is not 1-Lipschitz (concatenation is not 1-Lipschitz). We will see later how this affects certificate creation.
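To make the last point concrete, here is a small illustration of our own (not part of the original notebook): each head taken alone has a unit-norm weight vector and is therefore 1-Lipschitz, yet the matrix obtained by stacking two such heads can have a spectral norm larger than 1.

import numpy as np

w1 = np.array([1.0, 0.0])                    # head 1: unit L2 norm, hence 1-Lipschitz
w2 = np.array([np.sqrt(0.5), np.sqrt(0.5)])  # head 2: unit L2 norm, hence 1-Lipschitz
W = np.stack([w1, w2])                       # the two heads stacked into one output layer
print(np.linalg.norm(w1), np.linalg.norm(w2))     # ~1.0 and ~1.0
print(np.linalg.svd(W, compute_uv=False)[0])      # ~1.31 > 1: the stacked layer is not 1-Lipschitz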

the loss

The multiclass HKR loss is provided by MulticlassHKR. The loss has two parameters: alpha and min_margin. Decreasing alpha and increasing min_margin improve robustness (at the cost of accuracy). Note also that, in the case of Lipschitz networks, more robustness requires more parameters. For more information see our paper.

In this setup, choosing alpha=100 and min_margin=.25 provides good robustness without hurting the accuracy too much.

Finally, the MulticlassKR() metric indicates the robustness of the network (a proxy of the average certificate).
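As a quick sketch of how these knobs could be moved (the alternative values below are illustrative assumptions, not tuned settings from this demo):

from deel.lip.losses import MulticlassHKR, MulticlassKR

demo_loss = MulticlassHKR(alpha=100, min_margin=0.25)   # trade-off used in this notebook
robust_loss = MulticlassHKR(alpha=50, min_margin=0.5)   # lower alpha, larger margin: more robust, less accurate
kr_metric = MulticlassKR()                              # reported as a proxy of the average certificate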

# Sequential (resp. Model) from deel.lip.model has the same properties as any Lipschitz model.
# It acts only as a container, with features specific to Lipschitz
# functions (condensation, vanilla exportation...)
model = Sequential(
    [
        Input(shape=x_train.shape[1:]),
        # Lipschitz layers preserve the API of their superclass (here Conv2D)
        # an optional parameter is available: k_coef_lip, which controls the Lipschitz
        # constant of the layer
        SpectralConv2D(
            filters=16,
            kernel_size=(3, 3),
            activation=GroupSort(2),
            use_bias=True,
            kernel_initializer="orthogonal",
        ),
        # usual pooling layers are implemented (avg, max...), but new layers are also available
        ScaledL2NormPooling2D(pool_size=(2, 2), data_format="channels_last"),
        SpectralConv2D(
            filters=32,
            kernel_size=(3, 3),
            activation=GroupSort(2),
            use_bias=True,
            kernel_initializer="orthogonal",
        ),
        ScaledL2NormPooling2D(pool_size=(2, 2), data_format="channels_last"),
        # our layers are fully interoperable with existing keras layers
        Flatten(),
        SpectralDense(
            64,
            activation=GroupSort(2),
            use_bias=True,
            kernel_initializer="orthogonal",
        ),
        FrobeniusDense(
            y_train.shape[-1], activation=None, use_bias=False, kernel_initializer="orthogonal"
        ),
    ],
    # similarly, the model has a parameter to set the Lipschitz constant;
    # it automatically sets the constant of each layer
    k_coef_lip=1.0,
    name="hkr_model",
)

# HKR (Hinge-Kantorovich-Rubinstein) optimizes robustness along with accuracy
model.compile(
    # decreasing alpha and increasing min_margin improve robustness (at the cost of accuracy)
    # note also that, in the case of Lipschitz networks, more robustness requires more parameters
    loss=MulticlassHKR(alpha=100, min_margin=.25),
    optimizer=Adam(1e-4),
    metrics=["accuracy", MulticlassKR()],
)

model.summary()

Model: "hkr_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
spectral_conv2d (SpectralCon (None, 28, 28, 16)        321       
_________________________________________________________________
scaled_l2norm_pooling2d (Sca (None, 14, 14, 16)        0         
_________________________________________________________________
spectral_conv2d_1 (SpectralC (None, 14, 14, 32)        9281      
_________________________________________________________________
scaled_l2norm_pooling2d_1 (S (None, 7, 7, 32)          0         
_________________________________________________________________
flatten (Flatten)            (None, 1568)              0         
_________________________________________________________________
spectral_dense (SpectralDens (None, 64)                200833    
_________________________________________________________________
frobenius_dense (FrobeniusDe (None, 10)                1280      
=================================================================
Total params: 211,715
Trainable params: 105,856
Non-trainable params: 105,859
_________________________________________________________________

/home/thibaut.boissin/projects/repo_github/deel-lip/deel/lip/model.py:56: UserWarning: Sequential model contains a layer wich is not a Lipschitz layer: flatten
  layer.name

notes about constraint enforcement

There are currently three ways to enforce a constraint in a network: 1. regularization, 2. weight reparametrization, 3. weight projection.

The first one does not provide the required guarantees, which is why deel-lip focuses on the latter two. Weight reparametrization is done directly in the layers (parameter niter_bjorck): this trick allows arbitrary gradient updates without breaking the constraint. However, it is done in the graph, which increases resource consumption. The last method projects the weights after each batch, ensuring the constraint at a more affordable computational cost; it can be done in deel-lip using the CondenseCallback. The main drawback of this method is a reduced efficiency of each update.

As a rule of thumb, when reparametrization is used alone, setting niter_bjorck to at least 15 is advised. However, when combined with weight projection, this setting can be lowered greatly.
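A minimal sketch of that combination, assuming the niter_bjorck keyword mentioned above and the default CondenseCallback constructor (this variant is not used in the training run below):

from deel.lip.layers import SpectralDense
from deel.lip.activations import GroupSort
from deel.lip.callbacks import CondenseCallback

# hypothetical layer with fewer Björck iterations than the rule-of-thumb 15...
fast_dense = SpectralDense(64, activation=GroupSort(2), niter_bjorck=5)
# ...compensated by projecting the weights during training:
# model.fit(x_train, y_train, callbacks=[CondenseCallback()], ...)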

# fit the model
model.fit(
    x_train,
    y_train,
    batch_size=4096,
    epochs=100,
    validation_data=(x_test, y_test),
    shuffle=True,
    verbose=1,
)

Epoch 1/100
15/15 [==============================] - 5s 117ms/step - loss: 41.2174 - accuracy: 0.1382 - MulticlassKR: 0.0467 - val_loss: 29.5743 - val_accuracy: 0.2798 - val_MulticlassKR: 0.1810
Epoch 2/100
15/15 [==============================] - 1s 81ms/step - loss: 25.3826 - accuracy: 0.4441 - MulticlassKR: 0.2389 - val_loss: 19.8280 - val_accuracy: 0.5547 - val_MulticlassKR: 0.3549
Epoch 3/100
15/15 [==============================] - 1s 81ms/step - loss: 18.3231 - accuracy: 0.5899 - MulticlassKR: 0.4017 - val_loss: 16.0346 - val_accuracy: 0.6183 - val_MulticlassKR: 0.4835
Epoch 4/100
15/15 [==============================] - 1s 81ms/step - loss: 15.0896 - accuracy: 0.6402 - MulticlassKR: 0.5135 - val_loss: 13.9297 - val_accuracy: 0.6470 - val_MulticlassKR: 0.5607
Epoch 5/100
15/15 [==============================] - 1s 81ms/step - loss: 13.2237 - accuracy: 0.6814 - MulticlassKR: 0.5821 - val_loss: 12.5531 - val_accuracy: 0.6814 - val_MulticlassKR: 0.6186
Epoch 6/100
15/15 [==============================] - 1s 81ms/step - loss: 12.0225 - accuracy: 0.7057 - MulticlassKR: 0.6364 - val_loss: 11.6916 - val_accuracy: 0.6964 - val_MulticlassKR: 0.6655
Epoch 7/100
15/15 [==============================] - 1s 81ms/step - loss: 11.2456 - accuracy: 0.7178 - MulticlassKR: 0.6803 - val_loss: 11.0661 - val_accuracy: 0.7131 - val_MulticlassKR: 0.7020
Epoch 8/100
15/15 [==============================] - 1s 81ms/step - loss: 10.7023 - accuracy: 0.7343 - MulticlassKR: 0.7144 - val_loss: 10.6094 - val_accuracy: 0.7190 - val_MulticlassKR: 0.7339
Epoch 9/100
15/15 [==============================] - 1s 81ms/step - loss: 10.2158 - accuracy: 0.7353 - MulticlassKR: 0.7471 - val_loss: 10.2140 - val_accuracy: 0.7255 - val_MulticlassKR: 0.7639
Epoch 10/100
15/15 [==============================] - 1s 80ms/step - loss: 9.9306 - accuracy: 0.7444 - MulticlassKR: 0.7743 - val_loss: 9.8911 - val_accuracy: 0.7341 - val_MulticlassKR: 0.7875
Epoch 11/100
15/15 [==============================] - 1s 80ms/step - loss: 9.4766 - accuracy: 0.7500 - MulticlassKR: 0.8008 - val_loss: 9.5676 - val_accuracy: 0.7397 - val_MulticlassKR: 0.8139
Epoch 12/100
15/15 [==============================] - 1s 80ms/step - loss: 9.2583 - accuracy: 0.7547 - MulticlassKR: 0.8227 - val_loss: 9.3108 - val_accuracy: 0.7445 - val_MulticlassKR: 0.8375
Epoch 13/100
15/15 [==============================] - 1s 81ms/step - loss: 9.0268 - accuracy: 0.7571 - MulticlassKR: 0.8463 - val_loss: 9.0594 - val_accuracy: 0.7461 - val_MulticlassKR: 0.8565
Epoch 14/100
15/15 [==============================] - 1s 80ms/step - loss: 8.7289 - accuracy: 0.7631 - MulticlassKR: 0.8653 - val_loss: 8.8221 - val_accuracy: 0.7563 - val_MulticlassKR: 0.8798
Epoch 15/100
15/15 [==============================] - 1s 81ms/step - loss: 8.5468 - accuracy: 0.7660 - MulticlassKR: 0.8856 - val_loss: 8.6213 - val_accuracy: 0.7566 - val_MulticlassKR: 0.8976
Epoch 16/100
15/15 [==============================] - 1s 80ms/step - loss: 8.3208 - accuracy: 0.7699 - MulticlassKR: 0.9078 - val_loss: 8.4393 - val_accuracy: 0.7672 - val_MulticlassKR: 0.9187
Epoch 17/100
15/15 [==============================] - 1s 80ms/step - loss: 8.1348 - accuracy: 0.7747 - MulticlassKR: 0.9288 - val_loss: 8.2421 - val_accuracy: 0.7644 - val_MulticlassKR: 0.9369
Epoch 18/100
15/15 [==============================] - 1s 80ms/step - loss: 7.8150 - accuracy: 0.7807 - MulticlassKR: 0.9479 - val_loss: 8.0528 - val_accuracy: 0.7741 - val_MulticlassKR: 0.9598
Epoch 19/100
15/15 [==============================] - 1s 80ms/step - loss: 7.7277 - accuracy: 0.7813 - MulticlassKR: 0.9697 - val_loss: 7.8976 - val_accuracy: 0.7749 - val_MulticlassKR: 0.9754
Epoch 20/100
15/15 [==============================] - 1s 80ms/step - loss: 7.5802 - accuracy: 0.7822 - MulticlassKR: 0.9866 - val_loss: 7.7375 - val_accuracy: 0.7784 - val_MulticlassKR: 0.9936
Epoch 21/100
15/15 [==============================] - 1s 80ms/step - loss: 7.3151 - accuracy: 0.7893 - MulticlassKR: 1.0068 - val_loss: 7.5871 - val_accuracy: 0.7818 - val_MulticlassKR: 1.0131
Epoch 22/100
15/15 [==============================] - 1s 81ms/step - loss: 7.2699 - accuracy: 0.7901 - MulticlassKR: 1.0211 - val_loss: 7.4710 - val_accuracy: 0.7807 - val_MulticlassKR: 1.0305
Epoch 23/100
15/15 [==============================] - 1s 83ms/step - loss: 7.1052 - accuracy: 0.7939 - MulticlassKR: 1.0391 - val_loss: 7.3397 - val_accuracy: 0.7854 - val_MulticlassKR: 1.0450
Epoch 24/100
15/15 [==============================] - 1s 80ms/step - loss: 7.0167 - accuracy: 0.7962 - MulticlassKR: 1.0562 - val_loss: 7.2212 - val_accuracy: 0.7870 - val_MulticlassKR: 1.0637
Epoch 25/100
15/15 [==============================] - 1s 80ms/step - loss: 6.8205 - accuracy: 0.8002 - MulticlassKR: 1.0749 - val_loss: 7.1256 - val_accuracy: 0.7895 - val_MulticlassKR: 1.0808
Epoch 26/100
15/15 [==============================] - 1s 80ms/step - loss: 6.7542 - accuracy: 0.8013 - MulticlassKR: 1.0923 - val_loss: 7.0068 - val_accuracy: 0.7897 - val_MulticlassKR: 1.0966
Epoch 27/100
15/15 [==============================] - 1s 81ms/step - loss: 6.6025 - accuracy: 0.8022 - MulticlassKR: 1.1069 - val_loss: 6.8967 - val_accuracy: 0.7924 - val_MulticlassKR: 1.1105
Epoch 28/100
15/15 [==============================] - 1s 80ms/step - loss: 6.5729 - accuracy: 0.8033 - MulticlassKR: 1.1220 - val_loss: 6.8168 - val_accuracy: 0.7951 - val_MulticlassKR: 1.1275
Epoch 29/100
15/15 [==============================] - 1s 80ms/step - loss: 6.5147 - accuracy: 0.8074 - MulticlassKR: 1.1347 - val_loss: 6.7141 - val_accuracy: 0.7971 - val_MulticlassKR: 1.1425
Epoch 30/100
15/15 [==============================] - 1s 80ms/step - loss: 6.4094 - accuracy: 0.8059 - MulticlassKR: 1.1528 - val_loss: 6.6193 - val_accuracy: 0.7998 - val_MulticlassKR: 1.1605
Epoch 31/100
15/15 [==============================] - 1s 82ms/step - loss: 6.3102 - accuracy: 0.8090 - MulticlassKR: 1.1664 - val_loss: 6.5371 - val_accuracy: 0.8005 - val_MulticlassKR: 1.1746
Epoch 32/100
15/15 [==============================] - 1s 80ms/step - loss: 6.1902 - accuracy: 0.8078 - MulticlassKR: 1.1889 - val_loss: 6.4705 - val_accuracy: 0.8004 - val_MulticlassKR: 1.1924
Epoch 33/100
15/15 [==============================] - 1s 80ms/step - loss: 6.1780 - accuracy: 0.8127 - MulticlassKR: 1.1991 - val_loss: 6.3850 - val_accuracy: 0.8033 - val_MulticlassKR: 1.2076
Epoch 34/100
15/15 [==============================] - 1s 80ms/step - loss: 6.1156 - accuracy: 0.8123 - MulticlassKR: 1.2147 - val_loss: 6.3106 - val_accuracy: 0.8091 - val_MulticlassKR: 1.2191
Epoch 35/100
15/15 [==============================] - 1s 81ms/step - loss: 6.0083 - accuracy: 0.8143 - MulticlassKR: 1.2322 - val_loss: 6.2621 - val_accuracy: 0.8086 - val_MulticlassKR: 1.2360
Epoch 36/100
15/15 [==============================] - 1s 80ms/step - loss: 5.9177 - accuracy: 0.8158 - MulticlassKR: 1.2462 - val_loss: 6.1842 - val_accuracy: 0.8101 - val_MulticlassKR: 1.2483
Epoch 37/100
15/15 [==============================] - 1s 80ms/step - loss: 5.7953 - accuracy: 0.8186 - MulticlassKR: 1.2662 - val_loss: 6.1092 - val_accuracy: 0.8119 - val_MulticlassKR: 1.2654
Epoch 38/100
15/15 [==============================] - 1s 80ms/step - loss: 5.7620 - accuracy: 0.8179 - MulticlassKR: 1.2781 - val_loss: 6.0499 - val_accuracy: 0.8126 - val_MulticlassKR: 1.2815
Epoch 39/100
15/15 [==============================] - 1s 80ms/step - loss: 5.7588 - accuracy: 0.8187 - MulticlassKR: 1.2897 - val_loss: 5.9959 - val_accuracy: 0.8131 - val_MulticlassKR: 1.2936
Epoch 40/100
15/15 [==============================] - 1s 80ms/step - loss: 5.7005 - accuracy: 0.8208 - MulticlassKR: 1.3042 - val_loss: 5.9460 - val_accuracy: 0.8152 - val_MulticlassKR: 1.3039
Epoch 41/100
15/15 [==============================] - 1s 80ms/step - loss: 5.6319 - accuracy: 0.8232 - MulticlassKR: 1.3146 - val_loss: 5.8816 - val_accuracy: 0.8148 - val_MulticlassKR: 1.3217
Epoch 42/100
15/15 [==============================] - 1s 81ms/step - loss: 5.6429 - accuracy: 0.8232 - MulticlassKR: 1.3291 - val_loss: 5.8772 - val_accuracy: 0.8151 - val_MulticlassKR: 1.3317
Epoch 43/100
15/15 [==============================] - 1s 80ms/step - loss: 5.5395 - accuracy: 0.8245 - MulticlassKR: 1.3460 - val_loss: 5.8039 - val_accuracy: 0.8189 - val_MulticlassKR: 1.3538
Epoch 44/100
15/15 [==============================] - 1s 81ms/step - loss: 5.4303 - accuracy: 0.8249 - MulticlassKR: 1.3593 - val_loss: 5.7421 - val_accuracy: 0.8189 - val_MulticlassKR: 1.3669
Epoch 45/100
15/15 [==============================] - 1s 80ms/step - loss: 5.3844 - accuracy: 0.8268 - MulticlassKR: 1.3762 - val_loss: 5.6846 - val_accuracy: 0.8217 - val_MulticlassKR: 1.3765
Epoch 46/100
15/15 [==============================] - 1s 80ms/step - loss: 5.3307 - accuracy: 0.8281 - MulticlassKR: 1.3873 - val_loss: 5.6413 - val_accuracy: 0.8234 - val_MulticlassKR: 1.3881
Epoch 47/100
15/15 [==============================] - 1s 80ms/step - loss: 5.3788 - accuracy: 0.8258 - MulticlassKR: 1.3938 - val_loss: 5.6087 - val_accuracy: 0.8214 - val_MulticlassKR: 1.3971
Epoch 48/100
15/15 [==============================] - 1s 80ms/step - loss: 5.2561 - accuracy: 0.8314 - MulticlassKR: 1.4119 - val_loss: 5.5684 - val_accuracy: 0.8215 - val_MulticlassKR: 1.4106
Epoch 49/100
15/15 [==============================] - 1s 81ms/step - loss: 5.2374 - accuracy: 0.8276 - MulticlassKR: 1.4266 - val_loss: 5.5116 - val_accuracy: 0.8255 - val_MulticlassKR: 1.4254
Epoch 50/100
15/15 [==============================] - 1s 81ms/step - loss: 5.2404 - accuracy: 0.8299 - MulticlassKR: 1.4328 - val_loss: 5.4923 - val_accuracy: 0.8248 - val_MulticlassKR: 1.4351
Epoch 51/100
15/15 [==============================] - 1s 81ms/step - loss: 5.2273 - accuracy: 0.8302 - MulticlassKR: 1.4446 - val_loss: 5.4473 - val_accuracy: 0.8252 - val_MulticlassKR: 1.4494
Epoch 52/100
15/15 [==============================] - 1s 80ms/step - loss: 5.1193 - accuracy: 0.8302 - MulticlassKR: 1.4615 - val_loss: 5.4205 - val_accuracy: 0.8219 - val_MulticlassKR: 1.4643
Epoch 53/100
15/15 [==============================] - 1s 81ms/step - loss: 5.1053 - accuracy: 0.8338 - MulticlassKR: 1.4739 - val_loss: 5.3770 - val_accuracy: 0.8238 - val_MulticlassKR: 1.4766
Epoch 54/100
15/15 [==============================] - 1s 80ms/step - loss: 4.9836 - accuracy: 0.8338 - MulticlassKR: 1.4889 - val_loss: 5.3285 - val_accuracy: 0.8259 - val_MulticlassKR: 1.4896
Epoch 55/100
15/15 [==============================] - 1s 80ms/step - loss: 4.9996 - accuracy: 0.8337 - MulticlassKR: 1.4994 - val_loss: 5.3168 - val_accuracy: 0.8272 - val_MulticlassKR: 1.4970
Epoch 56/100
15/15 [==============================] - 1s 81ms/step - loss: 4.9064 - accuracy: 0.8372 - MulticlassKR: 1.5095 - val_loss: 5.2652 - val_accuracy: 0.8284 - val_MulticlassKR: 1.5102
Epoch 57/100
15/15 [==============================] - 1s 80ms/step - loss: 4.9659 - accuracy: 0.8335 - MulticlassKR: 1.5204 - val_loss: 5.2111 - val_accuracy: 0.8284 - val_MulticlassKR: 1.5191
Epoch 58/100
15/15 [==============================] - 1s 81ms/step - loss: 4.9272 - accuracy: 0.8351 - MulticlassKR: 1.5316 - val_loss: 5.1873 - val_accuracy: 0.8310 - val_MulticlassKR: 1.5290
Epoch 59/100
15/15 [==============================] - 1s 80ms/step - loss: 4.8504 - accuracy: 0.8367 - MulticlassKR: 1.5386 - val_loss: 5.1892 - val_accuracy: 0.8263 - val_MulticlassKR: 1.5440
Epoch 60/100
15/15 [==============================] - 1s 82ms/step - loss: 4.7810 - accuracy: 0.8399 - MulticlassKR: 1.5500 - val_loss: 5.1203 - val_accuracy: 0.8298 - val_MulticlassKR: 1.5517
Epoch 61/100
15/15 [==============================] - 1s 80ms/step - loss: 4.7313 - accuracy: 0.8394 - MulticlassKR: 1.5630 - val_loss: 5.1206 - val_accuracy: 0.8292 - val_MulticlassKR: 1.5662
Epoch 62/100
15/15 [==============================] - 1s 80ms/step - loss: 4.7666 - accuracy: 0.8406 - MulticlassKR: 1.5742 - val_loss: 5.0925 - val_accuracy: 0.8295 - val_MulticlassKR: 1.5692
Epoch 63/100
15/15 [==============================] - 1s 80ms/step - loss: 4.6527 - accuracy: 0.8418 - MulticlassKR: 1.5808 - val_loss: 5.0593 - val_accuracy: 0.8302 - val_MulticlassKR: 1.5836
Epoch 64/100
15/15 [==============================] - 1s 81ms/step - loss: 4.7434 - accuracy: 0.8410 - MulticlassKR: 1.5952 - val_loss: 5.0201 - val_accuracy: 0.8329 - val_MulticlassKR: 1.5966
Epoch 65/100
15/15 [==============================] - 1s 81ms/step - loss: 4.7347 - accuracy: 0.8386 - MulticlassKR: 1.6056 - val_loss: 5.0073 - val_accuracy: 0.8337 - val_MulticlassKR: 1.6002
Epoch 66/100
15/15 [==============================] - 1s 80ms/step - loss: 4.6701 - accuracy: 0.8414 - MulticlassKR: 1.6104 - val_loss: 4.9744 - val_accuracy: 0.8345 - val_MulticlassKR: 1.6125
Epoch 67/100
15/15 [==============================] - 1s 80ms/step - loss: 4.5813 - accuracy: 0.8430 - MulticlassKR: 1.6230 - val_loss: 4.9599 - val_accuracy: 0.8336 - val_MulticlassKR: 1.6252
Epoch 68/100
15/15 [==============================] - 1s 81ms/step - loss: 4.6265 - accuracy: 0.8420 - MulticlassKR: 1.6316 - val_loss: 4.9260 - val_accuracy: 0.8310 - val_MulticlassKR: 1.6337
Epoch 69/100
15/15 [==============================] - 1s 81ms/step - loss: 4.6232 - accuracy: 0.8426 - MulticlassKR: 1.6420 - val_loss: 4.8940 - val_accuracy: 0.8365 - val_MulticlassKR: 1.6376
Epoch 70/100
15/15 [==============================] - 1s 81ms/step - loss: 4.5432 - accuracy: 0.8430 - MulticlassKR: 1.6507 - val_loss: 4.8714 - val_accuracy: 0.8355 - val_MulticlassKR: 1.6471
Epoch 71/100
15/15 [==============================] - 1s 80ms/step - loss: 4.4822 - accuracy: 0.8438 - MulticlassKR: 1.6584 - val_loss: 4.8362 - val_accuracy: 0.8358 - val_MulticlassKR: 1.6575
Epoch 72/100
15/15 [==============================] - 1s 80ms/step - loss: 4.4781 - accuracy: 0.8444 - MulticlassKR: 1.6695 - val_loss: 4.8306 - val_accuracy: 0.8372 - val_MulticlassKR: 1.6670
Epoch 73/100
15/15 [==============================] - 1s 81ms/step - loss: 4.5386 - accuracy: 0.8424 - MulticlassKR: 1.6777 - val_loss: 4.8021 - val_accuracy: 0.8364 - val_MulticlassKR: 1.6715
Epoch 74/100
15/15 [==============================] - 1s 80ms/step - loss: 4.4138 - accuracy: 0.8447 - MulticlassKR: 1.6880 - val_loss: 4.7918 - val_accuracy: 0.8377 - val_MulticlassKR: 1.6845
Epoch 75/100
15/15 [==============================] - 1s 81ms/step - loss: 4.4090 - accuracy: 0.8476 - MulticlassKR: 1.6962 - val_loss: 4.7612 - val_accuracy: 0.8368 - val_MulticlassKR: 1.6925
Epoch 76/100
15/15 [==============================] - 1s 81ms/step - loss: 4.4482 - accuracy: 0.8459 - MulticlassKR: 1.6987 - val_loss: 4.7491 - val_accuracy: 0.8363 - val_MulticlassKR: 1.7041
Epoch 77/100
15/15 [==============================] - 1s 80ms/step - loss: 4.3394 - accuracy: 0.8462 - MulticlassKR: 1.7108 - val_loss: 4.7155 - val_accuracy: 0.8387 - val_MulticlassKR: 1.7075
Epoch 78/100
15/15 [==============================] - 1s 80ms/step - loss: 4.3768 - accuracy: 0.8482 - MulticlassKR: 1.7117 - val_loss: 4.6795 - val_accuracy: 0.8396 - val_MulticlassKR: 1.7135
Epoch 79/100
15/15 [==============================] - 1s 80ms/step - loss: 4.3540 - accuracy: 0.8476 - MulticlassKR: 1.7259 - val_loss: 4.6666 - val_accuracy: 0.8388 - val_MulticlassKR: 1.7266
Epoch 80/100
15/15 [==============================] - 1s 80ms/step - loss: 4.2509 - accuracy: 0.8469 - MulticlassKR: 1.7359 - val_loss: 4.6558 - val_accuracy: 0.8357 - val_MulticlassKR: 1.7321
Epoch 81/100
15/15 [==============================] - 1s 81ms/step - loss: 4.2792 - accuracy: 0.8461 - MulticlassKR: 1.7397 - val_loss: 4.6639 - val_accuracy: 0.8364 - val_MulticlassKR: 1.7419
Epoch 82/100
15/15 [==============================] - 1s 80ms/step - loss: 4.2849 - accuracy: 0.8465 - MulticlassKR: 1.7502 - val_loss: 4.6150 - val_accuracy: 0.8389 - val_MulticlassKR: 1.7488
Epoch 83/100
15/15 [==============================] - 1s 81ms/step - loss: 4.2858 - accuracy: 0.8466 - MulticlassKR: 1.7563 - val_loss: 4.6256 - val_accuracy: 0.8382 - val_MulticlassKR: 1.7551
Epoch 84/100
15/15 [==============================] - 1s 81ms/step - loss: 4.1836 - accuracy: 0.8491 - MulticlassKR: 1.7594 - val_loss: 4.5682 - val_accuracy: 0.8401 - val_MulticlassKR: 1.7607
Epoch 85/100
15/15 [==============================] - 1s 80ms/step - loss: 4.1970 - accuracy: 0.8497 - MulticlassKR: 1.7701 - val_loss: 4.5760 - val_accuracy: 0.8405 - val_MulticlassKR: 1.7660
Epoch 86/100
15/15 [==============================] - 1s 80ms/step - loss: 4.1455 - accuracy: 0.8507 - MulticlassKR: 1.7759 - val_loss: 4.5417 - val_accuracy: 0.8425 - val_MulticlassKR: 1.7734
Epoch 87/100
15/15 [==============================] - 1s 80ms/step - loss: 4.1810 - accuracy: 0.8506 - MulticlassKR: 1.7823 - val_loss: 4.5125 - val_accuracy: 0.8417 - val_MulticlassKR: 1.7786
Epoch 88/100
15/15 [==============================] - 1s 80ms/step - loss: 4.1159 - accuracy: 0.8518 - MulticlassKR: 1.7922 - val_loss: 4.5125 - val_accuracy: 0.8391 - val_MulticlassKR: 1.7913
Epoch 89/100
15/15 [==============================] - 1s 81ms/step - loss: 4.1807 - accuracy: 0.8500 - MulticlassKR: 1.7990 - val_loss: 4.4882 - val_accuracy: 0.8402 - val_MulticlassKR: 1.7938
Epoch 90/100
15/15 [==============================] - 1s 80ms/step - loss: 4.1548 - accuracy: 0.8504 - MulticlassKR: 1.8031 - val_loss: 4.5046 - val_accuracy: 0.8421 - val_MulticlassKR: 1.8073
Epoch 91/100
15/15 [==============================] - 1s 80ms/step - loss: 4.1227 - accuracy: 0.8501 - MulticlassKR: 1.8102 - val_loss: 4.4483 - val_accuracy: 0.8408 - val_MulticlassKR: 1.8036
Epoch 92/100
15/15 [==============================] - 1s 81ms/step - loss: 4.1302 - accuracy: 0.8512 - MulticlassKR: 1.8124 - val_loss: 4.4501 - val_accuracy: 0.8435 - val_MulticlassKR: 1.8101
Epoch 93/100
15/15 [==============================] - 1s 81ms/step - loss: 4.0846 - accuracy: 0.8502 - MulticlassKR: 1.8184 - val_loss: 4.4205 - val_accuracy: 0.8425 - val_MulticlassKR: 1.8175
Epoch 94/100
15/15 [==============================] - 1s 80ms/step - loss: 3.9720 - accuracy: 0.8539 - MulticlassKR: 1.8275 - val_loss: 4.4813 - val_accuracy: 0.8381 - val_MulticlassKR: 1.8186
Epoch 95/100
15/15 [==============================] - 1s 81ms/step - loss: 3.9978 - accuracy: 0.8542 - MulticlassKR: 1.8309 - val_loss: 4.3855 - val_accuracy: 0.8440 - val_MulticlassKR: 1.8287
Epoch 96/100
15/15 [==============================] - 1s 80ms/step - loss: 4.0764 - accuracy: 0.8506 - MulticlassKR: 1.8369 - val_loss: 4.3828 - val_accuracy: 0.8443 - val_MulticlassKR: 1.8371
Epoch 97/100
15/15 [==============================] - 1s 81ms/step - loss: 4.0436 - accuracy: 0.8517 - MulticlassKR: 1.8470 - val_loss: 4.3730 - val_accuracy: 0.8457 - val_MulticlassKR: 1.8344
Epoch 98/100
15/15 [==============================] - 1s 81ms/step - loss: 3.9989 - accuracy: 0.8532 - MulticlassKR: 1.8491 - val_loss: 4.3596 - val_accuracy: 0.8445 - val_MulticlassKR: 1.8446
Epoch 99/100
15/15 [==============================] - 1s 80ms/step - loss: 3.9820 - accuracy: 0.8541 - MulticlassKR: 1.8539 - val_loss: 4.3444 - val_accuracy: 0.8442 - val_MulticlassKR: 1.8477
Epoch 100/100
15/15 [==============================] - 1s 80ms/step - loss: 3.9592 - accuracy: 0.8523 - MulticlassKR: 1.8626 - val_loss: 4.3177 - val_accuracy: 0.8448 - val_MulticlassKR: 1.8529

<tensorflow.python.keras.callbacks.History at 0x7f9e52213c10>

model exportation

Once training is finished, the model can be optimized for inference by using the vanilla_export() method.

# once training is finished, you can convert SpectralDense layers into Dense
# layers and SpectralConv2D layers into Conv2D layers,
# which optimizes performance for inference
vanilla_model = model.vanilla_export()
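As an optional sanity check (our own addition, not part of the original notebook), the exported model should produce the same outputs as the constrained model, up to floating-point precision:

import numpy as np

# compare predictions of the constrained and the exported model on a few test samples
np.testing.assert_allclose(
    model.predict(x_test[:16]),
    vanilla_model.predict(x_test[:16]),
    atol=1e-4,
)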

certificate generation and adversarial attacks

import foolbox as fb
from tensorflow import convert_to_tensor
import matplotlib.pyplot as plt
import tensorflow as tf

# we will test the model on 10 samples, one per class
nb_adv = 10

hkr_fmodel = fb.TensorFlowModel(vanilla_model, bounds=(0., 1.), device="/GPU:0")

In order to test the robustness of the model, the first correctly classified element of each class is selected.

# strategy: pick the first correctly classified sample of each class
images_list = []
labels_list = []
# restrict the search to a subset of the test set
sub_y_test_ord = y_test_ord[:300]
sub_x_test = x_test[:300]
# drop misclassified elements
correct_mask = tf.argmax(vanilla_model.predict(sub_x_test), axis=-1).numpy() == sub_y_test_ord
sub_x_test = sub_x_test[correct_mask]
sub_y_test_ord = sub_y_test_ord[correct_mask]
# build one (image, label) pair per class
for i in range(10):
  # select the first element of the ith label
  label_mask = sub_y_test_ord == i
  x = sub_x_test[label_mask][0]
  y = sub_y_test_ord[label_mask][0]
  # convert to tensors for use with foolbox
  images_list.append(convert_to_tensor(x.astype("float32"), dtype="float32"))
  labels_list.append(convert_to_tensor(y, dtype="int64"))
images = convert_to_tensor(images_list)
labels = convert_to_tensor(labels_list)

In order to build a certificate, we take for each sample the top-2 outputs and apply the following formula: $$ \epsilon \geq \frac{\text{top}_1 - \text{top}_2}{2} $$ where $\epsilon$ is the robustness radius of the considered sample, i.e. a lower bound on the $L_2$ distance to the closest adversarial example.
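A one-line justification of this bound (our own sketch, not from the original notebook): since each head is 1-Lipschitz, a perturbation $\delta$ can decrease the top-1 score by at most $\|\delta\|_2$ and increase any other score by at most $\|\delta\|_2$, so for any class $j$ with score at most $\text{top}_2$, $$ f_{\text{top}_1}(x+\delta) - f_j(x+\delta) \geq (\text{top}_1 - \text{top}_2) - 2\|\delta\|_2, $$ which stays positive whenever $\|\delta\|_2 < \frac{\text{top}_1 - \text{top}_2}{2}$.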

values, classes = tf.math.top_k(hkr_fmodel(images), k=2)
certificates = (values[:, 0] - values[:, 1]) / 2
certificates
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([0.25511226, 1.0321686 , 0.34624586, 0.5743104 , 0.12979731,
       0.19581676, 0.08184442, 0.34386343, 0.68743587, 0.12055641],
      dtype=float32)>

Now we will attack the model to check whether the certificates are respected. In this setup, L2CarliniWagnerAttack is used, but in practice, as these kinds of networks are gradient norm preserving, other attacks give very similar results.

attack = fb.attacks.L2CarliniWagnerAttack(binary_search_steps=6, steps=8000)
imgs, advs, success = attack(hkr_fmodel, images, labels, epsilons=None)
dist_to_adv = np.sqrt(np.sum(np.square(images - advs), axis=(1,2,3)))
dist_to_adv
array([1.3944995 , 3.5208094 , 1.6824133 , 1.9192038 , 0.5746496 ,
       0.7780392 , 0.39687884, 1.1619285 , 2.367604  , 0.48984095],
      dtype=float32)

As we can see, the certificates are respected: for every sample, the observed distance to the adversarial example is larger than the certified radius.

tf.assert_less(certificates, dist_to_adv)
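As a small follow-up of our own (not in the original notebook), we can also look at how tight the certificates are, i.e. the ratio between the observed adversarial distance and the certified radius; it is greater than 1 whenever the certificates hold.

# ratio between observed adversarial distance and certified radius, per sample
tightness = dist_to_adv / certificates.numpy()
print(tightness)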

Finally, we can take a visual look at the obtained examples. We start with a utility function for display.

class_mapping = {
  0: "T-shirt/top",
  1: "Trouser",
  2: "Pullover",
  3: "Dress",
  4: "Coat",
  5: "Sandal",
  6: "Shirt",
  7: "Sneaker",
  8: "Bag",
  9: "Ankle boot",
}
def adversarial_viz(model, images, advs, class_mapping):
  """
  This function shows, for each sample:
  - the original image
  - the adversarial image
  - the difference map
  - the certificate and the observed distance to the adversarial example
  """
  scale = 1.5
  kwargs={}
  nb_imgs = images.shape[0]
  # compute certificates
  values, classes = tf.math.top_k(model(images), k=2)
  certificates = (values[:, 0] - values[:, 1]) / 2
  # compute the L2 distance to the adversarial example
  dist_to_adv = np.sqrt(np.sum(np.square(images - advs), axis=(1,2,3)))
  # find classes labels for imgs and advs
  orig_classes = [class_mapping[i] for i in tf.argmax(model(images), axis=-1).numpy()]
  advs_classes = [class_mapping[i] for i in tf.argmax(model(advs), axis=-1).numpy()]
  # compute difference maps
  if images.shape[-1] != 3:
    diff_pos = np.clip(advs - images, 0, 1.)
    diff_neg = np.clip(images - advs, 0, 1.)
    diff_map = np.concatenate([diff_neg, diff_pos, np.zeros_like(diff_neg)], axis=-1)
  else:
    diff_map = np.abs(advs - images)
  # expand grayscale images to RGB for display
  if images.shape[-1] != 3:
    images = np.repeat(images, 3, -1)
  if advs.shape[-1] != 3:
    advs = np.repeat(advs, 3, -1)
  # create plot
  figsize = (3 * scale, nb_imgs * scale)
  fig, axes = plt.subplots(
    ncols=3,
    nrows=nb_imgs,
    figsize=figsize,
    squeeze=False,
    constrained_layout=True,
    **kwargs,
  )
  for i in range(nb_imgs):
    ax = axes[i][0]
    ax.set_title(orig_classes[i])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis("off")
    ax.imshow(images[i])
    ax = axes[i][1]
    ax.set_title(advs_classes[i])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis("off")
    ax.imshow(advs[i])
    ax = axes[i][2]
    ax.set_title(f"certif: {certificates[i]:.2f}, obs: {dist_to_adv[i]:.2f}")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis("off")
    ax.imshow(diff_map[i]/diff_map[i].max())

When looking at the adversarial examples, we can see that the network has interesting properties:

predictability

By looking at the certificates, we can predict whether the adversarial example will be close or not.

disparity among classes

As we can see, the attacks are very efficient on similar classes (e.g. T-shirt/top and Shirt). This shows that not all classes are equal with respect to robustness.

explainability

The network is more explainable: attacks can be used as counterfactuals. It makes sense, for instance, that removing the print on a T-shirt/top turns it into a shirt. Non-robust examples reveal that the network relies on textures rather than shapes to make its decisions.

adversarial_viz(hkr_fmodel, images, advs, class_mapping)
(Figure: for each class, the original image, its adversarial example, and the difference map annotated with the certificate and the observed distance.)