ProtoDash¶
View colab tutorial | View source | 📰 Paper
ProtoDash associates non-negative weights with the prototypes, which indicate their importance. This makes it possible to identify both prototypes and criticisms (the least weighted examples among the prototypes) by maximizing the same weighted objective function.
Quote
Our work notably generalizes the recent work by Kim et al. (2016) where in addition to selecting prototypes, we also associate non-negative weights which are indicative of their importance. This extension provides a single coherent framework under which both prototypes and criticisms (i.e. outliers) can be found. Furthermore, our framework works for any symmetric positive definite kernel thus addressing one of the key open questions laid out in Kim et al. (2016).
-- Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).
More precisely, the weighted objective \(F(\mathcal{P},w)\) is defined as:
\begin{equation}
F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\sum_{i,j=1}^{|\mathcal{P}|}w_iw_j\kappa(p_i,p_j),
\end{equation}
where \(w\) are non-negative weights for each prototype. The problem then consists in finding a subset \(\mathcal{P}\) with a corresponding \(w\) that maximizes \(F(\mathcal{P}) \equiv \max_{w:\operatorname{supp}(w)\subseteq \mathcal{P},\,w\ge 0} F(\mathcal{P},w)\) s.t. \(|\mathcal{P}| \leq m=m_p+m_c\).
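To make the objective concrete, here is a minimal NumPy sketch that evaluates \(F(\mathcal{P},w)\) for a given set of prototypes and weights. It is an illustration only, not Xplique's implementation: the helper names `rbf_kernel` and `weighted_objective`, the RBF kernel choice, and the `gamma` value are all assumptions made for this example.

```python
import numpy as np


def rbf_kernel(a, b, gamma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of a and the rows of b."""
    sq_dists = (
        np.sum(a ** 2, axis=1)[:, None]
        + np.sum(b ** 2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    # clip tiny negative values caused by floating-point round-off
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def weighted_objective(prototypes, weights, data, gamma=1.0):
    """F(P, w) = (2/n) sum_ij w_i k(p_i, x_j) - sum_ij w_i w_j k(p_i, p_j)."""
    n = data.shape[0]
    cross = rbf_kernel(prototypes, data, gamma)          # shape (|P|, n)
    within = rbf_kernel(prototypes, prototypes, gamma)   # shape (|P|, |P|)
    first_term = (2.0 / n) * (weights @ cross.sum(axis=1))
    second_term = weights @ within @ weights
    return first_term - second_term


# toy data: 50 points in 2D, the first three used as prototypes
rng = np.random.default_rng(0)
data = rng.normal(size=(50, 2))
prototypes = data[:3]
weights = np.array([0.5, 0.3, 0.2])
score = weighted_objective(prototypes, weights, data)
```

With all weights set to zero the objective is zero, and a single prototype that coincides with a single data point and has weight 1 gives \(F = 2 - 1 = 1\), which is a quick sanity check on the two terms.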
Info
For ProtoDash, any symmetric positive definite kernel can be used, as the method relies on weak submodularity rather than full submodularity.
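Weak submodularity is what makes a greedy selection strategy reasonable. The sketch below is a simplified illustration of such a greedy loop, not Xplique's code: at each step it adds the candidate whose weight has the largest gradient of \(F\), then refits all weights under the non-negativity constraint. The function names, the RBF kernel, and the use of projected gradient ascent for the weight fit (instead of a dedicated quadratic programming solver) are assumptions made for this example.

```python
import numpy as np


def rbf_kernel(a, b, gamma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of a and the rows of b."""
    sq_dists = (
        np.sum(a ** 2, axis=1)[:, None]
        + np.sum(b ** 2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def fit_weights(k_sel, mu_sel, n_steps=500):
    """Maximize 2 w.mu - w.K.w over w >= 0 with projected gradient ascent."""
    w = np.zeros(len(mu_sel))
    # step size = 1 / Lipschitz constant of the gradient
    step = 1.0 / (2.0 * np.linalg.norm(k_sel, 2) + 1e-12)
    for _ in range(n_steps):
        grad = 2.0 * mu_sel - 2.0 * k_sel @ w
        w = np.maximum(w + step * grad, 0.0)  # project onto w >= 0
    return w


def greedy_prototypes(data, m, gamma=1.0):
    """Greedily pick at most m rows of data as weighted prototypes."""
    k = rbf_kernel(data, data, gamma)
    mu = k.mean(axis=1)  # mu_i = (1/n) sum_j k(x_i, x_j)
    selected = []
    weights = np.zeros(0)
    for _ in range(m):
        idx = np.array(selected, dtype=int)
        # gradient of F with respect to each candidate's weight
        grad = 2.0 * mu - 2.0 * k[:, idx] @ weights
        grad[idx] = -np.inf  # never pick the same sample twice
        best = int(np.argmax(grad))
        if grad[best] <= 0:  # no candidate can improve the objective
            break
        selected.append(best)
        idx = np.array(selected, dtype=int)
        weights = fit_weights(k[np.ix_(idx, idx)], mu[idx])
    return np.array(selected, dtype=int), weights


rng = np.random.default_rng(0)
data = rng.normal(size=(60, 2))
selected, weights = greedy_prototypes(data, m=4)
```

Here the candidates are the dataset samples themselves, so `selected` holds row indices into `data` and `weights` holds the matching non-negative importance weights.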
Example¶
from xplique.example_based import ProtoDash
from xplique.example_based.projections import LatentSpaceProjection

# load the training dataset and the model
cases_dataset = ...  # dataset in which the prototypes are searched
model = ...  # model used to build the projection

# load the test samples to explain
test_samples = ...

# parameters
case_returns = "all"  # elements returned by the explain function
latent_layer = "last_conv"  # where to split your model for the projection
nb_global_prototypes = 5
nb_local_prototypes = 1
kernel_fn = None  # the default rbf kernel will be used, the distance will be based on this

# construct a projection with your model
projection = LatentSpaceProjection(model, latent_layer=latent_layer)

protodash = ProtoDash(
    cases_dataset=cases_dataset,
    nb_global_prototypes=nb_global_prototypes,
    nb_local_prototypes=nb_local_prototypes,
    projection=projection,
    case_returns=case_returns,
    kernel_fn=kernel_fn,
)

# compute global explanation
global_prototypes = protodash.get_global_prototypes()

# compute local explanation
local_prototypes = protodash.explain(test_samples)
Notebooks¶
ProtoDash¶
__init__(self,
         cases_dataset: ~DatasetOrTensor,
         labels_dataset: Optional[~DatasetOrTensor] = None,
         targets_dataset: Optional[~DatasetOrTensor] = None,
         nb_global_prototypes: int = 1,
         nb_local_prototypes: int = 1,
         projection: Union[xplique.example_based.projections.base.Projection, Callable] = None,
         case_returns: Union[List[str], str] = 'examples',
         batch_size: Optional[int] = None,
         distance: Union[int, str, Callable, None] = None,
         kernel_fn = None,
         gamma: float = None)
explain(self,
inputs: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray],
targets: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None)
Return the relevant examples to explain the (inputs, targets). It projects the inputs into the search space with self.projection and finds examples with self.search_method.
Parameters
- inputs : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray]
  Tensor or Array. Input samples to be explained.
  Expected shape among (N, W), (N, T, W), (N, W, H, C).
  More information in the documentation.
- targets : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None
  Targets associated with the inputs for projection.
  Shape: (n, nb_classes) where n is the number of samples and nb_classes is the number of classes.
  They are used by the projection, although the projection can also compute them internally.
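As an illustration of the shapes described above, the following sketch builds a batch of inputs with shape (N, W, H, C) and a matching targets array with shape (n, nb_classes). The image size, the labels, and the one-hot encoding are assumptions made for this example; the documentation above only specifies the shapes.

```python
import numpy as np

nb_samples, nb_classes = 4, 3

# hypothetical batch of images with shape (N, W, H, C)
inputs = np.zeros((nb_samples, 32, 32, 3), dtype=np.float32)

# hypothetical integer class labels, one per sample
labels = np.array([0, 2, 1, 2])

# one-hot encode the labels into shape (n, nb_classes)
targets = np.eye(nb_classes, dtype=np.float32)[labels]
```

Arrays shaped like these could then be passed as the `inputs` and `targets` arguments of `explain`.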
Return
- return_dict
  Dictionary with the elements listed in self.returns.
  The elements that can be returned are defined with the _returns_possibilities static attribute of the class.
format_search_output(self,
search_output: Dict[str, tensorflow.python.framework.tensor.Tensor],
inputs: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray])
Format the output of the search_method to match the expected returns in self.returns.
Parameters
- search_output : Dict[str, tensorflow.python.framework.tensor.Tensor]
  Dictionary with the required outputs from the search_method.
- inputs : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray]
  Tensor or Array. Input samples to be explained.
  Expected shape among (N, W), (N, T, W), (N, W, H, C).
Return
- return_dict
  Dictionary with the elements listed in self.returns.
  The elements that can be returned are defined with the _returns_possibilities static attribute of the class.
get_global_prototypes(self) -> Dict[str, tensorflow.python.framework.tensor.Tensor]
Provide the global prototypes computed at the initialization.
Prototypes and their labels are extracted from the indices.
The weights of the prototypes and their indices are also returned.
Return
- prototypes_dict : Dict[str, tf.Tensor]
  A dictionary with the following keys:
  - 'prototypes': The prototypes found by the method.
  - 'prototype_labels': The labels of the prototypes.
  - 'prototype_weights': The weights of the prototypes.
  - 'prototype_indices': The indices of the prototypes.
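Since the weights indicate importance, one natural use of this dictionary is to rank the prototypes by weight. The sketch below does so on a mock dictionary that only mimics the keys listed above: the values are made up, NumPy arrays stand in for the tensors, and one weight per prototype along the first axis is assumed.

```python
import numpy as np

# mock output with the documented keys (values are invented for illustration)
prototypes_dict = {
    "prototypes": np.arange(12, dtype=np.float32).reshape(4, 3),
    "prototype_labels": np.array([1, 0, 1, 2]),
    "prototype_weights": np.array([0.10, 0.45, 0.05, 0.40]),
    "prototype_indices": np.array([17, 3, 42, 8]),
}

# sort every entry by decreasing prototype weight
order = np.argsort(prototypes_dict["prototype_weights"])[::-1]
ranked = {key: np.asarray(value)[order] for key, value in prototypes_dict.items()}

# dataset index of the highest and lowest weighted prototypes
most_important_index = int(ranked["prototype_indices"][0])
least_important_index = int(ranked["prototype_indices"][-1])
```

The lowest weighted entries at the end of the ranking are the candidates for criticisms in the sense used at the top of this page.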