API: Example-based
Context
Quote
While saliency maps have stolen the show for the last few years in the XAI field, their ability to reflect models' internal processes has been questioned. Although less in the spotlight, example-based XAI methods have continued to improve. It encompasses methods that use samples as explanations for a machine learning model's predictions. This aligns with the psychological mechanisms of human reasoning and makes example-based explanations natural and intuitive for users to understand. Indeed, humans learn and reason by forming mental representations of concepts based on examples.
As mentioned by our team members in the quote above, example-based methods are an alternative to saliency maps and can be better aligned with some users' expectations. We have therefore been implementing in Xplique some of these methods, which had been set aside in previous development.
While not exhaustive, we have tried to cover a range of methods that are representative of the field and belong to different families: similar examples, contrastive examples (counterfactuals and semi-factuals), and prototypes (concept-based methods have a dedicated section).
At present, we have made the following choices:
- Focus on methods that are natural example methods (post-hoc and non-generative; see the paper above for more details).
- Try to unify the four families of approaches with a common API.
Info
We are in the early stages of development and are looking for feedback on the API design and on the methods we have chosen to implement. We are also counting on the community to expand the collection of available methods. If you are willing to contribute, reach out to us on the GitHub repository (via an issue, a pull request, ...).
Common API
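```python
# ProjectionMethod and ExampleMethod are placeholders: substitute a concrete
# projection and one of the example-based methods listed in the table below.
projection = ProjectionMethod(model)

explainer = ExampleMethod(
    cases_dataset=cases_dataset,  # dataset the examples are retrieved from
    k=k,                          # number of examples returned per input
    projection=projection,        # maps samples to the search space
    case_returns=case_returns,    # elements returned by explain()
    distance=distance,            # distance used by the search method
)

outputs_dict = explainer.explain(inputs, targets)
```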
We kept the API as close as possible to that of the attribution methods to provide a consistent experience for users.
The `BaseExampleMethod` is an abstract base class for example-based methods that explain classification models. It provides examples from a dataset (usually the training dataset) to help understand a model's predictions. Examples are projected from the input space to a search space using a projection function, which therefore defines the search space. Examples are then selected by a search method within this search space. For all example-based methods, one can define the `distance` used by the search method.
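The project-then-search flow described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not Xplique's implementation: the helpers `pairwise_distances`, `knn_search` and `project`, the toy linear projection, and the `"euclidean"`/`"manhattan"` distance names are all hypothetical.

```python
import numpy as np
from typing import Callable, Tuple, Union

# A distance is either a name handled below or a callable returning a distance matrix.
Distance = Union[str, Callable[[np.ndarray, np.ndarray], np.ndarray]]


def pairwise_distances(a: np.ndarray, b: np.ndarray, distance: Distance) -> np.ndarray:
    """Return the (len(a), len(b)) matrix of distances between the rows of a and b."""
    if callable(distance):
        return distance(a, b)
    diff = a[:, None, :] - b[None, :, :]  # broadcast to (n_a, n_b, d)
    if distance == "euclidean":
        return np.sqrt((diff ** 2).sum(axis=-1))
    if distance == "manhattan":
        return np.abs(diff).sum(axis=-1)
    raise ValueError(f"Unknown distance: {distance!r}")


def knn_search(projected_inputs: np.ndarray, projected_cases: np.ndarray, k: int,
               distance: Distance = "euclidean") -> Tuple[np.ndarray, np.ndarray]:
    """Indices and distances of the k closest projected cases for each projected input."""
    dists = pairwise_distances(projected_inputs, projected_cases, distance)
    indices = np.argsort(dists, axis=1)[:, :k]  # k smallest distances per row
    return indices, np.take_along_axis(dists, indices, axis=1)


def project(x: np.ndarray) -> np.ndarray:
    """Hypothetical projection: a fixed linear map standing in for a model's feature extractor."""
    return x @ np.array([[1.0, 0.0], [0.0, 2.0]])


cases = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [3.0, 3.0]])
inputs = np.array([[0.9, 0.1]])
# The case base and the inputs go through the same projection before the search.
indices, dists = knn_search(project(inputs), project(cases), k=2)
```

Changing `distance` (for instance to `"manhattan"`) changes how neighbours are ranked without touching the projection, which mirrors how `distance` and `projection` are separate parameters in the API.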
We can broadly categorize example-based methods into four families: similar examples, counter-factuals, semi-factuals, and prototypes.
- Similar Examples: This method involves finding instances in the dataset that are similar to a given instance. The similarity is often determined based on the feature space, and these examples can help in understanding the model's decision by showing what other data points resemble the instance in question.
- Counter Factuals: Counterfactual explanations identify the minimal changes to an instance's features needed to change the model's prediction to a different, specified outcome. They help answer "what-if" questions by showing how altering certain aspects of the input would lead to a different decision.
- Semi Factuals: Semifactual explanations describe hypothetical situations where most features of an instance remain the same except for one or a few features, without changing the overall outcome. They highlight which features could vary without altering the prediction.
- Prototypes: Prototypes are representative examples from the dataset that summarize typical cases within a certain category or cluster. They act as archetypal instances, providing a reference point for understanding model behavior. Additional documentation can be found in the Prototypes API documentation.
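To make the contrast between the first three families concrete, the sketch below applies them to a toy, already-projected space. It is an illustration under assumptions, not Xplique's code: the function names are hypothetical, labels stand in for the model's predicted classes, the semi-factual follows the spirit of KLEOR Sim-Miss (the same-class case closest to the nearest unlike neighbour), and prototype selection is omitted. The actual classes listed below may differ in their details.

```python
import numpy as np


def nearest(query: np.ndarray, candidate_idx: np.ndarray, cases: np.ndarray) -> int:
    """Index (into cases) of the candidate closest to query, or -1 if there is none."""
    if len(candidate_idx) == 0:
        return -1
    d = np.linalg.norm(cases[candidate_idx] - query, axis=1)
    return int(candidate_idx[np.argmin(d)])


def similar_examples(query, cases, k):
    """Similar examples: the k cases closest to the query."""
    d = np.linalg.norm(cases - query, axis=1)
    return np.argsort(d)[:k].tolist()


def naive_counterfactual(query, query_label, cases, labels):
    """Counterfactual: closest case with a different label (nearest unlike neighbour, NUN)."""
    return nearest(query, np.flatnonzero(labels != query_label), cases)


def sim_miss_semifactual(query, query_label, cases, labels):
    """Semi-factual in the spirit of KLEOR Sim-Miss: the same-label case closest to the NUN."""
    nun = naive_counterfactual(query, query_label, cases, labels)
    if nun < 0:
        return -1
    return nearest(cases[nun], np.flatnonzero(labels == query_label), cases)


# Toy projected space: labels play the role of the model's predicted classes.
cases = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]])
labels = np.array([0, 0, 0, 1, 1])
query, query_label = np.array([0.2, 0.0]), 0
```

Here the similar examples are cases 0 and 1, the counterfactual is case 3 (the closest case of the other class), and the semi-factual is case 2: it keeps the query's class while lying right next to the class boundary, illustrating "even this far away, the prediction would not change".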
Table of example-based methods available
| Method | Family | Documentation | Tutorial |
|---|---|---|---|
| SimilarExamples | Similar Examples | SimilarExamples | |
| Cole | Similar Examples | Cole | |
| NaiveCounterFactuals | Counter Factuals | NaiveCounterFactuals | |
| LabelAwareCounterFactuals | Counter Factuals | LabelAwareCounterFactuals | |
| KLEORSimMiss | Semi Factuals | KLEOR | |
| KLEORGlobalSim | Semi Factuals | KLEOR | |
| ProtoGreedy | Prototypes | ProtoGreedy | |
| ProtoDash | Prototypes | ProtoDash | |
| MMDCritic | Prototypes | MMDCritic | |
Parameters
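```python
from typing import Union

import numpy as np
import tensorflow as tf

# PyTorch types are written as strings (forward references), so torch need not be imported.
DatasetOrTensor = Union[tf.Tensor, np.ndarray, "torch.Tensor", tf.data.Dataset, "torch.utils.data.DataLoader"]
```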
- cases_dataset (`DatasetOrTensor`): The dataset used to train the model; examples are extracted from this dataset. All datasets (cases, labels, and targets) should be of the same type. Supported types are `tf.data.Dataset`, `torch.utils.data.DataLoader`, `tf.Tensor`, `np.ndarray`, and `torch.Tensor`. For datasets with multiple columns, the first column is assumed to be the cases, the second the labels, and the third the targets. Warning: datasets tend to be reshuffled at each iteration; make sure your datasets are not reshuffled, as examples are retrieved by their index in the dataset.
- labels_dataset (`Optional[DatasetOrTensor]`): Labels associated with the examples in the cases dataset. It should have the same type as `cases_dataset`.
- targets_dataset (`Optional[DatasetOrTensor]`): Targets associated with `cases_dataset` for the dataset projection, often the one-hot encoding of a model's predictions. See `projection` for details. It should have the same type as `cases_dataset`. It is not necessary for all projections; furthermore, projections that require it compute it internally by default.
- k (`int`): The number of examples to retrieve per input.
- projection (`Union[Projection, Callable]`): A projection or callable that projects samples from the input space to the search space. The search space should be relevant for the model (see Projections).
- case_returns (`Union[List[str], str]`): Elements to return in `self.explain()`. Default is `"examples"`. `"all"` indicates that every possible output should be returned.
- batch_size (`Optional[int]`): Number of samples processed simultaneously for projection and search. Ignored if `cases_dataset` is a batched `tf.data.Dataset` or a batched `torch.utils.data.DataLoader`.
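The column convention and the no-reshuffling warning can be illustrated with plain NumPy arrays standing in for a batched dataset. This is a minimal sketch, not Xplique's loading code: `iter_batches` and `split_columns` are hypothetical helpers written only for this example.

```python
import numpy as np
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple


def iter_batches(columns: Sequence[np.ndarray], batch_size: int) -> Iterator[Tuple[np.ndarray, ...]]:
    """Yield aligned slices of every column, always in the same order (no reshuffling)."""
    n = len(columns[0])
    for start in range(0, n, batch_size):
        yield tuple(col[start:start + batch_size] for col in columns)


def split_columns(batches: Iterable[Sequence[np.ndarray]]
                  ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Concatenate batched tuples into (cases, labels, targets).

    Column 0 is read as cases, column 1 (if any) as labels and column 2 (if any)
    as targets; missing columns are returned as None.
    """
    columns: List[List[np.ndarray]] = []
    for batch in batches:
        if not columns:
            columns = [[] for _ in batch]
        for col, values in zip(columns, batch):
            col.append(np.asarray(values))
    stacked: List[Optional[np.ndarray]] = [np.concatenate(col, axis=0) for col in columns]
    stacked += [None] * (3 - len(stacked))
    return stacked[0], stacked[1], stacked[2]


cases = np.arange(10.0).reshape(5, 2)
labels = np.array([0, 1, 0, 1, 1])
# A two-column dataset of (cases, labels) tuples, read in batches of 2.
out_cases, out_labels, out_targets = split_columns(iter_batches((cases, labels), batch_size=2))
```

Because the batches are always read in the same order, index `i` in `out_cases` still refers to sample `i`. With a dataset that is reshuffled at each iteration, an index returned by the search could point to a different sample on the next pass.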
Tips
If the elements of your dataset are tuples (cases, labels), you can pass this dataset directly to `cases_dataset`.
Tips
Apart from contrastive explanations, in the case of classification, the built-in projections compute targets online, so `targets_dataset` is not necessary.
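"Computing targets online" amounts to deriving one-hot targets from the model's own predictions. The sketch below shows that transformation, assuming the model outputs logits or probabilities of shape (n, nb_classes); `predictions_to_targets` is a hypothetical helper, not part of Xplique, and a fixed array stands in for the model call.

```python
import numpy as np


def predictions_to_targets(predictions: np.ndarray) -> np.ndarray:
    """One-hot encode each sample's predicted class; output shape is (n, nb_classes)."""
    nb_classes = predictions.shape[-1]
    predicted_class = np.argmax(predictions, axis=-1)
    return np.eye(nb_classes, dtype=predictions.dtype)[predicted_class]


# Stand-in for model(inputs): logits for 2 samples and 3 classes.
logits = np.array([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
targets = predictions_to_targets(logits)
```

The resulting shape, (n, nb_classes), matches the one expected for the `targets` argument of `explain`.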
Properties
- search_method_class (`Type[BaseSearchMethod]`): Abstract property defining the search method class to use. Must be implemented in subclasses (see Search Methods).
- k (`int`): Getter and setter for the `k` parameter.
- returns (`Union[List[str], str]`): Getter and setter for the `returns` parameter. Defines the elements to return in `self.explain()`.
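The three properties above follow a standard Python pattern: an abstract property that subclasses must provide, plus validated getters and setters. The sketch below is a hypothetical illustration, not the library's source. `SketchExampleMethod`, `SketchSimilarExamples` and `DummySearch` are invented names, and the list of possible outputs in `POSSIBLE_RETURNS` is illustrative only. As in the documented API, the constructor argument is `case_returns` while the property is `returns`.

```python
from abc import ABC, abstractmethod
from typing import List, Type, Union

# Illustrative output names only; the real set depends on the method.
POSSIBLE_RETURNS = ["examples", "distances", "labels"]


class SketchExampleMethod(ABC):
    def __init__(self, k: int = 1, case_returns: Union[List[str], str] = "examples"):
        self.k = k                   # goes through the k setter below
        self.returns = case_returns  # goes through the returns setter below

    @property
    @abstractmethod
    def search_method_class(self) -> Type:
        """Subclasses must say which search method class they rely on."""

    @property
    def k(self) -> int:
        return self._k

    @k.setter
    def k(self, k: int) -> None:
        if not isinstance(k, int) or k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        self._k = k

    @property
    def returns(self) -> List[str]:
        return self._returns

    @returns.setter
    def returns(self, returns: Union[List[str], str]) -> None:
        if returns == "all":  # expand to every possible output
            returns = list(POSSIBLE_RETURNS)
        elif isinstance(returns, str):
            returns = [returns]
        unknown = set(returns) - set(POSSIBLE_RETURNS)
        if unknown:
            raise ValueError(f"Unknown returns: {sorted(unknown)}")
        self._returns = list(returns)


class DummySearch:
    """Placeholder search method class."""


class SketchSimilarExamples(SketchExampleMethod):
    @property
    def search_method_class(self) -> Type:
        return DummySearch
```

Instantiating `SketchExampleMethod` directly raises a `TypeError` because `search_method_class` is abstract, which is how the base class forces every concrete method to declare its search method.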
explain(self, inputs, targets)
Returns the relevant examples to explain the (inputs, targets) pair. The inputs are projected using `self.projection`, and examples are found using `self.search_method`.
- inputs (`Union[tf.Tensor, np.ndarray]`): Input samples to be explained. Shape: (n, ...) where n is the number of samples.
- targets (`Optional[Union[tf.Tensor, np.ndarray]]`): Targets associated with the `inputs` for projection. Shape: (n, nb_classes) where n is the number of samples and nb_classes is the number of classes. Not used by all projections. Used in contrastive methods to know the predicted classes of the provided samples.
Returns: Dictionary with the elements listed in `self.returns`.
Info
The `__call__` method is an alias for the `explain` method.
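The behaviour of `explain` (project, search, then return a dictionary restricted to the requested elements) and the `__call__` alias can be sketched with a toy class. This is a hypothetical illustration, not Xplique's implementation: `SketchExplainer` uses a brute-force Euclidean nearest-neighbour search, ignores `targets` (which only contrastive methods would use), and the output keys `"examples"`, `"distances"` and `"indices"` are assumptions.

```python
import numpy as np
from typing import Callable, Dict, List, Optional


class SketchExplainer:
    """Toy explainer: projection followed by a brute-force k-nearest-neighbour search."""

    def __init__(self, cases: np.ndarray, projection: Callable[[np.ndarray], np.ndarray],
                 k: int = 1, returns: Optional[List[str]] = None):
        self.cases = cases
        self.projection = projection
        self.projected_cases = projection(cases)  # project the case base once
        self.k = k
        self.returns = returns or ["examples"]

    def explain(self, inputs: np.ndarray, targets: Optional[np.ndarray] = None) -> Dict[str, np.ndarray]:
        # targets are unused here; contrastive methods would use them
        projected = self.projection(inputs)
        dists = np.linalg.norm(projected[:, None, :] - self.projected_cases[None, :, :], axis=-1)
        indices = np.argsort(dists, axis=1)[:, : self.k]
        outputs = {
            "examples": self.cases[indices],  # shape (n, k, ...)
            "distances": np.take_along_axis(dists, indices, axis=1),
            "indices": indices,
        }
        return {key: outputs[key] for key in self.returns}

    # Calling the explainer directly is the same as calling explain().
    __call__ = explain


explainer = SketchExplainer(np.array([[0.0], [1.0], [5.0]]), projection=lambda x: x, k=2,
                            returns=["examples", "indices"])
outputs = explainer(np.array([[0.8]]))
```

Assigning `__call__ = explain` in the class body is a simple way to make an instance callable without duplicating logic.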
Projections
Projections are functions that map input samples to a search space where examples are retrieved with a `search_method`. The search space should be relevant for the model (e.g., projecting the inputs into the latent space of the model).
Info
If one uses the identity function as the projection, the search space is the input space, so the method explains the dataset rather than the model.
The `Projection` class is a base class for projections. It involves two parts: `space_projection` and `weights`. The samples are first projected to a new space and then weighted.
Warning
If both parts are `None`, the projection acts as an identity function. In general, we advise that at least one part involve the model, so that distances are meaningful with respect to the model.
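The two-part structure described above (project to a new space, then weight) and its identity fallback can be sketched as follows. This is a hypothetical stand-in, not Xplique's `Projection` class: `SketchProjection` and `toy_latent` are invented names, `toy_latent` plays the role of a model's latent layer, and accepting either a fixed array or a per-sample callable for `weights` is an assumption made for the example.

```python
import numpy as np
from typing import Callable, Optional, Union

ArrayFn = Callable[[np.ndarray], np.ndarray]


class SketchProjection:
    """Project samples to a new space, then weight them; identity if both parts are None."""

    def __init__(self, space_projection: Optional[ArrayFn] = None,
                 weights: Optional[Union[np.ndarray, ArrayFn]] = None):
        self.space_projection = space_projection
        self.weights = weights

    def __call__(self, inputs: np.ndarray) -> np.ndarray:
        # Step 1: move the samples to the new space (or keep them as-is).
        projected = inputs if self.space_projection is None else self.space_projection(inputs)
        if self.weights is None:
            return projected
        # Step 2: weight the projected samples, with fixed or per-sample weights.
        w = self.weights(inputs) if callable(self.weights) else self.weights
        return projected * w


W = np.array([[1.0, -1.0], [0.5, 2.0]])


def toy_latent(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a model's latent layer: linear map followed by ReLU."""
    return np.maximum(x @ W, 0.0)


identity = SketchProjection()  # search space == input space: explains the dataset
latent = SketchProjection(space_projection=toy_latent, weights=np.array([2.0, 0.5]))
x = np.array([[1.0, 1.0]])
```

Here `latent(x)` first maps `x` to `[1.5, 1.0]` and then weights it to `[3.0, 0.5]`, while `identity(x)` returns `x` unchanged, which is the case the warning above cautions against.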
To learn more about projections and their importance, refer to the Projections section.
Search Methods
Info
The search methods are hidden from the user and only used internally. However, they help to understand how the API works.
Search methods are used to retrieve examples from the `cases_dataset` that are relevant to the input samples.
Warning
In a search method, the `cases_dataset` is the dataset that has been projected with a `Projection` object (see the previous section). The search methods find examples in this projected space.
Each example-based method has its own search method, defined in the `search_method_class` property of the `ExampleMethod` class.
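The sketch below illustrates how a method can delegate the retrieval to its search method class, which then operates only on already-projected cases. All names are hypothetical and the logic is a simplification, not Xplique's search methods. `BruteForceKNN` returns the nearest cases; `DifferentLabelKNN` restricts the search to cases whose label differs from the predicted class given by one-hot `targets`, roughly the kind of filtering a counterfactual method needs. A plain class attribute replaces the property for brevity.

```python
import numpy as np
from typing import Optional, Tuple


class BruteForceKNN:
    """Hypothetical search method: the k projected cases closest to each projected input."""

    def __init__(self, projected_cases: np.ndarray, k: int, labels: Optional[np.ndarray] = None):
        self.projected_cases = projected_cases
        self.k = k
        self.labels = labels

    def admissible(self, n_inputs: int, targets: Optional[np.ndarray]) -> np.ndarray:
        """Boolean (n_inputs, n_cases) mask of cases that may be returned; all of them here."""
        return np.ones((n_inputs, len(self.projected_cases)), dtype=bool)

    def find_examples(self, projected_inputs: np.ndarray,
                      targets: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
        dists = np.linalg.norm(projected_inputs[:, None, :] - self.projected_cases[None, :, :], axis=-1)
        # Excluded cases get an infinite distance so they sort last.
        dists = np.where(self.admissible(len(projected_inputs), targets), dists, np.inf)
        indices = np.argsort(dists, axis=1)[:, : self.k]
        return indices, np.take_along_axis(dists, indices, axis=1)


class DifferentLabelKNN(BruteForceKNN):
    """Hypothetical contrastive search: only cases labelled differently from the predicted class."""

    def admissible(self, n_inputs, targets):
        predicted = np.argmax(targets, axis=1)  # targets are one-hot, shape (n, nb_classes)
        return self.labels[None, :] != predicted[:, None]


class SketchMethod:
    # Each method only declares its search class; the search runs on projected cases.
    search_method_class = BruteForceKNN

    def __init__(self, projected_cases, k, labels=None):
        self.search_method = self.search_method_class(projected_cases, k, labels)


class SketchCounterFactuals(SketchMethod):
    search_method_class = DifferentLabelKNN
```

If fewer than `k` cases are admissible, this sketch still returns `k` indices, the extra ones with an infinite distance; a real implementation would need to handle that case explicitly.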