CRAFT
View colab Tensorflow tutorial | View colab Pytorch tutorial | View source | 📰 Paper
CRAFT, or Concept Recursive Activation FacTorization for Explainability, is a method for automatically extracting human-interpretable concepts from deep networks.
This concept activation factorization method aims to explain a trained model's decisions on a per-class and per-image basis by highlighting both "what" the model saw and "where" it saw it. CRAFT thus generates post-hoc local and global explanations.
It is made up of three ingredients:
- a method to recursively decompose concepts into sub-concepts
- a method to better estimate the importance of extracted concepts
- a method to use any attribution method to create concept attribution maps, using implicit differentiation
CRAFT requires splitting the model into two parts, \((g, h)\), such that \(f(x) = (h \circ g)(x) = h(g(x))\). Simply put, \(g\) is the function that maps the input to the latent space (an inner layer of the model), and \(h\) is the function that maps the latent space to the output. The concepts are extracted from this latent space.
Info
If the model contains a global average pooling layer, it is strongly recommended to provide CRAFT with the layer that precedes the global average pooling.
Warning
Please keep in mind that the activations must be positive (for instance, taken after a ReLU or any other activation function with non-negative outputs).
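To make the split concrete, here is a minimal NumPy-only sketch using a toy stand-in for a real network (the weights and the functions g, h and f are hypothetical, not part of Xplique). Following the two notes above, g ends with a ReLU and stops just before the global average pooling, while h performs the pooling and the final linear layer, so composing them recovers the full model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy "network" standing in for a real model f.
# g: a 1x1 convolution (per-pixel linear map) followed by ReLU,
#    returning non-negative spatial activations of shape (n, height, width, 8).
# h: global average pooling followed by a linear classifier,
#    returning logits of shape (n, 5).
W_conv = rng.normal(size=(3, 8))    # 3 input channels -> 8 feature channels
W_dense = rng.normal(size=(8, 5))   # 8 features -> 5 classes


def g(x):
    """input_to_latent: maps images to non-negative activations (before pooling)."""
    return np.maximum(x @ W_conv, 0.0)


def h(a):
    """latent_to_logit: global average pooling, then a linear layer."""
    pooled = a.mean(axis=(1, 2))
    return pooled @ W_dense


def f(x):
    """The full toy model, written in one piece."""
    return np.maximum(x @ W_conv, 0.0).mean(axis=(1, 2)) @ W_dense


x = rng.normal(size=(4, 16, 16, 3))  # 4 images of 16x16 pixels, 3 channels
activations = g(x)

print(activations.shape)             # (4, 16, 16, 8)
print(np.all(activations >= 0))      # True: the ReLU keeps activations non-negative
print(np.allclose(h(g(x)), f(x)))    # True: f = h o g
```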
Example
Use Craft to investigate a single class.
```python
import tensorflow as tf
from xplique.concepts import CraftTf as Craft

# `model` (a tf.keras model), `images_preprocessed` and `rabbit_class_id`
# are assumed to be defined beforehand.

# Cut the model in two parts (as explained in the paper):
# the first part is g(.), our 'input_to_latent' model returning positive activations,
# the second part is h(.), our 'latent_to_logit' model
g = tf.keras.Model(model.input, model.layers[-3].output)
h = tf.keras.Model(model.layers[-2].input, model.layers[-1].output)

# Create a Craft concept extractor from these 2 models
craft = Craft(input_to_latent_model=g,
              latent_to_logit_model=h,
              number_of_concepts=10,
              patch_size=80,
              batch_size=64)

# Use Craft to get the crops (crops), the embedding of the crops (crops_u),
# and the concept bank (w)
crops, crops_u, w = craft.fit(images_preprocessed, class_id=rabbit_class_id)

# Compute Sobol indices to understand which concepts matter
importances = craft.estimate_importance()

# Display those concepts by showing the 10 best crops for each concept
craft.plot_concepts_crops(nb_crops=10)
```
Use CraftManager to investigate multiple classes.
```python
import tensorflow as tf
from xplique.concepts import CraftManagerTf as CraftManager

# `model` (a tf.keras model), `inputs_preprocessed` and the labels `y`
# are assumed to be defined beforehand.

# Cut the model in two parts (as explained in the paper):
# the first part is g(.), our 'input_to_latent' model returning positive activations,
# the second part is h(.), our 'latent_to_logit' model
g = tf.keras.Model(model.input, model.layers[-3].output)
h = tf.keras.Model(model.layers[-2].input, model.layers[-1].output)

# CraftManager will create one instance of Craft per class of interest
# to investigate
list_of_class_of_interest = [0, 491, 497, 569, 574]  # list of class_ids
cm = CraftManager(input_to_latent_model=g,
                  latent_to_logit_model=h,
                  inputs=inputs_preprocessed,
                  labels=y,
                  list_of_class_of_interest=list_of_class_of_interest,
                  number_of_concepts=10,
                  patch_size=80,
                  batch_size=64)

cm.fit(nb_samples_per_class=50)

# Compute Sobol indices to understand which concepts matter
cm.estimate_importance()

# Display those concepts by showing the 10 best crops for each concept,
# for class 0 (the first class of interest)
cm.plot_concepts_crops(class_id=0, nb_crops=10)
```
CraftTf
Class implementing the CRAFT Concept Extraction Mechanism on TensorFlow.
__init__(self,
input_to_latent_model: Callable,
latent_to_logit_model: Callable,
number_of_concepts: int = 20,
batch_size: int = 64,
patch_size: int = 64)
Parameters
-
input_to_latent_model : Callable
The first part of the model taking an input and returning positive activations, g(.) in the original paper.
Must be a TensorFlow model (tf.keras.engine.base_layer.Layer) accepting data of shape (n_samples, height, width, channels).
-
latent_to_logit_model : Callable
The second part of the model taking activations and returning logits, h(.) in the original paper.
Must be a TensorFlow model (tf.keras.engine.base_layer.Layer).
-
number_of_concepts : int = 20
The number of concepts to extract. Default is 20.
-
batch_size : int = 64
The batch size to use during training and prediction. Default is 64.
-
patch_size : int = 64
The size of the patches to extract from the input data. Default is 64.
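The patch_size parameter controls the size of the crops that CRAFT cuts out of the input images. The sketch below shows one plausible way to extract such crops with a sliding window; the helper extract_patches and its stride argument are hypothetical, and neither the stride Xplique actually uses nor whether crops are resized before being passed to g is specified here.

```python
import numpy as np


def extract_patches(images: np.ndarray, patch_size: int, stride: int) -> np.ndarray:
    """Hypothetical helper: cut square patches out of a batch of images.

    images has shape (n_samples, height, width, channels); the result has shape
    (n_patches, patch_size, patch_size, channels). Windows that would run past
    the image border are skipped.
    """
    _, height, width, _ = images.shape
    patches = []
    for img in images:
        for top in range(0, height - patch_size + 1, stride):
            for left in range(0, width - patch_size + 1, stride):
                patches.append(img[top:top + patch_size, left:left + patch_size])
    return np.stack(patches)


rng = np.random.default_rng(0)
images = rng.random((2, 224, 224, 3))
crops = extract_patches(images, patch_size=64, stride=64)
print(crops.shape)  # (18, 64, 64, 3): 3 x 3 windows per image, 2 images
```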
check_if_fitted(self)
Checks if the factorization model has been fitted to input data.
compute_subplots_layout_parameters(images: numpy.ndarray,
cols: int = 5,
img_size: float = 2.0,
margin: float = 0.3,
spacing: float = 0.3)
Compute layout parameters for subplots, to be used by the method fig.subplots_adjust().
Parameters
-
images : numpy.ndarray
The images to display with subplots. Should be data of shape (n_samples, height, width, channels).
-
cols : int = 5
Number of columns to configure for the subplots.
Defaults to 5.
-
img_size : float = 2.0
Size of each subplot (in inches), keeping the aspect ratio. Defaults to 2.
-
margin : float = 0.3
The margin to use for the subplots. Defaults to 0.3.
-
spacing : float = 0.3
The spacing to use for the subplots. Defaults to 0.3.
Return
-
layout_parameters
A dictionary containing the layout description
-
rows
The number of rows needed to display the images
-
figwidth
The width of the figure holding the subplots
-
figheight
The height of the figure holding the subplots
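As an illustration of what these values can look like, the sketch below derives them from the parameters above so that every subplot is img_size wide, the gaps between subplots equal spacing, and the border equals margin. The helper subplots_layout is a hypothetical re-derivation, not Xplique's code; the returned dictionary uses the keyword names accepted by matplotlib's fig.subplots_adjust(), where wspace and hspace are fractions of the average subplot width and height.

```python
import math

import numpy as np


def subplots_layout(images, cols=5, img_size=2.0, margin=0.3, spacing=0.3):
    """Hypothetical re-derivation of subplot layout values (not Xplique's code).

    Returns keyword arguments for fig.subplots_adjust(), the number of rows,
    and the figure width and height (in the same unit as img_size).
    """
    height, width = images.shape[1:3]
    rows = math.ceil(len(images) / cols)
    # Each subplot is img_size wide; its height follows the image aspect ratio.
    sub_h = img_size * height / width
    figwidth = cols * img_size + (cols - 1) * spacing + 2 * margin
    figheight = rows * sub_h + (rows - 1) * spacing + 2 * margin
    layout = {
        "left": margin / figwidth,
        "right": 1 - margin / figwidth,
        "bottom": margin / figheight,
        "top": 1 - margin / figheight,
        "wspace": spacing / img_size,  # gap as a fraction of subplot width
        "hspace": spacing / sub_h,     # gap as a fraction of subplot height
    }
    return layout, rows, figwidth, figheight


layout, rows, figwidth, figheight = subplots_layout(np.zeros((12, 64, 64, 3)))
print(rows, round(figwidth, 2), round(figheight, 2))  # 3 11.8 7.2
```

With matplotlib, the results could then be used as `fig = plt.figure(figsize=(figwidth, figheight))` followed by `fig.subplots_adjust(**layout)`.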
estimate_importance(self,
inputs: numpy.ndarray = None,
nb_design: int = 32) -> numpy.ndarray
Estimates the importance of each concept for a given class, either globally on the whole dataset provided to the fit() method (in which case inputs should be set to None), or locally on a specific input image.
Parameters
-
inputs : numpy array or Tensor
The input data on which to compute the importances.
If None, then the inputs provided in the fit() method will be used (global importance of the whole dataset).
Default is None.
-
nb_design : int = 32
The number of designs to use for the importance estimation. Default is 32.
Return
-
importances : numpy.ndarray
The Sobol total index (importance score) for each concept.
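The importance returned here is a Sobol total index. As a rough illustration of how such an index can be estimated, the sketch below implements the Jansen estimator on random masks in [0, 1], one per concept. In the CRAFT setting, the score function would presumably rescale each concept coefficient by its mask, rebuild the activations from the concept bank and read the class logit through h; here it is replaced by a hypothetical additive toy score whose exact indices are known. Interpreting nb_design as the number of base masks is an assumption, as is the use of plain uniform random sampling (a quasi-Monte Carlo sequence is a common alternative).

```python
import numpy as np


def sobol_total_indices(score_fn, n_concepts, nb_design=32, seed=0):
    """Jansen estimator of Sobol total indices for random concept masks.

    score_fn maps masks of shape (n, n_concepts), with values in [0, 1], to one
    scalar score per mask. The score function is evaluated
    nb_design * (n_concepts + 2) times. Returns one total index per concept.
    """
    rng = np.random.default_rng(seed)
    a = rng.uniform(size=(nb_design, n_concepts))  # first matrix of masks
    b = rng.uniform(size=(nb_design, n_concepts))  # second, independent matrix
    y_a = score_fn(a)
    variance = np.var(np.concatenate([y_a, score_fn(b)]))
    totals = np.empty(n_concepts)
    for i in range(n_concepts):
        a_b = a.copy()
        a_b[:, i] = b[:, i]  # resample only concept i
        totals[i] = np.mean((y_a - score_fn(a_b)) ** 2) / (2 * variance)
    return totals


# Hypothetical toy score, additive in the masks, so the exact total indices are
# c_k**2 / sum(c**2), roughly [0.878, 0.098, 0.024, 0.0].
contributions = np.array([3.0, 1.0, 0.5, 0.0])


def toy_score(masks):
    return masks @ contributions


importances = sobol_total_indices(toy_score, n_concepts=4, nb_design=4096)
print(np.round(importances, 2))
```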
fit(self,
inputs: numpy.ndarray,
class_id: int = 0) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
Fit the Craft model to the input data.
Parameters
-
inputs : numpy.ndarray
Input data of shape (n_samples, height, width, channels).
(x1, x2, ..., xn) in the paper.
-
class_id : int = 0
The class id of the inputs.
Return
-
crops : numpy.ndarray
The crops (X in the paper).
-
crops_u : numpy.ndarray
The concepts' values (U in the paper).
-
concept_bank_w : numpy.ndarray
The concepts' basis (W in the paper).
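The three outputs are tied by the factorization at the heart of CRAFT: the non-negative activations of the crops are approximated by the product of crops_u (U) and the concept bank (W). The sketch below shows this step with the classic Lee and Seung multiplicative-update NMF written in NumPy; it is a stand-in solver chosen for brevity, and the function nmf and its arguments are hypothetical rather than Xplique's implementation. It assumes the activations have already been flattened to one row per crop.

```python
import numpy as np


def nmf(activations, n_concepts, n_iter=1000, seed=0, eps=1e-10):
    """Factorize non-negative activations A into U @ W (both non-negative).

    activations: (n_crops, n_features), all entries >= 0.
    Returns U (n_crops, n_concepts) and W (n_concepts, n_features).
    """
    rng = np.random.default_rng(seed)
    n_crops, n_features = activations.shape
    u = rng.uniform(size=(n_crops, n_concepts))
    w = rng.uniform(size=(n_concepts, n_features))
    for _ in range(n_iter):
        # Lee & Seung multiplicative updates for the loss ||A - U W||_F^2;
        # with A >= 0 they keep every entry of U and W non-negative.
        w *= (u.T @ activations) / (u.T @ u @ w + eps)
        u *= (activations @ w.T) / (u @ w @ w.T + eps)
    return u, w


rng = np.random.default_rng(1)
activations = rng.uniform(size=(50, 3)) @ rng.uniform(size=(3, 20))  # rank 3, >= 0
crops_u, concept_bank_w = nmf(activations, n_concepts=3)
error = np.linalg.norm(activations - crops_u @ concept_bank_w) / np.linalg.norm(activations)
print(f"relative reconstruction error: {error:.4f}")
```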
plot_concept_attribution_legend(self,
nb_most_important_concepts: int = 6,
border_width: int = 5)
Plot a legend for the concepts attribution maps.
Parameters
-
nb_most_important_concepts : int = 6
The number of concepts to focus on. Default is 6.
-
border_width : int = 5
Width of the border around each concept image, in pixels. Defaults to 5.
plot_concept_attribution_map(self,
image: numpy.ndarray,
most_important_concepts: numpy.ndarray,
nb_most_important_concepts: int = 5,
filter_percentile: int = 90,
clip_percentile: Optional[float] = 10,
alpha: float = 0.65,
**plot_kwargs)
Display the concept attribution map for a single image.
Parameters
-
image : numpy.ndarray
The image to display.
-
most_important_concepts : numpy.ndarray
The concepts ids to display.
-
nb_most_important_concepts : int = 5
The number of concepts to display. Default is 5.
-
filter_percentile : int = 90
Percentile used to filter the concept heatmap (the concept is only shown where it exceeds the N-th percentile). Defaults to 90.
-
clip_percentile : Optional[float] = 10
Percentile value to use if clipping is needed when drawing the concept, e.g. a value of 1 clips the heatmap between its 1st and 99th percentiles.
This parameter helps avoid outliers caused by extreme values.
It is applied after the filter_percentile operation.
Defaults to 10.
-
alpha : float = 0.65
The alpha channel value for the heatmaps. Defaults to 0.65.
-
plot_kwargs : **plot_kwargs
Additional parameters passed to plt.imshow().
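The interplay between filter_percentile and clip_percentile can be sketched as follows. This is a hypothetical helper written from the parameter descriptions above, not Xplique's code; in particular, a real implementation would typically also upsample the heatmap to the image resolution before display, which is omitted here.

```python
import numpy as np


def prepare_heatmap(heatmap, filter_percentile=90, clip_percentile=10):
    """Hypothetical post-processing of one concept heatmap before display.

    1. Filtering: only values strictly above the filter_percentile-th
       percentile are kept; the others are set to 0.
    2. Clipping (optional): the filtered map is clipped between its
       clip_percentile-th and (100 - clip_percentile)-th percentiles,
       which tames extreme values.
    """
    threshold = np.percentile(heatmap, filter_percentile)
    filtered = np.where(heatmap > threshold, heatmap, 0.0)
    if clip_percentile:
        low = np.percentile(filtered, clip_percentile)
        high = np.percentile(filtered, 100 - clip_percentile)
        filtered = np.clip(filtered, low, high)
    return filtered


heatmap = np.arange(100, dtype=float).reshape(10, 10)
kept = prepare_heatmap(heatmap, filter_percentile=90, clip_percentile=None)
print(np.count_nonzero(kept))     # 10: only the values 90 to 99 survive
clipped = prepare_heatmap(heatmap, filter_percentile=90, clip_percentile=1)
print(round(clipped.max(), 2))    # 98.01: the 99th percentile of the filtered map
```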
plot_concept_attribution_maps(self,
images: numpy.ndarray,
importances: numpy.ndarray = None,
nb_most_important_concepts: int = 5,
filter_percentile: int = 90,
clip_percentile: Optional[float] = 10.0,
alpha: float = 0.65,
cols: int = 5,
img_size: float = 2.0,
**plot_kwargs)
Display the concept attribution maps for the given images.
Parameters
-
images : numpy.ndarray
The images to display.
-
importances : numpy.ndarray = None
The importances computed by the estimate_importance() method.
If None is provided, then the global importances will be used, otherwise the local importances set in this parameter will be used.
-
nb_most_important_concepts : int = 5
The number of concepts to focus on. Default is 5.
-
filter_percentile : int = 90
Percentile used to filter the concept heatmap (the concept is only shown where it exceeds the N-th percentile). Defaults to 90.
-
clip_percentile : Optional[float] = 10.0
Percentile value to use if clipping is needed when drawing the concept, e.g. a value of 1 clips the heatmap between its 1st and 99th percentiles.
This parameter helps avoid outliers caused by extreme values.
It is applied after the filter_percentile operation.
Defaults to 10.
-
alpha : float = 0.65
The alpha channel value for the heatmaps. Defaults to 0.65.
-
cols : int = 5
Number of columns. Defaults to 5.
-
img_size : float = 2.0
Size of each subplot (in inches), keeping the aspect ratio. Defaults to 2.
-
plot_kwargs : **plot_kwargs
Additional parameters passed to plt.imshow().
plot_concepts_crops(self,
nb_crops: int = 10,
nb_most_important_concepts: int = None,
verbose: bool = False) -> None
Display the crops for each concept.
Parameters
-
nb_crops : int = 10
The number of crops (patches) to display per concept. Defaults to 10.
-
nb_most_important_concepts : int = None
The number of concepts to display. If provided, only display nb_most_important_concepts, otherwise display them all.
Default is None.
-
verbose : bool = False
If True, then print the importance value of each concept, otherwise no textual output will be printed.
plot_concepts_importances(self,
importances: numpy.ndarray = None,
display_importance_order: xplique.concepts.craft.DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL,
nb_most_important_concepts: int = None,
verbose: bool = False)
Plot a bar chart displaying the importance value of each concept.
Parameters
-
importances : numpy.ndarray = None
The importances computed by the estimate_importance() method.
Default is None, in this case the importances computed on the whole dataset will be used.
-
display_importance_order : DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL
Selects the order in which the concepts will be displayed, either following the global importance on the whole dataset (same order for all images) or the local importance of the concepts for a single image sample (local importance).
-
nb_most_important_concepts : int = None
The number of concepts to display. If None is provided, then all the concepts will be displayed unordered, otherwise only nb_most_important_concepts will be displayed, ordered by importance.
Default is None.
-
verbose : bool = False
If True, then print the importance value of each concept, otherwise no textual output will be printed.
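The selection described for nb_most_important_concepts can be sketched with np.argsort. The helper select_concepts is hypothetical, and "unordered" is interpreted here as concept-id order.

```python
import numpy as np


def select_concepts(importances, nb_most_important_concepts=None):
    """Hypothetical helper returning the ids of the concepts to display.

    With nb_most_important_concepts=None, all concepts are returned in
    concept-id order; otherwise only the top ones, by decreasing importance.
    """
    if nb_most_important_concepts is None:
        return np.arange(len(importances))
    order = np.argsort(importances)[::-1]  # indices sorted by decreasing importance
    return order[:nb_most_important_concepts]


importances = np.array([0.05, 0.40, 0.10, 0.30, 0.15])
print(select_concepts(importances, 3))  # [1 3 4]
print(select_concepts(importances))     # [0 1 2 3 4]
```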
plot_image_concepts(self,
img: numpy.ndarray,
display_importance_order: xplique.concepts.craft.DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL,
nb_most_important_concepts: int = 5,
filter_percentile: int = 90,
clip_percentile: Optional[float] = 10,
alpha: float = 0.65,
filepath: Optional[str] = None,
**plot_kwargs)
All-in-one method displaying several plots for the image given in argument:
- the concepts attribution map for this image
- the best crops for each concept (displayed around the heatmap)
- the importance of each concept
Parameters
-
img : numpy.ndarray
The image to display.
-
display_importance_order : DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL
Selects the order in which the concepts will be displayed, either following the global importance on the whole dataset (same order for all images) or the local importance of the concepts for a single image sample (local importance).
Defaults to GLOBAL.
-
nb_most_important_concepts : int = 5
The number of concepts to display. Default is 5.
-
filter_percentile : int = 90
Percentile used to filter the concept heatmap (the concept is only shown where it exceeds the N-th percentile). Defaults to 90.
-
clip_percentile : Optional[float] = 10
Percentile value to use if clipping is needed when drawing the concept, e.g. a value of 1 clips the heatmap between its 1st and 99th percentiles.
This parameter helps avoid outliers caused by extreme values.
Defaults to 10.
-
alpha : float = 0.65
The alpha channel value for the heatmaps. Defaults to 0.65.
-
filepath : Optional[str] = None
Path where the figure will be saved. If None, the function calls plt.show() instead.
-
plot_kwargs : **plot_kwargs
Additional parameters passed to plt.imshow().
transform(self,
inputs: numpy.ndarray,
activations: numpy.ndarray = None) -> numpy.ndarray
Transforms the input data into its concept representation.
Parameters
-
inputs : numpy.ndarray
The input data to be transformed.
-
activations : numpy.ndarray = None
Pre-computed activations of the input data. If not provided, the activations will be computed using the input_to_latent_model model on the inputs.
Return
-
coeffs_u : numpy.ndarray
The concepts' values of the inputs (U in the paper).
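Once the concept bank W is fitted, transforming new data amounts to finding non-negative coefficients U such that U @ W approximates the activations, with W held fixed. The sketch below solves this with multiplicative updates on U only; the helper project_on_concepts is hypothetical and Xplique may use a different solver. For spatial activations of shape (n, height, width, channels), one could reshape them to (n * height * width, channels) first and reshape U back afterwards to obtain one coefficient map per concept.

```python
import numpy as np


def project_on_concepts(activations, concept_bank_w, n_iter=1000, seed=0, eps=1e-10):
    """Hypothetical transform step: find U >= 0 such that U @ W approximates
    the activations, with the concept bank W kept fixed (only U is updated)."""
    rng = np.random.default_rng(seed)
    u = rng.uniform(size=(activations.shape[0], concept_bank_w.shape[0]))
    numerator = activations @ concept_bank_w.T  # (n, n_concepts)
    gram = concept_bank_w @ concept_bank_w.T    # (n_concepts, n_concepts)
    for _ in range(n_iter):
        # Multiplicative update for min ||A - U W||_F^2 over U >= 0.
        u *= numerator / (u @ gram + eps)
    return u


rng = np.random.default_rng(2)
concept_bank_w = rng.uniform(size=(3, 20))  # stands in for an already fitted W
new_activations = rng.uniform(size=(10, 3)) @ concept_bank_w
coeffs_u = project_on_concepts(new_activations, concept_bank_w)
print(coeffs_u.shape)  # (10, 3)
```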
CraftManagerTf
Class implementing the CraftManager on TensorFlow.
This manager creates one CraftTf instance per class to explain.
__init__(self,
input_to_latent_model: Callable,
latent_to_logit_model: Callable,
inputs: numpy.ndarray,
labels: numpy.ndarray,
list_of_class_of_interest: Optional[list] = None,
number_of_concepts: int = 20,
batch_size: int = 64,
patch_size: int = 64)
Parameters
-
input_to_latent_model : Callable
The first part of the model taking an input and returning positive activations, g(.) in the original paper.
Must return positive activations.
-
latent_to_logit_model : Callable
The second part of the model taking activations and returning logits, h(.) in the original paper.
-
inputs : numpy.ndarray
Input data of shape (n_samples, height, width, channels).
(x1, x2, ..., xn) in the paper.
-
labels : numpy.ndarray
Labels of the inputs, one class id per sample, of shape (n_samples,).
-
list_of_class_of_interest : Optional[list] = None
A list of the class ids to explain. The manager instantiates one CraftTf object per element of this list.
-
number_of_concepts : int = 20
The number of concepts to extract. Default is 20.
-
batch_size : int = 64
The batch size to use during training and prediction. Default is 64.
-
patch_size : int = 64
The size of the patches (crops) to extract from the input data. Default is 64.
compute_predictions(self)
Compute the predictions for the current dataset by chaining the 2 models input_to_latent_model and latent_to_logit_model.
Return
-
y_preds
The predictions.
estimate_importance(self,
nb_design: int = 32,
verbose: bool = False)
Estimates the importance of each concept for all the classes of interest.
Parameters
-
nb_design : int = 32
The number of designs to use for the importance estimation. Default is 32.
-
verbose : bool = False
If True, then print the current class CRAFT is estimating importances for, otherwise no textual output will be printed.
fit(self,
nb_samples_per_class: Optional[int] = None,
verbose: bool = False)
Fit the Craft models on their respective class of interest.
Parameters
-
nb_samples_per_class : Optional[int] = None
Number of samples per class to use to fit each Craft model.
Default is None, which means that all the samples will be used.
-
verbose : bool = False
If True, then print the current class CRAFT is fitting, otherwise no textual output will be printed.
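Fitting the manager amounts to fitting one Craft instance on the samples of each class of interest. The sketch below shows only the per-class selection step, assuming labels holds one class id per sample; the helper split_by_class, and the choice to keep the first nb_samples_per_class samples rather than a random subset, are assumptions.

```python
import numpy as np


def split_by_class(inputs, labels, classes_of_interest, nb_samples_per_class=None):
    """Hypothetical helper: group inputs by class id, optionally keeping only
    the first nb_samples_per_class samples of each class."""
    per_class = {}
    for class_id in classes_of_interest:
        class_inputs = inputs[labels == class_id]  # boolean mask on the first axis
        if nb_samples_per_class is not None:
            class_inputs = class_inputs[:nb_samples_per_class]
        per_class[class_id] = class_inputs
    return per_class


rng = np.random.default_rng(0)
inputs = rng.random((10, 8, 8, 3))
labels = np.array([0, 1, 0, 2, 1, 0, 2, 2, 0, 1])
groups = split_by_class(inputs, labels, [0, 2], nb_samples_per_class=3)
print({k: v.shape[0] for k, v in groups.items()})  # {0: 3, 2: 3}
```

Each group could then be passed to its own Craft instance, e.g. `craft.fit(groups[class_id], class_id=class_id)`, following the CraftTf fit() signature.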
plot_concepts_crops(self,
class_id: int,
nb_crops: int = 10,
nb_most_important_concepts: int = None)
Display the crops for each concept.
Parameters
-
class_id : int
The class to explain.
-
nb_crops : int = 10
The number of crops (patches) to display per concept. Defaults to 10.
-
nb_most_important_concepts : int = None
The number of concepts to display. If provided, only display nb_most_important_concepts, otherwise display them all.
Default is None.
plot_concepts_importances(self,
class_id: int,
nb_most_important_concepts: int = 5,
verbose: bool = False)
Plot a bar chart displaying the importance value of each concept.
Parameters
-
class_id : int
The class to explain.
-
nb_most_important_concepts : int = 5
The number of concepts to focus on. Default is 5.
-
verbose : bool = False
If True, then print the importance value of each concept, otherwise no textual output will be printed.
plot_image_concepts(self,
img: numpy.ndarray,
class_id: int,
display_importance_order: xplique.concepts.craft.DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL,
nb_most_important_concepts: int = 5,
filter_percentile: int = 90,
clip_percentile: Optional[float] = 10,
alpha: float = 0.65,
filepath: Optional[str] = None)
All-in-one method displaying several plots for the image given in argument:
- the concepts attribution map for this image
- the best crops for each concept (displayed around the heatmap)
- the importance of each concept
Parameters
-
img : numpy.ndarray
The image to explain.
-
class_id : int
The class to explain.
-
display_importance_order : DisplayImportancesOrder = DisplayImportancesOrder.GLOBAL
Selects the order in which the concepts will be displayed, either following the global importance on the whole dataset (same order for all images) or the local importance of the concepts for a single image sample (local importance).
Defaults to GLOBAL.
-
nb_most_important_concepts : int = 5
The number of concepts to focus on. Default is 5.
-
filter_percentile : int = 90
Percentile used to filter the concept heatmap (the concept is only shown where it exceeds the N-th percentile). Defaults to 90.
-
clip_percentile : Optional[float] = 10
Percentile value to use if clipping is needed when drawing the concept, e.g. a value of 1 clips the heatmap between its 1st and 99th percentiles.
This parameter helps avoid outliers caused by extreme values.
Defaults to 10.
-
alpha : float = 0.65
The alpha channel value for the heatmaps. Defaults to 0.65.
-
filepath : Optional[str] = None
Path where the figure will be saved. If None, the function calls plt.show() instead.