Generalized Entropy¶
This notebook evaluates the GEN (Generalized ENtropy) method.
This method computes a generalized entropy score from the softmax probabilities. Given the softmax output values $p_i$ (one per class), the OOD score is defined as
$$ S(p) = \sum_{i=1}^k p_i^\gamma (1-p_i)^\gamma. $$
The two parameters of the method are:
- $\gamma$, the order of the generalized entropy, between 0 and 1. The authors of the original paper propose to set $\gamma=0.1$.
- $k$, the number of largest softmax values kept in the entropy computation. Discarding the smallest values makes the score less sensitive to small variations among low-probability classes.
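The formula and its two parameters can be sketched in a few lines of plain Python. This is a minimal illustration, not the oodeel implementation: the helper name `gen_score` and the toy probability vectors are assumptions made for the example, and the score is computed for a single softmax vector.

```python
def gen_score(probs, gamma=0.1, k=None):
    """Generalized entropy of one softmax vector, restricted to its top-k entries.

    Hypothetical helper for illustration; not the oodeel implementation.
    """
    # Sort the probabilities in decreasing order and keep the k largest ones.
    top = sorted(probs, reverse=True)
    if k is not None:
        top = top[:k]
    # Sum p^gamma * (1 - p)^gamma over the retained classes.
    # max(..., 0.0) guards against tiny negative values due to rounding.
    return sum((p ** gamma) * (max(1.0 - p, 0.0) ** gamma) for p in top)


confident = [0.97, 0.01, 0.01, 0.01]  # peaked prediction
uncertain = [0.25, 0.25, 0.25, 0.25]  # flat prediction

# The flat prediction gets the larger score.
print(gen_score(confident), gen_score(uncertain))

# With many classes, the tail of tiny probabilities dominates the full sum;
# keeping only the top-k values removes that contribution.
many_classes = [0.9] + [0.1 / 999] * 999
print(gen_score(many_classes), gen_score(many_classes, k=10))
```

With $\gamma=0.1$, even a probability of about $10^{-4}$ contributes roughly 0.4 to the sum, which is why truncating to the $k$ largest values matters when the number of classes is large.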
Here, we focus on a toy convolutional network trained on MNIST[0-4] and a ResNet20 model trained on CIFAR-10, evaluated against MNIST[5-9] and SVHN as OOD datasets, respectively.
GEN: Pushing the Limits of Softmax-Based Out-of-Distribution Detection, CVPR 2023
%load_ext autoreload
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from IPython.display import clear_output
import tensorflow as tf
import matplotlib.pyplot as plt
from oodeel.methods import GEN
from oodeel.eval.metrics import bench_metrics
from oodeel.eval.plots import plot_ood_scores, plot_roc_curve, plot_2D_features
from oodeel.datasets import OODDataset
from oodeel.utils.tf_training_tools import train_tf_model
Note that, by default, models are saved to ~/.oodeel/saved_models and datasets are expected in ~/.oodeel/datasets. Edit the following cell to use custom paths.
model_path = os.path.join(os.path.expanduser("~"), ".oodeel", "saved_models")
data_path = os.path.join(os.path.expanduser("~"), ".oodeel", "datasets")
os.makedirs(model_path, exist_ok=True)
os.makedirs(data_path, exist_ok=True)
First experiment: MNIST[0-4] vs MNIST[5-9]¶
For this first experiment, we train a toy convolutional network on the MNIST dataset restricted to digits 0 to 4.
Data loading¶
- In-distribution data: MNIST[0-4]
- Out-of-distribution data: MNIST[5-9]
Note: We denote In-Distribution (ID) data with the suffix _in and Out-Of-Distribution (OOD) data with the suffix _out, to avoid confusion with OOD detection, which is the name of the task and is therefore used to name core classes such as OODDataset.
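Conceptually, the class-based split used in the next cell partitions samples according to whether their label belongs to the in-distribution label list. The sketch below illustrates the idea on a toy list of (input, label) pairs; `split_by_label` is a hypothetical helper written for this example and is not part of oodeel.

```python
def split_by_label(samples, in_labels):
    """Partition (input, label) pairs into ID and OOD lists.

    Hypothetical helper for illustration only; it is not part of oodeel.
    """
    in_set = set(in_labels)
    id_part = [s for s in samples if s[1] in in_set]
    ood_part = [s for s in samples if s[1] not in in_set]
    return id_part, ood_part


# Toy stand-in for a labelled dataset: (input, label) pairs.
toy = [("img_a", 0), ("img_b", 7), ("img_c", 4), ("img_d", 9)]
toy_in, toy_out = split_by_label(toy, in_labels=[0, 1, 2, 3, 4])

print(toy_in)   # pairs whose label is in [0, 4]
print(toy_out)  # pairs whose label is in [5, 9]
```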
# === Load ID and OOD data ===
batch_size = 128
in_labels = [0, 1, 2, 3, 4]
# 1- Load train/test MNIST dataset
ds_train = OODDataset("mnist", load_kwargs=dict(split="train"))
data_test = OODDataset("mnist", load_kwargs=dict(split="test"))
# 2- Split ID / OOD data depending on label value:
# in-distribution: MNIST[0-4] / out-of-distribution: MNIST[5-9]
ds_train, _ = ds_train.split_by_class(in_labels)
oods_in, oods_out = data_test.split_by_class(in_labels)
# 3- Prepare data (preprocess, shuffle, batch)
def preprocess_fn(*inputs):
    # scale pixel values to [0, 1]; pass labels through unchanged
    x = inputs[0] / 255
    return tuple([x] + list(inputs[1:]))
ds_train = ds_train.prepare(batch_size, preprocess_fn, shuffle=True)
ds_in = oods_in.prepare(batch_size, preprocess_fn, with_ood_labels=False)
ds_out = oods_out.prepare(batch_size, preprocess_fn, with_ood_labels=False)
Model training¶
Now let's train a simple model on MNIST[0-4] using train_tf_model.
# === Train / Load model ===
model_path_mnist_04 = os.path.join(model_path, "mnist_model_0-4.h5")
try:
    # if the model exists, load it
    model = tf.keras.models.load_model(model_path_mnist_04)
except OSError:
    # else, train a new model
    train_config = {
        "model": "toy_convnet",
        "input_shape": (28, 28, 1),
        "num_classes": 10,
        "batch_size": 128,
        "epochs": 5,
        "save_dir": model_path_mnist_04,
        "validation_data": ds_in,
    }
    model = train_tf_model(ds_train, **train_config)
_, accuracy = model.evaluate(ds_in)
print(f"Test accuracy:\t{accuracy:.4f}")
# penultimate features 2d visualization
print("\n=== Penultimate features viz ===")
plt.figure(figsize=(4.5, 3))
41/41 [==============================] - 1s 25ms/step - loss: 0.0059 - accuracy: 0.9979
Test accuracy: 0.9979

=== Penultimate features viz ===
GEN score¶
We now fit a GEN OOD detector to the model trained on MNIST[0-4] and compare the OOD scores returned for the MNIST[0-4] (ID) and MNIST[5-9] (OOD) test datasets.
# === GEN scores ===
gen = GEN()
# attach the trained model to the detector before scoring
gen.fit(model)
scores_in, _ = gen.score(ds_in)
scores_out, _ = gen.score(ds_out)
# === metrics ===
# auroc / fpr95
metrics = bench_metrics(
    (scores_in, scores_out),
    metrics=["auroc", "fpr95tpr"],
)
print("=== Metrics ===")
for k, v in metrics.items():
    print(f"{k:<10} {v:.6f}")
print("\n=== Plots ===")
# hists / roc
plt.figure(figsize=(9, 3))
# left panel: score histograms; right panel: ROC curve
plt.subplot(1, 2, 1)
plot_ood_scores(scores_in, scores_out, log_scale=False)
plt.subplot(1, 2, 2)
plot_roc_curve(scores_in, scores_out)
=== Metrics ===
auroc      0.914862
fpr95tpr   0.529675

=== Plots ===
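For readers unfamiliar with the two metrics reported above, the following plain-Python sketch shows one common way to compute them from raw scores. It assumes that OOD is the positive class and that higher scores mean "more OOD"; the function names and toy scores are illustrative assumptions, and this is not how `bench_metrics` is implemented.

```python
import math


def auroc(scores_in, scores_out):
    """Probability that a random OOD score exceeds a random ID score (ties count 0.5)."""
    wins = 0.0
    for s_out in scores_out:
        for s_in in scores_in:
            if s_out > s_in:
                wins += 1.0
            elif s_out == s_in:
                wins += 0.5
    return wins / (len(scores_in) * len(scores_out))


def fpr_at_tpr(scores_in, scores_out, tpr=0.95):
    """Fraction of ID samples flagged as OOD when at least `tpr` of OOD samples are flagged."""
    ranked = sorted(scores_out, reverse=True)
    # Smallest number of OOD samples that must be flagged to reach the target TPR.
    n_needed = math.ceil(tpr * len(ranked))
    threshold = ranked[n_needed - 1]
    # A sample is flagged as OOD when its score is at or above the threshold.
    return sum(s >= threshold for s in scores_in) / len(scores_in)


# Toy scores, not taken from the experiments in this notebook.
toy_scores_in = [0.1, 0.2, 0.3, 0.4]
toy_scores_out = [0.35, 0.5, 0.6, 0.7]

print(auroc(toy_scores_in, toy_scores_out))       # 0.9375
print(fpr_at_tpr(toy_scores_in, toy_scores_out))  # 0.25
```

An AUROC close to 1 means the ID and OOD scores are well separated, while a low FPR at 95% TPR means few ID samples are rejected when 95% of OOD samples are caught.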
Second experiment: CIFAR-10 vs SVHN¶
For this second experiment, we use CIFAR-10 as the in-distribution dataset and SVHN as the out-of-distribution dataset.
Data loading¶
- In-distribution data: CIFAR-10
- Out-of-distribution data: SVHN
# === Load ID and OOD data ===
batch_size = 128
# 1a- Load in-distribution dataset: CIFAR-10
ds_in = OODDataset("cifar10", load_kwargs={"split": "test"}, input_key="image")
# 1b- Load out-of-distribution dataset: SVHN
ds_out = OODDataset("svhn_cropped", load_kwargs={"split": "test"})
# 2- prepare data (preprocess, shuffle, batch)
def preprocess_fn(*inputs):
    # scale pixel values to [0, 1]; pass labels through unchanged
    x = inputs[0] / 255
    return tuple([x] + list(inputs[1:]))
ds_in = ds_in.prepare(batch_size, preprocess_fn)
ds_out = ds_out.prepare(batch_size, preprocess_fn)
Model loading¶
The model is a ResNet pretrained on CIFAR-10 that reaches a test accuracy of 92.75%.
# === Load model ===
# ResNet pretrained on CIFAR-10
model_path_resnet_cifar10 = tf.keras.utils.get_file(
model = tf.keras.models.load_model(model_path_resnet_cifar10)
# Evaluate model
_, accuracy = model.evaluate(ds_in)
print(f"Test accuracy:\t{accuracy:.4f}")
# penultimate features 2d visualization
print("\n=== Penultimate features viz ===")
plt.figure(figsize=(4.5, 3))
79/79 [==============================] - 2s 9ms/step - loss: 0.1268 - accuracy: 0.9275
Test accuracy: 0.9275

=== Penultimate features viz ===
GEN score¶
We now fit a GEN OOD detector to the ResNet trained on CIFAR-10 and compare the OOD scores returned for the CIFAR-10 (ID) and SVHN (OOD) test datasets.
# === GEN scores ===
gen = GEN()
# attach the pretrained model to the detector before scoring
gen.fit(model)
scores_in, _ = gen.score(ds_in)
scores_out, _ = gen.score(ds_out)
# === metrics ===
# auroc / fpr95
metrics = bench_metrics(
    (scores_in, scores_out),
    metrics=["auroc", "fpr95tpr"],
)
print("=== Metrics ===")
for k, v in metrics.items():
    print(f"{k:<10} {v:.6f}")
print("\n=== Plots ===")
# hists / roc
plt.figure(figsize=(9, 3))
# left panel: score histograms; right panel: ROC curve
plt.subplot(1, 2, 1)
plot_ood_scores(scores_in, scores_out, log_scale=False)
plt.subplot(1, 2, 2)
plot_roc_curve(scores_in, scores_out)
=== Metrics ===
auroc      0.969413
fpr95tpr   0.133700

=== Plots ===