MMDCritic

View colab tutorial | View source | 📰 Paper

MMDCritic finds prototypes and criticisms by maximizing two separate objectives based on the Maximum Mean Discrepancy (MMD).

Quote

MMD-critic uses the MMD statistic as a measure of similarity between points and potential prototypes, and efficiently selects prototypes that maximize the statistic. In addition to prototypes, MMD-critic selects criticism samples i.e. samples that are not well-explained by the prototypes using a regularized witness function score.

-- Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).

First, to find the prototypes \(\mathcal{P}\), a greedy algorithm is used to maximize \(F(\mathcal{P})\) subject to \(|\mathcal{P}| \le m_p\), where \(m_p\) is the number of prototypes to be found and \(F(\mathcal{P})\) is defined as: \begin{equation} F(\mathcal{P})=\frac{2}{|\mathcal{P}|\cdot n}\sum_{i=1}^{|\mathcal{P}|}\sum_{j=1}^{n}\kappa(p_i,x_j)-\frac{1}{|\mathcal{P}|^2}\sum_{i,j=1}^{|\mathcal{P}|}\kappa(p_i,p_j), \end{equation} Kim et al., 2016 used diagonal dominance conditions on the kernel to ensure the monotonicity and submodularity of \(F(\mathcal{P})\).
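
To make the greedy maximization concrete, here is a minimal NumPy sketch of the selection loop (an illustration only, not Xplique's implementation; rbf_kernel and greedy_mmd_prototypes are hypothetical helpers):

import numpy as np

def rbf_kernel(X, Y, gamma=1.0):
    # pairwise RBF kernel matrix between the rows of X and Y
    sq_dists = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * sq_dists)

def greedy_mmd_prototypes(X, m_p, gamma=1.0):
    # at each step, add the candidate point that yields the largest F(P)
    n = X.shape[0]
    K = rbf_kernel(X, X, gamma)
    data_term = K.mean(axis=1)  # (1/n) * sum_j kappa(x_i, x_j), for each i
    selected = []
    for _ in range(m_p):
        candidates = [i for i in range(n) if i not in selected]
        scores = []
        for i in candidates:
            P = selected + [i]
            m = len(P)
            # F(P) = 2/(m*n) sum_{i,j} kappa(p_i, x_j) - 1/m^2 sum_{i,j} kappa(p_i, p_j)
            scores.append(2.0 / m * data_term[P].sum() - K[np.ix_(P, P)].sum() / m**2)
        selected.append(candidates[int(np.argmax(scores))])
    return selected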

Second, to find the criticisms \(\mathcal{C}\), the same greedy algorithm is used to select the points that maximize another objective function, \(J(\mathcal{C})\): a regularized witness-function score that favors samples poorly represented by the prototypes.
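
For reference, Kim et al., 2016 define this objective as the sum of the absolute values of the MMD witness function over the selected points, plus an optional regularizer \(r(\mathcal{C})\) (e.g. a log-determinant term) that encourages diverse criticisms; the notation below follows the paper, not necessarily Xplique's internals: \begin{equation} J(\mathcal{C})=\sum_{c\in\mathcal{C}}\left|\frac{1}{n}\sum_{j=1}^{n}\kappa(x_j,c)-\frac{1}{|\mathcal{P}|}\sum_{i=1}^{|\mathcal{P}|}\kappa(p_i,c)\right|+r(\mathcal{C}). \end{equation} A point thus receives a high score when the prototype distribution under- or over-represents it relative to the data.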

Warning

For MMDCritic, the kernel must satisfy a condition that ensures the submodularity of the set function. The Gaussian kernel meets this requirement and is therefore recommended. If you wish to use a different kernel, it must satisfy the condition described in Kim et al., 2016.
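
If you do provide your own kernel_fn, a Gaussian kernel could be written as follows (a sketch only: the (x, y) to kernel-matrix signature is an assumption here, so check the kernel_fn parameter documentation below for the exact expected contract):

import tensorflow as tf

def gaussian_kernel(x, y, gamma=1.0):
    # pairwise Gaussian (RBF) kernel matrix between two batches,
    # each sample flattened to a vector
    # NOTE: the (x, y) -> matrix signature is an assumption, not the
    # documented contract of `kernel_fn`
    x = tf.reshape(x, (tf.shape(x)[0], -1))
    y = tf.reshape(y, (tf.shape(y)[0], -1))
    sq_dists = tf.reduce_sum(tf.square(x[:, None, :] - y[None, :, :]), axis=-1)
    return tf.exp(-gamma * sq_dists)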

Example

from xplique.example_based import MMDCritic
from xplique.example_based.projections import LatentSpaceProjection

# load the training dataset and the model
cases_dataset = ... # load the training dataset
model = ...

# load the test samples
test_samples = ... # load the test samples to search for

# parameters
case_returns = "all"  # elements returned by the explain function
latent_layer = "last_conv"  # where to split your model for the projection
nb_global_prototypes = 5
nb_local_prototypes = 1
kernel_fn = None  # if None, the default RBF kernel is used; the distance is derived from it

# construct a projection with your model
projection = LatentSpaceProjection(model, latent_layer=latent_layer)

mmd = MMDCritic(
    cases_dataset=cases_dataset,
    nb_global_prototypes=nb_global_prototypes,
    nb_local_prototypes=nb_local_prototypes,
    projection=projection,
    case_returns=case_returns,
    kernel_fn=kernel_fn,
)

# compute global explanation
global_prototypes = mmd.get_global_prototypes()

# compute local explanation
local_prototypes = mmd.explain(test_samples)
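
Since case_returns was set to "all", explain returns a dictionary rather than a tensor. Unpacking it might look like the following ("examples" is the documented default return; the full list of keys comes from the class's _returns_possibilities attribute):

# "examples" is the default return key; other available keys depend on
# `_returns_possibilities` and the requested `case_returns`
local_examples = local_prototypes["examples"]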

Notebooks

MMDCritic


__init__(self,
         cases_dataset: ~DatasetOrTensor,
         labels_dataset: Optional[~DatasetOrTensor] = None,
         targets_dataset: Optional[~DatasetOrTensor] = None,
         nb_global_prototypes: int = 1,
         nb_local_prototypes: int = 1,
         projection: Union[xplique.example_based.projections.base.Projection, Callable] = None,
         case_returns: Union[List[str], str] = 'examples',
         batch_size: Optional[int] = None,
         distance: Union[int, str, Callable, None] = None,
         kernel_fn: Callable = None,
         gamma: float = None)

explain(self,
        inputs: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray],
        targets: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None)

Return the relevant examples to explain the (inputs, targets). It projects the inputs into the search space with self.projection and finds examples with self.search_method.

Parameters

  • inputs : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray]

    • Tensor or Array. Input samples to be explained.

      Expected shape among (N, W), (N, T, W), (N, W, H, C).

      More information in the documentation.

  • targets : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None

    • Targets associated to the inputs for projection.

      Shape: (n, nb_classes) where n is the number of samples and nb_classes is the number of classes.

      They are used by the projection, but some projections can compute them internally.

Return

  • return_dict

    • Dictionary with the elements listed in self.returns.

      The elements that can be returned are defined with the _returns_possibilities static attribute of the class.


format_search_output(self,
                     search_output: Dict[str, tensorflow.python.framework.tensor.Tensor],
                     inputs: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray])

Format the output of the search_method to match the expected returns in self.returns.

Parameters

  • search_output : Dict[str, tensorflow.python.framework.tensor.Tensor]

    • Dictionary with the required outputs from the search_method.

  • inputs : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray]

    • Tensor or Array. Input samples to be explained.

      Expected shape among (N, W), (N, T, W), (N, W, H, C).

Return

  • return_dict

    • Dictionary with the elements listed in self.returns.

      The elements that can be returned are defined with the _returns_possibilities static attribute of the class.


get_global_prototypes(self) -> Dict[str, tensorflow.python.framework.tensor.Tensor]

Provide the global prototypes computed at initialization. Prototypes and their labels are extracted from their indices; the prototype weights and indices are also returned.

Return

  • prototypes_dict : Dict[str, tf.Tensor]

    • A dictionary with the following keys:

      - 'prototypes': the prototypes found by the method.

      - 'prototype_labels': the labels of the prototypes.

      - 'prototype_weights': the weights of the prototypes.

      - 'prototype_indices': the indices of the prototypes.
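
For example, the returned dictionary can be unpacked as follows (keys taken from the list above):

prototypes_dict = mmd.get_global_prototypes()
prototypes = prototypes_dict["prototypes"]                # the prototypes themselves
prototype_labels = prototypes_dict["prototype_labels"]    # their labels
prototype_weights = prototypes_dict["prototype_weights"]  # their importance weights
prototype_indices = prototypes_dict["prototype_indices"]  # their indices in the cases dataset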