Demo 1: Wasserstein distance estimation on toy example¶
In this notebook we will see how to estimate the Wasserstein distance with a neural network by using the Kantorovich-Rubinstein dual representation.
Wasserstein distance¶
The Wasserstein distance measures the distance between two probability distributions. The Wikipedia article gives an intuitive description of it:
> Intuitively, if each distribution is viewed as a unit amount of "dirt" piled on M, the metric is the minimum "cost" of turning one pile into the other, which is assumed to be the amount of dirt that needs to be moved times the mean distance it has to be moved. Because of this analogy, the metric is known in computer science as the earth mover's distance.
Mathematically it is defined as: $$ W_1(\mu, \nu) = \inf_{\pi \in \Pi(\mu,\nu)} \underset{(\textbf{x},\textbf{y}) \sim \pi}{\mathbb{E}} \left[\Vert \textbf{x} - \textbf{y} \Vert\right] $$
where \(\Pi(\mu,\nu)\) is the set of all probability measures on \(\Omega\times \Omega\) with marginals \(\mu\) and \(\nu\). In most cases this optimization problem is not tractable.
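In the one-dimensional case, however, the infimum has a closed form. The snippet below is only a quick illustration of the definition (scipy is used here for this check only and is not a dependency of the rest of the notebook):
from scipy.stats import wasserstein_distance

# W1 between the empirical distributions {0, 1} and {3, 4}: the optimal plan
# moves 0 -> 3 and 1 -> 4, each over a distance of 3, so W1 = 3
print(wasserstein_distance([0.0, 1.0], [3.0, 4.0]))  # 3.0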
KR dual formulation¶
In our setup, the KR dual formulation is stated as following: $$ W_1(\mu, \nu) = \sup_{f \in Lip_1(\Omega)} \underset{\textbf{x} \sim \mu}{\mathbb{E}} \left[f(\textbf{x} )\right] -\underset{\textbf{x} \sim \nu}{\mathbb{E}} \left[f(\textbf{x} )\right] $$
This states the problem as an optimization over the set of 1-Lipschitz functions. Therefore, k-Lipschitz networks [1] allow us to solve this maximization problem.
[1] C. Anil, J. Lucas, and R. Grosse, "Sorting out Lipschitz function approximation", arXiv:1811.05381 [cs, stat], Nov. 2018.
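Concretely, given a candidate 1-Lipschitz function \(f\) and samples from each distribution, the dual objective is just a difference of two empirical means. The minimal sketch below (not part of the notebook's pipeline) makes this explicit; the Lipschitz network built later plays the role of \(f\):
import numpy as np

def kr_dual_objective(f, x_mu, x_nu):
    # empirical KR dual objective for a candidate 1-Lipschitz function f
    return np.mean(f(x_mu)) - np.mean(f(x_nu))

# with the (1-Lipschitz) identity on the same 1-D samples as above,
# the dual objective already attains the true value W1 = 3
print(kr_dual_objective(lambda x: x, np.array([3.0, 4.0]), np.array([0.0, 1.0])))  # 3.0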
We will illustrate this on a synthetic image dataset where \(W_1\) is known.
# pip install deel-lip -qqq
from datetime import datetime
import os
import numpy as np
import math
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Flatten, ReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from deel.lip.layers import SpectralConv2D, SpectralDense, FrobeniusDense
from deel.lip.activations import MaxMin, GroupSort, FullSort
from deel.lip.losses import KR, HKR
from deel.lip.model import Model
Input image parameters¶
The synthetic dataset will be composed of images that are either entirely black or contain a white square, allowing us to check that the computed Wasserstein distance is correct. One distribution will be the set of black images, while the other will be the set of images with a square on them. These two distributions are Diracs, and the Wasserstein distance can be computed analytically: $$ W_1(\delta_a, \delta_b) = \Vert a - b \Vert $$
In the case of two Diracs, the Wasserstein distance is simply the ground distance between the two images; with the Euclidean ground metric used here, it is the L2 distance between the two images.
img_size = 64
frac_value = 0.3 # proportion of the center square
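Before going further, we can verify the analytic value of \(W_1\) with the parameters above (a quick check added for illustration, not part of the original notebook):
# with img_size = 64 and frac_value = 0.3, the white square is 35 x 35 pixels
# of value 1, so the L2 distance between the black image and the square image
# is sqrt(35 * 35) = 35
side = int(img_size * frac_value ** 0.5)
print(side, (side * side) ** 0.5)  # 35 35.0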
Generate images¶
def generate_toy_images(shape, frac=0, v=1):
    """
    Generate a single image.

    Args:
        shape: shape of the output image
        frac: proportion of the image area covered by the center square
        v: value assigned to the center square
    """
    img = np.zeros(shape)
    if frac == 0:
        return img
    frac = frac ** 0.5  # side ratio that yields an area ratio of `frac`
    l = int(shape[0] * frac)
    ldec = (shape[0] - l) // 2
    w = int(shape[1] * frac)
    wdec = (shape[1] - w) // 2
    img[ldec:ldec + l, wdec:wdec + w, :] = v
    return img
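# illustrative sanity check (not in the original notebook): the white square
# covers roughly `frac_value` of the image area
_check_img = generate_toy_images((img_size, img_size, 1), frac=frac_value, v=1)
print(_check_img.shape, _check_img.mean())  # (64, 64, 1)  ~0.3 (1225 / 4096)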
def binary_generator(batch_size, shape, frac=0):
    """
    Generate batches with half black images and half images with a white square.
    """
    batch_x = np.zeros((batch_size,) + shape, dtype=np.float16)
    batch_y = np.zeros((batch_size, 1), dtype=np.float16)
    batch_x[batch_size // 2:] = generate_toy_images(shape, frac=frac, v=1)
    batch_y[batch_size // 2:] = 1
    while True:
        yield batch_x, batch_y
def ternary_generator(batch_size, shape, frac=0):
    """
    Same as binary_generator, but the white square can have value 1 or -1.
    """
    batch_x = np.zeros((batch_size,) + shape, dtype=np.float16)
    batch_y = np.zeros((batch_size, 1), dtype=np.float16)
    batch_x[3 * batch_size // 4:] = generate_toy_images(shape, frac=frac, v=1)
    batch_x[batch_size // 2:3 * batch_size // 4] = generate_toy_images(shape, frac=frac, v=-1)
    batch_y[batch_size // 2:] = 1
    while True:
        yield batch_x, batch_y
def display_img(img):
    """
    Display an image.
    """
    if img.shape[-1] == 1:
        img = np.tile(img, (3,))
    fig, ax = plt.subplots()
    ax.imshow((img * 255).astype(np.uint8))
Now let's take a look at the generated batches
For the binary generator¶
test = binary_generator(2, (img_size, img_size, 1), frac=frac_value)
imgs, y = next(test)
display_img(imgs[0])
display_img(imgs[1])
# both quantities below should match the analytic W1 value (35)
print("Norm L2 " + str(np.linalg.norm(imgs[1])))
print("Norm L2 (count pixels) " + str(math.sqrt(np.size(imgs[1][imgs[1] == 1]))))
For the ternary generator¶
test = ternary_generator(4, (img_size, img_size, 1), frac=frac_value)
imgs, y = next(test)
for i in range(4):
    display_img(0.5 * (imgs[i] + 1.0))  # rescale [-1, 1] to [0, 1] so no negative value is displayed
# the three quantities below should match the analytic W1 value (35)
print("Norm L2 (imgs[2]-imgs[0]) " + str(np.linalg.norm(imgs[2] - imgs[0])))
print("Norm L2 (imgs[2]) " + str(np.linalg.norm(imgs[2])))
print("Norm L2 (count pixels) " + str(math.sqrt(np.size(imgs[2][imgs[2] == -1]))))
Experiment parameters¶
Now we know the Wasserstein distance between the black image and the images with a square on them: for both the binary and the ternary generator this distance is 35.
We will now estimate this distance with a neural network.
batch_size=64
epochs=5
steps_per_epoch=6400
generator = ternary_generator #binary_generator, ternary_generator
activation = FullSort #ReLU, MaxMin, GroupSort
Build the Lipschitz model¶
K.clear_session()
# build the 1-Lipschitz network with the Keras functional API
# (an equivalent construction with deel-lip's Sequential wrapper is sketched below)
inputs = Input((img_size, img_size, 1))
x = Flatten()(inputs)
x = SpectralDense(128, activation=activation())(x)
x = SpectralDense(64, activation=activation())(x)
x = SpectralDense(32, activation=activation())(x)
y = FrobeniusDense(1, activation=None)(x)
wass = Model(inputs=inputs, outputs=y)
wass.summary()
optimizer = Adam(learning_rate=0.01)
wass.compile(loss=HKR(alpha=0), optimizer=optimizer, metrics=[KR])
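The same 1-Lipschitz architecture can also be written with deel-lip's Sequential wrapper. The sketch below mirrors the functional model above and is purely illustrative (the name wass_seq is ours; the rest of the notebook keeps using wass):
from deel.lip.model import Sequential

# illustrative alternative to the functional model built above
wass_seq = Sequential(
    [
        Input((img_size, img_size, 1)),
        Flatten(),
        SpectralDense(128, activation=activation()),
        SpectralDense(64, activation=activation()),
        SpectralDense(32, activation=activation()),
        FrobeniusDense(1, activation=None),
    ],
    name="lip_wasserstein_estimator",
)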
Learn on toy dataset¶
wass.fit(
    generator(batch_size, (img_size, img_size, 1), frac=frac_value),
    steps_per_epoch=steps_per_epoch // batch_size,
    epochs=epochs,
    verbose=1,
)
As we can see, the KR metric converges to 35, which is the Wasserstein distance between the two distributions (square and non-square).
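As an extra, illustrative check (not part of the original notebook), we can evaluate the dual objective \(\underset{\textbf{x} \sim \mu}{\mathbb{E}} \left[f(\textbf{x})\right] - \underset{\textbf{x} \sim \nu}{\mathbb{E}} \left[f(\textbf{x})\right]\) on one batch with the trained network; the result should be close to 35:
# evaluate E_mu[f] - E_nu[f] on a fresh batch from the generator used for training
imgs, labels = next(generator(batch_size, (img_size, img_size, 1), frac=frac_value))
scores = wass.predict(imgs)
mask = labels.flatten() == 1  # 1 = images with a square, 0 = black images
print(scores[mask].mean() - scores[~mask].mean())  # expected to be close to 35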