MMDCritic
View colab tutorial | View source | 📰 Paper
MMDCritic finds prototypes and criticisms by maximizing two separate objectives based on the Maximum Mean Discrepancy (MMD).
Quote
MMD-critic uses the MMD statistic as a measure of similarity between points and potential prototypes, and efficiently selects prototypes that maximize the statistic. In addition to prototypes, MMD-critic selects criticism samples i.e. samples that are not well-explained by the prototypes using a regularized witness function score.
-- Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).
First, to find prototypes \(\mathcal{P}\), a greedy algorithm is used to maximize \(F(\mathcal{P})\) subject to \(|\mathcal{P}| \le m_p\), where \(m_p\) is the number of prototypes to be found and \(F(\mathcal{P})\) is defined as:
\begin{equation}
F(\mathcal{P})=\frac{2}{|\mathcal{P}|\cdot n}\sum_{i=1}^{|\mathcal{P}|}\sum_{j=1}^{n}\kappa(p_i,x_j)-\frac{1}{|\mathcal{P}|^2}\sum_{i,j=1}^{|\mathcal{P}|}\kappa(p_i,p_j),
\end{equation}
with \(x_1,\dots,x_n\) the points of the cases dataset and \(\kappa\) the kernel. Kim et al. (2016) use diagonal dominance conditions on the kernel to ensure that \(F(\mathcal{P})\) is monotone and submodular, which is the setting in which greedy selection comes with a \((1-1/e)\) approximation guarantee.
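To make the greedy step concrete, here is a minimal NumPy sketch that evaluates \(F(\mathcal{P})\) on a precomputed kernel matrix and, at each step, adds the point that yields the largest objective. It assumes a Gaussian kernel \(\kappa(x,y)=\exp(-\gamma\|x-y\|^2)\) and restricts prototypes to points of the dataset; rbf_kernel, mmd_objective, and greedy_prototypes are hypothetical helpers, not the Xplique implementation.

```python
import numpy as np


def rbf_kernel(a: np.ndarray, b: np.ndarray, gamma: float) -> np.ndarray:
    """Gaussian Gram matrix with entries exp(-gamma * ||a_i - b_j||^2)."""
    sq_dists = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    # Clip tiny negative values caused by floating-point cancellation.
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def mmd_objective(K: np.ndarray, proto_idx: list) -> float:
    """F(P) for prototypes given as row indices into the dataset.

    K is the (n, n) kernel matrix of the dataset, so K[i, j] = kappa(x_i, x_j).
    """
    m, n = len(proto_idx), K.shape[0]
    cross = K[proto_idx, :].sum()                   # sum over i, j of kappa(p_i, x_j)
    within = K[np.ix_(proto_idx, proto_idx)].sum()  # sum over i, j of kappa(p_i, p_j)
    return 2.0 / (m * n) * cross - within / m**2


def greedy_prototypes(X: np.ndarray, m_p: int, gamma: float = 1.0) -> list:
    """Add, one at a time, the candidate that maximizes F(P), until |P| = m_p."""
    K = rbf_kernel(X, X, gamma)
    selected: list = []
    for _ in range(m_p):
        candidates = [i for i in range(len(X)) if i not in selected]
        scores = [mmd_objective(K, selected + [c]) for c in candidates]
        selected.append(candidates[int(np.argmax(scores))])
    return selected
```

On two well-separated clusters, the two selected prototypes fall in different clusters: a second prototype from the same cluster would pay a large \(\kappa(p_1,p_2)\) penalty in the second term of \(F\). The sketch recomputes \(F\) from scratch for every candidate, which keeps it short but is slow on large datasets; an efficient version would update the two sums incrementally.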
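Second, to find criticisms \(\mathcal{C}\), that is, points that are not well explained by the prototypes, the same greedy algorithm is used to select points that maximize another objective function \(J(\mathcal{C})\), based on the regularized witness function score mentioned in the quote above.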
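Kim et al. (2016) build the criticism objective on the witness function \(\text{witness}(x)=\frac{1}{n}\sum_{i=1}^{n}\kappa(x,x_i)-\frac{1}{|\mathcal{P}|}\sum_{j=1}^{|\mathcal{P}|}\kappa(x,p_j)\), which is positive where the data is under-represented by the prototypes and negative where it is over-represented. A set of criticisms is scored by \(\sum_{c\in\mathcal{C}}|\text{witness}(c)|\) plus the diversity regularizer \(\log\det K_{\mathcal{C},\mathcal{C}}\). The NumPy sketch below follows that formulation under stated assumptions: a Gaussian kernel, and prototypes excluded from the candidates. The exact \(J(\mathcal{C})\) and regularizer used by MMDCritic in this library may differ, and rbf_kernel, witness, and greedy_criticisms are hypothetical helpers.

```python
import numpy as np


def rbf_kernel(a: np.ndarray, b: np.ndarray, gamma: float) -> np.ndarray:
    """Gaussian Gram matrix with entries exp(-gamma * ||a_i - b_j||^2)."""
    sq_dists = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def witness(X: np.ndarray, prototypes: np.ndarray, gamma: float) -> np.ndarray:
    """Witness function at every row of X: data term minus prototype term."""
    data_term = rbf_kernel(X, X, gamma).mean(axis=1)
    proto_term = rbf_kernel(X, prototypes, gamma).mean(axis=1)
    return data_term - proto_term


def greedy_criticisms(
    X: np.ndarray, proto_idx: list, m_c: int, gamma: float = 1.0, regularize: bool = True
) -> list:
    """Greedily pick up to m_c points maximizing sum |witness| (+ log det if regularized)."""
    abs_w = np.abs(witness(X, X[proto_idx], gamma))
    K = rbf_kernel(X, X, gamma)
    selected: list = []
    for _ in range(m_c):
        best, best_score = None, -np.inf
        for c in range(len(X)):
            if c in selected or c in proto_idx:
                continue  # assumption: prototypes cannot also be criticisms
            cand = selected + [c]
            score = abs_w[cand].sum()
            if regularize:
                # log det of the criticisms' kernel submatrix penalizes redundant picks
                sign, logdet = np.linalg.slogdet(K[np.ix_(cand, cand)])
                score = score + logdet if sign > 0 else -np.inf
            if score > best_score:
                best, best_score = c, score
        if best is None:
            break  # no candidate left with a finite score
        selected.append(best)
    return selected
```

Without the regularizer, the greedy loop reduces to taking the points with the largest \(|\text{witness}|\); the log-determinant term is what pushes successive criticisms apart.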
Warning
For MMDCritic, the kernel must satisfy a condition that ensures the submodularity of the set function \(F\). The Gaussian kernel meets this requirement and is recommended. If you wish to use a different kernel, it must satisfy the condition described by Kim et al., 2016.
Example
from xplique.example_based import MMDCritic
from xplique.example_based.projections import LatentSpaceProjection
# load the training dataset and the model
cases_dataset = ... # load the training dataset
model = ...
# load the test samples
test_samples = ... # load the test samples to search for
# parameters
case_returns = "all" # elements returned by the explain function
latent_layer = "last_conv" # where to split your model for the projection
nb_global_prototypes = 5
nb_local_prototypes = 1
kernel_fn = None # the default rbf kernel will be used, the distance will be based on this
# construct a projection with your model
projection = LatentSpaceProjection(model, latent_layer=latent_layer)
mmd = MMDCritic(
    cases_dataset=cases_dataset,
    nb_global_prototypes=nb_global_prototypes,
    nb_local_prototypes=nb_local_prototypes,
    projection=projection,
    case_returns=case_returns,
    kernel_fn=kernel_fn,
)
# compute global explanation
global_prototypes = mmd.get_global_prototypes()
# compute local explanation
local_prototypes = mmd.explain(test_samples)
Notebooks
MMDCritic
__init__(self,
cases_dataset: ~DatasetOrTensor,
labels_dataset: Optional[~DatasetOrTensor] = None,
targets_dataset: Optional[~DatasetOrTensor] = None,
nb_global_prototypes: int = 1,
nb_local_prototypes: int = 1,
projection: Union[xplique.example_based.projections.base.Projection, Callable] = None,
case_returns: Union[List[str], str] = 'examples',
batch_size: Optional[int] = None,
distance: Union[int, str, Callable, None] = None,
kernel_fn = None,
gamma: float = None)
explain(self,
inputs: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray],
targets: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None)
Return the relevant examples to explain the (inputs, targets).
It projects the inputs into the search space with self.projection and finds examples with self.search_method.
Parameters
- inputs : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray]
  Tensor or Array. Input samples to be explained.
  Expected shape: one of (N, W), (N, T, W), (N, W, H, C).
  More information in the documentation.
- targets : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None
  Targets associated with the inputs for the projection. Shape: (n, nb_classes), where n is the number of samples and nb_classes is the number of classes.
  They are used by the projection, which can also compute them internally when they are not provided.
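If your labels are stored as integer class indices, one way to build targets with the documented (n, nb_classes) shape is one-hot encoding. This is a minimal NumPy sketch; one_hot_targets is a hypothetical helper, and whether one-hot targets are what your projection expects is an assumption to check. tf.one_hot(labels, depth=nb_classes) produces the same values as a TensorFlow tensor.

```python
import numpy as np


def one_hot_targets(labels, nb_classes: int) -> np.ndarray:
    """Turn integer class labels of shape (n,) into targets of shape (n, nb_classes)."""
    labels = np.asarray(labels, dtype=int)
    # Row k of the identity matrix is the one-hot vector of class k.
    return np.eye(nb_classes, dtype=np.float32)[labels]
```

For example, targets = one_hot_targets(test_labels, nb_classes) followed by mmd.explain(test_samples, targets), where test_labels and nb_classes are placeholders for your own data.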
Return
- return_dict
  Dictionary with the elements listed in self.returns. The elements that can be returned are defined by the _returns_possibilities static attribute of the class.
format_search_output(self,
search_output: Dict[str, tensorflow.python.framework.tensor.Tensor],
inputs: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray])
Format the output of the search_method to match the expected returns in self.returns.
Parameters
- search_output : Dict[str, tensorflow.python.framework.tensor.Tensor]
  Dictionary with the required outputs from the search_method.
- inputs : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray]
  Tensor or Array. Input samples to be explained.
  Expected shape: one of (N, W), (N, T, W), (N, W, H, C).
Return
- return_dict
  Dictionary with the elements listed in self.returns. The elements that can be returned are defined by the _returns_possibilities static attribute of the class.
get_global_prototypes(self) -> Dict[str, tensorflow.python.framework.tensor.Tensor]
Provide the global prototypes computed during initialization.
The prototypes and their labels are extracted using the prototype indices.
The weights and indices of the prototypes are also returned.
Return
- prototypes_dict : Dict[str, tf.Tensor]
  A dictionary with the following keys:
  - 'prototypes': the prototypes found by the method.
  - 'prototype_labels': the labels of the prototypes.
  - 'prototype_weights': the weights of the prototypes.
  - 'prototype_indices': the indices of the prototypes.
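Extracting prototypes and their labels from the indices amounts to an indexed gather on the cases and labels datasets. The minimal NumPy sketch below builds a dictionary with the documented keys; gather_prototypes is a hypothetical helper, and the weights are supplied by the caller because how MMDCritic computes them is not described here. The method itself returns tf.Tensor values, for which tf.gather(cases, indices) plays the same role as NumPy indexing.

```python
import numpy as np


def gather_prototypes(cases, labels, indices, weights) -> dict:
    """Build a dict with the documented keys from prototype indices (illustrative only)."""
    indices = np.asarray(indices, dtype=int)
    return {
        "prototypes": np.asarray(cases)[indices],               # rows of the cases dataset
        "prototype_labels": np.asarray(labels)[indices],        # matching labels
        "prototype_weights": np.asarray(weights, dtype=float),  # taken as given here
        "prototype_indices": indices,
    }
```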