Demo 2: HKR Classifier on toy dataset¶
In this demo notebook we will show how to build a robust
classifier based on the regularized version of the Kantorovitch-Rubinstein
duality.
We will demonstrate this on the two moons
synthetic dataset.
# pip install deel-lip -qqq
import numpy as np
from sklearn.datasets import make_moons, make_circles # the synthetic dataset
import matplotlib.pyplot as plt
import seaborn as sns
# in order to build our classifier we will use elements from tensorflow along with
# layers from deel-lip
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import ReLU, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import binary_accuracy
from deel.lip.model import Model # use of deel.lip is not mandatory but offers the vanilla_export feature
from deel.lip.layers import SpectralConv2D, SpectralDense, FrobeniusDense
from deel.lip.activations import MaxMin, GroupSort, FullSort, GroupSort2
from deel.lip.losses import HKR, KR, HingeMargin # custom losses for HKR robust classif
Parameters¶
Let's first construct our two moons dataset
circle_or_moons = 1 # 0 for circles, 1 for moons
n_samples=5000 # number of samples in the dataset
noise=0.05 # amount of noise to add in the data; tested with 0.14 for circles, 0.05 for two moons
factor=0.4 # scale factor between the inner and the outer circle
if circle_or_moons == 0:
    X,Y=make_circles(n_samples=n_samples,noise=noise,factor=factor)
else:
    X,Y=make_moons(n_samples=n_samples,noise=noise)
# When working with the HKR-classifier, using labels {-1, 1} instead of {0, 1} is advised.
# This will be explained further on
Y[Y==1]=-1
Y[Y==0]=1
X1=X[Y==1]
X2=X[Y==-1]
# recent seaborn versions require keyword arguments
sns.scatterplot(x=X1[:1000,0], y=X1[:1000,1])
sns.scatterplot(x=X2[:1000,0], y=X2[:1000,1])
Relation with optimal transport¶
In this setup we can solve the optimal transport problem
between the distribution of X[Y==1]
and X[Y==-1]. This
usually requires matching each element of the first distribution
with an element of the second distribution such that the global
cost of the matching is minimized. In our setup this cost is the \(l_1\) distance, which
will allow us to make use of the KR dual formulation. The overall cost
is then the \(W_1\) distance.
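To make this concrete, here is a small illustrative sketch (not part of the original demo) that estimates this optimal matching cost on an equal-size subsample of each class, using scipy's assignment solver with the \(l_1\) (cityblock) cost; the average matching cost gives an empirical estimate of the \(W_1\) distance between the two classes. The subsample size is arbitrary.
# illustrative only: empirical W1 between the two classes via optimal matching
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
n = 500  # arbitrary subsample size, to keep the assignment problem small
cost = cdist(X1[:n], X2[:n], metric="cityblock")  # pairwise l1 distances
row_ind, col_ind = linear_sum_assignment(cost)    # optimal one-to-one matching
print("empirical W1 estimate:", cost[row_ind, col_ind].mean())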
Wasserstein distance¶
The Wasserstein distance measures the distance between two probability distributions. The Wikipedia article gives a more intuitive definition of it:
> Intuitively, if each distribution is viewed as a unit amount of "dirt" piled on \(M\), the metric is the minimum "cost" of turning one pile into the other, which is assumed to be the amount of dirt that needs to be moved times the mean distance it has to be moved. Because of this analogy, the metric is known in computer science as the earth mover's distance.
Mathematically it is defined as:
$$ W_1(\mu, \nu) = \inf_{\pi \in \Pi(\mu,\nu)} \underset{(\textbf{x},\textbf{z}) \sim \pi}{\mathbb{E}} \left[ \Vert \textbf{x} - \textbf{z} \Vert \right] $$
where \(\Pi(\mu,\nu)\) is the set of all probability measures on \(\Omega\times \Omega\) with marginals \(\mu\) and \(\nu\). Computing \(W_1\) directly from this definition is intractable in most cases.
KR dual formulation¶
In our setup, the KR dual formulation is stated as following: $$ W_1(\mu, \nu) = \sup_{f \in Lip_1(\Omega)} \underset{\textbf{x} \sim \mu}{\mathbb{E}} \left[f(\textbf{x} )\right] -\underset{\textbf{x} \sim \nu}{\mathbb{E}} \left[f(\textbf{x} )\right] $$
This states the problem as an optimization problem over the 1-Lipschitz functions. Therefore k-Lipschitz networks allow us to solve this maximization problem.
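As a side illustration (not from the original notebook), the expectation difference above can be estimated on a batch of labelled samples with a few lines of tensorflow; kr_estimate below is a hypothetical helper, assuming model is a 1-Lipschitz network and labels are in {-1, 1}.
# hypothetical sketch: Monte-Carlo estimate of the KR term on a labelled batch
def kr_estimate(model, x_batch, y_true):
    y_pred = model(x_batch)                           # f(x), shape (batch, 1)
    y_true = tf.reshape(tf.cast(y_true, tf.float32), tf.shape(y_pred))
    pos = tf.boolean_mask(y_pred, y_true > 0)         # samples drawn from mu
    neg = tf.boolean_mask(y_pred, y_true < 0)         # samples drawn from nu
    return tf.reduce_mean(pos) - tf.reduce_mean(neg)  # E_mu[f] - E_nu[f]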
Hinge-KR classification¶
When dealing with \(W_1\) one may note that many functions attain the supremum of the maximization problem described above. We also want this function to be meaningful in terms of classification. To do so, we want \(f\) to be centered at 0, which can be done without altering the initial problem. By doing so we can use the obtained function for binary classification, by looking at the sign of \(f\).
In order to enforce this, we will add a hinge term to the loss. It has been shown that this new problem is still an optimal transport problem and that it admits a meaningful optimal solution.
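deel-lip provides this loss as HKR (used below). For intuition only, a simplified version of such a hinge-regularized KR loss could look like the sketch below; the actual deel.lip.losses.HKR implementation may differ in its exact parametrization.
# simplified sketch of a hinge-regularized KR loss, for intuition only
# labels are expected in {-1, 1}; alpha balances the hinge term against the KR term
def hkr_loss_sketch(y_true, y_pred, alpha=10.0, min_margin=1.0):
    y_true = tf.cast(y_true, y_pred.dtype)
    kr = tf.reduce_mean(tf.boolean_mask(y_pred, y_true > 0)) \
         - tf.reduce_mean(tf.boolean_mask(y_pred, y_true < 0))
    hinge = tf.reduce_mean(tf.nn.relu(min_margin / 2.0 - y_true * y_pred))
    return -kr + alpha * hinge  # minimized: maximizes KR while enforcing the margin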
HKR-Classifier¶
Now we will show how to build a binary classifier based on the regularized version of the KR dual problem.
In order to ensure the 1-Lipschitz constraint deel-lip
uses spectral normalization. These layers can also use Björck orthonormalization to ensure that the gradient of the layer is 1 almost everywhere. Experiments show that the optimal solution lies in this sub-class of functions.
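For intuition, here is a small numpy sketch of the two ingredients (not the deel-lip implementation; spectral_normalize and bjorck_orthonormalize are hypothetical helper names): power iteration to estimate the largest singular value used for spectral normalization, and the Björck iteration that pushes a weight matrix towards orthonormality.
# illustrative numpy sketch; deel-lip layers implement these constraints internally
def spectral_normalize(W, n_iter=20):
    u = np.random.randn(W.shape[1])
    for _ in range(n_iter):        # power iteration estimating the top singular value
        v = W @ u
        v /= np.linalg.norm(v)
        u = W.T @ v
        u /= np.linalg.norm(u)
    sigma = v @ W @ u              # largest singular value estimate
    return W / sigma

def bjorck_orthonormalize(W, n_iter=15, beta=0.5):
    for _ in range(n_iter):        # first-order Bjorck iteration
        W = (1 + beta) * W - beta * W @ W.T @ W
    return W

W = bjorck_orthonormalize(spectral_normalize(np.random.randn(64, 128) * 0.1))
svals = np.linalg.svd(W, compute_uv=False)
print(svals.min(), svals.max())    # all singular values close to 1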
batch_size=256
steps_per_epoch=40480
epoch=10
hidden_layers_size = [256,128,64] # structure of the network
activation = FullSort # other Lipschitz activations are ReLU, MaxMin, GroupSort2, GroupSort
min_margin= 0.29 # minimum margin to enforce between the values of f for each class
# build data generator: yields balanced batches, one half from each class
def otp_generator(batch_size, X, Y):
    Y_ix = np.array([i for i in range(Y.shape[0])])
    Y0_ix = Y_ix[Y == 1]
    Y1_ix = Y_ix[Y == -1]
    half = Y.shape[0] // 2
    while True:
        batch_x = np.zeros(((batch_size,) + (X[0].shape)), dtype=np.float32)
        batch_y = np.zeros((batch_size, 1), dtype=np.float32)
        # first half of the batch: samples with label +1
        ind = np.random.choice(Y0_ix, size=batch_size // 2, replace=False)
        batch_x[:batch_size // 2, ] = X[ind]
        batch_y[:batch_size // 2, 0] = Y[ind]
        # second half of the batch: samples with label -1
        ind = np.random.choice(Y1_ix, size=batch_size // 2, replace=False)
        batch_x[batch_size // 2:, ] = X[ind]
        batch_y[batch_size // 2:, 0] = Y[ind]
        yield batch_x, batch_y
gen=otp_generator(batch_size,X,Y)
Build Lipschitz Model¶
Let's build our model now.
K.clear_session()
# build the network explicitly with deel-lip layers, using the Keras functional API:
inputs = Input((2,))
x = SpectralDense(256, activation=activation())(inputs)
x = SpectralDense(128, activation=activation())(x)
x = SpectralDense(64, activation=activation())(x)
y = FrobeniusDense(1, activation=None)(x)
wass = Model(inputs=inputs, outputs=y)
wass.summary()
As we can see the network has a gradient equal to 1 almost everywhere, as all the layers respect this property.
It is good to note that the last layer is a FrobeniusDense: when we have a single output, normalizing the Frobenius norm becomes equivalent to normalizing the spectral norm (as we only have a single singular value).
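This can be checked numerically: for a kernel with a single output column, the Frobenius norm and the spectral norm (the largest singular value) coincide. The quick check below is an added illustration, not part of the original notebook.
# a (n_features, 1) kernel: its Frobenius norm equals its only singular value
w = np.random.randn(64, 1)
print(np.linalg.norm(w))                      # Frobenius norm
print(np.linalg.svd(w, compute_uv=False)[0])  # largest (only) singular value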
optimizer = Adam(learning_rate=0.01)
# the classifier outputs a real value whose sign gives the class (labels are in {-1, 1}),
# so binary accuracy must be redefined
def HKR_binary_accuracy(y_true, y_pred):
    S_true = tf.dtypes.cast(tf.greater_equal(y_true[:,0], 0), dtype=tf.float32)
    S_pred = tf.dtypes.cast(tf.greater_equal(y_pred[:,0], 0), dtype=tf.float32)
    return binary_accuracy(S_true, S_pred)
wass.compile(
    loss=HKR(alpha=10, min_margin=min_margin),  # HKR stands for the hinge regularized KR loss
    metrics=[
        KR,  # shows the KR term of the loss
        HingeMargin(min_margin=min_margin),  # shows the hinge term of the loss
        HKR_binary_accuracy  # shows the classification accuracy
    ],
    optimizer=optimizer
)
Learn classification on toy dataset¶
Now we are ready to learn the classification task on the two moons dataset.
wass.fit(  # Model.fit accepts Python generators; fit_generator is deprecated
    gen,
    steps_per_epoch=steps_per_epoch // batch_size,
    epochs=epoch,
    verbose=1
)
Plot output contour line¶
As we can see the classifier gets a pretty good accuracy. Let's now take a look at the learnt function. As we are in 2D space, we can draw a contour plot to visualize \(f\).
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
batch_size=1024
x = np.linspace(X[:,0].min()-0.2, X[:,0].max()+0.2, 120)
y = np.linspace(X[:,1].min()-0.2, X[:,1].max()+0.2,120)
xx, yy = np.meshgrid(x, y, sparse=False)
X_pred=np.stack((xx.ravel(),yy.ravel()),axis=1)
# make predictions of f
pred=wass.predict(X_pred)
Y_pred=pred
Y_pred=Y_pred.reshape(x.shape[0],y.shape[0])
#plot the results
fig = plt.figure(figsize=(10,7))
ax1 = fig.add_subplot(111)
sns.scatterplot(x=X[Y==1,0], y=X[Y==1,1], alpha=0.1, ax=ax1)
sns.scatterplot(x=X[Y==-1,0], y=X[Y==-1,1], alpha=0.1, ax=ax1)
cset =ax1.contour(xx,yy,Y_pred,cmap='twilight')
ax1.clabel(cset, inline=1, fontsize=10)
Transfer network to a classical MLP and compare outputs¶
As we saw, our networks use custom layers in order to constrain training.
However, during inference these layers behave exactly as regular Dense
or Conv2D
layers.
Deel-lip has a functionality to export a model to its vanilla Keras equivalent, making it more
convenient for inference.
from deel.lip.model import vanillaModel
## this is equivalent to test2 = wass.vanilla_export()
test2 = vanillaModel(wass)
test2.summary()
pred_test=test2.predict(X_pred)
Y_pred=pred_test
Y_pred=Y_pred.reshape(x.shape[0],y.shape[0])
fig = plt.figure(figsize=(10,7))
ax1 = fig.add_subplot(111)
#ax2 = fig.add_subplot(312)
#ax3 = fig.add_subplot(313)
sns.scatterplot(x=X[Y==1,0], y=X[Y==1,1], alpha=0.1, ax=ax1)
sns.scatterplot(x=X[Y==-1,0], y=X[Y==-1,1], alpha=0.1, ax=ax1)
cset =ax1.contour(xx,yy,Y_pred,cmap='twilight')
ax1.clabel(cset, inline=1, fontsize=10)