Gram method¶
This notebook evaluates the Gram method for out-of-distribution detection.
The method consists of computing the Gram matrices of several layers to capture the channel-wise correlations of internal features. A score is built by computing these correlations for a new input and measuring their deviation from the reference correlation values computed on the in-distribution dataset.
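As an illustration, here is a minimal NumPy sketch (not the oodeel implementation) of the order-1 Gram matrix of a single convolutional feature map. The paper also uses higher-order Gram matrices, computed from element-wise powers of the activations; these correspond to the orders parameter of the Gram class used later in this notebook.
import numpy as np

def gram_matrix(feature_map):
    # feature_map: activations of one layer for one sample, shape (H, W, C)
    # Returns the (C, C) Gram matrix, i.e. the channel-wise correlations.
    num_channels = feature_map.shape[-1]
    flat = feature_map.reshape(-1, num_channels)  # (H*W, C)
    return flat.T @ flat

# Toy example on a random "feature map" with 16 channels
fmap = np.random.rand(8, 8, 16)
print(gram_matrix(fmap).shape)  # (16, 16)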
Important Disclaimer: Taking the statistics of the min/max deviations, as in the paper, raises some problems.
The method often yields a score of zero for some tasks. This is expected: the min/max among the samples of a random variable becomes more and more extreme as the sample size grows. As a result, computing the min/max over the training set is likely to produce values so extreme that none of the in-distribution correlations of the validation set goes beyond these thresholds. Worse, a significant part of the OOD data does not exceed the thresholds either. This can be alleviated by computing the min/max over a limited number of samples, but that is counter-intuitive and, in our opinion, not desirable: adding more information should only improve a method.
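To illustrate the point about the min/max growing with sample size, here is a quick simulation on made-up Gaussian data (not Gram correlations): the empirical maximum keeps drifting upward as more samples are used, so a threshold fitted on a large training set is almost never exceeded afterwards.
import numpy as np

rng = np.random.default_rng(0)
samples = rng.standard_normal(100_000)

# The empirical max grows with the number of samples it is computed on.
for n in [100, 1_000, 10_000, 100_000]:
    print(f"max over {n:>7} samples: {samples[:n].max():.2f}")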
Hence, we decided to replace the min/max by the q / (1-q) quantiles, with q a new parameter of the method. Specifically, instead of the deviation defined in Eq. 3 of the paper, we use the definition
$$ \delta(t_q, t_{1-q}, value) = \begin{cases} 0 & \text{if } t_q \leq value \leq t_{1-q}, \\ \frac{t_q - value}{|t_q|} & \text{if } value < t_q, \\ \frac{value - t_{1-q}}{|t_{1-q}|} & \text{if } value > t_{1-q}. \end{cases} $$
With this new deviation, the more points we add, the more accurate the quantile estimates become. In addition, the method can be made more or less discriminative by tuning the value of q.
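For illustration, this deviation can be written as follows for a single correlation coordinate (a minimal sketch; the actual oodeel implementation operates on batched tensors):
def deviation(value, t_q, t_1mq):
    # t_q and t_1mq are the q and (1 - q) quantiles of an in-distribution
    # Gram-matrix coordinate, estimated on the fitting set.
    if t_q <= value <= t_1mq:
        return 0.0
    if value < t_q:
        return (t_q - value) / abs(t_q)
    return (value - t_1mq) / abs(t_1mq)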
Finally, we found that this approach improved the performance of the baseline in our experiments.
Here, we focus on a simple MLP trained on MNIST[0-4] and a ResNet trained on CIFAR-10, challenged on the MNIST[5-9] and SVHN OOD datasets, respectively.
Reference: Detecting Out-of-Distribution Examples with Gram Matrices, ICML 2020.
Imports¶
%load_ext autoreload
%autoreload 2
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from IPython.display import clear_output
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from oodeel.methods import Gram
from oodeel.eval.metrics import bench_metrics
from oodeel.eval.plots import plot_ood_scores, plot_roc_curve, plot_2D_features
from oodeel.datasets import OODDataset
from oodeel.utils.tf_training_tools import train_tf_model
Note that models are saved to ~/.oodeel/saved_models and data is expected to be found in ~/.oodeel/datasets by default. Change the following cell to use a custom path.
model_path = os.path.expanduser("~/") + ".oodeel/saved_models"
data_path = os.path.expanduser("~/") + ".oodeel/datasets"
os.makedirs(model_path, exist_ok=True)
os.makedirs(data_path, exist_ok=True)
First experiment: MNIST[0-4] vs MNIST[5-9]¶
For this first experiment, we train a simple MLP on the MNIST dataset restricted to digits 0 to 4. After fitting the Gram detector on the train subset of this dataset, we will compare the scores returned for the MNIST[0-4] (in-distribution) and MNIST[5-9] (out-of-distribution) test subsets.
Data loading¶
- In-distribution data: MNIST[0-4]
- Out-of-distribution data: MNIST[5-9]
Note: We denote In-Distribution (ID) data with _in and Out-Of-Distribution (OOD) data with _out to avoid confusion with OOD detection, which is the name of the task and is therefore used to denote core classes such as OODDataset and OODBaseDetector.
# === Load ID and OOD data ===
batch_size = 128
in_labels = [0, 1, 2, 3, 4]
# 1- Load train/test MNIST dataset
ds_train = OODDataset("mnist", load_kwargs=dict(split="train"))
data_test = OODDataset("mnist", load_kwargs=dict(split="test"))
# 2- Split ID / OOD data depending on label value:
# in-distribution: MNIST[0-4] / out-of-distribution: MNIST[5-9]
ds_train, _ = ds_train.split_by_class(in_labels)
oods_in, oods_out = data_test.split_by_class(in_labels)
# 3- Prepare data (preprocess, shuffle, batch)
def preprocess_fn(*inputs):
x = inputs[0] / 255
return tuple([x] + list(inputs[1:]))
ds_train = ds_train.prepare(batch_size, preprocess_fn, shuffle=True)
ds_in = oods_in.prepare(batch_size, preprocess_fn, with_ood_labels=False)
ds_out = oods_out.prepare(batch_size, preprocess_fn, with_ood_labels=False)
clear_output()
Model training¶
Now let's train a simple MLP on MNIST[0-4] using the train_tf_model function.
# === Train / Load model ===
model_path_mnist_04 = os.path.join(model_path, "mnist_mlp_0-4.h5")
try:
# if the model exists, load it
model = tf.keras.models.load_model(model_path_mnist_04)
except OSError:
# else, train a new model
train_config = {
"model": "toy_mlp",
"input_shape": (28, 28, 1),
"num_classes": 10,
"batch_size": 128,
"epochs": 5,
"save_dir": model_path_mnist_04,
"validation_data": ds_in,
}
model = train_tf_model(ds_train, **train_config)
Epoch 1/5
240/240 [==============================] - 2s 6ms/step - loss: 0.1480 - accuracy: 0.9572 - val_loss: 0.0459 - val_accuracy: 0.9879
Epoch 2/5
240/240 [==============================] - 1s 2ms/step - loss: 0.0437 - accuracy: 0.9871 - val_loss: 0.0310 - val_accuracy: 0.9895
Epoch 3/5
240/240 [==============================] - 1s 2ms/step - loss: 0.0301 - accuracy: 0.9904 - val_loss: 0.0239 - val_accuracy: 0.9918
Epoch 4/5
240/240 [==============================] - 1s 2ms/step - loss: 0.0156 - accuracy: 0.9957 - val_loss: 0.0239 - val_accuracy: 0.9930
Epoch 5/5
240/240 [==============================] - 1s 2ms/step - loss: 0.0105 - accuracy: 0.9970 - val_loss: 0.0188 - val_accuracy: 0.9928
_, accuracy = model.evaluate(ds_in)
print(f"Test accuracy:\t{accuracy:.4f}")
# penultimate features 2d visualization
print("\n=== Penultimate features viz ===")
plt.figure(figsize=(4.5, 3))
plot_2D_features(
model=model,
in_dataset=ds_in,
out_dataset=ds_out,
output_layer_id=-2,
)
plt.tight_layout()
plt.show()
41/41 [==============================] - 0s 1ms/step - loss: 0.0239 - accuracy: 0.9930
Test accuracy:	0.9930

=== Penultimate features viz ===
Gram score¶
We now fit a Gram detector on the MNIST[0-4] train dataset, and compare the OOD scores returned for the MNIST[0-4] (ID) and MNIST[5-9] (OOD) test datasets.
# === gram scores ===
gram = Gram(quantile=0.2)
gram.fit(model, ds_train, feature_layers_id=["dense", "dense_1"])
scores_in, _ = gram.score(ds_in)
scores_out, _ = gram.score(ds_out)
# Since many scores are exactly 0, we add a small random noise to avoid
# numerical issues in the AUROC and TPR computation.
scores_in += np.random.random_sample(size=scores_in.shape) * 10e-6
scores_out += np.random.random_sample(size=scores_out.shape) * 10e-6
# === metrics ===
# auroc / fpr95
metrics = bench_metrics(
(scores_in, scores_out),
metrics=["auroc", "fpr95tpr"],
)
print("=== Metrics ===")
for k, v in metrics.items():
print(f"{k:<10} {v:.6f}")
=== Metrics ===
auroc      0.889509
fpr95tpr   0.412921
print("\n=== Plots ===")
# hists / roc
plt.figure(figsize=(9, 3))
plt.subplot(121)
plot_ood_scores(scores_in, scores_out, log_scale=True)
plt.subplot(122)
plot_roc_curve(scores_in, scores_out)
plt.tight_layout()
plt.show()
=== Plots ===
Second experiment: CIFAR-10 vs SVHN¶
For this second experiment, we compare CIFAR-10 (in-distribution dataset) with SVHN (out-of-distribution dataset).
Data loading¶
- In-distribution data: CIFAR-10
- Out-of-distribution data: SVHN
# === Load ID and OOD data ===
batch_size = 128
# 1a- Load in-distribution dataset: CIFAR-10
ds_fit = OODDataset("cifar10", load_kwargs={"split": "train"}, input_key="image")
ds_in = OODDataset("cifar10", load_kwargs={"split": "test"}, input_key="image")
# 1b- Load out-of-distribution dataset: SVHN
ds_out = OODDataset("svhn_cropped", load_kwargs={"split": "test"})
# 2- prepare data (preprocess, shuffle, batch)
def preprocess_fn(*inputs):
x = inputs[0] / 255
return tuple([x] + list(inputs[1:]))
ds_fit = ds_fit.prepare(batch_size, preprocess_fn)
ds_in = ds_in.prepare(batch_size, preprocess_fn)
ds_out = ds_out.prepare(batch_size, preprocess_fn)
clear_output()
Model loading¶
The model is a ResNet pretrained on CIFAR-10, reaching an accuracy of 92.75% on the test set.
# === Load model ===
# ResNet pretrained on CIFAR-10
model_path_resnet_cifar10 = tf.keras.utils.get_file(
"cifar10_resnet256.h5",
origin="https://share.deel.ai/s/kram9kLpx6JwRX4/download/cifar10_resnet256.h5",
cache_dir=model_path,
cache_subdir="",
)
model = tf.keras.models.load_model(model_path_resnet_cifar10)
# Evaluate model
model.compile(metrics=["accuracy"])
_, accuracy = model.evaluate(ds_in)
print(f"Test accuracy:\t{accuracy:.4f}")
# penultimate features 2d visualization
print("\n=== Penultimate features viz ===")
plt.figure(figsize=(4.5, 3))
plot_2D_features(
model=model,
in_dataset=ds_in,
out_dataset=ds_out,
output_layer_id=-2,
)
plt.tight_layout()
plt.show()
79/79 [==============================] - 3s 14ms/step - loss: 0.1268 - accuracy: 0.9276
Test accuracy:	0.9276

=== Penultimate features viz ===
Gram score¶
We now fit a Gram detector on the CIFAR-10 train dataset, and compare the OOD scores returned for the CIFAR-10 (ID) and SVHN (OOD) test datasets.
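The layer names passed to feature_layers_id below are specific to the loaded ResNet; one way to find suitable names for your own Keras model is, for instance:
# List convolution / activation layers whose names can be passed
# as feature_layers_id (the names depend on the model you loaded).
candidate_layers = [
    layer.name
    for layer in model.layers
    if "conv" in layer.name or "activation" in layer.name
]
print(candidate_layers[:10])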
# === gram scores ===
gram = Gram(orders=list(range(1, 11)), quantile=0.01)
gram.fit(
model,
ds_fit,
feature_layers_id=['conv2d_18', 'activation_17', 'conv2d_37', 'activation_35', 'conv2d_56', 'activation_53']
)
scores_in, _ = gram.score(ds_in)
scores_out, _ = gram.score(ds_out)
scores_in += np.random.random_sample(size=scores_in.shape) * 10e-6
scores_out += np.random.random_sample(size=scores_out.shape) * 10e-6
# === metrics ===
# auroc / fpr95
metrics = bench_metrics(
(scores_in, scores_out),
metrics=["auroc", "fpr95tpr"],
)
print("=== Metrics ===")
for k, v in metrics.items():
print(f"{k:<10} {v:.6f}")
=== Metrics ===
auroc      0.982648
fpr95tpr   0.067800
print("\n=== Plots ===")
# hists / roc
plt.figure(figsize=(9, 3))
plt.subplot(121)
plot_ood_scores(scores_in, scores_out, log_scale=True)
plt.subplot(122)
plot_roc_curve(scores_in, scores_out)
plt.tight_layout()
plt.show()
=== Plots ===