Example 3: HKR classifier on MNIST dataset


This notebook demonstrates how to train a binary classifier on the MNIST 0-8 dataset (MNIST restricted to the digits 0 and 8).

# Install the required library deel-torchlip (uncomment line below)
# %pip install -qqq deel-torchlip

1. Data preparation

For this task we select two classes: 0 and 8. Labels are mapped to {-1, 1}, which is compatible with the hinge term used in the loss.

import torch
from torchvision import datasets

# First we select the two classes
selected_classes = [0, 8]  # must be two classes as we perform binary classification


def prepare_data(dataset, class_a=0, class_b=8):
    """
    This function converts the MNIST data to make it suitable for our binary
    classification setup.
    """
    x = dataset.data
    y = dataset.targets
    # select only the items belonging to class_a or class_b
    mask = (y == class_a) | (y == class_b)
    x = x[mask]
    y = y[mask]

    # convert from integer range [0, 255] to float32 range [0, 1]
    x = x.float() / 255
    x = x.reshape((-1, 28, 28, 1))

    # map labels to {-1, 1} for binary classification
    y_ = torch.zeros_like(y).float()
    y_[y == class_a] = 1.0
    y_[y == class_b] = -1.0
    return torch.utils.data.TensorDataset(x, y_)


train = datasets.MNIST("./data", train=True, download=True)
test = datasets.MNIST("./data", train=False, download=True)

# Prepare the data
train = prepare_data(train, selected_classes[0], selected_classes[1])
test = prepare_data(test, selected_classes[0], selected_classes[1])

# Display info about the datasets
print(
    f"Train set size: {len(train)} samples, proportion of positive class (+1): "
    f"{100 * (train.tensors[1] == 1).numpy().mean():.2f} %"
)
print(
    f"Test set size: {len(test)} samples, proportion of positive class (+1): "
    f"{100 * (test.tensors[1] == 1).numpy().mean():.2f} %"
)
Train set size: 11774 samples, proportion of positive class (+1): 50.31 %
Test set size: 1954 samples, proportion of positive class (+1): 50.15 %

2. Build Lipschitz model

Here, the experiments use a model made only of fully-connected layers. However, torchlip also provides state-of-the-art 1-Lipschitz convolutional layers; a possible convolutional variant is sketched just below for reference.
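For reference only, such a convolutional variant could look like the sketch below. This is an illustration added here, not code from the original notebook: it assumes that torchlip.SpectralConv2d mirrors the torch.nn.Conv2d signature, that torchlip.GroupSort2 can be applied to feature maps, and that inputs use the standard NCHW layout (batch, 1, 28, 28) rather than the channels-last shape produced by prepare_data above; check the torchlip API before reusing it.

import torch
from deel import torchlip

# Hypothetical 1-Lipschitz convolutional classifier (illustrative sketch only;
# layer signatures are assumed to mirror their torch.nn counterparts).
conv_model = torchlip.Sequential(
    torchlip.SpectralConv2d(1, 16, kernel_size=3, padding=1),
    torchlip.GroupSort2(),
    torchlip.SpectralConv2d(16, 32, kernel_size=3, padding=1),
    torchlip.GroupSort2(),
    torch.nn.Flatten(),
    torchlip.SpectralLinear(32 * 28 * 28, 64),
    torchlip.FullSort(),
    torchlip.FrobeniusLinear(64, 1),
)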

import torch
from deel import torchlip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ninputs = 28 * 28
wass = torchlip.Sequential(
    torch.nn.Flatten(),
    torchlip.SpectralLinear(ninputs, 128),
    torchlip.FullSort(),
    torchlip.SpectralLinear(128, 64),
    torchlip.FullSort(),
    torchlip.SpectralLinear(64, 32),
    torchlip.FullSort(),
    torchlip.FrobeniusLinear(32, 1),
).to(device)

wass
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): ParametrizedSpectralLinear(
    in_features=784, out_features=128, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
        (1): _BjorckNorm()
      )
    )
  )
  (2): FullSort()
  (3): ParametrizedSpectralLinear(
    in_features=128, out_features=64, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
        (1): _BjorckNorm()
      )
    )
  )
  (4): FullSort()
  (5): ParametrizedSpectralLinear(
    in_features=64, out_features=32, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
        (1): _BjorckNorm()
      )
    )
  )
  (6): FullSort()
  (7): ParametrizedFrobeniusLinear(
    in_features=32, out_features=1, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _FrobeniusNorm()
      )
    )
  )
)
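As a quick sanity check (added here, not part of the original notebook), a dummy forward pass on random data confirms that the model returns a single score per sample; the input follows the (batch, 28, 28, 1) shape produced by prepare_data, which the initial Flatten layer makes irrelevant for this fully-connected model.

import torch

# Dummy forward pass: 4 random "images" in, 4 scores out.
with torch.no_grad():
    dummy = torch.randn(4, 28, 28, 1, device=device)
    print(wass(dummy).shape)  # expected: torch.Size([4, 1])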

3. Learn classification on MNIST

from deel.torchlip import KRLoss, HKRLoss, HingeMarginLoss

# training parameters
epochs = 10
batch_size = 128

# loss parameters
min_margin = 1
alpha = 0.98

kr_loss = KRLoss()
hkr_loss = HKRLoss(alpha=alpha, min_margin=min_margin)
hinge_margin_loss = HingeMarginLoss(min_margin=min_margin)
optimizer = torch.optim.Adam(lr=0.001, params=wass.parameters())

train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test, batch_size=32, shuffle=False)

for epoch in range(epochs):

    m_kr, m_hm, m_acc = 0, 0, 0
    wass.train()

    for step, (data, target) in enumerate(train_loader):

        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = wass(data)
        loss = hkr_loss(output, target)
        loss.backward()
        optimizer.step()

        # Compute metrics on batch
        m_kr += kr_loss(output, target)
        m_hm += hinge_margin_loss(output, target)
        m_acc += (torch.sign(output).flatten() == torch.sign(target)).sum() / len(
            target
        )

    # Train metrics for the current epoch
    metrics = [
        f"{k}: {v:.04f}"
        for k, v in {
            "loss": loss,
            "KR": m_kr / (step + 1),
            "acc": m_acc / (step + 1),
        }.items()
    ]

    # Compute test loss for the current epoch
    wass.eval()
    testo = []
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        testo.append(wass(data).detach().cpu())
    testo = torch.cat(testo).flatten()

    # Validation metrics for the current epoch
    metrics += [
        f"val_{k}: {v:.04f}"
        for k, v in {
            "loss": hkr_loss(testo, test.tensors[1]),
            "KR": kr_loss(testo, test.tensors[1]),
            "acc": (torch.sign(testo) == torch.sign(test.tensors[1])).float().mean(),
        }.items()
    ]

    print(f"Epoch {epoch + 1}/{epochs}")
    print(" - ".join(metrics))
Epoch 1/10
loss: -0.0341 - KR: 1.2932 - acc: 0.8849 - val_loss: -0.0332 - val_KR: 2.2678 - val_acc: 0.9918
Epoch 2/10
loss: -0.0578 - KR: 2.7714 - acc: 0.9918 - val_loss: -0.0588 - val_KR: 3.3804 - val_acc: 0.9928
Epoch 3/10
loss: -0.0787 - KR: 3.8889 - acc: 0.9930 - val_loss: -0.0777 - val_KR: 4.5303 - val_acc: 0.9903
Epoch 4/10
loss: -0.0989 - KR: 4.8591 - acc: 0.9946 - val_loss: -0.0949 - val_KR: 5.2321 - val_acc: 0.9928
Epoch 5/10
loss: -0.1063 - KR: 5.4602 - acc: 0.9948 - val_loss: -0.1064 - val_KR: 5.7756 - val_acc: 0.9923
Epoch 6/10
loss: -0.1025 - KR: 5.8678 - acc: 0.9946 - val_loss: -0.1123 - val_KR: 6.0558 - val_acc: 0.9928
Epoch 7/10
loss: -0.1155 - KR: 6.0850 - acc: 0.9958 - val_loss: -0.1178 - val_KR: 6.2237 - val_acc: 0.9939
Epoch 8/10
loss: -0.0989 - KR: 6.2505 - acc: 0.9957 - val_loss: -0.1193 - val_KR: 6.3435 - val_acc: 0.9933
Epoch 9/10
loss: -0.1330 - KR: 6.3675 - acc: 0.9956 - val_loss: -0.1218 - val_KR: 6.4925 - val_acc: 0.9939
Epoch 10/10
loss: -0.1272 - KR: 6.4819 - acc: 0.9966 - val_loss: -0.1148 - val_KR: 6.5502 - val_acc: 0.9862

4. Evaluate the Lipschitz constant of our network

4.1. Empirical evaluation

We can estimate the Lipschitz constant by evaluating

$$\frac{\Vert F(x_2) - F(x_1)\Vert}{\Vert x_2 - x_1\Vert} \quad\text{or}\quad \frac{\Vert F(x + \epsilon) - F(x)\Vert}{\Vert \epsilon\Vert}$$

for various inputs.
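Before using the library helper, such a ratio can be estimated by hand. The snippet below is a minimal sketch added here (not part of the original notebook); it reuses wass, test_loader and device defined above and only yields a lower bound on the true Lipschitz constant.

import torch

# Rough noise-based estimate of the Lipschitz ratio (lower bound only).
xb, _ = next(iter(test_loader))
xb = xb.to(device)
eps = 1e-2 * torch.randn_like(xb)  # small random perturbation
with torch.no_grad():
    num = (wass(xb + eps) - wass(xb)).flatten(1).norm(dim=1)  # ||F(x+eps) - F(x)||
    den = eps.flatten(1).norm(dim=1)  # ||eps||
print(f"Max ratio over the batch: {(num / den).max().item():.4f}")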

The deel.torchlip.utils.evaluate_lip_const function implements several methods to estimate this constant: adding a random noise $\epsilon$ to the input and evaluating $\frac{\Vert F(x + \epsilon) - F(x)\Vert}{\Vert \epsilon\Vert}$, running an adversarial attack on $\epsilon$ to maximize this ratio, or computing the Jacobian norm $\Vert\nabla_x F(x)\Vert$.

from deel.torchlip.utils import evaluate_lip_const

x, y = next(iter(test_loader))
evaluate_lip_const(
    wass,
    x.to(device),
    evaluation_type="all",
    disjoint_neurons=False,
    double_attack=True,
)
Empirical lipschitz constant is 1.0000001192092896 with method jacobian_norm
Empirical lipschitz constant is 0.16324962675571442 with method noise_norm
Warning : double_attack is set to True, the computation time will be doubled
Empirical lipschitz constant is 0.9863337874412537 with method attack
1.0000001192092896

As we can see, each method provides an under-estimation of the Lipschitz constant. The adversarial attack and the Jacobian norm methods find values very close to one.

4.2. Singular value decomposition

Since our network is made only of linear layers and FullSort activations, we can compute the singular value decomposition (SVD) of each linear layer's weight matrix and check that all singular values are equal to 1.

print("=== Before export ===")
layers = list(wass.children())
for layer in layers:
    if hasattr(layer, "weight"):
        w = layer.weight
        s = torch.linalg.svdvals(w)  # only the singular values are needed
        print(f"{type(layer)}, min={s.min()}, max={s.max()}")
=== Before export ===
<class 'abc.ParametrizedSpectralLinear'>, min=0.9999998807907104, max=1.0
<class 'abc.ParametrizedSpectralLinear'>, min=1.000000238418579, max=1.000009536743164
<class 'abc.ParametrizedSpectralLinear'>, min=0.9999998807907104, max=1.0
<class 'abc.ParametrizedFrobeniusLinear'>, min=0.9999998211860657, max=0.9999998211860657

4.3. Model export

Once training is finished, the model can be optimized for inference by using the vanilla_export() method. The torchlip layers are converted to their PyTorch counterparts, e.g. SpectralConv2d layers will be converted into torch.nn.Conv2d layers.

Warnings:

The vanilla_export() method modifies the model in-place.

In order to build and export a new model while keeping the reference one, it is required to follow these steps:

# Build a new model, for instance with torchlip.Sequential(torchlip.SpectralConv2d(…), …)
wexport = <your_function_to_build_the_model>()

# Copy the parameters from the reference model to the new one
wexport.load_state_dict(wass.state_dict())

# One forward pass is required to initialize the parametrizations
wexport(one_input)

# vanilla_export the new model
wexport = wexport.vanilla_export()

wexport = wass.vanilla_export()

print("=== After export ===")
layers = list(wexport.children())
for layer in layers:
    if hasattr(layer, "weight"):
        w = layer.weight
        s = torch.linalg.svdvals(w)  # only the singular values are needed
        print(f"{type(layer)}, min={s.min()}, max={s.max()}")
=== After export ===
<class 'torch.nn.modules.linear.Linear'>, min=0.9999998807907104, max=1.0
<class 'torch.nn.modules.linear.Linear'>, min=1.000000238418579, max=1.000009536743164
<class 'torch.nn.modules.linear.Linear'>, min=0.9999998807907104, max=1.0
<class 'torch.nn.modules.linear.Linear'>, min=0.9999998211860657, max=0.9999998211860657

As we can see, all our singular values are very close to one.
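As a final check (a sketch added here, not part of the original notebook; it reuses test_loader and device from above), the exported model can be run on the test set to verify that its accuracy matches that of the 1-Lipschitz model, since vanilla_export only re-expresses the same weights with plain torch.nn layers.

import torch

# Recompute test accuracy with the exported (plain PyTorch) model.
wexport = wexport.to(device)
wexport.eval()
correct, total = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        out = wexport(data.to(device)).cpu().flatten()
        correct += (torch.sign(out) == torch.sign(target)).sum().item()
        total += len(target)
print(f"Exported model test accuracy: {correct / total:.4f}")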

