Demo 3: HKR Classifier on MNIST dataset

This notebook demonstrates learning a binary classification task on the MNIST 0-8 dataset (only the digits 0 and 8 are kept).

# pip install deel-lip -qqq
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import binary_accuracy
from tensorflow.keras.models import Sequential

from deel.lip.layers import (
    SpectralConv2D,
    SpectralDense,
    FrobeniusDense,
    ScaledL2NormPooling2D,
)
from deel.lip.activations import MaxMin, GroupSort, GroupSort2, FullSort
from deel.lip.losses import HKR, KR, HingeMargin

Data preparation

For this task we select two classes: 0 and 8. Labels are changed to {-1, 1} (class 0 becomes 1 and class 8 becomes -1), which is compatible with the hinge term used in the loss.
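With labels in {-1, 1}, the product y · f(x) of a label and a real-valued score acts as a signed margin. It is positive exactly when the sign of the score matches the label, and a hinge term penalizes it whenever it falls below a chosen margin. The NumPy sketch below illustrates this with made-up labels and scores. It uses the textbook hinge max(0, min_margin - y · f(x)). The exact margin scaling used by deel-lip's HingeMargin is an assumption here and may differ.

```python
import numpy as np

# hypothetical raw digit labels, standing in for a few MNIST labels of classes 0 and 8
raw_labels = np.array([0, 8, 8, 0, 0], dtype=np.int64)

# same convention as this notebook: class 0 -> +1, class 8 -> -1
y = np.where(raw_labels == 0, 1.0, -1.0).astype("float32")

# hypothetical real-valued classifier outputs f(x)
scores = np.array([2.3, -0.7, 0.4, -0.1, 1.5], dtype="float32")

# with {-1, 1} labels, y * f(x) is a signed margin:
# positive when the sign of the score matches the label, negative otherwise
margins = y * scores
correct = margins > 0

# textbook hinge term: zero only for samples whose margin reaches min_margin
min_margin = 1.0
hinge = np.maximum(0.0, min_margin - margins)

print(y)
print(correct)
print(hinge)
```

Note that the second sample is correctly classified but still receives a hinge penalty, because its margin (0.7) is below min_margin.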

from tensorflow.keras.datasets import mnist

# first we select the two classes
selected_classes = [0, 8]  # must be two classes as we perform binary classification


def prepare_data(x, y, class_a=0, class_b=8):
    """
    This function converts the MNIST data to make it suitable for our binary
    classification setup.
    """
    # select items from the two selected classes
    mask = (y == class_a) | (y == class_b)  # keep only items from class_a or class_b
    x = x[mask]
    y = y[mask]
    x = x.astype("float32")
    # convert from range int[0,255] to float32[0,1]
    x /= 255
    x = x.reshape((-1, 28, 28, 1))
    # change labels to binary classification {-1,1}: class_a -> 1, class_b -> -1
    # (computed in one step from the original labels, so it also works when class_b == 1)
    y = (y == class_a).astype("float32") * 2.0 - 1.0
    return x, y


# now we load the dataset
(x_train, y_train_ord), (x_test, y_test_ord) = mnist.load_data()

# prepare the data
x_train, y_train = prepare_data(
    x_train, y_train_ord, selected_classes[0], selected_classes[1]
)
x_test, y_test = prepare_data(
    x_test, y_test_ord, selected_classes[0], selected_classes[1]
)

# display information about the dataset (the proportion is that of samples labelled +1)
print(
    "train set size: %i samples, classes proportions: %.3f percent"
    % (y_train.shape[0], 100 * y_train[y_train == 1].sum() / y_train.shape[0])
)
print(
    "test set size: %i samples, classes proportions: %.3f percent"
    % (y_test.shape[0], 100 * y_test[y_test == 1].sum() / y_test.shape[0])
)
train set size: 11774 samples, classes proportions: 50.306 percent
test set size: 1954 samples, classes proportions: 50.154 percent
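The printed proportion is the share of samples labelled +1, which here means digit 0. As a cross-check, the sketch below recomputes both printed lines from the per-digit counts of the standard MNIST splits: 5923 zeros and 5851 eights in the training set, and 980 zeros and 974 eights in the test set. These counts come from the MNIST dataset itself, not from this notebook. The helper name positive_share is hypothetical.

```python
# per-digit counts of the standard MNIST splits (digits 0 and 8 only)
train_counts = {0: 5923, 8: 5851}
test_counts = {0: 980, 8: 974}


def positive_share(counts, positive_digit=0):
    """Return (total samples, percentage labelled +1) for the two selected digits."""
    total = sum(counts.values())
    return total, 100.0 * counts[positive_digit] / total


train_total, train_pct = positive_share(train_counts)
test_total, test_pct = positive_share(test_counts)
print("train: %i samples, %.3f percent labelled +1" % (train_total, train_pct))
print("test: %i samples, %.3f percent labelled +1" % (test_total, test_pct))
```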

Build Lipschitz model

Let's first make the parameters of this experiment explicit.

# training parameters
epochs = 10
batch_size = 128

# network parameters
activation = GroupSort2  # used in the hidden layers below (alternatives: MaxMin, GroupSort, FullSort)

# loss parameters
min_margin = 1.0
alpha = 10.0
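GroupSort-style activations such as GroupSort and GroupSort2 sort features within small groups instead of clipping them. The rough NumPy sketch below shows why they suit Lipschitz networks. It assumes that GroupSort2 sorts consecutive pairs of features in ascending order. The ordering convention is an assumption, and the hypothetical helper group_sort is not deel-lip's implementation. Sorting only permutes values, so it preserves norms and is 1-Lipschitz.

```python
import numpy as np


def group_sort(x, group_size=2):
    """Sort the last axis of x within consecutive groups of group_size (ascending)."""
    n = x.shape[-1]
    if n % group_size != 0:
        raise ValueError("the number of features must be divisible by group_size")
    grouped = x.reshape(x.shape[:-1] + (n // group_size, group_size))
    return np.sort(grouped, axis=-1).reshape(x.shape)


x = np.array([[3.0, -1.0, 0.5, 2.0]])
print(group_sort(x))  # each pair is sorted independently

rng = np.random.default_rng(0)
a = rng.normal(size=(1000, 16))
b = rng.normal(size=(1000, 16))

# sorting only permutes entries, so per-sample L2 norms are unchanged
norm_preserved = np.allclose(
    np.linalg.norm(group_sort(a), axis=1), np.linalg.norm(a, axis=1)
)
# output distances never exceed input distances (1-Lipschitz)
lipschitz_ok = bool(
    np.all(
        np.linalg.norm(group_sort(a) - group_sort(b), axis=1)
        <= np.linalg.norm(a - b, axis=1) + 1e-9
    )
)
print(norm_preserved, lipschitz_ok)
```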

Now we can build the network. Here the experiment is done with an MLP, but deel-lip also provides state-of-the-art 1-Lipschitz convolutions (such as SpectralConv2D).

K.clear_session()
# build the 1-Lipschitz MLP as a Sequential model
wass = Sequential(
    layers=[
        Input((28, 28, 1)),
        Flatten(),
        SpectralDense(32, GroupSort2(), use_bias=True),
        SpectralDense(16, GroupSort2(), use_bias=True),
        FrobeniusDense(1, activation=None, use_bias=False),
    ],
    name="lipModel",
)
wass.summary()
Model: "lipModel"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
spectral_dense (SpectralDens (None, 32)                50241     
_________________________________________________________________
spectral_dense_1 (SpectralDe (None, 16)                1057      
_________________________________________________________________
frobenius_dense (FrobeniusDe (None, 1)                 32        
=================================================================
Total params: 51,330
Trainable params: 25,664
Non-trainable params: 25,666
_________________________________________________________________
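Each layer above is constrained to be 1-Lipschitz. The Lipschitz constant of a composition is at most the product of the layers' constants, so the whole MLP stays 1-Lipschitz as long as the activations are also 1-Lipschitz. The NumPy sketch below illustrates the idea behind the constraint with power-iteration spectral normalization on a hypothetical random kernel. deel-lip describes SpectralDense as also applying Björck orthonormalization, which this sketch omits, so it is not the library's implementation. The sketch also suggests why a Frobenius normalization is enough for the single-output last layer: for a column vector, the Frobenius norm equals the spectral norm.

```python
import numpy as np


def spectral_norm(w, n_iter=500, seed=0):
    """Estimate the largest singular value of a 2-D kernel with power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=w.shape[1])
    u /= np.linalg.norm(u)
    for _ in range(n_iter):
        v = w @ u  # left singular vector update
        v /= np.linalg.norm(v)
        u = w.T @ v  # right singular vector update
        u /= np.linalg.norm(u)
    return float(v @ w @ u)  # estimate of the largest singular value


rng = np.random.default_rng(42)
kernel = rng.normal(size=(64, 32))  # hypothetical dense kernel (inputs x units)
sigma = spectral_norm(kernel)
kernel_sn = kernel / sigma  # spectrally normalized kernel

# after normalization the largest singular value is (approximately) 1,
# so x -> x @ kernel_sn cannot stretch distances by more than a factor 1
print(sigma, np.linalg.norm(kernel, 2), np.linalg.norm(kernel_sn, 2))

# a single-output kernel is a column vector: its Frobenius and spectral norms coincide
last_kernel = rng.normal(size=(16, 1))
frob = np.linalg.norm(last_kernel, "fro")
spec = np.linalg.norm(last_kernel, 2)
print(frob, spec)
```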

optimizer = Adam(learning_rate=0.001)


# the classifier outputs an unbounded real score and the labels are in {-1, 1}, whereas
# Keras' binary_accuracy expects {0, 1} labels and thresholds predictions at 0.5:
# accuracy is therefore redefined by comparing signs (a score >= 0 is read as class +1)
def HKR_binary_accuracy(y_true, y_pred):
    S_true = tf.dtypes.cast(tf.greater_equal(y_true[:, 0], 0), dtype=tf.float32)
    S_pred = tf.dtypes.cast(tf.greater_equal(y_pred[:, 0], 0), dtype=tf.float32)
    return binary_accuracy(S_true, S_pred)

wass.compile(
    loss=HKR(
        alpha=alpha, min_margin=min_margin
    ),  # HKR stands for the hinge regularized KR loss
    metrics=[
        KR,  # shows the KR term of the loss
        HingeMargin(min_margin=min_margin),  # shows the hinge term of the loss
        HKR_binary_accuracy,  # shows the classification accuracy
    ],
    optimizer=optimizer,
)
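The compiled model optimizes HKR and monitors KR, HingeMargin and a sign-based accuracy. The minimal NumPy sketch below computes these quantities on a toy batch. The formulas are assumptions and do not reproduce deel-lip's code:

- the KR term is taken as the difference between the mean scores of the +1 and -1 samples;
- the hinge term is the textbook max(0, min_margin - y · f(x));
- HKR is taken as -KR + alpha · hinge.

Under these assumptions, minimizing HKR maximizes the class separation while penalizing margin violations. The alpha and min_margin defaults match this notebook's parameters.

```python
import numpy as np


def kr_term(y_true, y_pred):
    """Difference of mean scores between the +1 and -1 samples (empirical KR term)."""
    return y_pred[y_true == 1].mean() - y_pred[y_true == -1].mean()


def hinge_term(y_true, y_pred, min_margin=1.0):
    """Mean textbook hinge penalty on the signed margin y * f(x)."""
    return np.maximum(0.0, min_margin - y_true * y_pred).mean()


def hkr_loss(y_true, y_pred, alpha=10.0, min_margin=1.0):
    """Assumed hinge-regularized KR loss: maximize KR, penalize margin violations."""
    return -kr_term(y_true, y_pred) + alpha * hinge_term(y_true, y_pred, min_margin)


def sign_accuracy(y_true, y_pred):
    """Same idea as HKR_binary_accuracy: compare signs (score >= 0 counts as +1)."""
    return np.mean((y_true >= 0) == (y_pred >= 0))


y_true = np.array([1.0, 1.0, -1.0, -1.0])
y_pred = np.array([3.0, 0.5, -2.0, 0.2])  # hypothetical network outputs

print(kr_term(y_true, y_pred))
print(hinge_term(y_true, y_pred))
print(hkr_loss(y_true, y_pred))
print(sign_accuracy(y_true, y_pred))
```

The toy batch shows the trade-off the loss has to balance. The KR term rewards pushing the two classes' scores apart, while the alpha-weighted hinge term dominates as soon as samples sit on the wrong side of, or too close to, zero.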

Learn classification on MNIST

Now that the model is built, we can learn the task.

wass.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_test, y_test),
    batch_size=batch_size,
    shuffle=True,
    epochs=epochs,
    verbose=1,
)
Epoch 1/10
92/92 [==============================] - 2s 10ms/step - loss: -1.6675 - KR: 3.7144 - HingeMargin: 0.2047 - HKR_binary_accuracy: 0.9382 - val_loss: -5.0961 - val_KR: 5.5990 - val_HingeMargin: 0.0519 - val_HKR_binary_accuracy: 0.9786
Epoch 2/10
92/92 [==============================] - 1s 7ms/step - loss: -5.0297 - KR: 5.5716 - HingeMargin: 0.0542 - HKR_binary_accuracy: 0.9793 - val_loss: -5.4469 - val_KR: 5.7710 - val_HingeMargin: 0.0354 - val_HKR_binary_accuracy: 0.9879
Epoch 3/10
92/92 [==============================] - 1s 7ms/step - loss: -5.3788 - KR: 5.7838 - HingeMargin: 0.0405 - HKR_binary_accuracy: 0.9858 - val_loss: -5.6435 - val_KR: 5.9555 - val_HingeMargin: 0.0334 - val_HKR_binary_accuracy: 0.9860
Epoch 4/10
92/92 [==============================] - 1s 8ms/step - loss: -5.6172 - KR: 5.9671 - HingeMargin: 0.0350 - HKR_binary_accuracy: 0.9874 - val_loss: -5.7918 - val_KR: 6.0764 - val_HingeMargin: 0.0308 - val_HKR_binary_accuracy: 0.9879
Epoch 5/10
92/92 [==============================] - 1s 7ms/step - loss: -5.7598 - KR: 6.0676 - HingeMargin: 0.0308 - HKR_binary_accuracy: 0.9891 - val_loss: -5.8711 - val_KR: 6.1062 - val_HingeMargin: 0.0264 - val_HKR_binary_accuracy: 0.9899
Epoch 6/10
92/92 [==============================] - 1s 7ms/step - loss: -5.7647 - KR: 6.0829 - HingeMargin: 0.0318 - HKR_binary_accuracy: 0.9879 - val_loss: -5.8503 - val_KR: 6.1463 - val_HingeMargin: 0.0315 - val_HKR_binary_accuracy: 0.9879
Epoch 7/10
92/92 [==============================] - 1s 7ms/step - loss: -5.8007 - KR: 6.1082 - HingeMargin: 0.0307 - HKR_binary_accuracy: 0.9884 - val_loss: -5.8470 - val_KR: 6.1179 - val_HingeMargin: 0.0296 - val_HKR_binary_accuracy: 0.9879
Epoch 8/10
92/92 [==============================] - 1s 7ms/step - loss: -5.8268 - KR: 6.1185 - HingeMargin: 0.0292 - HKR_binary_accuracy: 0.9897 - val_loss: -5.8439 - val_KR: 6.1153 - val_HingeMargin: 0.0294 - val_HKR_binary_accuracy: 0.9889
Epoch 9/10
92/92 [==============================] - 1s 7ms/step - loss: -5.8865 - KR: 6.1548 - HingeMargin: 0.0268 - HKR_binary_accuracy: 0.9910 - val_loss: -5.8800 - val_KR: 6.1668 - val_HingeMargin: 0.0312 - val_HKR_binary_accuracy: 0.9874
Epoch 10/10
92/92 [==============================] - 1s 7ms/step - loss: -5.8578 - KR: 6.1453 - HingeMargin: 0.0288 - HKR_binary_accuracy: 0.9892 - val_loss: -5.9233 - val_KR: 6.1783 - val_HingeMargin: 0.0282 - val_HKR_binary_accuracy: 0.9889

As we can see, the model reaches a good accuracy on this task: about 98.9% on the test set (val_HKR_binary_accuracy of 0.9889 after 10 epochs).
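The training log can be cross-checked with plain arithmetic. The sketch below uses numbers copied from the outputs above. It recovers the 92 steps per epoch from the 11774 training samples and the batch size of 128. It also checks that each logged training loss matches -KR + alpha · HingeMargin up to display rounding. This suggests, though it does not prove, that this is how HKR combines its two terms in the deel-lip version used here.

```python
import math

train_size = 11774  # printed in the data preparation step
batch_size = 128  # training parameter of this notebook
alpha = 10.0  # loss parameter of this notebook

# Keras runs ceil(train_size / batch_size) steps per epoch; the last batch is smaller
steps_per_epoch = math.ceil(train_size / batch_size)
last_batch = train_size - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch)

# (loss, KR, HingeMargin) copied from the training log for epochs 1 and 10
logged = {1: (-1.6675, 3.7144, 0.2047), 10: (-5.8578, 6.1453, 0.0288)}
checks = {}
for epoch, (loss, kr, hinge) in logged.items():
    reconstructed = -kr + alpha * hinge
    checks[epoch] = (loss, reconstructed)
    print("epoch %i: logged %.4f, -KR + alpha * hinge = %.4f" % (epoch, loss, reconstructed))
```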