Maximum Logit Score / Maximum Softmax Probability¶
This notebook aims at evaluating the MLS and MSP methods.
These methods return an OOD score based on the maximum value of the output logits or softmax activations.
Here, we focus on a toy convolutional network trained on MNIST[0-4] and a ResNet model trained on CIFAR-10, respectively challenged on MNIST[5-9] and SVHN OOD datasets.
References
Imports¶
%load_ext autoreload
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from IPython.display import clear_output
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from oodeel.methods import MLS
from oodeel.eval.metrics import bench_metrics
from oodeel.eval.plots import plot_ood_scores, plot_roc_curve, plot_2D_features
from oodeel.datasets import OODDataset
from oodeel.utils.tf_training_tools import train_tf_model
Note that models are saved at ~/.oodeel/saved_models and data is supposed to be found at ~/.oodeel/datasets by default. Change the following cell for a custom path.
model_path = os.path.expanduser("~/") + ".oodeel/saved_models"
data_path = os.path.expanduser("~/") + ".oodeel/datasets"
os.makedirs(model_path, exist_ok=True)
os.makedirs(data_path, exist_ok=True)
First experiment: MNIST[0-4] vs MNIST[5-9]¶
For this first experiment, we train a toy convolutional network on the MNIST dataset restricted to digits 0 to 4. After fitting the train subset of this dataset to the MLS method, we will compare the scores returned for MNIST[0-4] (in-distribution) and MNIST[5-9] (out-of-distribution) test subsets.
Data loading¶
- In-distribution data: MNIST[0-4]
- Out-of-distribution data: MNIST[5-9]
Note: We denote In-Distribution (ID) data with
_in
and Out-Of-Distribution (OOD) data with_out
to avoid confusion with OOD detection which is the name of the task, and is therefore used to denote core classes such asOODDataset
andOODBaseDetector
.
# === Load ID and OOD data ===
batch_size = 128
in_labels = [0, 1, 2, 3, 4]
# 1- Load train/test MNIST dataset
ds_train = OODDataset("mnist", load_kwargs=dict(split="train"))
data_test = OODDataset("mnist", load_kwargs=dict(split="test"))
# 2- Split ID / OOD data depending on label value:
# in-distribution: MNIST[0-4] / out-of-distribution: MNIST[5-9]
ds_train, _ = ds_train.split_by_class(in_labels)
oods_in, oods_out = data_test.split_by_class(in_labels)
# 3- Prepare data (preprocess, shuffle, batch)
def preprocess_fn(*inputs):
x = inputs[0] / 255
return tuple([x] + list(inputs[1:]))
ds_train = ds_train.prepare(batch_size, preprocess_fn, shuffle=True)
ds_in = oods_in.prepare(batch_size, preprocess_fn, with_ood_labels=False)
ds_out = oods_out.prepare(batch_size, preprocess_fn, with_ood_labels=False)
clear_output()
Model training¶
Now let's train a simple model on MNIST[0-4] using train_tf_model
function.
# === Train / Load model ===
model_path_mnist_04 = os.path.join(model_path, "mnist_model_0-4.h5")
try:
# if the model exists, load it
model = tf.keras.models.load_model(model_path_mnist_04)
except OSError:
# else, train a new model
train_config = {
"model": "toy_convnet",
"input_shape": (28, 28, 1),
"num_classes": 10,
"batch_size": 128,
"epochs": 5,
"save_dir": model_path_mnist_04,
"validation_data": ds_in,
}
model = train_tf_model(ds_train, **train_config)
_, accuracy = model.evaluate(ds_in)
print(f"Test accuracy:\t{accuracy:.4f}")
# penultimate features 2d visualization
print("\n=== Penultimate features viz ===")
plt.figure(figsize=(4.5, 3))
plot_2D_features(
model=model,
in_dataset=ds_in,
out_dataset=ds_out,
output_layer_id=-2,
)
plt.tight_layout()
plt.show()
41/41 [==============================] - 2s 12ms/step - loss: 0.0118 - accuracy: 0.9965 Test accuracy: 0.9965 === Penultimate features viz ===
MLS score¶
We now fit a MLS OOD detector with MNIST[0-4] train dataset, and compare OOD scores returned for MNIST[0-4] (ID) and MNIST[5-9] (OOD) test datasets.
# === MLS scores ===
mls = MLS()
mls.fit(model)
scores_in, _ = mls.score(ds_in)
scores_out, _ = mls.score(ds_out)
# === metrics ===
# auroc / fpr95
metrics = bench_metrics(
(scores_in, scores_out),
metrics=["auroc", "fpr95tpr"],
)
print("=== Metrics ===")
for k, v in metrics.items():
print(f"{k:<10} {v:.6f}")
print("\n=== Plots ===")
# hists / roc
plt.figure(figsize=(9, 3))
plt.subplot(121)
plot_ood_scores(scores_in, scores_out, log_scale=False)
plt.subplot(122)
plot_roc_curve(scores_in, scores_out)
plt.tight_layout()
plt.show()
=== Metrics === auroc 0.934842 fpr95tpr 0.392878 === Plots ===
MSP score¶
Using the softmax activations instead, we get the MSP scores for MNIST[0-4] (ID) and MNIST[5-9] (OOD) test datasets.
# === MSP scores ===
msp = MLS(output_activation="softmax")
msp.fit(model)
scores_in, _ = msp.score(ds_in)
scores_out, _ = msp.score(ds_out)
# === metrics ===
# auroc / fpr95
metrics = bench_metrics(
(scores_in, scores_out),
metrics=["auroc", "fpr95tpr"],
)
print("=== Metrics ===")
for k, v in metrics.items():
print(f"{k:<10} {v:.6f}")
print("\n=== Plots ===")
# Normalize scores for a better hist visualization
minim = np.min([np.min(scores_in), np.min(scores_out)])
scores_in_ = scores_in - 2 * minim + np.min(scores_in[np.where(scores_in != minim)])
scores_out_ = scores_out - 2 * minim + np.min(scores_in[np.where(scores_in != minim)])
# hists / roc
plt.figure(figsize=(9, 3))
plt.subplot(121)
plot_ood_scores(scores_in_, scores_out_, log_scale=True)
plt.subplot(122)
plot_roc_curve(scores_in, scores_out)
plt.tight_layout()
plt.show()
=== Metrics === auroc 0.937479 fpr95tpr 0.416229 === Plots ===
Second experiment: CIFAR-10 vs SVHN¶
For this second experiment, we oppose CIFAR-10 (in-distribution dataset) to SVHN (out-of-distribution dataset).
Data loading¶
- In-distribution data: CIFAR-10
- Out-of-distribution data: SVHN
# === Load ID and OOD data ===
batch_size = 128
# 1a- Load in-distribution dataset: CIFAR-10
ds_in = OODDataset("cifar10", load_kwargs={"split": "test"}, input_key="image")
# 1b- Load out-of-distribution dataset: SVHN
ds_out = OODDataset("svhn_cropped", load_kwargs={"split": "test"})
# 2- prepare data (preprocess, shuffle, batch)
def preprocess_fn(*inputs):
x = inputs[0] / 255
return tuple([x] + list(inputs[1:]))
ds_in = ds_in.prepare(batch_size, preprocess_fn)
ds_out = ds_out.prepare(batch_size, preprocess_fn)
Model loading¶
The model is a ResNet pretrained on CIFAR-10 and getting an accuracy score of 92.75%.
# === Load model ===
# ResNet pretrained on CIFAR-10
model_path_resnet_cifar10 = tf.keras.utils.get_file(
"cifar10_resnet256.h5",
origin="https://share.deel.ai/s/kram9kLpx6JwRX4/download/cifar10_resnet256.h5",
cache_dir=model_path,
cache_subdir="",
)
model = tf.keras.models.load_model(model_path_resnet_cifar10)
# Evaluate model
model.compile(metrics=["accuracy"])
_, accuracy = model.evaluate(ds_in)
print(f"Test accuracy:\t{accuracy:.4f}")
# penultimate features 2d visualization
print("\n=== Penultimate features viz ===")
plt.figure(figsize=(4.5, 3))
plot_2D_features(
model=model,
in_dataset=ds_in,
out_dataset=ds_out,
output_layer_id=-2,
)
plt.tight_layout()
plt.show()
79/79 [==============================] - 3s 25ms/step - loss: 0.1268 - accuracy: 0.9276 Test accuracy: 0.9276 === Penultimate features viz ===
MLS score¶
We now fit a MLS OOD detector with CIFAR-10 train dataset, and compare OOD scores returned for CIFAR-10 (ID) and SVHN (OOD) test datasets.
# === MLS scores ===
mls = MLS()
mls.fit(model)
scores_in, _ = mls.score(ds_in)
scores_out, _ = mls.score(ds_out)
# === metrics ===
# auroc / fpr95
metrics = bench_metrics(
(scores_in, scores_out),
metrics=["auroc", "fpr95tpr"],
)
print("=== Metrics ===")
for k, v in metrics.items():
print(f"{k:<10} {v:.6f}")
print("\n=== Plots ===")
# hists / roc
plt.figure(figsize=(9, 3))
plt.subplot(121)
plot_ood_scores(scores_in, scores_out, log_scale=False)
plt.subplot(122)
plot_roc_curve(scores_in, scores_out)
plt.tight_layout()
plt.show()
=== Metrics === auroc 0.915858 fpr95tpr 0.205700 === Plots ===
MSP score¶
Using the softmax activations instead, we get the MSP scores for CIFAR-10 (ID) and SVHN (OOD) test datasets.
# === MSP scores ===
msp = MLS(output_activation="softmax")
msp.fit(model)
scores_in, _ = msp.score(ds_in)
scores_out, _ = msp.score(ds_out)
# === metrics ===
# auroc / fpr95
metrics = bench_metrics(
(scores_in, scores_out),
metrics=["auroc", "fpr95tpr"],
)
print("=== Metrics ===")
for k, v in metrics.items():
print(f"{k:<10} {v:.6f}")
print("\n=== Plots ===")
# Normalize scores for a better hist visualization
minim = np.min([np.min(scores_in), np.min(scores_out)])
scores_in_ = scores_in - 2 * minim + np.min(scores_in[np.where(scores_in != minim)])
scores_out_ = scores_out - 2 * minim + np.min(scores_in[np.where(scores_in != minim)])
# hists / roc
plt.figure(figsize=(9, 3))
plt.subplot(121)
plot_ood_scores(scores_in_, scores_out_, log_scale=True)
plt.subplot(122)
plot_roc_curve(scores_in, scores_out)
plt.tight_layout()
plt.show()
=== Metrics === auroc 0.932143 fpr95tpr 0.209000 === Plots ===