Skip to content

TorchFeatureExtractor

TorchFeatureExtractor

Bases: FeatureExtractor

Feature extractor based on "model" to construct a feature space on which OOD detection is performed. The features can be the output activation values of internal model layers, or the output of the model (softmax/logits).

Parameters:

Name Type Description Default
model Module

model to extract the features from

required
feature_layers_id List[Union[int, str]]

list of str or int that identify the features to output. If int, the rank of the layer in the layer list; if str, the name of the layer. Defaults to [].

[]
input_layer_id Optional[Union[int, str]]

input layer of the feature extractor (to avoid useless forwards when working on the feature space without finetuning the bottom of the model). Defaults to None.

None
react_threshold Optional[float]

if not None, penultimate layer activations are clipped so that they never exceed this threshold value (useful for ReAct). Defaults to None.

None
Source code in oodeel/extractor/torch_feature_extractor.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
class TorchFeatureExtractor(FeatureExtractor):
    """
    Feature extractor based on "model" to construct a feature space
    on which OOD detection is performed. The features can be the output
    activation values of internal model layers,
    or the output of the model (softmax/logits).

    Args:
        model: model to extract the features from
        feature_layers_id: list of str or int that identify features to output.
            If int, the rank of the layer in the layer list;
            if str, the name of the layer. Defaults to None, which is treated
            as an empty list (only the model output is recorded).
        input_layer_id: input layer of the feature extractor (to avoid useless forwards
            when working on the feature space without finetuning the bottom of
            the model).
            Defaults to None.
        react_threshold: if not None, penultimate layer activations are clipped so
            that they never exceed this threshold value (useful for ReAct).
            Defaults to None.
    """

    def __init__(
        self,
        model: nn.Module,
        feature_layers_id: Optional[List[Union[int, str]]] = None,
        input_layer_id: Optional[Union[int, str]] = None,
        react_threshold: Optional[float] = None,
    ):
        # Build a fresh list per instance instead of sharing a mutable default.
        if feature_layers_id is None:
            feature_layers_id = []
        model = model.eval()
        super().__init__(
            model=model,
            feature_layers_id=feature_layers_id,
            input_layer_id=input_layer_id,
            react_threshold=react_threshold,
        )
        # Inputs are moved to the device holding the model parameters.
        self._device = next(model.parameters()).device
        self._features = {layer: torch.empty(0) for layer in self._hook_layers_id}
        self._last_logits = None
        self.backend = "torch"

    @property
    def _hook_layers_id(self):
        """Ids of hooked layers: the requested feature layers plus the last one."""
        return self.feature_layers_id + [-1]

    def _get_features_hook(self, layer_id: Union[str, int]) -> Callable:
        """
        Hook that stores features corresponding to a specific layer
        in a class dictionary.

        Args:
            layer_id (Union[str, int]): layer identifier

        Returns:
            Callable: hook function

        Raises:
            NotImplementedError: (when the hook fires) if the layer output is
                not a single torch.Tensor.
        """

        def hook(_, __, output):
            if isinstance(output, torch.Tensor):
                self._features[layer_id] = output
            else:
                raise NotImplementedError(
                    f"Layer {layer_id!r} returned an output of type "
                    f"{type(output).__name__}; only torch.Tensor outputs are "
                    "supported for feature extraction."
                )

        return hook

    @staticmethod
    def find_layer(
        model: nn.Module,
        layer_id: Union[str, int],
        index_offset: int = 0,
        return_id: bool = False,
    ) -> Union[nn.Module, Tuple[nn.Module, int]]:
        """Find a layer in a model either by its name or by its index.

        Integer ids index ``model`` directly when it is an ``nn.Sequential``,
        and otherwise index the list of ``model.named_modules()`` (whose
        entry 0 is the model itself). String ids are looked up among the
        names of ``model.named_modules()``.

        Args:
            model (nn.Module): model whose identified layer will be returned
            layer_id (Union[str, int]): layer identifier
            index_offset (int): index offset to find layers located before (negative
                offset) or after (positive offset) the identified layer
            return_id (bool): if True, the layer will be returned with its id

        Returns:
            Union[nn.Module, Tuple[nn.Module, int]]: the corresponding layer, and
                its integer index if return_id is True.

        Raises:
            ValueError: if ``layer_id`` is a name that does not exist in ``model``.
        """
        if isinstance(layer_id, int):
            layer_id += index_offset
            if isinstance(model, nn.Sequential):
                layer = model[layer_id]
            else:
                layer = list(model.named_modules())[layer_id][1]
        else:
            named_modules = list(model.named_modules())
            names = [name for name, _ in named_modules]
            try:
                layer_id = names.index(layer_id)
            except ValueError as err:
                raise ValueError(f"No layer named {layer_id!r} in model") from err
            layer_id += index_offset
            layer = named_modules[layer_id][1]

        if return_id:
            return layer, layer_id
        return layer

    def prepare_extractor(self) -> None:
        """Prepare the feature extractor by adding hooks to self.model.

        Removes previously attached forward hooks, optionally registers the
        ReAct clipping hook on the penultimate layer, registers a feature hook
        on every layer of ``self._hook_layers_id`` and, when ``input_layer_id``
        is set, replaces ``self.model`` by an ``nn.Sequential`` starting at
        that layer.

        Raises:
            NotImplementedError: if cropping is requested on a model that is not
                an ``nn.Sequential``, or with an id that is neither int nor str.
        """
        # remove forward hooks attached to the model
        self._clean_forward_hooks()

        # === If react method, clip activations from penultimate layer ===
        if self.react_threshold is not None:
            pen_layer = self.find_layer(self.model, -2)
            pen_layer.register_forward_hook(self._get_clip_hook(self.react_threshold))

        # Register a hook to store feature values for each considered layer + last layer
        for layer_id in self._hook_layers_id:
            layer = self.find_layer(self.model, layer_id)
            layer.register_forward_hook(self._get_features_hook(layer_id))

        # Crop model if input layer is provided
        if self.input_layer_id is None:
            return
        if not isinstance(self.input_layer_id, (int, str)) or not isinstance(
            self.model, nn.Sequential
        ):
            raise NotImplementedError

        # modules() is a pre-order traversal whose entry 0 is the model itself.
        flat_modules = list(self.model.modules())
        if isinstance(self.input_layer_id, int):
            start = self.input_layer_id
        else:
            # Names exclude the root (""), hence the +1 shift into flat_modules.
            module_names = [
                name for name, _ in self.model.named_modules() if name != ""
            ]
            start = module_names.index(self.input_layer_id) + 1

        # Keep a module only if no module already kept contains it: a kept
        # container runs its own children, so they must not be applied twice.
        kept, covered = [], set()
        for module in flat_modules[start:]:
            if id(module) in covered:
                continue
            kept.append(module)
            covered.update(id(sub) for sub in module.modules())
        self.model = nn.Sequential(*kept)

    @sanitize_input
    def predict_tensor(
        self,
        x: TensorType,
        postproc_fns: Optional[List[Callable]] = None,
        detach: bool = True,
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Get the projection of tensor in the feature space of self.model

        Args:
            x (TensorType): input tensor (or dataset elem)
            postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
                each feature immediately after forward. Defaults to None.
                Features and functions are paired with ``zip``, so features
                beyond the number of functions are dropped.
            detach (bool): if True, return features detached from the computational
                graph. Defaults to True.

        Returns:
            List[torch.Tensor], torch.Tensor: features, logits
        """
        if x.device != self._device:
            x = x.to(self._device)
        # The forward pass fills self._features through the registered hooks.
        _ = self.model(x)

        if detach:
            features = [
                self._features[layer_id].detach() for layer_id in self._hook_layers_id
            ]
        else:
            features = [self._features[layer_id] for layer_id in self._hook_layers_id]

        # split features and logits (the last hooked layer is layer -1)
        logits = features.pop()

        if postproc_fns is not None:
            features = [
                postproc_fn(feature)
                for feature, postproc_fn in zip(features, postproc_fns)
            ]

        self._last_logits = logits
        return features, logits

    def predict(
        self,
        dataset: Union[DataLoader, ItemType],
        postproc_fns: Optional[List[Callable]] = None,
        detach: bool = True,
        **kwargs,
    ) -> Tuple[List[torch.Tensor], dict]:
        """Get the projection of the dataset in the feature space of self.model

        Args:
            dataset (Union[DataLoader, ItemType]): input dataset
            postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
                each feature immediately after forward. Defaults to None.
            detach (bool): if True, return features detached from the computational
                graph. Defaults to True.
            kwargs (dict): additional arguments not considered for prediction

        Returns:
            List[torch.Tensor], dict: features and extra information (logits, labels) as
                a dictionary. For an empty iterable dataset, features entries,
                logits and labels are None.
        """
        labels = None

        if isinstance(dataset, get_args(ItemType)):
            tensor = TorchDataHandler.get_input_from_dataset_item(dataset)
            features, logits = self.predict_tensor(tensor, postproc_fns, detach=detach)

            # Get labels if dataset is a tuple/list
            if isinstance(dataset, (list, tuple)) and len(dataset) > 1:
                labels = TorchDataHandler.get_label_from_dataset_item(dataset)

        else:
            feature_batches = [[] for _ in self.feature_layers_id]
            logits_batches = []
            label_batches = []
            contains_labels = None
            for elem in dataset:
                # Decide from the first batch whether batches carry labels.
                # Doing it here (instead of via next(iter(dataset))) avoids
                # loading an extra batch and losing the first element of
                # one-shot iterators.
                if contains_labels is None:
                    contains_labels = isinstance(elem, (list, tuple)) and len(elem) > 1
                tensor = TorchDataHandler.get_input_from_dataset_item(elem)
                features_batch, logits_batch = self.predict_tensor(
                    tensor, postproc_fns, detach=detach
                )
                for batches, feature in zip(feature_batches, features_batch):
                    batches.append(feature)
                logits_batches.append(logits_batch)
                if contains_labels:
                    label_batches.append(
                        TorchDataHandler.get_label_from_dataset_item(elem)
                    )

            # Concatenate once at the end: repeated torch.cat inside the loop
            # copies the accumulated data at every batch (quadratic).
            features = [
                torch.cat(batches, dim=0) if batches else None
                for batches in feature_batches
            ]
            logits = torch.cat(logits_batches, dim=0) if logits_batches else None
            if label_batches:
                labels = torch.cat(label_batches, dim=0)

        # store extra information in a dict
        info = dict(labels=labels, logits=logits)
        return features, info

    def get_weights(self, layer_id: Union[str, int]) -> List[torch.Tensor]:
        """Get the weights of a layer

        Args:
            layer_id (Union[int, str]): layer identifier

        Returns:
            List[torch.Tensor]: ``[weight, bias]``, both converted to NumPy arrays
                on CPU. ``bias`` is None when the layer has no bias.
        """
        layer = self.find_layer(self.model, layer_id)
        weight = layer.weight.detach().cpu().numpy()
        bias = getattr(layer, "bias", None)
        if bias is not None:
            bias = bias.detach().cpu().numpy()
        return [weight, bias]

    def _get_clip_hook(self, threshold: float) -> Callable:
        """
        Hook that clips activation features so they never exceed a threshold value

        Args:
            threshold (float): threshold value

        Returns:
            Callable: hook function (its return value replaces the layer output)
        """

        def hook(_, __, output):
            output = torch.clip(output, max=threshold)
            return output

        return hook

    def _clean_forward_hooks(self) -> None:
        """
        Remove all the forward hooks attached to the model's submodules (the
        root module's own hooks are left untouched). This function should be
        called at initialization; it prevents hooks from accumulating when a
        new TorchFeatureExtractor is defined for the same model.
        """

        def __clean_hooks(m: nn.Module):
            for _, child in m._modules.items():
                if child is not None:
                    if hasattr(child, "_forward_hooks"):
                        child._forward_hooks = OrderedDict()
                    __clean_hooks(child)

        return __clean_hooks(self.model)

_clean_forward_hooks()

Remove all the forward hooks attached to the model's layers. This function should be called at initialization; it prevents hooks from accumulating when a new TorchFeatureExtractor is defined for the same model.

Source code in oodeel/extractor/torch_feature_extractor.py
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
def _clean_forward_hooks(self) -> None:
    """
    Remove all the forward hook attached to the model's layers. This function should
    be called at the __init__, and prevent from accumulating the hooks when
    defining a new TorchFeatureExtractor for the same model.
    """

    def __clean_hooks(m: nn.Module):
        for _, child in m._modules.items():
            if child is not None:
                if hasattr(child, "_forward_hooks"):
                    child._forward_hooks = OrderedDict()
                __clean_hooks(child)

    return __clean_hooks(self.model)

_get_clip_hook(threshold)

Hook that clips activation features so that they never exceed a threshold value

Parameters:

Name Type Description Default
threshold float

threshold value

required

Returns:

Name Type Description
Callable Callable

hook function

Source code in oodeel/extractor/torch_feature_extractor.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
def _get_clip_hook(self, threshold: float) -> Callable:
    """
    Hook that truncate activation features under a threshold value

    Args:
        threshold (float): threshold value

    Returns:
        Callable: hook function
    """

    def hook(_, __, output):
        output = torch.clip(output, max=threshold)
        return output

    return hook

_get_features_hook(layer_id)

Hook that stores features corresponding to a specific layer in a class dictionary.

Parameters:

Name Type Description Default
layer_id Union[str, int]

layer identifier

required

Returns:

Name Type Description
Callable Callable

hook function

Source code in oodeel/extractor/torch_feature_extractor.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def _get_features_hook(self, layer_id: Union[str, int]) -> Callable:
    """
    Hook that stores features corresponding to a specific layer
    in a class dictionary.

    Args:
        layer_id (Union[str, int]): layer identifier

    Returns:
        Callable: hook function
    """

    def hook(_, __, output):
        if isinstance(output, torch.Tensor):
            self._features[layer_id] = output
        else:
            raise NotImplementedError

    return hook

find_layer(model, layer_id, index_offset=0, return_id=False) staticmethod

Find a layer in a model either by its name or by its index.

Parameters:

Name Type Description Default
model Module

model whose identified layer will be returned

required
layer_id Union[str, int]

layer identifier

required
index_offset int

index offset to find layers located before (negative offset) or after (positive offset) the identified layer

0
return_id bool

if True, the layer will be returned with its id

False

Returns:

Type Description
Union[Module, Tuple[Module, str]]

Union[nn.Module, Tuple[nn.Module, str]]: the corresponding layer, and its (integer) index if return_id is True.

Source code in oodeel/extractor/torch_feature_extractor.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@staticmethod
def find_layer(
    model: nn.Module,
    layer_id: Union[str, int],
    index_offset: int = 0,
    return_id: bool = False,
) -> Union[nn.Module, Tuple[nn.Module, int]]:
    """Find a layer in a model either by its name or by its index.

    Integer ids index ``model`` directly when it is an ``nn.Sequential``, and
    otherwise index the list of ``model.named_modules()`` (whose entry 0 is the
    model itself). String ids are looked up among the names of
    ``model.named_modules()``.

    Args:
        model (nn.Module): model whose identified layer will be returned
        layer_id (Union[str, int]): layer identifier
        index_offset (int): index offset to find layers located before (negative
            offset) or after (positive offset) the identified layer
        return_id (bool): if True, the layer will be returned with its id

    Returns:
        Union[nn.Module, Tuple[nn.Module, int]]: the corresponding layer, and its
            integer index if return_id is True.

    Raises:
        ValueError: if ``layer_id`` is a name that does not exist in ``model``.
    """
    if isinstance(layer_id, int):
        layer_id += index_offset
        if isinstance(model, nn.Sequential):
            layer = model[layer_id]
        else:
            layer = list(model.named_modules())[layer_id][1]
    else:
        # Materialize the traversal once and reuse it for lookup and indexing.
        named_modules = list(model.named_modules())
        names = [name for name, _ in named_modules]
        try:
            layer_id = names.index(layer_id)
        except ValueError as err:
            raise ValueError(f"No layer named {layer_id!r} in model") from err
        layer_id += index_offset
        layer = named_modules[layer_id][1]

    if return_id:
        return layer, layer_id
    return layer

get_weights(layer_id)

Get the weights of a layer

Parameters:

Name Type Description Default
layer_id Union[int, str]

layer identifier

required

Returns:

Type Description
List[Tensor]

List[torch.Tensor]: the layer's weight and bias arrays (converted to NumPy arrays on CPU).

Source code in oodeel/extractor/torch_feature_extractor.py
288
289
290
291
292
293
294
295
296
297
298
def get_weights(self, layer_id: Union[str, int]) -> List[torch.Tensor]:
    """Get the weights of a layer

    Args:
        layer_id (Union[int, str]): layer identifier

    Returns:
        List[torch.Tensor]: ``[weight, bias]``, both converted to NumPy arrays on
            CPU. ``bias`` is None when the layer has no bias (e.g.
            ``nn.Linear(..., bias=False)``).
    """
    layer = self.find_layer(self.model, layer_id)
    weight = layer.weight.detach().cpu().numpy()
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias = bias.detach().cpu().numpy()
    return [weight, bias]

predict(dataset, postproc_fns=None, detach=True, **kwargs)

Get the projection of the dataset in the feature space of self.model

Parameters:

Name Type Description Default
dataset Union[DataLoader, ItemType]

input dataset

required
postproc_fns Optional[List[Callable]]

postprocessing function to apply to each feature immediately after forward. Default to None.

None
detach bool

if True, return features detached from the computational graph. Defaults to True.

True
kwargs dict

additional arguments not considered for prediction

{}

Returns:

Type Description
Tuple[List[Tensor], dict]

List[torch.Tensor], dict: features and extra information (logits, labels) as a dictionary.

Source code in oodeel/extractor/torch_feature_extractor.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
def predict(
    self,
    dataset: Union[DataLoader, ItemType],
    postproc_fns: Optional[List[Callable]] = None,
    detach: bool = True,
    **kwargs,
) -> Tuple[List[torch.Tensor], dict]:
    """Get the projection of the dataset in the feature space of self.model

    Args:
        dataset (Union[DataLoader, ItemType]): input dataset
        postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
            each feature immediately after forward. Defaults to None.
        detach (bool): if True, return features detached from the computational
            graph. Defaults to True.
        kwargs (dict): additional arguments not considered for prediction

    Returns:
        List[torch.Tensor], dict: features and extra information (logits, labels) as
            a dictionary. For an empty iterable dataset, features entries,
            logits and labels are None.
    """
    labels = None

    if isinstance(dataset, get_args(ItemType)):
        tensor = TorchDataHandler.get_input_from_dataset_item(dataset)
        features, logits = self.predict_tensor(tensor, postproc_fns, detach=detach)

        # Get labels if dataset is a tuple/list
        if isinstance(dataset, (list, tuple)) and len(dataset) > 1:
            labels = TorchDataHandler.get_label_from_dataset_item(dataset)

    else:
        feature_batches = [[] for _ in self.feature_layers_id]
        logits_batches = []
        label_batches = []
        contains_labels = None
        for elem in dataset:
            # Decide from the first batch whether batches carry labels. Doing
            # it here (instead of via next(iter(dataset))) avoids loading an
            # extra batch and losing the first element of one-shot iterators.
            if contains_labels is None:
                contains_labels = isinstance(elem, (list, tuple)) and len(elem) > 1
            tensor = TorchDataHandler.get_input_from_dataset_item(elem)
            features_batch, logits_batch = self.predict_tensor(
                tensor, postproc_fns, detach=detach
            )
            for batches, feature in zip(feature_batches, features_batch):
                batches.append(feature)
            logits_batches.append(logits_batch)
            if contains_labels:
                label_batches.append(TorchDataHandler.get_label_from_dataset_item(elem))

        # Concatenate once at the end: repeated torch.cat inside the loop
        # copies the accumulated data at every batch (quadratic).
        features = [
            torch.cat(batches, dim=0) if batches else None
            for batches in feature_batches
        ]
        logits = torch.cat(logits_batches, dim=0) if logits_batches else None
        if label_batches:
            labels = torch.cat(label_batches, dim=0)

    # store extra information in a dict
    info = dict(labels=labels, logits=logits)
    return features, info

predict_tensor(x, postproc_fns=None, detach=True)

Get the projection of tensor in the feature space of self.model

Parameters:

Name Type Description Default
x TensorType

input tensor (or dataset elem)

required
postproc_fns Optional[List[Callable]]

postprocessing function to apply to each feature immediately after forward. Default to None.

None
detach bool

if True, return features detached from the computational graph. Defaults to True.

True

Returns:

Type Description
Tuple[List[Tensor], Tensor]

List[torch.Tensor], torch.Tensor: features, logits

Source code in oodeel/extractor/torch_feature_extractor.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
@sanitize_input
def predict_tensor(
    self,
    x: TensorType,
    postproc_fns: Optional[List[Callable]] = None,
    detach: bool = True,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
    """Run ``x`` through self.model and collect the hooked activations.

    Args:
        x (TensorType): input tensor (or dataset elem); moved to the model's
            device if needed
        postproc_fns (Optional[List[Callable]]): optional callables paired
            (via ``zip``) with the intermediate features right after the forward
            pass. Defaults to None.
        detach (bool): when True, returned tensors are cut from the autograd
            graph. Defaults to True.

    Returns:
        List[torch.Tensor], torch.Tensor: intermediate features, then the output
            of the last hooked layer (logits)
    """
    if x.device != self._device:
        x = x.to(self._device)
    # The forward pass populates self._features through the registered hooks.
    self.model(x)

    collected = [self._features[key] for key in self._hook_layers_id]
    if detach:
        collected = [tensor.detach() for tensor in collected]

    # The last hooked layer is always -1, i.e. the model output.
    *features, logits = collected

    if postproc_fns is not None:
        features = [fn(feat) for feat, fn in zip(features, postproc_fns)]

    self._last_logits = logits
    return features, logits

prepare_extractor()

Prepare the feature extractor by adding hooks to self.model

Source code in oodeel/extractor/torch_feature_extractor.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def prepare_extractor(self) -> None:
    """Prepare the feature extractor by adding hooks to self.model.

    Removes previously attached forward hooks, optionally registers the ReAct
    clipping hook on the penultimate layer, registers a feature hook on every
    layer of ``self._hook_layers_id`` and, when ``input_layer_id`` is set,
    replaces ``self.model`` by an ``nn.Sequential`` starting at that layer.

    Raises:
        NotImplementedError: if cropping is requested on a model that is not an
            ``nn.Sequential``, or with an id that is neither int nor str.
    """
    # remove forward hooks attached to the model
    self._clean_forward_hooks()

    # === If react method, clip activations from penultimate layer ===
    if self.react_threshold is not None:
        pen_layer = self.find_layer(self.model, -2)
        pen_layer.register_forward_hook(self._get_clip_hook(self.react_threshold))

    # Register a hook to store feature values for each considered layer + last layer
    for layer_id in self._hook_layers_id:
        layer = self.find_layer(self.model, layer_id)
        layer.register_forward_hook(self._get_features_hook(layer_id))

    # Crop model if input layer is provided
    if self.input_layer_id is None:
        return
    if not isinstance(self.input_layer_id, (int, str)) or not isinstance(
        self.model, nn.Sequential
    ):
        raise NotImplementedError

    # modules() is a pre-order traversal whose entry 0 is the model itself.
    flat_modules = list(self.model.modules())
    if isinstance(self.input_layer_id, int):
        start = self.input_layer_id
    else:
        # Names exclude the root (""), hence the +1 shift into flat_modules.
        module_names = [name for name, _ in self.model.named_modules() if name != ""]
        start = module_names.index(self.input_layer_id) + 1

    # Keep a module only if no module already kept contains it: a kept
    # container runs its own children, so they must not be applied twice.
    kept, covered = [], set()
    for module in flat_modules[start:]:
        if id(module) in covered:
            continue
        kept.append(module)
        covered.update(id(sub) for sub in module.modules())
    self.model = nn.Sequential(*kept)