MMDCritic¶
    
MMDCritic finds prototypes and criticisms by maximizing two separate objectives based on the Maximum Mean Discrepancy (MMD).
Quote
MMD-critic uses the MMD statistic as a measure of similarity between points and potential prototypes, and efficiently selects prototypes that maximize the statistic. In addition to prototypes, MMD-critic selects criticism samples i.e. samples that are not well-explained by the prototypes using a regularized witness function score.
-- Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).
First, to find prototypes \(\mathcal{P}\), a greedy algorithm maximizes \(F(\mathcal{P})\) subject to \(|\mathcal{P}| \le m_p\), where \(m_p\) is the number of prototypes to be found and \(F(\mathcal{P})\) is defined as: \begin{equation} F(\mathcal{P})=\frac{2}{|\mathcal{P}|\cdot n}\sum_{i=1}^{|\mathcal{P}|}\sum_{j=1}^{n}\kappa(p_i,x_j)-\frac{1}{|\mathcal{P}|^2}\sum_{i,j=1}^{|\mathcal{P}|}\kappa(p_i,p_j), \end{equation} where \(\kappa\) is the kernel, \(p_i\) are the prototypes and \(x_j\) are the \(n\) samples of the dataset. Kim et al. (2016) impose diagonal dominance conditions on the kernel to ensure the monotonicity and submodularity of \(F(\mathcal{P})\).
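The greedy maximization of \(F(\mathcal{P})\) can be illustrated with a small NumPy sketch. It is not Xplique's implementation: the helper names (`rbf_kernel`, `mmd_objective`, `greedy_prototypes`), the `gamma` value and the toy data are all assumptions made for illustration, and the objective is recomputed from scratch at every step for readability rather than efficiency.

```python
import numpy as np


def rbf_kernel(a, b, gamma=1.0):
    """Gaussian kernel matrix between the rows of a and the rows of b."""
    sq_dists = (
        np.sum(a ** 2, axis=1)[:, None]
        + np.sum(b ** 2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def mmd_objective(kernel, selected):
    """F(P) for the prototypes indexed by `selected` in a precomputed kernel matrix."""
    idx = np.asarray(selected)
    m, n = len(idx), kernel.shape[0]
    cross = kernel[idx, :].sum()             # sum_i sum_j k(p_i, x_j)
    within = kernel[np.ix_(idx, idx)].sum()  # sum_{i,j} k(p_i, p_j)
    return 2.0 * cross / (m * n) - within / m ** 2


def greedy_prototypes(x, nb_prototypes, gamma=1.0):
    """Repeatedly add the sample that gives the largest F(P) until m_p are chosen."""
    kernel = rbf_kernel(x, x, gamma)
    selected = []
    for _ in range(nb_prototypes):
        candidates = [i for i in range(len(x)) if i not in selected]
        scores = [mmd_objective(kernel, selected + [i]) for i in candidates]
        selected.append(candidates[int(np.argmax(scores))])
    return selected


rng = np.random.default_rng(0)
# hypothetical toy data: two well-separated clusters of 20 points each
data = np.vstack([rng.normal(0.0, 0.3, (20, 2)), rng.normal(5.0, 0.3, (20, 2))])
prototype_indices = greedy_prototypes(data, nb_prototypes=2, gamma=1.0)
```

On this toy data the second term of \(F(\mathcal{P})\) penalizes picking two prototypes from the same cluster, so the two selected indices fall in different clusters.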
Second, to find criticisms \(\mathcal{C}\), the same greedy algorithm is used to select points that maximize another objective function \(J(\mathcal{C})\).
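The objective \(J(\mathcal{C})\) is not spelled out here. As a rough illustration, the sketch below assumes the witness-function score used in the MMD-critic formulation, \(\left|\frac{1}{n}\sum_{i}\kappa(x_i,c)-\frac{1}{|\mathcal{P}|}\sum_{j}\kappa(p_j,c)\right|\), and omits the regularizer, so selection reduces to taking the highest-scoring points instead of running the greedy procedure. The function names, the hand-picked prototypes and the toy data are hypothetical, and this is not Xplique's code.

```python
import numpy as np


def rbf_kernel(a, b, gamma=1.0):
    """Gaussian kernel matrix between the rows of a and the rows of b."""
    sq_dists = (
        np.sum(a ** 2, axis=1)[:, None]
        + np.sum(b ** 2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def witness_scores(x, prototypes, gamma=1.0):
    """Absolute witness value |mean_i k(x_i, c) - mean_j k(p_j, c)| for each c in x."""
    data_term = rbf_kernel(x, x, gamma).mean(axis=0)
    proto_term = rbf_kernel(prototypes, x, gamma).mean(axis=0)
    return np.abs(data_term - proto_term)


def select_criticisms(x, prototypes, nb_criticisms, gamma=1.0):
    """Indices of the samples with the largest (unregularized) witness score."""
    scores = witness_scores(x, prototypes, gamma)
    return np.argsort(-scores)[:nb_criticisms]


# hypothetical toy data: clusters A and B (8 points each) and a small cluster C (4 points)
offsets4 = np.array([[0.1, 0.0], [-0.1, 0.0], [0.0, 0.1], [0.0, -0.1]])
offsets8 = np.vstack([offsets4, [[0.1, 0.1], [-0.1, -0.1], [0.1, -0.1], [-0.1, 0.1]]])
data = np.vstack([
    np.array([0.0, 0.0]) + offsets8,   # cluster A, rows 0..7
    np.array([5.0, 5.0]) + offsets8,   # cluster B, rows 8..15
    np.array([10.0, 0.0]) + offsets4,  # cluster C, rows 16..19
])
# hand-picked prototypes that cover A and B but not C
prototypes = np.array([[0.0, 0.0], [5.0, 5.0]])

scores = witness_scores(data, prototypes)
criticism_indices = select_criticisms(data, prototypes, nb_criticisms=2)
```

In this configuration the points of cluster C, which no prototype represents, receive the highest scores and are returned as criticisms.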
Warning
For MMDCritic, the kernel must satisfy a condition that ensures the submodularity of the set function. The Gaussian kernel meets this requirement and it is recommended. If you wish to choose a different kernel, it must satisfy the condition described by Kim et al., 2016.
Example¶
from xplique.example_based import MMDCritic
from xplique.example_based.projections import LatentSpaceProjection
# load the training dataset and the model
cases_dataset = ... # load the training dataset
model = ...
# load the test samples
test_samples = ... # load the test samples to search for
# parameters
case_returns = "all"  # elements returned by the explain function
latent_layer = "last_conv"  # where to split your model for the projection
nb_global_prototypes = 5
nb_local_prototypes = 1
kernel_fn = None  # None selects the default RBF kernel; the distance is derived from this kernel
# construct a projection with your model
projection = LatentSpaceProjection(model, latent_layer=latent_layer)
mmd = MMDCritic(
    cases_dataset=cases_dataset,
    nb_global_prototypes=nb_global_prototypes,
    nb_local_prototypes=nb_local_prototypes,
    projection=projection,
    case_returns=case_returns,
    kernel_fn=kernel_fn,
)
# compute global explanation
global_prototypes = mmd.get_global_prototypes()
# compute local explanation
local_prototypes = mmd.explain(test_samples)
MMDCritic¶
__init__(self,
          cases_dataset:  ~DatasetOrTensor,
          labels_dataset:  Optional[~DatasetOrTensor] = None,
          targets_dataset:  Optional[~DatasetOrTensor] = None,
          nb_global_prototypes:  int = 1,
          nb_local_prototypes:  int = 1,
          projection:  Union[xplique.example_based.projections.base.Projection, Callable] = None,
          case_returns:  Union[List[str], str] = 'examples',
          batch_size:  Optional[int] = None,
          distance:  Union[int, str, Callable, None] = None,
          kernel_fn = None,
          gamma:  float = None)¶
explain(self,
         inputs:  Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray],
         targets:  Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None)¶
Return the relevant examples to explain the (inputs, targets).
It projects inputs with self.projection into the search space
and finds examples with self.search_method.
Parameters
- inputs : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray]
  - Tensor or Array. Input samples to be explained.
  - Expected shape among (N, W), (N, T, W), (N, W, H, C).
  - More information in the documentation.
- targets : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None
  - Targets associated to the inputs for projection.
  - Shape: (n, nb_classes) where n is the number of samples and nb_classes is the number of classes.
  - It is used in the projection, but the projection can compute it internally.
Return
- return_dict
  - Dictionary with listed elements in self.returns.
  - The elements that can be returned are defined with the _returns_possibilities static attribute of the class.
format_search_output(self,
                      search_output:  Dict[str, tensorflow.python.framework.tensor.Tensor],
                      inputs:  Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray])¶
Format the output of the search_method to match the expected returns in self.returns.
Parameters
- search_output : Dict[str, tensorflow.python.framework.tensor.Tensor]
  - Dictionary with the required outputs from the search_method.
- inputs : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray]
  - Tensor or Array. Input samples to be explained.
  - Expected shape among (N, W), (N, T, W), (N, W, H, C).
Return
- return_dict
  - Dictionary with listed elements in self.returns.
  - The elements that can be returned are defined with the _returns_possibilities static attribute of the class.
get_global_prototypes(self) -> Dict[str, tensorflow.python.framework.tensor.Tensor]¶
Provide the global prototypes computed at the initialization.
Prototypes and their labels are extracted from the indices.
The weights of the prototypes and their indices are also returned. 
Return
- prototypes_dict : Dict[str, tf.Tensor]
  - A dictionary with the following keys:
  - 'prototypes': The prototypes found by the method.
  - 'prototype_labels': The labels of the prototypes.
  - 'prototype_weights': The weights of the prototypes.
  - 'prototype_indices': The indices of the prototypes.
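As a hedged illustration of how the returned dictionary might be consumed, the sketch below ranks prototypes by weight. `describe_prototypes` is a hypothetical helper, not part of Xplique, and the stand-in dictionary of NumPy arrays only mimics the four documented keys. It assumes labels were provided, one weight per prototype, and values that `np.asarray` can convert (which holds for eager TensorFlow tensors).

```python
import numpy as np


def describe_prototypes(prototypes_dict):
    """Return one (index, label, weight) tuple per prototype, heaviest first."""
    indices = np.asarray(prototypes_dict["prototype_indices"])
    labels = np.asarray(prototypes_dict["prototype_labels"])
    # assumption: exactly one scalar weight per prototype
    weights = np.asarray(prototypes_dict["prototype_weights"]).reshape(-1)
    order = np.argsort(-weights)
    return [
        (indices[i].tolist(), labels[i].tolist(), float(weights[i]))
        for i in order
    ]


# stand-in for the output of get_global_prototypes(), with made-up values
stand_in = {
    "prototypes": np.zeros((3, 2)),
    "prototype_labels": np.array([0, 1, 1]),
    "prototype_weights": np.array([0.2, 0.5, 0.3]),
    "prototype_indices": np.array([12, 4, 30]),
}
summary = describe_prototypes(stand_in)
```

Indices and labels are passed through `tolist()` so the helper makes no assumption about their exact shape.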