ASH method¶
This notebook evaluates the ASH method.
ASH consists in re-using existing logit-based OOD methods, but with the penultimate-layer activations pruned and scaled. Let $a$ be the activation vector and $P_p(a)$ the $p$-th percentile of its values. Activations with $a_i \leq P_p(a)$ are pruned (set to zero), and the remaining ones are multiplied by the scaling factor $$ s = \exp\left(\frac{\sum_{i} a_i}{\sum_{a_i > P_p(a)} a_i}\right), $$ i.e. the ratio of the activation sum before pruning to the activation sum after pruning.
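To make the shaping step concrete, here is a minimal NumPy sketch of it applied to a single activation vector. It only illustrates the formula above and is not oodeel's internal implementation; the function name ash_s and the 90th-percentile choice are ours.
import numpy as np

def ash_s(a: np.ndarray, percentile: float = 90.0) -> np.ndarray:
    # threshold at the p-th percentile of the activation values
    threshold = np.percentile(a, percentile)
    s1 = a.sum()                              # sum of activations before pruning
    pruned = np.where(a > threshold, a, 0.0)  # prune values a_i <= P_p(a)
    s2 = pruned.sum()                         # sum of the remaining activations
    return pruned * np.exp(s1 / s2)           # scale the surviving activations

# toy "penultimate" activation vector
rng = np.random.default_rng(0)
a = rng.exponential(size=512).astype(np.float32)
shaped = ash_s(a)
print(f"kept {(shaped > 0).sum()} / {a.size} activations")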
Here, we focus on a ResNet trained on CIFAR-10, challenged on SVHN.
Reference
Extremely Simple Activation Shaping for Out-of-Distribution Detection, ICLR 2023
http://arxiv.org/abs/2209.09858
Imports¶
%load_ext autoreload
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from IPython.display import clear_output
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from oodeel.methods import MLS, Energy, GEN, ODIN, Entropy
from oodeel.eval.metrics import bench_metrics
from oodeel.eval.plots import plot_ood_scores, plot_roc_curve, plot_2D_features
from oodeel.datasets import OODDataset
from oodeel.utils.tf_training_tools import train_tf_model
Note that models are saved in ~/.oodeel/saved_models and data is expected to be found in ~/.oodeel/datasets by default. Change the following cell to use custom paths.
model_path = os.path.expanduser("~/") + ".oodeel/saved_models"
data_path = os.path.expanduser("~/") + ".oodeel/datasets"
os.makedirs(model_path, exist_ok=True)
os.makedirs(data_path, exist_ok=True)
Data loading¶
- In-distribution data: CIFAR-10
- Out-of-distribution data: SVHN
Note: We denote In-Distribution (ID) data with _in and Out-Of-Distribution (OOD) data with _out, to avoid confusion with OOD detection, which is the name of the task and is therefore used to denote core classes such as OODDataset and OODBaseDetector.
# === Load ID and OOD data ===
batch_size = 128
# 1a- Load in-distribution dataset: CIFAR-10
ds_fit = OODDataset("cifar10", load_kwargs={"split": "train"}, input_key="image")
ds_in = OODDataset("cifar10", load_kwargs={"split": "test"}, input_key="image")
# 1b- Load out-of-distribution dataset: SVHN
ds_out = OODDataset("svhn_cropped", load_kwargs={"split": "test"})
# 2- prepare data (preprocess, shuffle, batch)
def preprocess_fn(*inputs):
x = inputs[0] / 255
return tuple([x] + list(inputs[1:]))
ds_fit = ds_fit.prepare(batch_size, preprocess_fn)
ds_in = ds_in.prepare(batch_size, preprocess_fn)
ds_out = ds_out.prepare(batch_size, preprocess_fn)
clear_output()
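As a quick sanity check before moving on, one can peek at a single prepared batch. This assumes the prepared datasets behave like regular tf.data datasets yielding (image, label) batches, which should be the case with the TensorFlow backend used here.
# inspect one prepared batch: images should be float values in [0, 1]
x_batch, y_batch = next(iter(ds_in))
print(x_batch.shape, y_batch.shape)
print(float(tf.reduce_min(x_batch)), float(tf.reduce_max(x_batch)))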
Model loading¶
Instead of training a model from scratch with the train_tf_model function, we load a ResNet pretrained on CIFAR-10.
# === Load model ===
# ResNet pretrained on CIFAR-10
model_path_resnet_cifar10 = tf.keras.utils.get_file(
"cifar10_resnet256.h5",
origin="https://share.deel.ai/s/kram9kLpx6JwRX4/download/cifar10_resnet256.h5",
cache_dir=model_path,
cache_subdir="",
)
model = tf.keras.models.load_model(model_path_resnet_cifar10)
# Evaluate model
model.compile(metrics=["accuracy"])
_, accuracy = model.evaluate(ds_in)
print(f"Test accuracy:\t{accuracy:.4f}")
# penultimate features 2d visualization
print("\n=== Penultimate features viz ===")
plt.figure(figsize=(4.5, 3))
plot_2D_features(
model=model,
in_dataset=ds_in,
out_dataset=ds_out,
output_layer_id=-2,
)
plt.tight_layout()
plt.show()
79/79 [==============================] - 4s 14ms/step - loss: 0.1268 - accuracy: 0.9278
Test accuracy:	0.9278

=== Penultimate features viz ===
[2D projection of penultimate features: CIFAR-10 (ID) vs SVHN (OOD)]
ASH scores¶
We now fit some OOD detectors using ASH + [ODIN, MLS, Energy] with the CIFAR-10 train dataset, and compare the OOD scores returned for the CIFAR-10 (ID) and SVHN (OOD) test datasets.
%autoreload 2
detectors = {
"odin": {
"class": ODIN,
"kwargs": dict(temperature=1000),
},
"mls": {
"class": MLS,
"kwargs": dict(),
},
"energy": {
"class": Energy,
"kwargs": dict(),
},
}
for d in detectors.keys():
print(f"=== {d.upper()} ===")
for use_ash in [False, True]:
print(["~ Without", "~ With"][int(use_ash)] + " ASH ~")
# === ood scores ===
d_kwargs = detectors[d]["kwargs"]
d_kwargs.update(dict(
use_ash=use_ash,
ash_percentile=0.90,
))
detector = detectors[d]["class"](**d_kwargs)
detector.fit(model)
scores_in, _ = detector.score(ds_in)
scores_out, _ = detector.score(ds_out)
# === metrics ===
# auroc / fpr95
metrics = bench_metrics(
(scores_in, scores_out),
metrics=["auroc", "fpr95tpr"],
)
for k, v in metrics.items():
print(f"{k:<10} {v:.6f}")
# hists / roc
plt.figure(figsize=(9, 3))
plt.subplot(121)
plot_ood_scores(scores_in, scores_out)
plt.subplot(122)
plot_roc_curve(scores_in, scores_out)
plt.tight_layout()
plt.show()
=== ODIN ===
~ Without ASH ~
auroc      0.950493
fpr95tpr   0.151200
~ With ASH ~
auroc      0.976817
fpr95tpr   0.087200

=== MLS ===
~ Without ASH ~
auroc      0.915857
fpr95tpr   0.205700
~ With ASH ~
auroc      0.982385
fpr95tpr   0.084700

=== ENERGY ===
~ Without ASH ~
auroc      0.909186
fpr95tpr   0.206900
~ With ASH ~
auroc      0.982398
fpr95tpr   0.084600