Example 4: HKR multiclass and fooling
=====================================

|Open in Colab|

This notebook will show how to train a Lipschitz network in a multiclass
configuration. The HKR (hinge-Kantorovich-Rubinstein) loss is extended to
multiclass using a one-vs-all setup. The notebook will go through the process
of designing and training the network. It will also show how to compute
robustness certificates from the outputs of the network. Finally, the
guarantees provided by these certificates will be checked by attacking the
network.

.. |Open in Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/github/deel-ai/deel-torchlip/blob/master/docs/notebooks/wasserstein_classification_fashionMNIST.ipynb

.. code:: ipython3

    # Install the required libraries deel-torchlip and foolbox (uncomment below if needed)
    # %pip install -qqq deel-torchlip foolbox

1. Data preparation
-------------------

For this example, the ``fashion_mnist`` dataset is used. In order to keep
things simple, no data augmentation is performed.

.. code:: ipython3

    import torch
    from torchvision import datasets, transforms

    train_set = datasets.FashionMNIST(
        root="./data",
        download=True,
        train=True,
        transform=transforms.ToTensor(),
    )

    test_set = datasets.FashionMNIST(
        root="./data",
        download=True,
        train=False,
        transform=transforms.ToTensor(),
    )

    batch_size = 100
    train_loader = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size)

2. Model architecture
---------------------

The original one-vs-all setup would require 10 different networks (one per
class). In practice, however, we use a single network with a common body and a
Lipschitz head (linear layer) containing 10 output neurons, like any standard
network for multiclass classification. Note that we use
``torchlip.FrobeniusLinear`` with ``disjoint_neurons=True`` to enforce that
each head neuron is a 1-Lipschitz function.

Notes about constraint enforcement
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

There are currently three ways to enforce the Lipschitz constraint in a
network:

1. weight regularization
2. weight reparametrization
3. weight projection

Weight regularization does not provide the required guarantees, as it is only a
regularization term. Weight reparametrization is available in ``torchlip`` and
is done directly in the layers (parameter ``niter_bjorck``). This approach
allows performing arbitrary gradient updates without breaking the constraint.
However, it is done in the computation graph, which increases resource
consumption. Weight projection is not implemented in ``torchlip``.
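Weight reparametrization can be checked empirically: since the constrained
weight is the one used in the forward pass, any pair of inputs should satisfy
the 1-Lipschitz inequality. Below is a minimal, illustrative sketch of such a
check on a single ``SpectralLinear`` layer (the layer sizes and the random test
batch are arbitrary choices for this illustration, not values required by
``torchlip``).

.. code:: ipython3

    import torch
    from deel import torchlip

    # A 1-Lipschitz layer must satisfy ||f(x1) - f(x2)|| <= ||x1 - x2|| for any
    # pair of inputs. We check this on random pairs; the ratio is expected to
    # stay below 1 (up to a small numerical tolerance).
    layer = torchlip.SpectralLinear(32, 16)
    x1, x2 = torch.randn(256, 32), torch.randn(256, 32)
    with torch.no_grad():
        ratios = (layer(x1) - layer(x2)).norm(dim=1) / (x1 - x2).norm(dim=1)
    print(f"max Lipschitz ratio on random pairs: {ratios.max():.4f}")

The next cell builds the full Lipschitz network for Fashion-MNIST.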
.. code:: ipython3

    from deel import torchlip

    # Sequential has the same properties as any Lipschitz layer. It only acts as a
    # container, with features specific to Lipschitz functions (condensation,
    # vanilla_exportation, ...)
    model = torchlip.Sequential(
        # Lipschitz layers preserve the API of their superclass (here Conv2d). An
        # optional argument is available, k_coef_lip, which controls the Lipschitz
        # constant of the layer
        torchlip.SpectralConv2d(
            in_channels=1, out_channels=16, kernel_size=(3, 3), padding="same"
        ),
        torchlip.GroupSort2(),
        # Usual pooling layers are implemented (avg, max), but new pooling layers
        # are also available
        torchlip.ScaledL2NormPool2d(kernel_size=(2, 2)),
        torchlip.SpectralConv2d(
            in_channels=16, out_channels=32, kernel_size=(3, 3), padding="same"
        ),
        torchlip.GroupSort2(),
        torchlip.ScaledL2NormPool2d(kernel_size=(2, 2)),
        # Our layers are fully interoperable with existing PyTorch layers
        torch.nn.Flatten(),
        torchlip.SpectralLinear(1568, 64),
        torchlip.GroupSort2(),
        torchlip.FrobeniusLinear(64, 10, bias=True, disjoint_neurons=True),
        # Similarly, the model has a parameter to set the Lipschitz constant that
        # automatically sets the constant of each layer.
        k_coef_lip=1.0,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

.. parsed-literal::

    Sequential(
      (0): ParametrizedSpectralConv2d(
        1, 16, kernel_size=(3, 3), stride=(1, 1), padding=same
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): _SpectralNorm()
            (1): _BjorckNorm()
            (2): _LConvNorm()
          )
        )
      )
      (1): GroupSort2()
      (2): ScaledL2NormPool2d(norm_type=2, kernel_size=(2, 2), stride=None, ceil_mode=False)
      (3): ParametrizedSpectralConv2d(
        16, 32, kernel_size=(3, 3), stride=(1, 1), padding=same
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): _SpectralNorm()
            (1): _BjorckNorm()
            (2): _LConvNorm()
          )
        )
      )
      (4): GroupSort2()
      (5): ScaledL2NormPool2d(norm_type=2, kernel_size=(2, 2), stride=None, ceil_mode=False)
      (6): Flatten(start_dim=1, end_dim=-1)
      (7): ParametrizedSpectralLinear(
        in_features=1568, out_features=64, bias=True
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): _SpectralNorm()
            (1): _BjorckNorm()
          )
        )
      )
      (8): GroupSort2()
      (9): ParametrizedFrobeniusLinear(
        in_features=64, out_features=10, bias=True
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): _FrobeniusNorm()
          )
        )
      )
    )

3. HKR loss and training
------------------------

The multiclass HKR loss can be found in the ``HKRMulticlassLoss`` class. The
loss has two parameters: ``alpha`` and ``min_margin``. Decreasing ``alpha`` and
increasing ``min_margin`` improve robustness (at the cost of accuracy). Note
also that, in the case of Lipschitz networks, more robustness requires more
parameters. For more information, see `our paper
<https://arxiv.org/abs/2006.06520>`__.

In this setup, choosing ``alpha=0.99`` and ``min_margin=0.25`` provides good
robustness without hurting the accuracy too much. A more accurate network can
be obtained with ``alpha=0.999`` and ``min_margin=0.1``.

We also provide the ``SoftHKRMulticlassLoss`` introduced in `this paper
<https://arxiv.org/abs/2206.06854>`__, which can achieve performance equivalent
to unconstrained networks (92% validation accuracy with ``alpha=0.995``,
``min_margin=0.10``, ``temperature=50.0``).

Finally, the ``KRMulticlassLoss`` gives an indication of the robustness of the
network (a proxy of the average certificate).
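Before running the full training loop (next cell), here is a quick illustration
of the call convention of these loss classes: they take the raw outputs of the
network and one-hot encoded targets, exactly as in the training loop below. The
tensor values in this sketch are arbitrary and serve only as an example.

.. code:: ipython3

    import torch
    from deel import torchlip

    # Dummy raw outputs for a batch of 4 samples, and their one-hot targets
    dummy_logits = torch.randn(4, 10)
    dummy_targets = torch.nn.functional.one_hot(torch.tensor([0, 3, 7, 9]), num_classes=10)

    # The HKR loss to minimize, and the KR metric used as a robustness proxy
    hkr = torchlip.HKRMulticlassLoss(alpha=0.99, min_margin=0.25)
    kr = torchlip.KRMulticlassLoss()
    print(hkr(dummy_logits, dummy_targets), kr(dummy_logits, dummy_targets))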
.. code:: ipython3

    loss_choice = "HKRMulticlassLoss"  # "HKRMulticlassLoss" or "SoftHKRMulticlassLoss"

    epochs = 50
    optimizer = torch.optim.Adam(lr=1e-3, params=model.parameters())

    hkr_loss = None
    if loss_choice == "HKRMulticlassLoss":
        hkr_loss = torchlip.HKRMulticlassLoss(alpha=0.99, min_margin=0.25)  # Robust
        # hkr_loss = torchlip.HKRMulticlassLoss(alpha=0.999, min_margin=0.10)  # Accurate
    if loss_choice == "SoftHKRMulticlassLoss":
        hkr_loss = torchlip.SoftHKRMulticlassLoss(alpha=0.995, min_margin=0.10, temperature=50.0)
    assert hkr_loss is not None, "Please choose a valid loss function"

    kr_multiclass_loss = torchlip.KRMulticlassLoss()

    for epoch in range(epochs):
        m_kr, m_acc = 0, 0

        for step, (data, target) in enumerate(train_loader):
            # For multiclass HKR loss, the targets must be one-hot encoded
            target = torch.nn.functional.one_hot(target, num_classes=10)
            data, target = data.to(device), target.to(device)

            # Forward + backward pass
            optimizer.zero_grad()
            output = model(data)
            loss = hkr_loss(output, target)
            loss.backward()
            optimizer.step()

            # Compute metrics on batch
            m_kr += kr_multiclass_loss(output, target)
            m_acc += (output.argmax(dim=1) == target.argmax(dim=1)).sum() / len(target)

        # Train metrics for the current epoch
        metrics = [
            f"{k}: {v:.04f}"
            for k, v in {
                "loss": loss,
                "acc": m_acc / (step + 1),
                "KR": m_kr / (step + 1),
            }.items()
        ]

        # Compute validation metrics for the current epoch
        test_output, test_targets = [], []
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            test_output.append(model(data).detach().cpu())
            test_targets.append(
                torch.nn.functional.one_hot(target, num_classes=10).detach().cpu()
            )
        test_output = torch.cat(test_output)
        test_targets = torch.cat(test_targets)

        val_loss = hkr_loss(test_output, test_targets)
        val_kr = kr_multiclass_loss(test_output, test_targets)
        val_acc = (test_output.argmax(dim=1) == test_targets.argmax(dim=1)).float().mean()

        # Validation metrics for the current epoch
        metrics += [
            f"val_{k}: {v:.04f}"
            for k, v in {
                "loss": val_loss,
                "acc": val_acc,
                "KR": val_kr,
            }.items()
        ]

        print(f"Epoch {epoch + 1}/{epochs}")
        print(" - ".join(metrics))

.. parsed-literal::

    Epoch 1/50
    loss: 0.0161 - acc: 0.7948 - KR: 0.8425 - val_loss: 0.0234 - val_acc: 0.8237 - val_KR: 1.1728
    Epoch 2/50
    loss: 0.0144 - acc: 0.8425 - KR: 1.3253 - val_loss: 0.0181 - val_acc: 0.8474 - val_KR: 1.4679
    Epoch 3/50
    loss: 0.0040 - acc: 0.8522 - KR: 1.6386 - val_loss: 0.0191 - val_acc: 0.8214 - val_KR: 1.7816
    Epoch 4/50
    loss: 0.0046 - acc: 0.8574 - KR: 1.9427 - val_loss: 0.0098 - val_acc: 0.8596 - val_KR: 2.0056
    Epoch 5/50
    loss: 0.0000 - acc: 0.8605 - KR: 2.1595 - val_loss: 0.0079 - val_acc: 0.8680 - val_KR: 2.1441
    Epoch 6/50
    loss: 0.0049 - acc: 0.8642 - KR: 2.2765 - val_loss: 0.0063 - val_acc: 0.8634 - val_KR: 2.3429
    Epoch 7/50
    loss: -0.0053 - acc: 0.8670 - KR: 2.3516 - val_loss: 0.0051 - val_acc: 0.8664 - val_KR: 2.3691
    Epoch 8/50
    loss: -0.0021 - acc: 0.8708 - KR: 2.4078 - val_loss: 0.0031 - val_acc: 0.8698 - val_KR: 2.4568
    Epoch 9/50
    loss: -0.0072 - acc: 0.8731 - KR: 2.4747 - val_loss: 0.0031 - val_acc: 0.8688 - val_KR: 2.5106
    Epoch 10/50
    loss: 0.0009 - acc: 0.8726 - KR: 2.5210 - val_loss: 0.0026 - val_acc: 0.8685 - val_KR: 2.5051
.. parsed-literal::

    Epoch 11/50
    loss: -0.0028 - acc: 0.8751 - KR: 2.5462 - val_loss: 0.0022 - val_acc: 0.8730 - val_KR: 2.5741
    Epoch 12/50
    loss: -0.0035 - acc: 0.8751 - KR: 2.5864 - val_loss: 0.0025 - val_acc: 0.8707 - val_KR: 2.5648
    Epoch 13/50
    loss: -0.0027 - acc: 0.8764 - KR: 2.5977 - val_loss: 0.0019 - val_acc: 0.8718 - val_KR: 2.6368
    Epoch 14/50
    loss: -0.0047 - acc: 0.8789 - KR: 2.6347 - val_loss: 0.0044 - val_acc: 0.8539 - val_KR: 2.6234
    Epoch 15/50
    loss: 0.0189 - acc: 0.8788 - KR: 2.6543 - val_loss: 0.0003 - val_acc: 0.8723 - val_KR: 2.5902
    Epoch 16/50
    loss: 0.0142 - acc: 0.8793 - KR: 2.6534 - val_loss: 0.0006 - val_acc: 0.8673 - val_KR: 2.6843
    Epoch 17/50
    loss: -0.0018 - acc: 0.8809 - KR: 2.6729 - val_loss: 0.0014 - val_acc: 0.8670 - val_KR: 2.7061
    Epoch 18/50
    loss: 0.0005 - acc: 0.8805 - KR: 2.6892 - val_loss: 0.0002 - val_acc: 0.8692 - val_KR: 2.6683
    Epoch 19/50
    loss: 0.0144 - acc: 0.8814 - KR: 2.7032 - val_loss: 0.0006 - val_acc: 0.8754 - val_KR: 2.6909
    Epoch 20/50
    loss: 0.0095 - acc: 0.8827 - KR: 2.7164 - val_loss: 0.0001 - val_acc: 0.8707 - val_KR: 2.7713
    Epoch 21/50
    loss: -0.0062 - acc: 0.8815 - KR: 2.7312 - val_loss: -0.0008 - val_acc: 0.8776 - val_KR: 2.7397
    Epoch 22/50
    loss: -0.0057 - acc: 0.8834 - KR: 2.7449 - val_loss: -0.0002 - val_acc: 0.8638 - val_KR: 2.7346
    Epoch 23/50
    loss: -0.0109 - acc: 0.8844 - KR: 2.7543 - val_loss: -0.0016 - val_acc: 0.8781 - val_KR: 2.7080
    Epoch 24/50
    loss: -0.0091 - acc: 0.8844 - KR: 2.7597 - val_loss: -0.0006 - val_acc: 0.8731 - val_KR: 2.7509
    Epoch 25/50
    loss: 0.0054 - acc: 0.8839 - KR: 2.7827 - val_loss: -0.0021 - val_acc: 0.8789 - val_KR: 2.7414
    Epoch 26/50
    loss: -0.0093 - acc: 0.8865 - KR: 2.7827 - val_loss: -0.0024 - val_acc: 0.8815 - val_KR: 2.7571
    Epoch 27/50
    loss: -0.0028 - acc: 0.8854 - KR: 2.7891 - val_loss: -0.0007 - val_acc: 0.8671 - val_KR: 2.8054
    Epoch 28/50
    loss: 0.0045 - acc: 0.8848 - KR: 2.8087 - val_loss: -0.0005 - val_acc: 0.8765 - val_KR: 2.7992
    Epoch 29/50
    loss: -0.0050 - acc: 0.8855 - KR: 2.8126 - val_loss: -0.0003 - val_acc: 0.8716 - val_KR: 2.7960
    Epoch 30/50
    loss: -0.0090 - acc: 0.8858 - KR: 2.8186 - val_loss: -0.0015 - val_acc: 0.8727 - val_KR: 2.7698
    Epoch 31/50
    loss: -0.0086 - acc: 0.8882 - KR: 2.8209 - val_loss: -0.0029 - val_acc: 0.8752 - val_KR: 2.8335
    Epoch 32/50
    loss: -0.0064 - acc: 0.8871 - KR: 2.8258 - val_loss: -0.0030 - val_acc: 0.8820 - val_KR: 2.8266
    Epoch 33/50
    loss: -0.0086 - acc: 0.8882 - KR: 2.8410 - val_loss: -0.0025 - val_acc: 0.8742 - val_KR: 2.8252
    Epoch 34/50
    loss: -0.0157 - acc: 0.8873 - KR: 2.8518 - val_loss: -0.0021 - val_acc: 0.8736 - val_KR: 2.7995
    Epoch 35/50
    loss: 0.0009 - acc: 0.8877 - KR: 2.8418 - val_loss: -0.0028 - val_acc: 0.8739 - val_KR: 2.8467
    Epoch 36/50
    loss: -0.0137 - acc: 0.8882 - KR: 2.8552 - val_loss: -0.0023 - val_acc: 0.8778 - val_KR: 2.8063
    Epoch 37/50
    loss: -0.0103 - acc: 0.8881 - KR: 2.8597 - val_loss: -0.0023 - val_acc: 0.8720 - val_KR: 2.8331
.. parsed-literal::

    Epoch 38/50
    loss: -0.0100 - acc: 0.8897 - KR: 2.8594 - val_loss: -0.0033 - val_acc: 0.8811 - val_KR: 2.8638
    Epoch 39/50
    loss: -0.0047 - acc: 0.8887 - KR: 2.8630 - val_loss: -0.0035 - val_acc: 0.8801 - val_KR: 2.8755
    Epoch 40/50
    loss: -0.0047 - acc: 0.8902 - KR: 2.8691 - val_loss: -0.0023 - val_acc: 0.8752 - val_KR: 2.8752
    Epoch 41/50
    loss: -0.0085 - acc: 0.8897 - KR: 2.8753 - val_loss: -0.0018 - val_acc: 0.8756 - val_KR: 2.8190
    Epoch 42/50
    loss: -0.0170 - acc: 0.8892 - KR: 2.8745 - val_loss: -0.0034 - val_acc: 0.8807 - val_KR: 2.8524
    Epoch 43/50
    loss: -0.0025 - acc: 0.8909 - KR: 2.8805 - val_loss: -0.0030 - val_acc: 0.8811 - val_KR: 2.8388
    Epoch 44/50
    loss: -0.0093 - acc: 0.8922 - KR: 2.8824 - val_loss: -0.0034 - val_acc: 0.8805 - val_KR: 2.8573
    Epoch 45/50
    loss: -0.0065 - acc: 0.8898 - KR: 2.8861 - val_loss: -0.0027 - val_acc: 0.8763 - val_KR: 2.8508
    Epoch 46/50
    loss: -0.0046 - acc: 0.8908 - KR: 2.8799 - val_loss: -0.0038 - val_acc: 0.8808 - val_KR: 2.8540
    Epoch 47/50
    loss: -0.0141 - acc: 0.8902 - KR: 2.8932 - val_loss: -0.0037 - val_acc: 0.8794 - val_KR: 2.8714
    Epoch 48/50
    loss: -0.0101 - acc: 0.8912 - KR: 2.8959 - val_loss: -0.0033 - val_acc: 0.8789 - val_KR: 2.8827
    Epoch 49/50
    loss: -0.0111 - acc: 0.8918 - KR: 2.8873 - val_loss: -0.0040 - val_acc: 0.8859 - val_KR: 2.9193
    Epoch 50/50
    loss: -0.0008 - acc: 0.8933 - KR: 2.9104 - val_loss: -0.0041 - val_acc: 0.8818 - val_KR: 2.8705

4. Model export
---------------

Once training is finished, the model can be optimized for inference by using
the ``vanilla_export()`` method. The ``torchlip`` layers are converted to their
PyTorch counterparts, e.g. ``SpectralConv2d`` layers are converted into
``torch.nn.Conv2d`` layers.

Warnings:
~~~~~~~~~

The ``vanilla_export()`` method modifies the model in-place. In order to build
and export a new model while keeping the reference one, the following steps are
required:

1. Build a new model, for instance with ``torchlip.Sequential(torchlip.SpectralConv2d(…), …)``:
   ``vanilla_model = <your_function_to_build_the_model>()``
2. Copy the parameters from the reference model to the new model:
   ``vanilla_model.load_state_dict(model.state_dict())``
3. Run one forward pass, required to initialize the parametrizations:
   ``vanilla_model(one_input)``
4. Export the new model:
   ``vanilla_model = vanilla_model.vanilla_export()``

.. code:: ipython3

    vanilla_model = model.vanilla_export()
    vanilla_model.eval()
    vanilla_model.to(device)

.. parsed-literal::

    Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): GroupSort2()
      (2): LPPool2d(norm_type=2, kernel_size=(2, 2), stride=None, ceil_mode=False)
      (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (4): GroupSort2()
      (5): LPPool2d(norm_type=2, kernel_size=(2, 2), stride=None, ceil_mode=False)
      (6): Flatten(start_dim=1, end_dim=-1)
      (7): Linear(in_features=1568, out_features=64, bias=True)
      (8): GroupSort2()
      (9): Linear(in_features=64, out_features=10, bias=True)
    )
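Since the export replaces the constrained layers by standard PyTorch layers
carrying the effective weights, the exported model should compute the same
function as the reference one. The cell below is an illustrative sanity check
on one test batch (the variable name ``check_data`` is introduced here for the
example only).

.. code:: ipython3

    # Compare the outputs of the reference model and the exported model on one
    # test batch; the difference is expected to be negligible.
    check_data, _ = next(iter(test_loader))
    check_data = check_data.to(device)
    model.eval()
    with torch.no_grad():
        max_diff = (model(check_data) - vanilla_model(check_data)).abs().max()
    print(f"Maximum output difference after export: {max_diff:.2e}")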
5. Robustness evaluation: certificate generation and adversarial attacks
-------------------------------------------------------------------------

A Lipschitz network provides certificates guaranteeing that there is no
adversarial attack smaller than the certificate. We will show how to compute a
certificate for a given image sample. We will also run attacks on 10 images
(one per class) and show that the distance between the obtained adversarial
images and the original images is greater than the certificates.

The ``foolbox`` library is used to perform adversarial attacks.

.. code:: ipython3

    import numpy as np

    # Select only the first batch from the test set
    sub_data, sub_targets = next(iter(test_loader))
    sub_data, sub_targets = sub_data.to(device), sub_targets.to(device)

    # Drop misclassified elements
    output = vanilla_model(sub_data)
    well_classified_mask = output.argmax(dim=-1) == sub_targets
    sub_data = sub_data[well_classified_mask]
    sub_targets = sub_targets[well_classified_mask]

    # Retrieve one image per class
    images_list, targets_list = [], []
    for i in range(10):
        # Select the elements of the i-th label and keep the first one
        label_mask = sub_targets == i
        x = sub_data[label_mask][0]
        y = sub_targets[label_mask][0]
        images_list.append(x)
        targets_list.append(y)
    images = torch.stack(images_list)
    targets = torch.stack(targets_list)

In order to build a certificate :math:`\mathcal{M}` for a given sample, we take
the two largest outputs (top-2) and apply the following formula:

.. math::

   \mathcal{M} = \frac{\text{top}_1 - \text{top}_2}{2}

This certificate guarantees that no L2 attack can fool the network on the given
image sample with a perturbation of norm :math:`\epsilon` lower than the
certificate, i.e.

.. math::

   \epsilon \geq \mathcal{M}

In the following cell, we attack the model on the ten selected images and
compare the obtained perturbation norm :math:`\epsilon` with the certificates
:math:`\mathcal{M}`. In this setup, ``L2CarliniWagnerAttack`` from ``foolbox``
is used, but in practice, since these kinds of networks are gradient-norm
preserving, other attacks give very similar results.

.. code:: ipython3

    import foolbox as fb

    # Compute certificates
    values, _ = vanilla_model(images).topk(k=2)
    # The factor is 2.0 when using disjoint_neurons=True
    certificates = (values[:, 0] - values[:, 1]) / 2.0

    # Run the Carlini & Wagner attack
    fmodel = fb.PyTorchModel(vanilla_model, bounds=(0.0, 1.0), device=device)
    attack = fb.attacks.L2CarliniWagnerAttack(binary_search_steps=6, steps=8000)
    _, advs, success = attack(fmodel, images, targets, epsilons=None)
    dist_to_adv = (images - advs).square().sum(dim=(1, 2, 3)).sqrt()

    # Print results
    print("Image #     Certificate     Distance to adversarial")
    print("---------------------------------------------------")
    for i in range(len(certificates)):
        print(f"Image {i}        {certificates[i]:.3f}           {dist_to_adv[i]:.2f}")

.. parsed-literal::

    Image #     Certificate     Distance to adversarial
    ---------------------------------------------------
    Image 0        0.309           1.29
    Image 1        1.864           4.65
    Image 2        0.397           1.56
    Image 3        0.527           2.81
    Image 4        0.105           0.44
    Image 5        0.188           0.82
    Image 6        0.053           0.26
    Image 7        0.450           1.62
    Image 8        1.488           3.91
    Image 9        0.161           0.69

Finally, we can take a visual look at the obtained images. When looking at the
adversarial examples, we can see that the network has interesting properties:

- **Predictability**: by looking at the certificates, we can predict whether
  the adversarial example will be close to the original image or not.
- **Disparity among classes**: the attacks are very efficient on similar
  classes (e.g. T-shirt/top and Shirt). This shows that not all classes are
  equal with respect to robustness.
- **Explainability**: the network is more explainable, as attacks can be used
  as counterfactuals. For instance, it makes sense that removing the
  inscription on a T-shirt turns it into a shirt. Non-robust examples reveal
  that the network relies on textures rather than on shapes to make its
  decision.
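Before visualizing the adversarial examples, we can also check the guarantee
programmatically: for every attacked image, the observed L2 distance to the
adversarial example must be at least as large as its certificate. This is a
simple, illustrative check on the tensors computed in the attack cell above.

.. code:: ipython3

    # The certificate is a lower bound on the distance to any adversarial example,
    # so this condition must hold for every image.
    guarantee_holds = bool((dist_to_adv >= certificates).all())
    print(f"All adversarial distances are above the certificates: {guarantee_holds}")

We now plot the original images, their adversarial counterparts, and the
difference maps.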
.. code:: ipython3

    import matplotlib.pyplot as plt


    def adversarial_viz(model, images, advs, class_names):
        """
        This function shows, for each image sample:
        - the original image
        - the adversarial image
        - the difference map
        - the certificate and the observed distance to the adversarial
        """
        scale = 1.5
        nb_imgs = images.shape[0]

        # Compute certificates (the factor is 2.0 when using disjoint_neurons=True)
        values, _ = model(images).topk(k=2)
        certificates = (values[:, 0] - values[:, 1]) / 2.0

        # Compute distance between each image and its adversarial
        dist_to_adv = (images - advs).square().sum(dim=(1, 2, 3)).sqrt()

        # Find predicted classes for the images and their adversarials
        orig_classes = [class_names[i] for i in model(images).argmax(dim=-1)]
        advs_classes = [class_names[i] for i in model(advs).argmax(dim=-1)]

        # Compute difference maps
        advs = advs.detach().cpu()
        images = images.detach().cpu()
        diff_pos = np.clip(advs - images, 0, 1.0)
        diff_neg = np.clip(images - advs, 0, 1.0)
        diff_map = np.concatenate(
            [diff_neg, diff_pos, np.zeros_like(diff_neg)], axis=1
        ).transpose((0, 2, 3, 1))

        # Create plot
        def _set_ax(ax, title):
            ax.set_title(title)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis("off")

        figsize = (3 * scale, nb_imgs * scale)
        _, axes = plt.subplots(
            ncols=3, nrows=nb_imgs, figsize=figsize, squeeze=False, constrained_layout=True
        )
        for i in range(nb_imgs):
            _set_ax(axes[i][0], orig_classes[i])
            axes[i][0].imshow(images[i].squeeze(), cmap="gray")
            _set_ax(axes[i][1], advs_classes[i])
            axes[i][1].imshow(advs[i].squeeze(), cmap="gray")
            _set_ax(axes[i][2], f"certif: {certificates[i]:.2f}, obs: {dist_to_adv[i]:.2f}")
            axes[i][2].imshow(diff_map[i] / diff_map[i].max())


    adversarial_viz(vanilla_model, images, advs, test_set.classes)

.. image:: wasserstein_classification_fashionMNIST_files/wasserstein_classification_fashionMNIST_16_0.png
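As noted above, since the network is gradient-norm preserving, other attacks
are expected to find adversarial examples at similar distances. This can be
checked, for instance, with another minimization attack from ``foolbox`` such
as ``L2DeepFoolAttack``; the sketch below is optional and illustrative, and the
number of steps and the variable names (``deepfool``, ``advs_df``, ``dist_df``)
are arbitrary choices for this example.

.. code:: ipython3

    # Optional check with a different minimization attack: the obtained distances
    # must remain above the certificates and, per the remark above, are expected
    # to be similar to the Carlini & Wagner ones.
    deepfool = fb.attacks.L2DeepFoolAttack(steps=100)
    _, advs_df, _ = deepfool(fmodel, images, targets, epsilons=None)
    dist_df = (images - advs_df).square().sum(dim=(1, 2, 3)).sqrt()
    print(dist_df)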