KerasFeatureExtractor

Bases: FeatureExtractor

Feature extractor based on "model" to construct a feature space on which OOD detection is performed. The features can be the output activation values of internal model layers, or the output of the model (softmax/logits).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Callable` | Model to extract the features from. | *required* |
| `feature_layers_id` | `List[Union[int, str]]` | List of str or int that identify the features to output. If int, the rank of the layer in the layer list; if str, the name of the layer. Defaults to `[-1]`. | `[-1]` |
| `input_layer_id` | `Optional[Union[int, str]]` | Input layer of the feature extractor (to avoid useless forwards when working on the feature space without fine-tuning the bottom of the model). Defaults to `None`. | `None` |
| `react_threshold` | `Optional[float]` | If not `None`, penultimate layer activations are clipped under this threshold value (useful for ReAct). Defaults to `None`. | `None` |
| `scale_percentile` | `Optional[float]` | If not `None`, the features are scaled following the method of Xu et al., ICLR 2024 (SCALE). Defaults to `None`. | `None` |
| `ash_percentile` | `Optional[float]` | If not `None`, the features are pruned and scaled following the method of Djurisic et al., ICLR 2023 (ASH). Defaults to `None`. | `None` |
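A minimal usage sketch, assuming the import path matches the source file shown below; the toy model, layer names, and shapes are illustrative, not part of the API:

```python
import tensorflow as tf

from oodeel.extractor.keras_feature_extractor import KerasFeatureExtractor

# Illustrative classifier built with the functional API (names are arbitrary).
inputs = tf.keras.Input(shape=(32, 32, 3))
h = tf.keras.layers.Flatten()(inputs)
h = tf.keras.layers.Dense(128, activation="relu", name="feat")(h)
outputs = tf.keras.layers.Dense(10, activation="softmax")(h)
model = tf.keras.Model(inputs, outputs)

# Extract the "feat" activations; the model output is always returned as well,
# as logits, since the constructor replaces the last activation by a linear one.
extractor = KerasFeatureExtractor(model, feature_layers_id=["feat"])

batch = tf.random.uniform((8, 32, 32, 3))
features, logits = extractor.predict_tensor(batch)
print(features[0].shape, logits.shape)  # (8, 128) (8, 10)
```

As `prepare_extractor` shows, at most one of `react_threshold`, `scale_percentile`, and `ash_percentile` is applied, with `react_threshold` taking precedence, then `scale_percentile`.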
Source code in oodeel/extractor/keras_feature_extractor.py
class KerasFeatureExtractor(FeatureExtractor):
    """
    Feature extractor based on "model" to construct a feature space
    on which OOD detection is performed. The features can be the output
    activation values of internal model layers, or the output of the model
    (softmax/logits).

    Args:
        model: model to extract the features from
        feature_layers_id: list of str or int that identify features to output.
            If int, the rank of the layer in the layer list
            If str, the name of the layer. Defaults to [-1].
        input_layer_id: input layer of the feature extractor (to avoid useless forwards
            when working on the feature space without finetuning the bottom of the
            model).
            Defaults to None.
        react_threshold: if not None, penultimate layer activations are clipped under
            this threshold value (useful for ReAct). Defaults to None.
        scale_percentile: if not None, the features are scaled
            following the method of Xu et al., ICLR 2024.
            Defaults to None.
        ash_percentile: if not None, the features are scaled following
            the method of Djurisic et al., ICLR 2023.
    """

    def __init__(
        self,
        model: Callable,
        feature_layers_id: List[Union[int, str]] = [-1],
        input_layer_id: Optional[Union[int, str]] = None,
        react_threshold: Optional[float] = None,
        scale_percentile: Optional[float] = None,
        ash_percentile: Optional[float] = None,
    ):
        if input_layer_id is None:
            input_layer_id = 0
        super().__init__(
            model=model,
            feature_layers_id=feature_layers_id,
            input_layer_id=input_layer_id,
            react_threshold=react_threshold,
            scale_percentile=scale_percentile,
            ash_percentile=ash_percentile,
        )

        self.backend = "tensorflow"
        self.model.layers[-1].activation = getattr(tf.keras.activations, "linear")
        self._last_logits = None

    @staticmethod
    def find_layer(
        model: Callable,
        layer_id: Union[str, int],
        index_offset: int = 0,
        return_id: bool = False,
    ) -> Union[tf.keras.layers.Layer, Tuple[tf.keras.layers.Layer, str]]:
        """Find a layer in a model either by his name or by his index.

        Args:
            model (Callable): model whose identified layer will be returned
            layer_id (Union[str, int]): layer identifier
            index_offset (int): index offset to find layers located before (negative
                offset) or after (positive offset) the identified layer
            return_id (bool): if True, the layer will be returned with its id

        Raises:
            ValueError: if the layer is not found

        Returns:
            Union[tf.keras.layers.Layer, Tuple[tf.keras.layers.Layer, str]]:
                the corresponding layer and its id if return_id is True.
        """
        if isinstance(layer_id, str):
            layers_names = [layer.name for layer in model.layers]
            layer_id = layers_names.index(layer_id)
        if isinstance(layer_id, int):
            layer_id += index_offset
            layer = model.get_layer(index=layer_id)
        else:
            raise ValueError(f"Could not find any layer {layer_id}.")

        if return_id:
            return layer, layer_id
        else:
            return layer

    # @tf.function
    # TODO check with Thomas about @tf.function
    def prepare_extractor(self) -> tf.keras.models.Model:
        """Constructs the feature extractor model

        Returns:
            tf.keras.models.Model: truncated model (extractor)
        """
        input_layer = self.find_layer(self.model, self.input_layer_id)
        new_input = tf.keras.layers.Input(tensor=input_layer.input)
        output_tensors = [
            self.find_layer(self.model, id).output for id in self.feature_layers_id
        ]

        # === If react method, clip activations from penultimate layer ===
        if self.react_threshold is not None:
            penultimate_layer = self.find_layer(self.model, -2)
            penult_extractor = tf.keras.models.Model(
                new_input, penultimate_layer.output
            )
            last_layer = self.find_layer(self.model, -1)

            # clip penultimate activations
            x = tf.clip_by_value(
                penult_extractor(new_input),
                clip_value_min=tf.float32.min,
                clip_value_max=self.react_threshold,
            )
            # apply ultimate layer on clipped activations
            output_tensors.append(last_layer(x))

        # === If SCALE method, scale activations from penultimate layer ===
        # === If ASH method, scale and prune activations from penultimate layer ===
        elif (self.scale_percentile is not None) or (self.ash_percentile is not None):
            penultimate_layer = self.find_layer(self.model, -2)
            penult_extractor = tf.keras.models.Model(
                new_input, penultimate_layer.output
            )
            last_layer = self.find_layer(self.model, -1)

            # apply scaling on penultimate activations
            penultimate = penult_extractor(new_input)
            if self.scale_percentile is not None:
                output_percentile = tfp.stats.percentile(
                    penultimate, 100 * self.scale_percentile, axis=1
                )
            else:
                output_percentile = tfp.stats.percentile(
                    penultimate, 100 * self.ash_percentile, axis=1
                )

            mask = penultimate > tf.reshape(output_percentile, (-1, 1))
            filtered_penultimate = tf.where(
                mask, penultimate, tf.zeros_like(penultimate)
            )
            s = tf.math.exp(
                tf.reduce_sum(penultimate, axis=1)
                / tf.reduce_sum(filtered_penultimate, axis=1)
            )

            if self.scale_percentile is not None:
                x = penultimate * tf.expand_dims(s, 1)
            else:
                x = filtered_penultimate * tf.expand_dims(s, 1)
            # apply ultimate layer on scaled activations
            output_tensors.append(last_layer(x))

        else:
            output_tensors.append(self.find_layer(self.model, -1).output)

        extractor = tf.keras.models.Model(new_input, output_tensors)
        return extractor

    @sanitize_input
    def predict_tensor(
        self,
        tensor: TensorType,
        postproc_fns: Optional[List[Callable]] = None,
    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
        """Get the projection of tensor in the feature space of self.model

        Args:
            tensor (TensorType): input tensor (or dataset elem)
            postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
                each feature immediately after forward. Default to None.

        Returns:
            Tuple[List[tf.Tensor], tf.Tensor]: features, logits
        """
        features = self.forward(tensor)

        if type(features) is not list:
            features = [features]

        # split features and logits
        logits = features.pop()

        if postproc_fns is not None:
            features = [
                postproc_fn(feature)
                for feature, postproc_fn in zip(features, postproc_fns)
            ]

        self._last_logits = logits
        return features, logits

    @tf.function
    def forward(self, tensor: TensorType) -> List[tf.Tensor]:
        return self.extractor(tensor, training=False)

    def predict(
        self,
        dataset: Union[ItemType, tf.data.Dataset],
        postproc_fns: Optional[List[Callable]] = None,
        verbose: bool = False,
        **kwargs,
    ) -> Tuple[List[tf.Tensor], dict]:
        """Get the projection of the dataset in the feature space of self.model

        Args:
            dataset (Union[ItemType, tf.data.Dataset]): input dataset
            postproc_fns (Optional[Callable]): postprocessing function to apply to each
                feature immediately after forward. Default to None.
            verbose (bool): if True, display a progress bar. Defaults to False.
            kwargs (dict): additional arguments not considered for prediction

        Returns:
            List[tf.Tensor], dict: features and extra information (logits, labels) as a
                dictionary.
        """
        labels = None

        if isinstance(dataset, get_args(ItemType)):
            tensor = TFDataHandler.get_input_from_dataset_item(dataset)
            features, logits = self.predict_tensor(tensor, postproc_fns)

            # Get labels if dataset is a tuple/list
            if isinstance(dataset, (list, tuple)):
                labels = TFDataHandler.get_label_from_dataset_item(dataset)

        else:  # if dataset is a tf.data.Dataset
            features = [None for i in range(len(self.feature_layers_id))]
            logits = None
            contains_labels = TFDataHandler.get_item_length(dataset) > 1
            for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
                tensor = TFDataHandler.get_input_from_dataset_item(elem)
                features_batch, logits_batch = self.predict_tensor(tensor, postproc_fns)

                for i, f in enumerate(features_batch):
                    features[i] = (
                        f
                        if features[i] is None
                        else tf.concat([features[i], f], axis=0)
                    )
                # concatenate logits
                logits = (
                    logits_batch
                    if logits is None
                    else tf.concat([logits, logits_batch], axis=0)
                )
                # concatenate labels of current batch with previous batches
                if contains_labels:
                    lbl_batch = TFDataHandler.get_label_from_dataset_item(elem)

                    if labels is None:
                        labels = lbl_batch
                    else:
                        labels = tf.concat([labels, lbl_batch], axis=0)

        # store extra information in a dict
        info = dict(labels=labels, logits=logits)
        return features, info

    def get_weights(self, layer_id: Union[int, str]) -> List[tf.Tensor]:
        """Get the weights of a layer

        Args:
            layer_id (Union[int, str]): layer identifier

        Returns:
            List[tf.Tensor]: weight and bias matrices
        """
        return self.find_layer(self.model, layer_id).get_weights()

find_layer(model, layer_id, index_offset=0, return_id=False) staticmethod

Find a layer in a model either by its name or by its index.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Callable` | Model whose identified layer will be returned. | *required* |
| `layer_id` | `Union[str, int]` | Layer identifier. | *required* |
| `index_offset` | `int` | Index offset to find layers located before (negative offset) or after (positive offset) the identified layer. | `0` |
| `return_id` | `bool` | If True, the layer will be returned with its id. | `False` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the layer is not found. |

Returns:

| Type | Description |
| --- | --- |
| `Union[tf.keras.layers.Layer, Tuple[tf.keras.layers.Layer, str]]` | The corresponding layer, and its id if `return_id` is True. |
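A hedged sketch of calling `find_layer` on the toy model from the constructor example above; negative indices work because they are forwarded to `model.get_layer`, as in the internal `find_layer(self.model, -2)` calls:

```python
# By name: returns the Dense layer named "feat".
dense = KerasFeatureExtractor.find_layer(model, "feat")

# By index, with return_id=True: also returns the (possibly offset) index.
last, last_id = KerasFeatureExtractor.find_layer(model, -1, return_id=True)
print(dense.name, last.name, last_id)
```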

Source code in oodeel/extractor/keras_feature_extractor.py
@staticmethod
def find_layer(
    model: Callable,
    layer_id: Union[str, int],
    index_offset: int = 0,
    return_id: bool = False,
) -> Union[tf.keras.layers.Layer, Tuple[tf.keras.layers.Layer, str]]:
    """Find a layer in a model either by his name or by his index.

    Args:
        model (Callable): model whose identified layer will be returned
        layer_id (Union[str, int]): layer identifier
        index_offset (int): index offset to find layers located before (negative
            offset) or after (positive offset) the identified layer
        return_id (bool): if True, the layer will be returned with its id

    Raises:
        ValueError: if the layer is not found

    Returns:
        Union[tf.keras.layers.Layer, Tuple[tf.keras.layers.Layer, str]]:
            the corresponding layer and its id if return_id is True.
    """
    if isinstance(layer_id, str):
        layers_names = [layer.name for layer in model.layers]
        layer_id = layers_names.index(layer_id)
    if isinstance(layer_id, int):
        layer_id += index_offset
        layer = model.get_layer(index=layer_id)
    else:
        raise ValueError(f"Could not find any layer {layer_id}.")

    if return_id:
        return layer, layer_id
    else:
        return layer

get_weights(layer_id)

Get the weights of a layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `layer_id` | `Union[int, str]` | Layer identifier. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `List[tf.Tensor]` | Weight and bias matrices. |
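Continuing the toy sketch above, `get_weights` simply forwards to Keras' `layer.get_weights()`:

```python
# Kernel and bias of the last Dense layer of the toy model above.
kernel, bias = extractor.get_weights(-1)
print(kernel.shape, bias.shape)  # (128, 10) (10,)
```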

Source code in oodeel/extractor/keras_feature_extractor.py
def get_weights(self, layer_id: Union[int, str]) -> List[tf.Tensor]:
    """Get the weights of a layer

    Args:
        layer_id (Union[int, str]): layer identifier

    Returns:
        List[tf.Tensor]: weight and bias matrices
    """
    return self.find_layer(self.model, layer_id).get_weights()

predict(dataset, postproc_fns=None, verbose=False, **kwargs)

Get the projection of the dataset in the feature space of self.model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `Union[ItemType, tf.data.Dataset]` | Input dataset. | *required* |
| `postproc_fns` | `Optional[List[Callable]]` | Postprocessing functions to apply to each feature immediately after the forward pass (one per feature layer). Defaults to `None`. | `None` |
| `verbose` | `bool` | If True, display a progress bar. Defaults to `False`. | `False` |
| `kwargs` | `dict` | Additional arguments not considered for prediction. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[List[tf.Tensor], dict]` | Features, and extra information (logits, labels) as a dictionary. |
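A hedged sketch of `predict` on a small labelled `tf.data.Dataset`, continuing the toy model above; features, logits, and labels are concatenated across batches:

```python
# Toy labelled dataset; any batched tf.data.Dataset yielding (input, label)
# items works the same way.
ds = tf.data.Dataset.from_tensor_slices(
    (
        tf.random.uniform((64, 32, 32, 3)),
        tf.random.uniform((64,), maxval=10, dtype=tf.int32),
    )
).batch(16)

features, info = extractor.predict(ds, verbose=True)
print(features[0].shape)     # (64, 128)
print(info["logits"].shape)  # (64, 10)
print(info["labels"])        # labels concatenated over the 4 batches
```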

Source code in oodeel/extractor/keras_feature_extractor.py
def predict(
    self,
    dataset: Union[ItemType, tf.data.Dataset],
    postproc_fns: Optional[List[Callable]] = None,
    verbose: bool = False,
    **kwargs,
) -> Tuple[List[tf.Tensor], dict]:
    """Get the projection of the dataset in the feature space of self.model

    Args:
        dataset (Union[ItemType, tf.data.Dataset]): input dataset
        postproc_fns (Optional[Callable]): postprocessing function to apply to each
            feature immediately after forward. Default to None.
        verbose (bool): if True, display a progress bar. Defaults to False.
        kwargs (dict): additional arguments not considered for prediction

    Returns:
        List[tf.Tensor], dict: features and extra information (logits, labels) as a
            dictionary.
    """
    labels = None

    if isinstance(dataset, get_args(ItemType)):
        tensor = TFDataHandler.get_input_from_dataset_item(dataset)
        features, logits = self.predict_tensor(tensor, postproc_fns)

        # Get labels if dataset is a tuple/list
        if isinstance(dataset, (list, tuple)):
            labels = TFDataHandler.get_label_from_dataset_item(dataset)

    else:  # if dataset is a tf.data.Dataset
        features = [None for i in range(len(self.feature_layers_id))]
        logits = None
        contains_labels = TFDataHandler.get_item_length(dataset) > 1
        for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
            tensor = TFDataHandler.get_input_from_dataset_item(elem)
            features_batch, logits_batch = self.predict_tensor(tensor, postproc_fns)

            for i, f in enumerate(features_batch):
                features[i] = (
                    f
                    if features[i] is None
                    else tf.concat([features[i], f], axis=0)
                )
            # concatenate logits
            logits = (
                logits_batch
                if logits is None
                else tf.concat([logits, logits_batch], axis=0)
            )
            # concatenate labels of current batch with previous batches
            if contains_labels:
                lbl_batch = TFDataHandler.get_label_from_dataset_item(elem)

                if labels is None:
                    labels = lbl_batch
                else:
                    labels = tf.concat([labels, lbl_batch], axis=0)

    # store extra information in a dict
    info = dict(labels=labels, logits=logits)
    return features, info

predict_tensor(tensor, postproc_fns=None)

Get the projection of a tensor in the feature space of self.model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tensor` | `TensorType` | Input tensor (or dataset item). | *required* |
| `postproc_fns` | `Optional[List[Callable]]` | Postprocessing functions to apply to each feature immediately after the forward pass (one per feature layer). Defaults to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[List[tf.Tensor], tf.Tensor]` | Features, logits. |
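Since the functions and the extracted features are zipped, `postproc_fns` expects one function per entry of `feature_layers_id`. A small sketch continuing the toy model above:

```python
def flatten(t):
    # Collapse all non-batch dimensions (a no-op for the already 2-D "feat"
    # activations, but illustrative for convolutional feature maps).
    return tf.reshape(t, (tf.shape(t)[0], -1))

features, logits = extractor.predict_tensor(batch, postproc_fns=[flatten])
print(features[0].shape, logits.shape)  # (8, 128) (8, 10)
```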

Source code in oodeel/extractor/keras_feature_extractor.py
@sanitize_input
def predict_tensor(
    self,
    tensor: TensorType,
    postproc_fns: Optional[List[Callable]] = None,
) -> Tuple[List[tf.Tensor], tf.Tensor]:
    """Get the projection of tensor in the feature space of self.model

    Args:
        tensor (TensorType): input tensor (or dataset elem)
        postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
            each feature immediately after forward. Default to None.

    Returns:
        Tuple[List[tf.Tensor], tf.Tensor]: features, logits
    """
    features = self.forward(tensor)

    if type(features) is not list:
        features = [features]

    # split features and logits
    logits = features.pop()

    if postproc_fns is not None:
        features = [
            postproc_fn(feature)
            for feature, postproc_fn in zip(features, postproc_fns)
        ]

    self._last_logits = logits
    return features, logits

prepare_extractor()

Constructs the feature extractor model.

Returns:

| Type | Description |
| --- | --- |
| `tf.keras.models.Model` | Truncated model (extractor). |
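To clarify what the SCALE/ASH branch of the source below computes, here is a stand-alone transcription of that block on a toy activation matrix (the percentile value is arbitrary):

```python
import tensorflow as tf
import tensorflow_probability as tfp

penultimate = tf.random.uniform((4, 16))  # toy penultimate activations
p = 0.85                                  # scale_percentile / ash_percentile

# Per-sample percentile threshold and mask of the activations above it.
thr = tfp.stats.percentile(penultimate, 100 * p, axis=1)
mask = penultimate > tf.reshape(thr, (-1, 1))
filtered = tf.where(mask, penultimate, tf.zeros_like(penultimate))

# Common rescaling factor s = exp(sum(all) / sum(kept)).
s = tf.math.exp(
    tf.reduce_sum(penultimate, axis=1) / tf.reduce_sum(filtered, axis=1)
)

scaled_scale = penultimate * tf.expand_dims(s, 1)  # SCALE keeps all activations
scaled_ash = filtered * tf.expand_dims(s, 1)       # ASH keeps only the pruned set
```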

Source code in oodeel/extractor/keras_feature_extractor.py
def prepare_extractor(self) -> tf.keras.models.Model:
    """Constructs the feature extractor model

    Returns:
        tf.keras.models.Model: truncated model (extractor)
    """
    input_layer = self.find_layer(self.model, self.input_layer_id)
    new_input = tf.keras.layers.Input(tensor=input_layer.input)
    output_tensors = [
        self.find_layer(self.model, id).output for id in self.feature_layers_id
    ]

    # === If react method, clip activations from penultimate layer ===
    if self.react_threshold is not None:
        penultimate_layer = self.find_layer(self.model, -2)
        penult_extractor = tf.keras.models.Model(
            new_input, penultimate_layer.output
        )
        last_layer = self.find_layer(self.model, -1)

        # clip penultimate activations
        x = tf.clip_by_value(
            penult_extractor(new_input),
            clip_value_min=tf.float32.min,
            clip_value_max=self.react_threshold,
        )
        # apply ultimate layer on clipped activations
        output_tensors.append(last_layer(x))

    # === If SCALE method, scale activations from penultimate layer ===
    # === If ASH method, scale and prune activations from penultimate layer ===
    elif (self.scale_percentile is not None) or (self.ash_percentile is not None):
        penultimate_layer = self.find_layer(self.model, -2)
        penult_extractor = tf.keras.models.Model(
            new_input, penultimate_layer.output
        )
        last_layer = self.find_layer(self.model, -1)

        # apply scaling on penultimate activations
        penultimate = penult_extractor(new_input)
        if self.scale_percentile is not None:
            output_percentile = tfp.stats.percentile(
                penultimate, 100 * self.scale_percentile, axis=1
            )
        else:
            output_percentile = tfp.stats.percentile(
                penultimate, 100 * self.ash_percentile, axis=1
            )

        mask = penultimate > tf.reshape(output_percentile, (-1, 1))
        filtered_penultimate = tf.where(
            mask, penultimate, tf.zeros_like(penultimate)
        )
        s = tf.math.exp(
            tf.reduce_sum(penultimate, axis=1)
            / tf.reduce_sum(filtered_penultimate, axis=1)
        )

        if self.scale_percentile is not None:
            x = penultimate * tf.expand_dims(s, 1)
        else:
            x = filtered_penultimate * tf.expand_dims(s, 1)
        # apply ultimate layer on scaled activations
        output_tensors.append(last_layer(x))

    else:
        output_tensors.append(self.find_layer(self.model, -1).output)

    extractor = tf.keras.models.Model(new_input, output_tensors)
    return extractor