
Label Aware Counterfactuals

View colab tutorial | View source | 📰 Paper

Note

The paper referenced here is not exactly the one we implemented; however, it is the closest in spirit to our implementation.

In contrast to the Naive Counterfactuals approach, Label Aware CounterFactuals leverage a priori knowledge of the counterfactuals' (CFs) targets to guide the search for CFs (e.g. looking for a CF of the digit 8 in the MNIST dataset only among instances of the digit 0).
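
For example, if one explains MNIST images of the digit 8 and wants CFs taken from the digit 0, the expected class is passed as one one-hot vector per input. A minimal sketch, assuming TensorFlow and a 10-class problem (the variable names are illustrative):

import tensorflow as tf

# hypothetical MNIST setting: three inputs predicted as "8",
# for which we request counterfactuals drawn from the "0" class
nb_classes = 10
cf_expected_classes = tf.one_hot([0, 0, 0], depth=nb_classes)  # one row per input sample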

Warning

Consequently, for this class, when calling the explain method, the user is expected to provide both the targets corresponding to the input samples and cf_expected_classes, the one-hot encoding of the label expected for the CFs. However, in most cases, targets can be set to None as they are computed internally by the projection.

Info

One can use a Projection object to define the space in which distances between samples are computed (e.g. to search for the CFs in the latent space of a model).

Example

from xplique.example_based import LabelAwareCounterFactuals
from xplique.example_based.projections import LatentSpaceProjection

# load the training dataset and the model
cases_dataset = ... # load the training dataset
targets_dataset = ... # load the one-hot encoding of predicted labels of the training dataset
model = ...

# load the test samples
test_samples = ... # load the test samples to search for
test_cf_expected_classes = ... # WARNING: provide the one-hot encoding of the expected label of the CFs

# parameters
k = 5  # number of examples for each input
case_returns = "all"  # elements returned by the explain function
distance = "euclidean"
latent_layer = "last_conv"  # where to split your model for the projection

# construct a projection with your model
projection = LatentSpaceProjection(model, latent_layer=latent_layer)

# instantiate the LabelAwareCounterfactuals object
lacf = LabelAwareCounterFactuals(
    cases_dataset=cases_dataset,
    targets_dataset=targets_dataset,
    k=k,
    projection=projection,
    case_returns=case_returns,
    distance=distance,
)

# search the CFs for the test samples
output_dict = lacf.explain(
    inputs=test_samples,
    targets=None,  # not necessary for this projection
    cf_expected_classes=test_cf_expected_classes,
)
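
Since case_returns is set to "all", output_dict contains every available element; the exact keys are given by the returns property of the base class. A minimal sketch of reading the result, assuming at least the "examples" entry is present:

# inspect the available keys, then retrieve the CF examples
print(output_dict.keys())
counterfactuals = output_dict["examples"]  # CF examples retrieved for each input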

Notebooks

LabelAwareCounterFactuals

This method searches for the counterfactuals of a query within an expected class. This expected class should be provided along with the query when calling the explain method.

__init__(self,
         cases_dataset: ~DatasetOrTensor,
         targets_dataset: ~DatasetOrTensor,
         labels_dataset: Optional[~DatasetOrTensor] = None,
         k: int = 1,
         projection: Union[xplique.example_based.projections.base.Projection, Callable] = None,
         case_returns: Union[List[str], str] = 'examples',
         batch_size: Optional[int] = None,
         distance: Union[int, str, Callable] = 'euclidean')

Parameters

  • cases_dataset : ~DatasetOrTensor

    • The dataset used to train the model, examples are extracted from this dataset.

      All datasets (cases, labels, and targets) should be of the same type.

      Supported types are: tf.data.Dataset, torch.utils.data.DataLoader, tf.Tensor, np.ndarray, torch.Tensor.

      For datasets with multiple columns, the first column is assumed to be the cases, the second the labels, and the third the targets.

      Warning: datasets tend to reshuffle at each iteration; ensure the datasets are not reshuffled, as we use indices into the dataset (see the sketch after this parameter list).

  • targets_dataset : ~DatasetOrTensor

    • Targets associated with the cases_dataset for dataset projection, often the one-hot encoding of a model's predictions. See projection for details.

      They are also used to know the model's predictions on the dataset.

      It should have the same type as cases_dataset.

  • labels_dataset : Optional[~DatasetOrTensor] = None

    • Labels associated with the examples in the cases_dataset.

      It should have the same type as cases_dataset.

  • k : int = 1

    • The number of examples to retrieve per input.

  • projection : Union[xplique.example_based.projections.base.Projection, Callable] = None

    • Projection or Callable that projects samples from the input space to the search space.

      The search space should be a space where distances are relevant for the model.

      It should not be None; otherwise, the model is not involved and thus not explained.

      Example of Callable:

      def custom_projection(inputs: Union[tf.Tensor, np.ndarray]):
          '''
          Example of projection,
          inputs are the elements to project.
          '''
          projected_inputs = ...  # do some magic on inputs, it should use the model
          return projected_inputs

  • case_returns : Union[List[str], str] = 'examples'

    • String or list of string with the elements to return in self.explain().

      See the base class returns property for more details.

  • batch_size : Optional[int] = None

    • Number of samples treated simultaneously for projection and search.

      Ignored if cases_dataset is a batched tf.data.Dataset or a batched torch.utils.data.DataLoader.

  • distance : Union[int, str, Callable] = 'euclidean'

    • Distance used by the FilterKNN search method to compare projected examples.

      It can be an integer, a string in {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable; by default "euclidean".
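
As noted in the cases_dataset parameter above, the datasets must keep a stable ordering across iterations. A minimal sketch of preparing them with synthetic data, assuming TensorFlow (other supported types such as NumPy arrays, PyTorch tensors, or DataLoaders work similarly):

import numpy as np
import tensorflow as tf

# illustrative only: replace with your real cases and one-hot model predictions
x_train = np.random.rand(1000, 32, 32, 3).astype("float32")
y_pred_one_hot = tf.one_hot(np.random.randint(0, 10, size=1000), depth=10)

# batch the datasets but do NOT shuffle them: the explainer relies on stable indices
cases_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(64)
targets_dataset = tf.data.Dataset.from_tensor_slices(y_pred_one_hot).batch(64)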

explain(self,
        inputs: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray],
        targets: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None,
        cf_expected_classes: Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray] = None)

Return the relevant CF examples to explain the inputs. The CF examples are searched among the cases whose target matches the class provided in cf_expected_classes. Inputs are projected with self.projection into the search space and examples are found with self.search_method.

Parameters

  • inputs : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray]

    • Tensor or Array. Input samples to be explained.

      Expected shape among (N, W), (N, T, W), (N, W, H, C).

      More information in the documentation.

  • targets : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray, None] = None

    • Tensor or Array. One-hot encoded labels or regression targets (e.g. {+1, -1}), one for each sample. If not provided, the model's predictions are used.

      Targets associated with the inputs for the projection. Shape: (n, nb_classes) where n is the number of samples and nb_classes is the number of classes.

      They are used by the projection, but the projection can also compute them internally.

  • cf_expected_classes : Union[tensorflow.python.framework.tensor.Tensor, numpy.ndarray] = None

    • Tensor or Array. One-hot encoding of the target class for the counterfactuals.

Return

  • return_dict

    • Dictionary with listed elements in self.returns.

      The elements that can be returned are defined with the _returns_possibilities static attribute of the class.


filter_fn(self,
          _,
          __,
          cf_expected_classes,
          cases_targets) -> tensorflow.python.framework.tensor.Tensor

Filter function to mask the cases for which the target is different from the target(s) expected for the counterfactuals.

Parameters

  • cf_expected_classes

    • The one-hot encoding of the target class for the counterfactuals.

  • cases_targets

    • The one-hot encoding of the target class for the cases.
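
To make the role of this filter concrete, here is a minimal sketch of the kind of masking it performs (illustrative only, not the library's exact implementation): a case is kept only when its class matches the expected CF class of the corresponding input.

import tensorflow as tf

# illustrative masking: a case is kept only when its class matches
# the expected CF class of the corresponding input
cf_expected_classes = tf.one_hot([0, 3], depth=10)   # (n_inputs, nb_classes)
cases_targets = tf.one_hot([0, 1, 3, 8], depth=10)   # (n_cases, nb_classes)

expected = tf.argmax(cf_expected_classes, axis=-1)    # (n_inputs,)
cases = tf.argmax(cases_targets, axis=-1)             # (n_cases,)
mask = tf.equal(expected[:, None], cases[None, :])    # (n_inputs, n_cases) boolean mask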