HSIC Attribution Method
The HSIC attribution method from Novello, Fel and Vigouroux [1] explains a neural network's prediction for a given input image by assessing the dependence between the output and patches of the input. Thanks to the sample efficiency of the HSIC estimator, this black-box method requires relatively few forward passes to produce relevant explanations.
Consider the random perturbation \(X_i\) associated with each patch of the input image, \(i \in \{1,\dots,d\}\), where \(d = \mathrm{grid\_size}^2\) is the number of patches, and the model output \(Y\). Let \(X^1_i,\dots,X^p_i\) and \(Y^1,\dots,Y^p\) be \(p\) samples of \(X_i\) and \(Y\). The HSIC attribution method requires selecting one kernel for the input and one for the output. These kernels define reproducing kernel Hilbert spaces (RKHS) in which HSIC is computed as the (squared) Maximum Mean Discrepancy (MMD), a dissimilarity metric between distributions, between the joint distribution of \((X_i, Y)\) and the product of its marginals. Let \(k:\mathbb{R}^2 \rightarrow \mathbb{R}\) and \(l:\mathbb{R}^2 \rightarrow \mathbb{R}\) be the kernels selected for \(X_i\) and \(Y\). HSIC is then estimated with an error of \(\mathcal{O}(1/\sqrt{p})\) using the estimator
$$ \mathcal{H}^p_{X_i, Y} = \frac{1}{(p-1)^2} \operatorname{tr} (KHLH), $$
where \(K, L, H \in \mathbb{R}^{p \times p}\) with \(K_{ab} = k(X_i^a, X_i^b)\), \(L_{ab} = l(Y^a, Y^b)\) and \(H_{ab} = \delta(a=b) - p^{-1}\), where \(\delta(a=b) = 1\) if \(a=b\) and \(0\) otherwise.
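The estimator above translates almost literally into NumPy. The sketch below is a minimal illustration of the formula, not Xplique's implementation; the function name hsic_estimate is hypothetical, and it takes the precomputed Gram matrices \(K\) and \(L\) as inputs.

```python
import numpy as np


def hsic_estimate(K: np.ndarray, L: np.ndarray) -> float:
    """Biased HSIC estimate tr(K H L H) / (p - 1)^2 from two p x p Gram matrices."""
    p = K.shape[0]
    # Centering matrix H with H_ab = delta(a = b) - 1/p
    H = np.eye(p) - 1.0 / p
    return float(np.trace(K @ H @ L @ H) / (p - 1) ** 2)


# Toy usage: a binary patch variable x and an output y that follows it closely
x = np.array([0, 1, 1, 0, 1])
y = np.array([0.1, 2.0, 1.9, 0.0, 2.1])
K = (x[:, None] == x[None, :]).astype(float)          # Dirac kernel Gram matrix
L = np.exp(-((y[:, None] - y[None, :]) ** 2) / 2.0)   # RBF kernel Gram matrix, bandwidth 1
print(hsic_estimate(K, L))
```

Because \(H\) removes the mean, a constant kernel matrix (no variation in \(X_i\)) yields a score of zero, and the estimate is symmetric in \(K\) and \(L\).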
In the paper Making Sense of Dependence: Efficient Black-box Explanations Using Dependence Measure, the LatinHypercube sampler is used to sample the perturbations. Note, however, that the present implementation uses TFSobolSequence as the default sampler, because LatinHypercube requires scipy \(\geq\) 1.7.0. You can nevertheless use LatinHypercube, which is included in the library, by passing it as the sampler when initializing your explainer.
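As a rough illustration of what a Latin hypercube sampler produces, the sketch below draws nb_design points over \(d = \mathrm{grid\_size}^2\) dimensions with scipy.stats.qmc (the module introduced in scipy 1.7.0) and thresholds them into binary patch masks. The 0.5 threshold and the helper name binary_lhs_designs are assumptions for illustration; this is not Xplique's LatinHypercube implementation.

```python
import numpy as np
from scipy.stats import qmc


def binary_lhs_designs(grid_size: int, nb_design: int, seed: int = 0) -> np.ndarray:
    """Return an (nb_design, grid_size**2) array of 0/1 patch masks (hypothetical helper)."""
    d = grid_size ** 2
    sampler = qmc.LatinHypercube(d=d, seed=seed)
    points = sampler.random(n=nb_design)  # values in [0, 1), one point per stratum in each dimension
    # Assumption: a patch is kept (1) or perturbed (0) by thresholding at 0.5
    return (points > 0.5).astype(np.float32)


masks = binary_lhs_designs(grid_size=7, nb_design=750)
print(masks.shape)  # (750, 49)
```

Rows are designs and columns are patches. Since LHS stratifies every column, with an even nb_design this thresholding keeps each patch in exactly half of the designs.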
For the kernel \(k\) applied to \(X_i\), a modified Dirac kernel is used, which enables an ANOVA-like decomposition that allows assessing pairwise patch interactions (see the paper for more details). For the kernel \(l\) on the output \(Y\), a Radial Basis Function (RBF) kernel is used.
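The sketch below puts these pieces together on toy data: a Dirac kernel on binary patch masks, an RBF kernel on a scalar output, and one HSIC score per patch. It rests on several assumptions: the paper's modified Dirac kernel is not reproduced (a plain indicator kernel stands in for it), the RBF bandwidth is fixed at 1.0 rather than tuned, and the "model" is a hypothetical function whose output only depends on patch 0.

```python
import numpy as np


def hsic_estimate(K: np.ndarray, L: np.ndarray) -> float:
    # Biased estimator tr(K H L H) / (p - 1)^2 with centering matrix H
    p = K.shape[0]
    H = np.eye(p) - 1.0 / p
    return float(np.trace(K @ H @ L @ H) / (p - 1) ** 2)


def dirac_gram(x: np.ndarray) -> np.ndarray:
    # Plain Dirac kernel k(a, b) = 1 if a == b else 0 (not the paper's modified version)
    return (x[:, None] == x[None, :]).astype(float)


def rbf_gram(y: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    # RBF kernel l(a, b) = exp(-(a - b)^2 / (2 sigma^2))
    return np.exp(-((y[:, None] - y[None, :]) ** 2) / (2.0 * sigma**2))


rng = np.random.default_rng(0)
p, d = 200, 4                                  # p samples, d patches
X = rng.integers(0, 2, size=(p, d))            # binary patch masks, one row per design
Y = 3.0 * X[:, 0] + 0.1 * rng.normal(size=p)   # toy output: depends on patch 0 only

L = rbf_gram(Y)                                # output Gram matrix, shared by all patches
scores = np.array([hsic_estimate(dirac_gram(X[:, i]), L) for i in range(d)])
print(scores.round(4))
```

Patch 0 receives by far the largest score, while the independent patches only get the small positive bias of the estimator.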
Tip
We recommend using a grid size of \(7 \times 7\) to define the image patches. The paper uses \(1500\) forward passes to obtain the most faithful explanations and \(750\) for a lower-budget, but still faithful, version.
Info
To explain small objects in images, it may be necessary to increase grid_size, which in turn requires increasing nb_design. However, increasing both raises memory usage and may lead to out-of-memory errors; setting the estimator_batch_size parameter limits the memory used by the estimator. Note that batch_size is the batch size used for model calls, whereas estimator_batch_size is internal to the HSIC estimator.
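To get a feel for the memory involved: the estimator_batch_size entry of the API reference states that, by default, the estimator creates a tensor of grid_size² × nb_design² elements. The back-of-the-envelope helper below is hypothetical; it assumes 4-byte float32 entries and ignores every other intermediate tensor.

```python
def estimator_tensor_gib(grid_size: int, nb_design: int, bytes_per_value: int = 4) -> float:
    """Size in GiB of a grid_size**2 x nb_design**2 tensor (float32 assumed by default)."""
    n_values = grid_size**2 * nb_design**2
    return n_values * bytes_per_value / 1024**3


# The last pair simply doubles both parameters; it is not a recommended setting
for grid_size, nb_design in [(7, 750), (7, 1500), (14, 3000)]:
    print(grid_size, nb_design, round(estimator_tensor_gib(grid_size, nb_design), 2))
```

This gives roughly 0.1, 0.41 and 6.57 GiB: doubling both grid_size and nb_design multiplies the tensor size by 16, which is where estimator_batch_size becomes useful.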
Example
Low budget version
from xplique.attributions import HsicAttributionMethod
# load images, labels and model
# ...
explainer = HsicAttributionMethod(model, grid_size=7, nb_design=750)
explanations = explainer(images, labels)
High budget version
from xplique.attributions import HsicAttributionMethod
# load images, labels and model
# ...
explainer = HsicAttributionMethod(model, grid_size=7, nb_design=1500)
explanations = explainer(images, labels)
Recommended version (requires scipy \(\geq\) 1.7.0)
from xplique.attributions import HsicAttributionMethod
from xplique.attributions.global_sensitivity_analysis import LatinHypercube
# load images, labels and model
# ...
explainer = HsicAttributionMethod(model,
                                  grid_size=7, nb_design=1500,
                                  sampler=LatinHypercube(binary=True))
explanations = explainer(images, labels)
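The examples above rely on the default perturbation_function='inpainting'. The sketch below only illustrates the general idea of turning one binary design over the grid into a perturbed image: each grid cell is upsampled to its patch with np.kron, and patches whose cell is 0 are set to zero. The helper name apply_grid_mask and the zero-fill baseline are assumptions for illustration; Xplique's own 'inpainting', 'blurring' and 'amplitude' functions may behave differently.

```python
import numpy as np


def apply_grid_mask(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Zero out the patches of an (H, W, C) image whose grid cell is 0 in `mask`.

    Hypothetical helper; assumes H and W are divisible by the grid size.
    """
    grid_size = mask.shape[0]
    h, w = image.shape[:2]
    # Nearest-neighbour upsampling: each grid cell becomes an (h/grid, w/grid) block
    block = np.ones((h // grid_size, w // grid_size), dtype=image.dtype)
    pixel_mask = np.kron(mask.astype(image.dtype), block)
    return image * pixel_mask[:, :, None]


image = np.ones((224, 224, 3), dtype=np.float32)
mask = np.ones((7, 7), dtype=np.float32)
mask[0, 0] = 0.0  # perturb the top-left patch only
perturbed = apply_grid_mask(image, mask)
print(perturbed[:32, :32].sum(), perturbed.sum())
```

With a 7 × 7 grid on a 224 × 224 image, each patch covers 32 × 32 pixels, so one design removes exactly one such block.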
Notebooks
HsicAttributionMethod
HSIC Attribution Method.
Compute the dependence between each input dimension and the output using the Hilbert-Schmidt Independence
Criterion, a perturbation function on a grid and an adapted sampling, as described in
the original paper.
__init__(self,
model,
grid_size: int = 8,
nb_design: int = 500,
sampler: Optional[xplique.attributions.global_sensitivity_analysis.samplers.Sampler] = None,
estimator: Optional[xplique.attributions.global_sensitivity_analysis.hsic_estimators.HsicEstimator] = None,
perturbation_function: Union[Callable, str, None] = 'inpainting',
batch_size: int = 256,
estimator_batch_size: int = None,
operator: Union[xplique.commons.operators_operations.Tasks, str,
Callable[[keras.src.engine.training.Model, tensorflow.python.framework.tensor.Tensor, tensorflow.python.framework.tensor.Tensor], float], None] = None)
Parameters
- model : model
  Model used for computing explanations.
- grid_size : int = 8
  Cuts the image into a grid of (grid_size, grid_size) cells and estimates one index per cell.
- nb_design : int = 500
  Number of designs drawn by the sampler.
- sampler : Optional[xplique.attributions.global_sensitivity_analysis.samplers.Sampler] = None
  Sampler used to generate the (quasi-)Monte Carlo samples, LHS or QMC. For more options, see the sampler module. Note that the original paper uses LHS, but the default sampler here is TFSobolSequence, as LHS requires scipy \(\geq\) 1.7.0.
- estimator : Optional[xplique.attributions.global_sensitivity_analysis.hsic_estimators.HsicEstimator] = None
  Estimator used to compute the HSIC score.
- perturbation_function : Union[Callable, str, None] = 'inpainting'
  Function called to apply the perturbation to the input. Can also be a string: 'inpainting', 'blurring', or 'amplitude'.
- batch_size : int = 256
  Batch size used for the forward passes.
- estimator_batch_size : int = None
  Batch size used in the estimator. It should only be set if the HSIC computation exceeds the available memory. By default, a tensor of grid_size² × nb_design² elements is created; estimator_batch_size batches the computation over the nb_design² dimension.
- operator : Union[xplique.commons.operators_operations.Tasks, str, Callable[[keras.src.engine.training.Model, tensorflow.python.framework.tensor.Tensor, tensorflow.python.framework.tensor.Tensor], float], None] = None
  Function g to explain. g takes 3 parameters (f, x, y) and should return a scalar, where f is the model, x the inputs and y the targets. If None, the standard operator g(f, x, y) = f(x)[y] is used.
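The estimator_batch_size entry says batching happens over the nb_design² dimension. One way such batching can bound memory, shown here as a hedged sketch rather than Xplique's actual estimator, uses the identity \(\operatorname{tr}(KHLH) = \sum_{a,b} K_{ab} (HLH)_{ab}\) (valid because \(HLH\) is symmetric). The output Gram matrix is centered once, and the Dirac Gram entries of all \(d\) patches are materialized only for one chunk of \((a, b)\) pairs at a time. The function name and chunking scheme are assumptions.

```python
import numpy as np


def hsic_per_patch_chunked(X: np.ndarray, L: np.ndarray, chunk: int) -> np.ndarray:
    """Biased HSIC of each column of X (p samples x d patches) against the p x p Gram matrix L.

    Hypothetical sketch: relies on tr(K H L H) = sum_ab K_ab (H L H)_ab and only
    materializes the Dirac Gram entries of all d patches for `chunk` pairs at a time.
    """
    p, d = X.shape
    H = np.eye(p) - 1.0 / p
    centered_L = (H @ L @ H).ravel()  # (p * p,), computed once and shared by all patches
    scores = np.zeros(d)
    for start in range(0, p * p, chunk):
        flat = np.arange(start, min(start + chunk, p * p))
        a, b = np.divmod(flat, p)  # row-major pair indices: flat = a * p + b
        K_chunk = (X[a] == X[b]).astype(float)  # (chunk, d) Dirac kernel values
        scores += K_chunk.T @ centered_L[flat]
    return scores / (p - 1) ** 2


rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(50, 4))
y = X[:, 0] + 0.1 * rng.normal(size=50)
L = np.exp(-((y[:, None] - y[None, :]) ** 2) / 2.0)
small = hsic_per_patch_chunked(X, L, chunk=97)       # 26 chunks of at most 97 pairs
full = hsic_per_patch_chunked(X, L, chunk=50 * 50)   # a single chunk with all p**2 pairs
print(np.allclose(small, full))
```

Peak memory for the kernel values drops from d × p² to d × chunk, at the cost of more loop iterations; the result does not depend on the chunk size.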
explain(self,
inputs: Union[tf.data.Dataset, tensorflow.python.framework.tensor.Tensor, numpy.ndarray],
targets: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None) -> tensorflow.python.framework.tensor.Tensor
Compute the HSIC attribution scores according to the explainer parameters (perturbation
function, grid size...). Accepts a Tensor, a numpy array or a tf.data.Dataset (in the latter case,
targets must be None).
Parameters
- inputs : Union[tf.data.Dataset, tensorflow.python.framework.tensor.Tensor, numpy.ndarray]
  Images to be explained, either a tf.data.Dataset, a Tensor or a numpy array.
  If a Dataset is given, targets should not be provided (they are included in the Dataset).
  Expected shape (N, W, H, C) or (N, W, H).
- targets : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None
  One-hot encoding for classification or direction {-1, +1} for regression.
  Tensor or numpy array.
  Expected shape (N, C) or (N).
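For classification, targets is expected as one-hot vectors of shape (N, C). If labels are stored as integer class indices, a minimal way to build such targets with NumPy is shown below; the label values and the number of classes are placeholders.

```python
import numpy as np

labels = np.array([3, 0, 7])   # placeholder integer class indices, shape (N,)
nb_classes = 10                # placeholder number of classes C
# Indexing the identity matrix with the labels picks one one-hot row per sample
targets = np.eye(nb_classes, dtype=np.float32)[labels]
print(targets.shape)  # (3, 10)
```

Such an array can be passed as the targets argument, for instance in place of labels in the calls explainer(images, labels) shown in the examples.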
Return
- attributions_maps : tensorflow.python.framework.tensor.Tensor
  Global sensitivity analysis (GSA) attribution maps, with the same shape as the inputs except for the channel dimension.