📞 Callable or Models handled by BlackBox Attribution methods
The model can be something other than a tf.keras.Model if it satisfies one of the following conditions:
- model(inputs: np.ndarray) returns either a np.ndarray or a tf.Tensor
- The model has a scikit-learn API and has a predict_proba function
- The model is a xgboost.XGBModel from the XGBoost Python library
- The model is a TF Lite model. Note that this feature is experimental.
- The model is a PyTorch model (see the dedicated documentation)
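As an illustration of the first condition, here is a minimal sketch of a plain Python object that maps a np.ndarray batch to a np.ndarray of class scores. The class name SoftmaxLinearModel, its weights, and the input sizes are hypothetical and chosen only for the example; any callable with the same input and output types should qualify in the same way.

```python
import numpy as np


class SoftmaxLinearModel:
    """Hypothetical NumPy-only classifier, used purely for illustration."""

    def __init__(self, weights: np.ndarray, bias: np.ndarray):
        self.weights = weights  # shape (n_features, n_classes)
        self.bias = bias        # shape (n_classes,)

    def __call__(self, inputs: np.ndarray) -> np.ndarray:
        # Flatten each sample, apply the linear layer, then a stable softmax.
        flat = inputs.reshape(inputs.shape[0], -1)
        logits = flat @ self.weights + self.bias
        logits = logits - logits.max(axis=1, keepdims=True)
        exp = np.exp(logits)
        return exp / exp.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
model = SoftmaxLinearModel(rng.normal(size=(4, 3)), np.zeros(3))

# One row of class scores per input sample: 5 samples, 3 classes.
probabilities = model(rng.normal(size=(5, 4)))
```

The object takes a NumPy batch and returns one row of scores per sample, which is the contract described by the first condition.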
In fact, when no custom operator is provided (see the operator's documentation) and model (see the model's documentation) is not a tf.keras.Model, a tf.Module, or a tf.keras.layers.Layer, the predictions_one_hot_callable operator is used:
from typing import Callable

import tensorflow as tf


def predictions_one_hot_callable(
    model: Callable,
    inputs: tf.Tensor,
    targets: tf.Tensor) -> tf.Tensor:
    """
    Compute predictions scores, only for the label class, for a batch of samples.

    Parameters
    ----------
    model
        Model used for computing predictions.
    inputs
        Input samples to be explained.
    targets
        One-hot encoded labels or regression target (e.g. {+1, -1}), one for each sample.

    Returns
    -------
    scores
        Predictions scores computed, only for the label class.
    """
    if isinstance(model, tf.lite.Interpreter):
        model.resize_tensor_input(0, [*inputs.shape], strict=False)
        model.allocate_tensors()
        model.set_tensor(model.get_input_details()[0]["index"], inputs)
        model.invoke()
        pred = model.get_tensor(model.get_output_details()[0]["index"])
    # can be a sklearn model or xgboost model
    elif hasattr(model, 'predict_proba'):
        pred = model.predict_proba(inputs.numpy())
    # can be another model thus it needs to implement a call function
    else:
        pred = model(inputs.numpy())

    # make sure that the prediction shape is coherent
    if inputs.shape[0] != 1:
        # a batch of prediction is required
        if len(pred.shape) == 1:
            # The prediction dimension disappeared
            pred = tf.expand_dims(pred, axis=1)

    pred = tf.cast(pred, dtype=tf.float32)
    scores = tf.reduce_sum(pred * targets, axis=-1)

    return scores
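To make the final steps of the operator concrete, the sketch below restates the shape fix and the masked sum in plain NumPy. The helper name scores_from_predictions and the numeric values are hypothetical; this is an illustration of the arithmetic, not the library's code.

```python
import numpy as np


def scores_from_predictions(pred: np.ndarray, targets: np.ndarray, batch_size: int) -> np.ndarray:
    """NumPy restatement of the last steps of the operator (illustrative only)."""
    # Restore the class axis if the model returned a single value per sample.
    if batch_size != 1 and pred.ndim == 1:
        pred = pred[:, np.newaxis]
    pred = pred.astype(np.float32)
    # Multiplying by the targets zeroes every class except the labelled one.
    return np.sum(pred * targets, axis=-1)


# Classification: one-hot targets select the score of the labelled class.
pred = np.array([[0.1, 0.7, 0.2],
                 [0.6, 0.3, 0.1]])
targets = np.array([[0.0, 1.0, 0.0],
                    [1.0, 0.0, 0.0]], dtype=np.float32)
scores = scores_from_predictions(pred, targets, batch_size=2)  # approximately [0.7, 0.6]

# Regression: a 1-D prediction is expanded, then multiplied by the {+1, -1} target.
reg_pred = np.array([2.0, -1.0, 0.5])
reg_targets = np.array([[1.0], [-1.0], [1.0]], dtype=np.float32)
reg_scores = scores_from_predictions(reg_pred, reg_targets, batch_size=3)  # [2.0, 1.0, 0.5]
```

In the regression case, the sign of the target decides whether an increase or a decrease of the model output counts as a higher score.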
Knowing that, you are free to wrap your model to make it work with our API and/or write a custom operator (see the operator's documentation)!
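As a rough sketch of the wrapping approach, suppose a model exposes its predictions through a method that is neither __call__ nor predict_proba. The classes ThirdPartyModel and CallableWrapper and the method name infer are hypothetical, invented only for this example. A thin adapter can then expose the model(inputs: np.ndarray) signature that the default operator falls back to:

```python
import numpy as np


class ThirdPartyModel:
    """Hypothetical model whose prediction method is not __call__ or predict_proba."""

    def infer(self, batch: np.ndarray) -> np.ndarray:
        # Toy two-class scorer based on the mean feature value of each sample.
        mean = batch.reshape(batch.shape[0], -1).mean(axis=1)
        positive = 1.0 / (1.0 + np.exp(-mean))
        return np.stack([1.0 - positive, positive], axis=1)


class CallableWrapper:
    """Hypothetical adapter exposing model(inputs) -> (batch, n_classes) scores."""

    def __init__(self, model):
        self.model = model

    def __call__(self, inputs: np.ndarray) -> np.ndarray:
        # Forward the NumPy batch to the underlying method and return an array.
        return np.asarray(self.model.infer(inputs), dtype=np.float32)


wrapped = CallableWrapper(ThirdPartyModel())
outputs = wrapped(np.array([[1.0, 2.0], [-3.0, 0.5]]))
```

Because the wrapper has no predict_proba attribute and is not a TF Lite interpreter, the operator shown above would reach it through its final branch, calling it with inputs.numpy().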