📞 Callables or Models Handled by BlackBox Attribution Methods

The model can be something other than a tf.keras.Model if it satisfies one of the following conditions:

  • model(inputs: np.ndarray) returns either a np.ndarray or a tf.Tensor
  • The model has a scikit-learn API and has a predict_proba method
  • The model is a xgboost.XGBModel from the XGBoost Python library
  • The model is a TF Lite model (a tf.lite.Interpreter). Note that this feature is experimental.
  • The model is a PyTorch model (see the dedicated documentation)
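As a rough illustration of the first two conditions, here is a minimal NumPy-only sketch of two toy models. Both `numpy_model` and `ToyProbaClassifier` are hypothetical names made up for this example: the first is a plain callable returning a np.ndarray, the second only exposes a scikit-learn style `predict_proba` method.

```python
import numpy as np


# Condition 1 (hypothetical example): a plain callable mapping a NumPy batch
# to a NumPy array of class probabilities, shape (n_samples, n_classes).
def numpy_model(inputs: np.ndarray) -> np.ndarray:
    logits = np.stack([inputs.sum(axis=1), -inputs.sum(axis=1)], axis=1)
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))  # stable softmax
    return exp / exp.sum(axis=1, keepdims=True)


# Condition 2 (hypothetical example): an object with a scikit-learn style
# predict_proba method; it is not callable itself.
class ToyProbaClassifier:
    def __init__(self, weights: np.ndarray):
        self.weights = weights  # shape (n_features, n_classes)

    def predict_proba(self, inputs: np.ndarray) -> np.ndarray:
        logits = inputs @ self.weights
        exp = np.exp(logits - logits.max(axis=1, keepdims=True))
        return exp / exp.sum(axis=1, keepdims=True)


x = np.array([[1.0, 2.0], [-1.0, 0.5]])
clf = ToyProbaClassifier(np.array([[1.0, -1.0], [0.5, 0.0]]))
probs_callable = numpy_model(x)     # one row of probabilities per sample
probs_sklearn = clf.predict_proba(x)
```

Either object could then be passed as the model of a black-box attribution method; real scikit-learn or XGBoost classifiers expose the same `predict_proba` method.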

Concretely, when no custom operator is provided (see the operator documentation) and the model (see the model documentation) is not a tf.keras.Model, a tf.Module or a tf.keras.layers.Layer, the predictions_one_hot_callable operator is used:

```python
from typing import Callable

import tensorflow as tf


def predictions_one_hot_callable(
    model: Callable,
    inputs: tf.Tensor,
    targets: tf.Tensor) -> tf.Tensor:
    """
    Compute predictions scores, only for the label class, for a batch of samples.

    Parameters
    ----------
    model
        Model used for computing predictions.
    inputs
        Input samples to be explained.
    targets
        One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.

    Returns
    -------
    scores
        Predictions scores computed, only for the label class.
    """
    if isinstance(model, tf.lite.Interpreter):
        # resize the input tensor to the batch shape, then (re)allocate buffers
        model.resize_tensor_input(0, [*inputs.shape], strict=False)
        model.allocate_tensors()
        model.set_tensor(model.get_input_details()[0]["index"], inputs)
        # run inference before reading the output tensor
        model.invoke()
        pred = model.get_tensor(model.get_output_details()[0]["index"])

    # can be a sklearn model or xgboost model
    elif hasattr(model, 'predict_proba'):
        pred = model.predict_proba(inputs.numpy())

    # can be another model thus it needs to implement a call function
    else:
        pred = model(inputs.numpy())

    # make sure that the prediction shape is coherent
    if inputs.shape[0] != 1:
        # a batch of prediction is required
        if len(pred.shape) == 1:
            # The prediction dimension disappeared
            pred = tf.expand_dims(pred, axis=1)

    pred = tf.cast(pred, dtype=tf.float32)
    # the one-hot targets act as a mask keeping only the label-class score
    scores = tf.reduce_sum(pred * targets, axis=-1)

    return scores
```
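To make the last steps of the operator concrete, the sketch below mirrors the shape fix and the masked sum in plain NumPy. It is an illustrative re-implementation, not the library code: `one_hot_scores` is a hypothetical name, and it uses `targets.shape[0]` as the batch size, assuming one target row per input sample.

```python
import numpy as np


def one_hot_scores(pred: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """NumPy mirror (illustrative only) of the scoring steps of
    predictions_one_hot_callable: shape fix, float32 cast, masked sum."""
    pred = np.asarray(pred)
    targets = np.asarray(targets, dtype=np.float32)
    # assumption: one target row per sample, so targets.shape[0] is the batch size
    if targets.shape[0] != 1 and pred.ndim == 1:
        # a 1-D batch of predictions (e.g. a regressor) gets a trailing axis
        pred = np.expand_dims(pred, axis=1)
    pred = pred.astype(np.float32)
    # multiplying by the one-hot targets keeps only the label-class score
    return np.sum(pred * targets, axis=-1)


# classification: label classes 1 and 0, so the scores are 0.9 and 0.7
cls_scores = one_hot_scores(np.array([[0.1, 0.9], [0.7, 0.3]]),
                            np.array([[0.0, 1.0], [1.0, 0.0]]))

# regression with {+1, -1} targets: 1-D predictions are expanded to (2, 1)
reg_scores = one_hot_scores(np.array([2.0, -1.5]), np.array([[1.0], [-1.0]]))

# single sample: a 1-D prediction is the class vector and broadcasts to (1, 2)
single_scores = one_hot_scores(np.array([0.2, 0.8]), np.array([[0.0, 1.0]]))
```

The batch-size check explains why a 1-D output is only expanded for batches larger than one: with a single sample, a 1-D output is most likely the vector of class scores rather than one score per sample.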

Knowing this, you are free to wrap your model so that it works with our API and/or to write a more customized operator (see the operator documentation)!
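For instance, a model that is neither callable nor exposes `predict_proba` can be adapted with a thin wrapper so that it meets the first condition above. In this minimal sketch, `LogitsOnlyModel` and its `predict_logits` method are hypothetical stand-ins for a third-party model, and `CallableWrapper` is an illustrative adapter, not part of the library.

```python
import numpy as np


class LogitsOnlyModel:
    """Hypothetical third-party model: no __call__ and no predict_proba,
    only a predict_logits method."""

    def __init__(self, weights: np.ndarray):
        self.weights = weights

    def predict_logits(self, inputs: np.ndarray) -> np.ndarray:
        return inputs @ self.weights


class CallableWrapper:
    """Illustrative adapter satisfying `model(inputs: np.ndarray) -> np.ndarray`."""

    def __init__(self, model: LogitsOnlyModel):
        self.model = model

    def __call__(self, inputs: np.ndarray) -> np.ndarray:
        logits = self.model.predict_logits(np.asarray(inputs, dtype=np.float32))
        # softmax, so the outputs are class probabilities like predict_proba's
        exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
        return exp / exp.sum(axis=-1, keepdims=True)


wrapped = CallableWrapper(LogitsOnlyModel(np.eye(3)))
probs = wrapped(np.array([[0.0, 0.0, 5.0]]))  # shape (1, 3), peaked on class 2
```

Since the wrapper has no `predict_proba` method and is not a TF Lite interpreter, the operator falls through to its last branch and simply calls `wrapped(inputs.numpy())`.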