
CRAFT

View Colab TensorFlow tutorial | View Colab PyTorch tutorial | View source | 📰 Paper

CRAFT, or Concept Recursive Activation FacTorization for Explainability, is a method for automatically extracting human-interpretable concepts from deep networks.

This concept activation factorization method aims to explain a trained model's decisions on a per-class and per-image basis by highlighting both "what" the model saw and "where" it saw it. CRAFT therefore generates post-hoc explanations that are both local (per image) and global (per class).

It is made up of three ingredients:

  1. a method to recursively decompose concepts into sub-concepts
  2. a method to better estimate the importance of extracted concepts (using Sobol indices)
  3. a method to use any attribution method to create concept attribution maps, using implicit differentiation

CRAFT requires splitting the model into two parts, \((g, h)\), such that \(f(x) = (h \circ g)(x) = h(g(x))\). Put simply, \(g\) is the function that maps the input to the latent space (the output of an inner layer of the model), and \(h\) is the function that maps the latent space to the output. The concepts are extracted from this latent space.
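To make the split concrete, here is a minimal NumPy sketch in which a toy two-layer network stands in for a real model. The weights `W1`, `W2` and the functions `g`, `h` and `f` are hypothetical stand-ins, not Keras models. The only point is that chaining \(g\) then \(h\) must reproduce \(f\) exactly, and that \(g\) returns non-negative activations.

```python
import numpy as np

# Hypothetical toy weights standing in for a trained network.
rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 8))  # input (4 features) -> latent space (8 units)
W2 = rng.normal(size=(8, 3))  # latent space (8 units) -> logits (3 classes)


def g(x):
    """input_to_latent: inner layer followed by a ReLU, so activations are >= 0."""
    return np.maximum(x @ W1, 0.0)


def h(a):
    """latent_to_logit: maps latent activations to class logits."""
    return a @ W2


def f(x):
    """The full, unsplit model."""
    return np.maximum(x @ W1, 0.0) @ W2


x = rng.normal(size=(5, 4))
assert np.allclose(f(x), h(g(x)))  # f = h o g
assert (g(x) >= 0).all()           # latent activations are non-negative
```

With a real Keras model, the same check would compare the full model's outputs with those of the two sub-models chained.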

Info

It is important to note that if the model contains a global average pooling layer, it is strongly recommended to provide CRAFT with the layer before the global average pooling.

Warning

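Please keep in mind that the activations must be positive, or more precisely non-negative (e.g. taken after a ReLU or any other activation function with non-negative outputs), since CRAFT factorizes them with Non-negative Matrix Factorization (NMF).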
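Why non-negativity matters: NMF looks for A ≈ U W with U and W both non-negative, which only makes sense when A itself has no negative entries. The sketch below uses the classic Lee and Seung multiplicative updates as a stand-in solver. It is not Xplique's implementation (which may rely on a different solver), and the `nmf` helper and its parameters are hypothetical. Each row of `A` plays the role of one crop's activation vector.

```python
import numpy as np


def nmf(A, n_concepts, n_iter=1000, eps=1e-10, seed=0):
    """Factorize A (n_crops x n_channels) as A ~= U @ W with U, W >= 0.

    Uses Lee & Seung multiplicative updates for the squared Frobenius loss.
    """
    if (A < 0).any():
        raise ValueError("NMF requires non-negative activations (e.g. after a ReLU).")
    rng = np.random.default_rng(seed)
    U = rng.random((A.shape[0], n_concepts))  # concept coefficients per crop
    W = rng.random((n_concepts, A.shape[1]))  # concept bank, one row per concept
    for _ in range(n_iter):
        # Multiplicative updates keep every entry non-negative.
        W *= (U.T @ A) / (U.T @ U @ W + eps)
        U *= (A @ W.T) / (U @ W @ W.T + eps)
    return U, W


# Toy non-negative "activations" built from 3 hidden concepts.
rng = np.random.default_rng(1)
A = rng.random((30, 3)) @ rng.random((3, 16))
U, W = nmf(A, n_concepts=3)
rel_error = np.linalg.norm(A - U @ W) / np.linalg.norm(A)
```

Feeding the same helper activations with negative values (for instance taken before a ReLU) raises an error, which mirrors the requirement stated in the warning.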

Example

Use Craft to investigate a single class.

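import tensorflow as tf
from xplique.concepts import CraftTf as Craft

# `model` (a trained tf.keras classifier), `images_preprocessed` (preprocessed
# images of the class to explain) and `rabbit_class_id` are assumed to be
# defined beforehand.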

# Cut the model in two parts (as explained in the paper)
# first part is g(.) our 'input_to_latent' model returning positive activations,
# second part is h(.) our 'latent_to_logit' model

g = tf.keras.Model(model.input, model.layers[-3].output)
h = tf.keras.Model(model.layers[-2].input, model.layers[-1].output)

# Create a Craft concept extractor from these 2 models
craft = Craft(input_to_latent_model=g,
              latent_to_logit_model=h,
              number_of_concepts=10,
              patch_size=80,
              batch_size=64)

# Use Craft to get the crops (crops), the embedding of the crops (crops_u),
# and the concept bank (w)
crops, crops_u, w = craft.fit(images_preprocessed, class_id=rabbit_class_id)

# Compute Sobol indices to understand which concepts matter
importances = craft.estimate_importance()

# Display those concepts by showing the 10 best crops for each concept
craft.plot_concepts_crops(nb_crops=10)

Use CraftManager to investigate multiple classes.

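import tensorflow as tf
from xplique.concepts import CraftManagerTf as CraftManager

# `model`, `inputs_preprocessed` (preprocessed images) and `y` (their labels)
# are assumed to be defined beforehand.
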
# Cut the model in two parts (as explained in the paper)
# first part is g(.) our 'input_to_latent' model returning positive activations,
# second part is h(.) our 'latent_to_logit' model

g = tf.keras.Model(model.input, model.layers[-3].output)
h = tf.keras.Model(model.layers[-2].input, model.layers[-1].output)

# CraftManager will create one instance of Craft per class of interest
# to investigate
list_of_class_of_interest = [0, 491, 497, 569, 574]  # list of class ids
cm = CraftManager(input_to_latent_model=g,
                  latent_to_logit_model=h,
                  inputs=inputs_preprocessed,
                  labels=y,
                  list_of_class_of_interest=list_of_class_of_interest,
                  number_of_concepts=10,
                  patch_size=80,
                  batch_size=64)

cm.fit(nb_samples_per_class=50)

# Compute Sobol indices to understand which concepts matter
cm.estimate_importance()

# Display those concepts by showing the 10 best crops for each concept,
# for class 0 (the first class of interest)
cm.plot_concepts_crops(class_id=0, nb_crops=10)

CraftTf

Class implementing the CRAFT concept extraction mechanism in TensorFlow.

__init__(self,
         input_to_latent_model: Callable,
         latent_to_logit_model: Callable,
         number_of_concepts: int = 20,
         batch_size: int = 64,
         patch_size: int = 64)

Parameters

  • input_to_latent_model : Callable

    • The first part of the model taking an input and returning positive activations, g(.) in the original paper.

      Must be a TensorFlow model (tf.keras.engine.base_layer.Layer) accepting data of shape (n_samples, height, width, channels).

  • latent_to_logit_model : Callable

    • The second part of the model, taking activations and returning logits, h(.) in the original paper.

      Must be a TensorFlow model (tf.keras.engine.base_layer.Layer).

  • number_of_concepts : int = 20

    • The number of concepts to extract. Default is 20.

  • batch_size : int = 64

    • The batch size to use during training and prediction. Default is 64.

  • patch_size : int = 64

    • The size of the patches to extract from the input data. Default is 64.
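patch_size sets the side length of the square crops that CRAFT cuts out of each input before computing their activations. A minimal sliding-window sketch of such an extraction follows. The helper name and the stride value are assumptions for illustration (the stride Xplique actually uses is not given on this page), and any resizing of the crops the library may perform is left out.

```python
import numpy as np


def extract_patches(images, patch_size, stride):
    """Cut square patches of side `patch_size` from images shaped
    (n_samples, height, width, channels), sliding by `stride` pixels."""
    _, height, width, _ = images.shape
    patches = []
    for image in images:
        for top in range(0, height - patch_size + 1, stride):
            for left in range(0, width - patch_size + 1, stride):
                patches.append(image[top:top + patch_size, left:left + patch_size])
    return np.stack(patches)


# Hypothetical stride of about 80% of the patch size, so neighboring patches overlap.
images = np.zeros((2, 224, 224, 3), dtype=np.float32)
patches = extract_patches(images, patch_size=64, stride=51)
print(patches.shape)  # (32, 64, 64, 3): 4 x 4 positions per image, 2 images
```

A larger patch_size yields fewer, coarser crops, so concepts tend to capture larger object parts.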

check_if_fitted(self)

Checks if the factorization model has been fitted to input data.


compute_subplots_layout_parameters(images: numpy.ndarray,
                                   cols: int = 5,
                                   img_size: float = 2.0,
                                   margin: float = 0.3,
                                   spacing: float = 0.3)

Compute layout parameters for subplots, to be used by the method fig.subplots_adjust()

Parameters

  • images : numpy.ndarray

    • The images to display with subplots. Should be data of shape (n_samples, height, width, channels).

  • cols : int = 5

    • Number of columns to configure for the subplots.

      Defaults to 5.

  • img_size : float = 2.0

    • Size of each subplot (in inches), preserving the aspect ratio. Defaults to 2.0.

  • margin : float = 0.3

    • The margin to use for the subplots. Defaults to 0.3.

  • spacing : float = 0.3

    • The spacing to use for the subplots. Defaults to 0.3.

Return

  • layout_parameters

    • A dictionary containing the layout description

  • rows

    • The number of rows needed to display the images

  • figwidth

    • The width of the figure

  • figheight

    • The height of the figure
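One plausible way to derive these four return values from cols, img_size, margin and spacing is sketched below. The formulas and the `subplots_layout` name are assumptions, not the library's code. The dictionary keys are, however, genuine keyword arguments of matplotlib's Figure.subplots_adjust(), where wspace and hspace are fractions of the average subplot width and height.

```python
import math


def subplots_layout(images, cols=5, img_size=2.0, margin=0.3, spacing=0.3):
    """Return (layout_parameters, rows, figwidth, figheight) for a grid of images.

    Sizes are in inches; the dict keys match Figure.subplots_adjust() arguments.
    """
    rows = math.ceil(len(images) / cols)
    figwidth = cols * img_size + (cols - 1) * spacing + 2 * margin
    figheight = rows * img_size + (rows - 1) * spacing + 2 * margin
    layout_parameters = {
        "left": margin / figwidth,
        "right": 1 - margin / figwidth,
        "bottom": margin / figheight,
        "top": 1 - margin / figheight,
        # wspace/hspace are fractions of the average subplot width/height.
        "wspace": spacing / img_size,
        "hspace": spacing / img_size,
    }
    return layout_parameters, rows, figwidth, figheight


layout, rows, figwidth, figheight = subplots_layout(list(range(12)))
```

With matplotlib, these values would then be used as `fig = plt.figure(figsize=(figwidth, figheight))` followed by `fig.subplots_adjust(**layout)`.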


estimate_importance(self,
                    inputs: numpy.ndarray = None,
                    nb_design: int = 32) -> numpy.ndarray

Estimates the importance of each concept for a given class, either globally on the whole dataset provided to the fit() method (in which case inputs must be None), or locally on a specific input image.

Parameters

  • inputs : numpy array or Tensor

    • The input data on which to compute the importances.

      If None, then the inputs provided in the fit() method will be used (global importance of the whole dataset).

      Default is None.

  • nb_design : int = 32

    • The number of designs to use for the importance estimation. Default is 32.

Return

  • importances : numpy.ndarray

    • The Sobol total index (importance score) for each concept.
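The total Sobol index of a concept measures the share of the output's variance that involves that concept, including its interactions with the other concepts. The sketch below estimates total indices with plain Monte Carlo sampling and the Jansen estimator, on a toy linear function of per-concept masks. In CRAFT, the function would instead rescale each concept's coefficients in U, rebuild the activations with W and read the class logit through h. The sampler, the estimator and the names used here are assumptions and may differ from what estimate_importance() does internally; `n_samples` plays a role only loosely analogous to nb_design.

```python
import numpy as np


def total_sobol_indices(fn, n_concepts, n_samples=4096, seed=0):
    """Estimate total Sobol indices of fn: [0, 1]^n_concepts -> R.

    Plain Monte Carlo sampling with the Jansen estimator. Each input
    coordinate is a mask value applied to one concept.
    """
    rng = np.random.default_rng(seed)
    A = rng.random((n_samples, n_concepts))
    B = rng.random((n_samples, n_concepts))
    f_A = fn(A)
    variance = np.var(np.concatenate([f_A, fn(B)]))
    indices = np.empty(n_concepts)
    for i in range(n_concepts):
        A_B = A.copy()
        A_B[:, i] = B[:, i]  # resample concept i only, keep the others fixed
        indices[i] = np.mean((f_A - fn(A_B)) ** 2) / (2 * variance)
    return indices


# Toy "model": concept 0 matters a lot, concept 1 not at all, concept 2 a little.
def toy_fn(masks):
    return masks @ np.array([3.0, 0.0, 1.0])


importances = total_sobol_indices(toy_fn, n_concepts=3)
```

For this toy function the exact total indices are 0.9, 0 and 0.1, so the estimates should land close to those values.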


fit(self,
    inputs: numpy.ndarray,
    class_id: int = 0) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]

Fit the Craft model to the input data.

Parameters

  • inputs : numpy.ndarray

    • Input data of shape (n_samples, height, width, channels).

      (x1, x2, ..., xn) in the paper.

  • class_id : int = 0

    • The class id of the inputs.

Return

  • crops : numpy.ndarray

    • The crops (X in the paper).

  • crops_u : numpy.ndarray

    • The concepts' values (U in the paper).

  • concept_bank_w : numpy.ndarray

    • The concepts' basis (W in the paper).


plot_concept_attribution_legend(self,
                                nb_most_important_concepts: int = 6,
                                border_width: int = 5)

Plot a legend for the concept attribution maps.

Parameters

  • nb_most_important_concepts : int = 6

    • The number of concepts to focus on. Default is 6.

  • border_width : int = 5

    • Width of the border around each concept image, in pixels. Defaults to 5.


plot_concept_attribution_map(self,
                             image: numpy.ndarray,
                             most_important_concepts: numpy.ndarray,
                             nb_most_important_concepts: int = 5,
                             filter_percentile: int = 90,
                             clip_percentile: Optional[float] = 10,
                             alpha: float = 0.65,
                             **plot_kwargs)

Display the concept attribution map for the single image given as argument.

Parameters

  • image : numpy.ndarray

    • The image to display.

  • most_important_concepts : numpy.ndarray

    • The ids of the concepts to display.

  • nb_most_important_concepts : int = 5

    • The number of concepts to display. Default is 5.

  • filter_percentile : int = 90

    • Percentile used to filter the concept heatmap: a concept is only shown where its value exceeds the N-th percentile.

      Defaults to 90.

  • clip_percentile : Optional[float] = 10

    • Percentile value to use if clipping is needed when drawing the concept, e.g. a value of 1 clips the heatmap between its 1st and 99th percentiles.

      This parameter helps avoid outliers caused by extreme values.

      It is applied after the filter_percentile operation.

      Defaults to 10.

  • alpha : float = 0.65

    • The alpha channel value for the heatmaps. Defaults to 0.65.

  • plot_kwargs : **plot_kwargs

    • Additional parameters passed to plt.imshow().
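The filter_percentile and clip_percentile descriptions above can be read as a two-step post-processing of each concept heatmap: zero out every value that does not exceed the N-th percentile, then clip the result between the p-th and (100 - p)-th percentiles. The sketch below is an interpretation of those parameter descriptions, not Xplique's source, and `filter_and_clip` is a hypothetical helper.

```python
import numpy as np


def filter_and_clip(heatmap, filter_percentile=90, clip_percentile=10):
    """Post-process a concept heatmap as the parameter descriptions suggest."""
    heatmap = np.asarray(heatmap, dtype=float).copy()
    # 1) Keep only values exceeding the N-th percentile, zero out the rest.
    threshold = np.percentile(heatmap, filter_percentile)
    heatmap[heatmap <= threshold] = 0.0
    # 2) Optionally clip between the p-th and (100 - p)-th percentiles.
    if clip_percentile is not None:
        low = np.percentile(heatmap, clip_percentile)
        high = np.percentile(heatmap, 100 - clip_percentile)
        heatmap = np.clip(heatmap, low, high)
    return heatmap


print(filter_and_clip(np.arange(10), filter_percentile=50, clip_percentile=None))
# [0. 0. 0. 0. 0. 5. 6. 7. 8. 9.]
```

Raising filter_percentile shows only the strongest part of each concept; raising clip_percentile flattens more of the extreme values.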


plot_concept_attribution_maps(self,
                              images: numpy.ndarray,
                              importances: numpy.ndarray = None,
                              nb_most_important_concepts: int = 5,
                              filter_percentile: int = 90,
                              clip_percentile: Optional[float] = 10.0,
                              alpha: float = 0.65,
                              cols: int = 5,
                              img_size: float = 2.0,
                              **plot_kwargs)

Display the concept attribution maps for the images given as argument.

Parameters

  • images : numpy.ndarray

    • The images to display.

  • importances : numpy.ndarray = None

    • The importances computed by the estimate_importance() method.

      If None is provided, then the global importances will be used, otherwise the local importances set in this parameter will be used.

  • nb_most_important_concepts : int = 5

    • The number of concepts to focus on. Default is 5.

  • filter_percentile : int = 90

    • Percentile used to filter the concept heatmap: a concept is only shown where its value exceeds the N-th percentile. Defaults to 90.

  • clip_percentile : Optional[float] = 10.0

    • Percentile value to use if clipping is needed when drawing the concept, e.g. a value of 1 clips the heatmap between its 1st and 99th percentiles.

      This parameter helps avoid outliers caused by extreme values.

      It is applied after the filter_percentile operation.

      Defaults to 10.

  • alpha : float = 0.65

    • The alpha channel value for the heatmaps. Defaults to 0.65.

  • cols : int = 5

    • Number of columns. Defaults to 5.

  • img_size : float = 2.0

    • Size of each subplot (in inches), preserving the aspect ratio. Defaults to 2.0.

  • plot_kwargs : **plot_kwargs

    • Additional parameters passed to plt.imshow().


plot_concepts_crops(self,
                    nb_crops: int = 10,
                    nb_most_important_concepts: int = None,
                    verbose: bool = False) -> None

Display the crops for each concept.

Parameters

  • nb_crops : int = 10

    • The number of crops (patches) to display per concept. Defaults to 10.

  • nb_most_important_concepts : int = None

    • The number of concepts to display. If provided, only display nb_most_important_concepts, otherwise display them all.

      Default is None.

  • verbose : bool = False

    • If True, then print the importance value of each concept, otherwise no textual output will be printed.
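The best crops of a concept are naturally the crops with the highest coefficient for that concept in U (crops_u). A minimal sketch of that selection follows, assuming crops and crops_u as returned by fit(). The helper name is hypothetical and the library's exact ranking logic may differ.

```python
import numpy as np


def best_crops_per_concept(crops, crops_u, nb_crops=10):
    """Map each concept id to its `nb_crops` crops with the highest coefficient."""
    best = {}
    for concept_id in range(crops_u.shape[1]):
        order = np.argsort(crops_u[:, concept_id])[::-1][:nb_crops]
        best[concept_id] = crops[order]
    return best


# Toy data: 5 "crops" (stood in for by their ids) and 2 concepts.
crops = np.arange(5)
crops_u = np.array([[0.1, 0.9],
                    [0.5, 0.2],
                    [0.3, 0.0],
                    [0.9, 0.4],
                    [0.0, 0.8]])
best = best_crops_per_concept(crops, crops_u, nb_crops=2)
print(best[0], best[1])  # [3 1] [0 4]
```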


plot_concepts_importances(self,
                          importances: numpy.ndarray = None,
                          display_importance_order: xplique.concepts.craft.DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL,
                          nb_most_important_concepts: int = None,
                          verbose: bool = False)

Plot a bar chart displaying the importance value of each concept.

Parameters

  • importances : numpy.ndarray = None

    • The importances computed by the estimate_importance() method.

      Default is None, in this case the importances computed on the whole dataset will be used.

  • display_importance_order : DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL

    • Selects the order in which the concepts will be displayed, either following the global importance on the whole dataset (same order for all images) or the local importance of the concepts for a single image sample (local importance).

      Defaults to GLOBAL.

  • nb_most_important_concepts : int = None

    • The number of concepts to display. If None is provided, then all the concepts will be displayed unordered, otherwise only nb_most_important_concepts will be displayed, ordered by importance.

      Default is None.

  • verbose : bool = False

    • If True, then print the importance value of each concept, otherwise no textual output will be printed.
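Selecting the nb_most_important_concepts concepts amounts to sorting the importances in decreasing order. The sketch below uses a hypothetical helper that is not part of the API. Feeding it global importances (from estimate_importance() with inputs=None) or local ones (computed on a single image) would correspond to the two display orders described above, and its output has the form expected by the most_important_concepts argument of plot_concept_attribution_map().

```python
import numpy as np


def select_most_important_concepts(importances, nb_most_important_concepts=None):
    """Return concept ids: all of them in index order if None, else the top-k by importance."""
    importances = np.asarray(importances)
    if nb_most_important_concepts is None:
        return np.arange(len(importances))
    return np.argsort(importances)[::-1][:nb_most_important_concepts]


importances = np.array([0.05, 0.40, 0.10, 0.30])
print(select_most_important_concepts(importances, 2))  # [1 3]
```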


plot_image_concepts(self,
                    img: numpy.ndarray,
                    display_importance_order: xplique.concepts.craft.DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL,
                    nb_most_important_concepts: int = 5,
                    filter_percentile: int = 90,
                    clip_percentile: Optional[float] = 10,
                    alpha: float = 0.65,
                    filepath: Optional[str] = None,
                    **plot_kwargs)

All-in-one method displaying several plots for the image given as argument:

  • the concept attribution map for this image

  • the best crops for each concept (displayed around the heatmap)

  • the importance of each concept

Parameters

  • img : numpy.ndarray

    • The image to display.

  • display_importance_order : DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL

    • Selects the order in which the concepts will be displayed, either following the global importance on the whole dataset (same order for all images) or the local importance of the concepts for a single image sample (local importance).

      Defaults to GLOBAL.

  • nb_most_important_concepts : int = 5

    • The number of concepts to display. Default is 5.

  • filter_percentile : int = 90

    • Percentile used to filter the concept heatmap: a concept is only shown where its value exceeds the N-th percentile. Defaults to 90.

  • clip_percentile : Optional[float] = 10

    • Percentile value to use if clipping is needed when drawing the concept, e.g. a value of 1 clips the heatmap between its 1st and 99th percentiles.

      This parameter helps avoid outliers caused by extreme values.

      Defaults to 10.

  • alpha : float = 0.65

    • The alpha channel value for the heatmaps. Defaults to 0.65.

  • filepath : Optional[str] = None

    • Path where the figure will be saved. If None, plt.show() is called instead.

  • plot_kwargs : **plot_kwargs

    • Additional parameters passed to plt.imshow().


transform(self,
          inputs: numpy.ndarray,
          activations: numpy.ndarray = None) -> numpy.ndarray

Transforms the input data into its concept representation.

Parameters

  • inputs : numpy.ndarray

    • The input data to be transformed.

  • activations : numpy.ndarray = None

    • Pre-computed activations of the input data. If not provided, the activations will be computed using the input_to_latent_model model on the inputs.

Return

  • coeffs_u : numpy.ndarray

    • The concepts' values of the inputs (U in the paper).
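Once fit() has fixed the concept bank W, transforming new activations means finding non-negative coefficients U such that activations ≈ U W. A minimal sketch using multiplicative updates on U only (W held fixed) follows. This is an assumed stand-in for the library's solver, and `transform_with_fixed_w` is a hypothetical helper.

```python
import numpy as np


def transform_with_fixed_w(activations, W, n_iter=1000, eps=1e-10, seed=0):
    """Find U >= 0 such that activations ~= U @ W, keeping the concept bank W fixed."""
    rng = np.random.default_rng(seed)
    U = rng.random((activations.shape[0], W.shape[0]))
    for _ in range(n_iter):
        # Multiplicative update on U only; W is never modified.
        U *= (activations @ W.T) / (U @ W @ W.T + eps)
    return U


rng = np.random.default_rng(2)
W = rng.random((3, 16))               # a fixed, previously fitted concept bank
activations = rng.random((5, 3)) @ W  # new non-negative activations
coeffs_u = transform_with_fixed_w(activations, W)
rel_error = np.linalg.norm(activations - coeffs_u @ W) / np.linalg.norm(activations)
```

Keeping W fixed is what makes the concept values of new images comparable with those of the crops used during fit().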


CraftManagerTf

Class implementing the CraftManager in TensorFlow. This manager creates one CraftTf instance per class to explain.

__init__(self,
         input_to_latent_model: Callable,
         latent_to_logit_model: Callable,
         inputs: numpy.ndarray,
         labels: numpy.ndarray,
         list_of_class_of_interest: Optional[list] = None,
         number_of_concepts: int = 20,
         batch_size: int = 64,
         patch_size: int = 64)

Parameters

  • input_to_latent_model : Callable

    • The first part of the model taking an input and returning positive activations, g(.) in the original paper.

      Must return positive activations.

  • latent_to_logit_model : Callable

    • The second part of the model, taking activations and returning logits, h(.) in the original paper.

  • inputs : numpy.ndarray

    • Input data of shape (n_samples, height, width, channels).

      (x1, x2, ..., xn) in the paper.

  • labels : numpy.ndarray

    • Labels of the inputs of shape (n_samples, class_id)

  • list_of_class_of_interest : Optional[list] = None

    • A list of the class ids to explain. The manager will instantiate one CraftTf object per element of this list.

  • number_of_concepts : int = 20

    • The number of concepts to extract. Default is 20.

  • batch_size : int = 64

    • The batch size to use during training and prediction. Default is 64.

  • patch_size : int = 64

    • The size of the patches (crops) to extract from the input data. Default is 64.

compute_predictions(self)

Compute the predictions for the current dataset by chaining the two models input_to_latent_model and latent_to_logit_model.

Return

  • y_preds

    • The predictions.
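Chaining the two models amounts to computing h(g(x)) one batch at a time. The sketch below assumes that a prediction is the arg-max of the logits, which this page does not state. `compute_predictions_sketch`, `toy_g` and `toy_h` are illustrative stand-ins, not library functions.

```python
import numpy as np


def compute_predictions_sketch(g, h, inputs, batch_size=64):
    """Predicted class ids for `inputs`, computing h(g(x)) one batch at a time."""
    predictions = []
    for start in range(0, len(inputs), batch_size):
        logits = h(g(inputs[start:start + batch_size]))
        predictions.append(np.argmax(logits, axis=-1))
    return np.concatenate(predictions)


def toy_g(x):
    """Toy input_to_latent stand-in: a ReLU."""
    return np.maximum(x, 0.0)


def toy_h(a):
    """Toy latent_to_logit stand-in: identity on 3 "classes"."""
    return a @ np.eye(3)


inputs = np.array([[1.0, 0.0, 0.0],
                   [0.0, 2.0, 0.0],
                   [0.0, 0.0, 3.0],
                   [5.0, 1.0, -9.0]])
print(compute_predictions_sketch(toy_g, toy_h, inputs, batch_size=3))  # [0 1 2 0]
```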


estimate_importance(self,
                    nb_design: int = 32,
                    verbose: bool = False)

Estimates the importance of each concept for all the classes of interest.

Parameters

  • nb_design : int = 32

    • The number of designs to use for the importance estimation. Default is 32.

  • verbose : bool = False

    • If True, then print the current class CRAFT is estimating importances for, otherwise no textual output will be printed.


fit(self,
    nb_samples_per_class: Optional[int] = None,
    verbose: bool = False)

Fit the Craft models on their respective classes of interest.

Parameters

  • nb_samples_per_class : Optional[int] = None

    • Number of samples per class used to fit each Craft model.

      Default is None, which means that all the samples will be used.

  • verbose : bool = False

    • If True, then print the current class CRAFT is fitting, otherwise no textual output will be printed.
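Fitting one Craft instance per class first requires picking the samples that belong to each class, capped by nb_samples_per_class. A hypothetical sketch of that selection follows. Whether the manager groups samples by the provided labels or by the model's predictions is not stated on this page, so the class ids are passed in explicitly.

```python
import numpy as np


def select_class_samples(inputs, class_ids, class_of_interest, nb_samples_per_class=None):
    """Keep the inputs whose class id matches, optionally capped to the first N."""
    selected = inputs[class_ids == class_of_interest]
    if nb_samples_per_class is not None:
        selected = selected[:nb_samples_per_class]
    return selected


inputs = np.arange(6)                     # stand-ins for 6 images
class_ids = np.array([0, 1, 0, 2, 0, 1])  # one class id per input
print(select_class_samples(inputs, class_ids, 0, nb_samples_per_class=2))  # [0 2]
print(select_class_samples(inputs, class_ids, 1))                          # [1 5]
```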


plot_concepts_crops(self,
                    class_id: int,
                    nb_crops: int = 10,
                    nb_most_important_concepts: int = None)

Display the crops for each concept.

Parameters

  • class_id : int

    • The class to explain.

  • nb_crops : int = 10

    • The number of crops (patches) to display per concept. Defaults to 10.

  • nb_most_important_concepts : int = None

    • The number of concepts to display. If provided, only display nb_most_important_concepts, otherwise display them all.

      Default is None.


plot_concepts_importances(self,
                          class_id: int,
                          nb_most_important_concepts: int = 5,
                          verbose: bool = False)

Plot a bar chart displaying the importance value of each concept.

Parameters

  • class_id : int

    • The class to explain.

  • nb_most_important_concepts : int = 5

    • The number of concepts to focus on. Default is 5.

  • verbose : bool = False

    • If True, then print the importance value of each concept, otherwise no textual output will be printed.


plot_image_concepts(self,
                    img: numpy.ndarray,
                    class_id: int,
                    display_importance_order: xplique.concepts.craft.DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL,
                    nb_most_important_concepts: int = 5,
                    filter_percentile: int = 90,
                    clip_percentile: Optional[float] = 10,
                    alpha: float = 0.65,
                    filepath: Optional[str] = None)

All-in-one method displaying several plots for the image given as argument:

  • the concept attribution map for this image

  • the best crops for each concept (displayed around the heatmap)

  • the importance of each concept

Parameters

  • img : numpy.ndarray

    • The image to explain.

  • class_id : int

    • The class to explain.

  • display_importance_order : DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL

    • Selects the order in which the concepts will be displayed, either following the global importance on the whole dataset (same order for all images) or the local importance of the concepts for a single image sample (local importance).

      Defaults to GLOBAL.

  • nb_most_important_concepts : int = 5

    • The number of concepts to focus on. Default is 5.

  • filter_percentile : int = 90

    • Percentile used to filter the concept heatmap: a concept is only shown where its value exceeds the N-th percentile. Defaults to 90.

  • clip_percentile : Optional[float] = 10

    • Percentile value to use if clipping is needed when drawing the concept, e.g. a value of 1 clips the heatmap between its 1st and 99th percentiles.

      This parameter helps avoid outliers caused by extreme values.

      Defaults to 10.

  • alpha : float = 0.65

    • The alpha channel value for the heatmaps. Defaults to 0.65.

  • filepath : Optional[str] = None

    • Path where the figure will be saved. If None, plt.show() is called instead.