Example 3: HKR classifier on MNIST dataset¶
This notebook demonstrates how to train a binary classifier on the MNIST 0-8 dataset (MNIST restricted to the digits 0 and 8) using the hinge-Kantorovich-Rubinstein (HKR) loss.
# Install the required library deel-torchlip (uncomment line below)
# %pip install -qqq deel-torchlip
1. Data preparation¶
For this task we select two classes: 0 and 8. Labels are mapped to {-1, 1}, which is compatible with the hinge term used in the loss.
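As an illustration of why the ±1 convention matters (this shows a generic margin hinge, max(0, margin - y * y_hat), which is not necessarily the exact form implemented by HingeMarginLoss below), the hinge term only penalizes samples that are on the wrong side of the margin when labels take values in {-1, +1}:
import torch
# Illustrative only: generic margin hinge with labels in {-1, +1}
y = torch.tensor([1.0, -1.0, 1.0])       # ground-truth labels in {-1, +1}
y_hat = torch.tensor([2.0, -0.1, -0.5])  # raw network outputs
margin = 1.0
print(torch.relu(margin - y * y_hat))    # tensor([0.0000, 0.9000, 1.5000])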
import torch
from torchvision import datasets
# First we select the two classes
selected_classes = [0, 8] # must be two classes as we perform binary classification
def prepare_data(dataset, class_a=0, class_b=8):
    """
    Convert the MNIST dataset to make it suitable for our binary
    classification setup.
    """
    x = dataset.data
    y = dataset.targets
    # select only the items from class_a and class_b
    mask = (y == class_a) | (y == class_b)
    x = x[mask]
    y = y[mask]
    # convert images from int[0, 255] to float32[0, 1]
    x = x.float() / 255
    x = x.reshape((-1, 28, 28, 1))
    # change labels to the binary classification convention {-1, 1}
    y_ = torch.zeros_like(y).float()
    y_[y == class_a] = 1.0
    y_[y == class_b] = -1.0
    return torch.utils.data.TensorDataset(x, y_)
train = datasets.MNIST("./data", train=True, download=True)
test = datasets.MNIST("./data", train=False, download=True)
# Prepare the data
train = prepare_data(train, selected_classes[0], selected_classes[1])
test = prepare_data(test, selected_classes[0], selected_classes[1])
# Display information about the dataset
print(
    f"Train set size: {len(train)} samples, proportion of label +1: "
    f"{100 * (train.tensors[1] == 1).numpy().mean():.2f} %"
)
print(
    f"Test set size: {len(test)} samples, proportion of label +1: "
    f"{100 * (test.tensors[1] == 1).numpy().mean():.2f} %"
)
Train set size: 11774 samples, proportion of label +1: 50.31 %
Test set size: 1954 samples, proportion of label +1: 50.15 %
2. Build Lipschitz model¶
Here, the experiments are done with a model composed only of fully-connected
layers. However, torchlip also provides state-of-the-art 1-Lipschitz
convolutional layers (see the sketch at the end of this section).
import torch
from deel import torchlip
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ninputs = 28 * 28
wass = torchlip.Sequential(
    torch.nn.Flatten(),
    torchlip.SpectralLinear(ninputs, 128),
    torchlip.FullSort(),
    torchlip.SpectralLinear(128, 64),
    torchlip.FullSort(),
    torchlip.SpectralLinear(64, 32),
    torchlip.FullSort(),
    torchlip.FrobeniusLinear(32, 1),
).to(device)
wass
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): ParametrizedSpectralLinear(
    in_features=784, out_features=128, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
        (1): _BjorckNorm()
      )
    )
  )
  (2): FullSort()
  (3): ParametrizedSpectralLinear(
    in_features=128, out_features=64, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
        (1): _BjorckNorm()
      )
    )
  )
  (4): FullSort()
  (5): ParametrizedSpectralLinear(
    in_features=64, out_features=32, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _SpectralNorm()
        (1): _BjorckNorm()
      )
    )
  )
  (6): FullSort()
  (7): ParametrizedFrobeniusLinear(
    in_features=32, out_features=1, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _FrobeniusNorm()
      )
    )
  )
)
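For reference, a convolutional variant could look like the sketch below. This is only an illustration: the layer signatures are assumed to mirror torch.nn.Conv2d (check the torchlip API reference), and convolutional layers expect channel-first inputs of shape (N, 1, 28, 28) rather than the (N, 28, 28, 1) tensors prepared above.
# Hypothetical convolutional 1-Lipschitz model (illustration only; signatures
# assumed to mirror torch.nn.Conv2d). Inputs must be channel-first: (N, 1, 28, 28).
wass_conv = torchlip.Sequential(
    torchlip.SpectralConv2d(1, 16, kernel_size=3, padding=1),
    torchlip.FullSort(),
    torchlip.SpectralConv2d(16, 32, kernel_size=3, padding=1),
    torchlip.FullSort(),
    torch.nn.Flatten(),
    torchlip.SpectralLinear(32 * 28 * 28, 64),
    torchlip.FullSort(),
    torchlip.FrobeniusLinear(64, 1),
).to(device)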
3. Learn classification on MNIST¶
from deel.torchlip import KRLoss, HKRLoss, HingeMarginLoss
# training parameters
epochs = 10
batch_size = 128
# loss parameters
min_margin = 1
alpha = 0.98
kr_loss = KRLoss()
hkr_loss = HKRLoss(alpha=alpha, min_margin=min_margin)
hinge_margin_loss = HingeMarginLoss(min_margin=min_margin)
optimizer = torch.optim.Adam(lr=0.001, params=wass.parameters())
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test, batch_size=32, shuffle=False)
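Before training, it can be useful to sanity-check the three losses on a single batch of the untrained network. A minimal sketch using the objects defined above:
# Evaluate the losses on one batch before training (illustrative sanity check)
data, target = next(iter(train_loader))
with torch.no_grad():
    out = wass(data.to(device)).cpu()
print(f"KR: {kr_loss(out, target):.4f}, "
      f"hinge: {hinge_margin_loss(out, target):.4f}, "
      f"HKR: {hkr_loss(out, target):.4f}")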
for epoch in range(epochs):
    m_kr, m_hm, m_acc = 0, 0, 0
    wass.train()

    for step, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = wass(data)
        loss = hkr_loss(output, target)
        loss.backward()
        optimizer.step()

        # Compute metrics on batch
        m_kr += kr_loss(output, target)
        m_hm += hinge_margin_loss(output, target)
        m_acc += (torch.sign(output).flatten() == torch.sign(target)).sum() / len(target)

    # Train metrics for the current epoch
    metrics = [
        f"{k}: {v:.04f}"
        for k, v in {
            "loss": loss,
            "KR": m_kr / (step + 1),
            "acc": m_acc / (step + 1),
        }.items()
    ]

    # Compute test loss for the current epoch
    wass.eval()
    testo = []
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        testo.append(wass(data).detach().cpu())
    testo = torch.cat(testo).flatten()

    # Validation metrics for the current epoch
    metrics += [
        f"val_{k}: {v:.04f}"
        for k, v in {
            "loss": hkr_loss(testo, test.tensors[1]),
            "KR": kr_loss(testo, test.tensors[1]),
            "acc": (torch.sign(testo) == torch.sign(test.tensors[1])).float().mean(),
        }.items()
    ]

    print(f"Epoch {epoch + 1}/{epochs}")
    print(" - ".join(metrics))
Epoch 1/10
loss: -0.0272 - KR: 1.4492 - acc: 0.9184 - val_loss: -0.0367 - val_KR: 2.3308 - val_acc: 0.9939
Epoch 2/10
loss: -0.0518 - KR: 2.7784 - acc: 0.9926 - val_loss: -0.0574 - val_KR: 3.3190 - val_acc: 0.9939
Epoch 3/10
loss: -0.0782 - KR: 3.6303 - acc: 0.9938 - val_loss: -0.0751 - val_KR: 4.1403 - val_acc: 0.9939
Epoch 4/10
loss: -0.0978 - KR: 4.5607 - acc: 0.9952 - val_loss: -0.0927 - val_KR: 4.9920 - val_acc: 0.9933
Epoch 5/10
loss: -0.0873 - KR: 5.2546 - acc: 0.9958 - val_loss: -0.1037 - val_KR: 5.5868 - val_acc: 0.9944
Epoch 6/10
loss: -0.1186 - KR: 5.7066 - acc: 0.9960 - val_loss: -0.1081 - val_KR: 5.9397 - val_acc: 0.9913
Epoch 7/10
loss: -0.1189 - KR: 6.0129 - acc: 0.9955 - val_loss: -0.1161 - val_KR: 6.1834 - val_acc: 0.9933
Epoch 8/10
loss: -0.1281 - KR: 6.2577 - acc: 0.9958 - val_loss: -0.1151 - val_KR: 6.3653 - val_acc: 0.9923
Epoch 9/10
loss: -0.1292 - KR: 6.4227 - acc: 0.9967 - val_loss: -0.1216 - val_KR: 6.5185 - val_acc: 0.9933
Epoch 10/10
loss: -0.1375 - KR: 6.5687 - acc: 0.9965 - val_loss: -0.1253 - val_KR: 6.6100 - val_acc: 0.9939
4. Evaluate the Lipschitz constant of our network¶
4.1. Empirical evaluation¶
We can empirically estimate the Lipschitz constant by evaluating the ratio
$\frac{\|F(x_2) - F(x_1)\|_2}{\|x_2 - x_1\|_2}$ for various pairs of inputs.
from scipy.spatial.distance import pdist
wass.eval()
p = []
for _ in range(64):
    eps = 1e-3
    batch, _ = next(iter(train_loader))
    dist = torch.distributions.Uniform(-eps, +eps).sample(batch.shape)
    y1 = wass(batch.to(device)).detach().cpu()
    y2 = wass((batch + dist).to(device)).detach().cpu()
    p.append(
        torch.max(
            torch.norm(y2 - y1, dim=1)
            / torch.norm(dist.reshape(dist.shape[0], -1), dim=1)
        )
    )
print(torch.tensor(p).max())
tensor(0.1420)
p = []
for batch, _ in train_loader:
    x = batch.numpy()
    y = wass(batch.to(device)).detach().cpu().numpy()
    xd = pdist(x.reshape(batch.shape[0], -1))
    yd = pdist(y.reshape(batch.shape[0], -1))
    p.append((yd / xd).max())
print(torch.tensor(p).max())
tensor(0.8841, dtype=torch.float64)
As we can see, with the $\epsilon$-perturbation version we greatly under-estimate the Lipschitz constant. Using pairs of samples from the train dataset, we find a Lipschitz constant close to 0.9, which is better, but our network should be 1-Lipschitz.
4.2. Singular-Value Decomposition¶
Since our network is composed only of linear layers and FullSort
activations, we can compute the Singular Value Decomposition (SVD) of each
weight matrix and check that, for every linear layer, all singular values
are equal to 1.
print("=== Before export ===")
layers = list(wass.children())
for layer in layers:
    if hasattr(layer, "weight"):
        w = layer.weight
        u, s, v = torch.svd(w)
        print(f"{layer}, min={s.min()}, max={s.max()}")
=== Before export ===
ParametrizedSpectralLinear(
  in_features=784, out_features=128, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
      (1): _BjorckNorm()
    )
  )
), min=0.9999998211860657, max=1.0
ParametrizedSpectralLinear(
  in_features=128, out_features=64, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
      (1): _BjorckNorm()
    )
  )
), min=1.000001072883606, max=1.000012755393982
ParametrizedSpectralLinear(
  in_features=64, out_features=32, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
      (1): _BjorckNorm()
    )
  )
), min=0.9999998807907104, max=1.0
ParametrizedFrobeniusLinear(
  in_features=32, out_features=1, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _FrobeniusNorm()
    )
  )
), min=1.0000001192092896, max=1.0000001192092896
4.3. Model export¶
Once training is finished, the model can be optimized for inference by
using the vanilla_export()
method. The torchlip
layers are
converted to their PyTorch counterparts, e.g. SpectralConv2d
layers will be converted into torch.nn.Conv2d
layers.
Warning:¶
The vanilla_export() method modifies the model in-place.
To build and export a new model while keeping the reference one, follow these steps:
# Build a new model, for instance with torchlip.Sequential(torchlip.SpectralConv2d(…), …)
wexport = <your_function_to_build_the_model>()
# Copy the parameters from the reference model to the new model
wexport.load_state_dict(wass.state_dict())
# One forward pass is required to initialize the parametrizations
wexport(one_input)
# vanilla_export the new model
wexport = wexport.vanilla_export()
wexport = wass.vanilla_export()
print("=== After export ===")
layers = list(wexport.children())
for layer in layers:
    if hasattr(layer, "weight"):
        w = layer.weight
        u, s, v = torch.svd(w)
        print(f"{layer}, min={s.min()}, max={s.max()}")
=== After export ===
Linear(in_features=784, out_features=128, bias=True), min=0.9999998211860657, max=1.0
Linear(in_features=128, out_features=64, bias=True), min=1.000001072883606, max=1.000012755393982
Linear(in_features=64, out_features=32, bias=True), min=0.9999998807907104, max=1.0
Linear(in_features=32, out_features=1, bias=True), min=1.0000001192092896, max=1.0000001192092896
As we can see, all our singular values are very close to one.
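Since the network is 1-Lipschitz, the magnitude of its output also gives a per-sample robustness certificate: |f(x)| is a lower bound on the L2 distance from x to the decision boundary. A minimal sketch using the exported model and the test tensors defined above:
# Certified L2 radius for each test sample (valid because the network is 1-Lipschitz)
with torch.no_grad():
    scores = wexport.to(device)(test.tensors[0].to(device)).flatten().cpu()
certificates = scores.abs()  # per-sample certified L2 radius
print(f"mean certified radius: {certificates.mean():.4f}")
print(f"median certified radius: {certificates.median():.4f}")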