TFDataHandler

Bases: DataHandler

Class to manage tf.data.Dataset objects. The aim is to provide a simple interface for working with tf.data.Datasets without having to use TensorFlow syntax directly.

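The snippet below is a minimal usage sketch, not taken from the library docs: it assumes numpy and TensorFlow are installed, that the class is importable from the module path shown below (oodeel.datasets.tf_data_handler), and the arrays are random toy data.

import numpy as np
from oodeel.datasets.tf_data_handler import TFDataHandler

handler = TFDataHandler()

# Build a dict-based tf.data.Dataset from in-memory arrays
x = np.random.rand(100, 32, 32, 3).astype("float32")
y = np.random.randint(0, 10, size=(100,))
dataset = handler.load_dataset((x, y))  # keys: "input", "label"

# Batch it for inference
batched = handler.prepare(dataset, batch_size=16)
for images, labels in batched.take(1):
    print(images.shape, labels.shape)  # (16, 32, 32, 3) (16,)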
Source code in oodeel/datasets/tf_data_handler.py
class TFDataHandler(DataHandler):
    """
    Class to manage tf.data.Dataset. The aim is to provide a simple interface for
    working with tf.data.Datasets and manage them without having to use
    tensorflow syntax.
    """

    def __init__(self) -> None:
        super().__init__()
        self.backend = "tensorflow"
        self.channel_order = "channels_last"

    @classmethod
    def load_dataset(
        cls,
        dataset_id: Union[tf.data.Dataset, ItemType, str],
        columns: Optional[list] = None,
        load_kwargs: dict = {},
    ) -> tf.data.Dataset:
        """Load dataset from different manners, ensuring to return a dict based
        tf.data.Dataset.

        Args:
            dataset_id (Union[tf.data.Dataset, ItemType, str]): dataset identification.
            Can be the name of a dataset from tensorflow_datasets, a tf.data.Dataset,
            or a tuple/dict of np.ndarrays/tf.Tensors.
            columns (list, optional): Column names. If None, assigned as "input_i"
                for i-th column. Defaults to None.
            load_kwargs (dict, optional): Additional args for loading from
                tensorflow_datasets. Defaults to {}.

        Returns:
            tf.data.Dataset: A dict based tf.data.Dataset
        """
        load_kwargs["as_supervised"] = False

        if isinstance(dataset_id, get_args(ItemType)):
            dataset = cls.load_dataset_from_arrays(dataset_id, columns)
        elif isinstance(dataset_id, tf.data.Dataset):
            dataset = cls.load_custom_dataset(dataset_id, columns)
        elif isinstance(dataset_id, str):
            dataset = cls.load_from_tensorflow_datasets(dataset_id, load_kwargs)
        return dataset

    @staticmethod
    def load_dataset_from_arrays(
        dataset_id: ItemType, columns: Optional[list] = None
    ) -> tf.data.Dataset:
        """Load a tf.data.Dataset from a np.ndarray, a tf.Tensor or a tuple/dict
        of np.ndarrays/tf.Tensors.

        Args:
            dataset_id (ItemType): numpy array(s) to load.
            columns (list, optional): Column names to assign. If None,
                assigned as "input_i" for i-th column. Defaults to None.

        Returns:
            tf.data.Dataset
        """
        # If dataset_id is a numpy array, convert it to a dict
        if isinstance(dataset_id, get_args(TensorType)):
            dataset_dict = {"input": dataset_id}

        # If dataset_id is a tuple, convert it to a dict
        elif isinstance(dataset_id, tuple):
            len_elem = len(dataset_id)
            if columns is None:
                if len_elem == 2:
                    dataset_dict = {"input": dataset_id[0], "label": dataset_id[1]}
                else:
                    dataset_dict = {
                        f"input_{i}": dataset_id[i] for i in range(len_elem - 1)
                    }
                    dataset_dict["label"] = dataset_id[-1]
                print(
                    'Loading tf.data.Dataset with elems as dicts, assigning "input_i" '
                    'key to the i-th tuple dimension and "label" key to the last '
                    "tuple dimension."
                )
            else:
                assert (
                    len(columns) == len_elem
                ), "Number of column names mismatch with the number of columns"
                dataset_dict = {columns[i]: dataset_id[i] for i in range(len_elem)}

        elif isinstance(dataset_id, dict):
            if columns is not None:
                len_elem = len(dataset_id)
                assert (
                    len(columns) == len_elem
                ), "Number of column names mismatch with the number of columns"
                original_columns = list(dataset_id.keys())
                dataset_dict = {
                    columns[i]: dataset_id[original_columns[i]]
                    for i in range(len_elem)
                }
            else:
                # No renaming requested: keep the dict columns as provided
                dataset_dict = dataset_id

        dataset = tf.data.Dataset.from_tensor_slices(dataset_dict)
        return dataset

    @classmethod
    def load_custom_dataset(
        cls, dataset_id: tf.data.Dataset, columns: Optional[list] = None
    ) -> tf.data.Dataset:
        """Load a custom Dataset by ensuring it has the correct format (dict-based)

        Args:
            dataset_id (tf.data.Dataset): tf.data.Dataset
            columns (list, optional): Column names to use for elements if dataset_id is
                tuple based. If None, assigned as "input_i"
                for i-th column. Defaults to None.

        Returns:
            tf.data.Dataset
        """
        # If dataset_id is a tuple based tf.data.dataset, convert it to a dict
        if not isinstance(dataset_id.element_spec, dict):
            len_elem = len(dataset_id.element_spec)
            if columns is None:
                print(
                    "Column name not found, assigning 'input_i' "
                    "key to the i-th tensor and 'label' key to the last"
                )
                if len_elem == 2:
                    columns = ["input", "label"]
                else:
                    columns = [f"input_{i}" for i in range(len_elem)]
                    columns[-1] = "label"
            else:
                assert (
                    len(columns) == len_elem
                ), "Number of column names mismatch with the number of columns"

            dataset_id = cls.tuple_to_dict(dataset_id, columns)

        dataset = dataset_id
        return dataset

    @staticmethod
    def load_from_tensorflow_datasets(
        dataset_id: str,
        load_kwargs: dict = {},
    ) -> tf.data.Dataset:
        """Load a tf.data.Dataset from the tensorflow_datasets catalog

        Args:
            dataset_id (str): Identifier of the dataset
            load_kwargs (dict, optional): Loading kwargs to add to tfds.load().
                Defaults to {}.

        Returns:
            tf.data.Dataset
        """
        assert (
            dataset_id in tfds.list_builders()
        ), "Dataset not available on tensorflow datasets catalog"
        dataset = tfds.load(dataset_id, **load_kwargs)
        return dataset

    @staticmethod
    @dict_only_ds
    def dict_to_tuple(
        dataset: tf.data.Dataset, columns: Optional[list] = None
    ) -> tf.data.Dataset:
        """Turn a dict based tf.data.Dataset to a tuple based tf.data.Dataset

        Args:
            dataset (tf.data.Dataset): Dict based tf.data.Dataset
            columns (list, optional): Columns to use for the tuples based
                tf.data.Dataset. If None, takes all the columns. Defaults to None.

        Returns:
            tf.data.Dataset
        """
        if columns is None:
            columns = list(dataset.element_spec.keys())
        dataset = dataset.map(lambda x: tuple(x[k] for k in columns))
        return dataset

    @staticmethod
    def tuple_to_dict(dataset: tf.data.Dataset, columns: list) -> tf.data.Dataset:
        """Turn a tuple based tf.data.Dataset to a dict based tf.data.Dataset

        Args:
            dataset (tf.data.Dataset): Tuple based tf.data.Dataset
            columns (list): Column names to use for the dict based tf.data.Dataset

        Returns:
            tf.data.Dataset
        """
        assert isinstance(
            dataset.element_spec, tuple
        ), "dataset elements must be tuples"
        len_elem = len(dataset.element_spec)
        assert len_elem == len(
            columns
        ), "The number of columns must be equal to the number of tuple elements"

        def tuple_to_dict(*inputs):
            return {columns[i]: inputs[i] for i in range(len_elem)}

        dataset = dataset.map(tuple_to_dict)
        return dataset

    @staticmethod
    @dict_only_ds
    def get_ds_column_names(dataset: tf.data.Dataset) -> list:
        """Get the column names of a tf.data.Dataset

        Args:
            dataset (tf.data.Dataset): tf.data.Dataset to get the column names from

        Returns:
            list: List of column names
        """
        return list(dataset.element_spec.keys())

    @staticmethod
    def map_ds(
        dataset: tf.data.Dataset,
        map_fn: Callable,
        num_parallel_calls: Optional[int] = None,
    ) -> tf.data.Dataset:
        """Map a function to a tf.data.Dataset

        Args:
            dataset (tf.data.Dataset): tf.data.Dataset to map the function to
            map_fn (Callable): Function to map
            num_parallel_calls (Optional[int], optional): Number of parallel processes
                to use. Defaults to None.

        Returns:
            tf.data.Dataset: Mapped dataset
        """
        if num_parallel_calls is None:
            num_parallel_calls = tf.data.experimental.AUTOTUNE
        dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_calls)
        return dataset

    @staticmethod
    @dict_only_ds
    def filter_by_value(
        dataset: tf.data.Dataset,
        column_name: str,
        values: list,
        excluded: bool = False,
    ) -> tf.data.Dataset:
        """Filter a tf.data.Dataset by checking if the value of a column is in 'values'

        Args:
            dataset (tf.data.Dataset): tf.data.Dataset to filter
            column_name (str): Column to filter the dataset with
            values (list): Column values to keep (if excluded is False)
                or to exclude
            excluded (bool, optional): To keep (False) or exclude (True) the samples
                with Column values included in Values. Defaults to False.

        Returns:
            tf.data.Dataset: Filtered dataset
        """
        # If the labels are one-hot encoded, prepare a function to get the label as int
        if len(dataset.element_spec[column_name].shape) > 0:

            def get_label_int(elem):
                return int(tf.argmax(elem[column_name]))

        else:

            def get_label_int(elem):
                return elem[column_name]

        def filter_fn(elem):
            value = get_label_int(elem)
            if excluded:
                return not tf.reduce_any(tf.equal(value, values))
            else:
                return tf.reduce_any(tf.equal(value, values))

        dataset_to_filter = dataset
        dataset_to_filter = dataset_to_filter.filter(filter_fn)
        return dataset_to_filter

    @classmethod
    def prepare(
        cls,
        dataset: tf.data.Dataset,
        batch_size: int,
        preprocess_fn: Optional[Callable] = None,
        augment_fn: Optional[Callable] = None,
        columns: Optional[list] = None,
        shuffle: bool = False,
        dict_based_fns: bool = True,
        return_tuple: bool = True,
        shuffle_buffer_size: Optional[int] = None,
        prefetch_buffer_size: Optional[int] = None,
        drop_remainder: Optional[bool] = False,
    ) -> tf.data.Dataset:
        """Prepare a tf.data.Dataset for training

        Args:
            dataset (tf.data.Dataset): tf.data.Dataset to prepare
            batch_size (int): Batch size
            preprocess_fn (Callable, optional): Preprocessing function to apply to
                the dataset. Defaults to None.
            augment_fn (Callable, optional): Augment function to be used (when the
                returned dataset is to be used for training). Defaults to None.
            columns (list, optional): List of column names corresponding to the columns
                that will be returned. Keep all columns if None. Defaults to None.
            shuffle (bool, optional): To shuffle the returned dataset or not.
                Defaults to False.
            dict_based_fns (bool): Whether to use preprocess and DA functions as dict
                based (if True) or as tuple based (if False). Defaults to True.
            return_tuple (bool, optional): Whether to return each dataset item
                as a tuple. Defaults to True.
            shuffle_buffer_size (int, optional): Size of the shuffle buffer. If None,
                taken as the number of samples in the dataset. Defaults to None.
            prefetch_buffer_size (Optional[int], optional): Buffer size for prefetch.
                If None, automatically chosen using tf.data.experimental.AUTOTUNE.
                Defaults to None.
            drop_remainder (Optional[bool], optional): To drop the last batch when
                its size is lower than batch_size. Defaults to False.

        Returns:
            tf.data.Dataset: Prepared dataset
        """
        # dict based to tuple based
        columns = columns or cls.get_ds_column_names(dataset)
        if not dict_based_fns:
            dataset = cls.dict_to_tuple(dataset, columns)

        # preprocess + DA
        if preprocess_fn is not None:
            dataset = cls.map_ds(dataset, preprocess_fn)
        if augment_fn is not None:
            dataset = cls.map_ds(dataset, augment_fn)

        if dict_based_fns and return_tuple:
            dataset = cls.dict_to_tuple(dataset, columns)

        dataset = dataset.cache()

        # shuffle
        if shuffle:
            num_samples = cls.get_dataset_length(dataset)
            shuffle_buffer_size = (
                num_samples if shuffle_buffer_size is None else shuffle_buffer_size
            )
            dataset = dataset.shuffle(shuffle_buffer_size)
        # batch
        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
        # prefetch (AUTOTUNE when no explicit buffer size is given, as documented)
        if prefetch_buffer_size is None:
            prefetch_buffer_size = tf.data.experimental.AUTOTUNE
        dataset = dataset.prefetch(prefetch_buffer_size)
        return dataset

    @staticmethod
    def make_channel_first(input_key: str, dataset: tf.data.Dataset) -> tf.data.Dataset:
        """Make a tf.data.Dataset channel first. Make sure that the dataset is not
            already Channel first. If so, the tensor will have the format
            (batch_size, x_size, channel, y_size).

        Args:
            input_key (str): input key of the dict-based tf.data.Dataset
            dataset (tf.data.Dataset): tf.data.Dataset to make channel first

        Returns:
            tf.data.Dataset: Channel first dataset
        """

        def channel_first(x):
            x[input_key] = tf.transpose(x[input_key], perm=[2, 0, 1])
            return x

        dataset = dataset.map(channel_first)
        return dataset

    @classmethod
    def merge(
        cls,
        id_dataset: tf.data.Dataset,
        ood_dataset: tf.data.Dataset,
        resize: Optional[bool] = False,
        shape: Optional[Tuple[int]] = None,
        channel_order: Optional[str] = "channels_last",
    ) -> tf.data.Dataset:
        """Merge two tf.data.Datasets

        Args:
            id_dataset (tf.data.Dataset): dataset of in-distribution data
            ood_dataset (tf.data.Dataset): dataset of out-of-distribution data
            resize (Optional[bool], optional): toggles if input tensors of the
                datasets have to be resized to have the same shape. Defaults to False.
            shape (Optional[Tuple[int]], optional): shape to use for resizing input
                tensors. If None, the tensors are resized with the shape of the
                id_dataset input tensors. Defaults to None.
            channel_order (Optional[str], optional): channel order of the input
                tensors. Defaults to "channels_last".

        Returns:
            tf.data.Dataset: merged dataset
        """
        len_elem_id = cls.get_item_length(id_dataset)
        len_elem_ood = cls.get_item_length(ood_dataset)
        assert (
            len_elem_id == len_elem_ood
        ), "incompatible dataset elements (different elem dict length)"

        # If a desired shape is given, triggers the resize
        if shape is not None:
            resize = True

        id_elem_spec = id_dataset.element_spec
        ood_elem_spec = ood_dataset.element_spec
        assert isinstance(id_elem_spec, dict), "dataset elements must be dicts"
        assert isinstance(ood_elem_spec, dict), "dataset elements must be dicts"

        input_key_id = list(id_elem_spec.keys())[0]
        input_key_ood = list(ood_elem_spec.keys())[0]
        shape_id = id_dataset.element_spec[input_key_id].shape
        shape_ood = ood_dataset.element_spec[input_key_ood].shape

        # If the shape of the two datasets are different, triggers the resize
        if shape_id != shape_ood:
            resize = True

            if shape is None:
                print(
                    "Resizing the first item of elem (usually the image)",
                    " with the shape of id_dataset",
                )
                if channel_order == "channels_first":
                    shape = shape_id[1:]
                else:
                    shape = shape_id[:2]

        if resize:

            def reshape_im_id(elem):
                elem[input_key_id] = tf.image.resize(elem[input_key_id], shape)
                return elem

            def reshape_im_ood(elem):
                elem[input_key_ood] = tf.image.resize(elem[input_key_ood], shape)
                return elem

            id_dataset = id_dataset.map(reshape_im_id)
            ood_dataset = ood_dataset.map(reshape_im_ood)

        merged_dataset = id_dataset.concatenate(ood_dataset)
        return merged_dataset

    @staticmethod
    def get_item_length(dataset: tf.data.Dataset) -> int:
        """Get the length of a dataset element. If an element is a tensor, the length is
        one and if it is a sequence (list or tuple), it is len(elem).

        Args:
            dataset (tf.data.Dataset): Dataset to process

        Returns:
            int: length of the dataset elems
        """
        if isinstance(dataset.element_spec, (tuple, list, dict)):
            return len(dataset.element_spec)
        return 1

    @staticmethod
    def get_dataset_length(dataset: tf.data.Dataset) -> int:
        """Get the length of a dataset. Try to access it with len(), and if not
        available, with a reduce op.

        Args:
            dataset (tf.data.Dataset): Dataset to process

        Returns:
            int: Number of items in the dataset
        """
        try:
            return len(dataset)
        except TypeError:
            cardinality = dataset.reduce(0, lambda x, _: x + 1)
            return int(cardinality)

    @staticmethod
    def get_column_elements_shape(
        dataset: tf.data.Dataset, column_name: Union[str, int]
    ) -> tuple:
        """Get the shape of the elements of a column of dataset identified by
        column_name

        Args:
            dataset (tf.data.Dataset): a tf.data.dataset
            column_name (Union[str, int]): The column name to get
                the element shape from.

        Returns:
            tuple: the shape of an element from column_name
        """
        return tuple(dataset.element_spec[column_name].shape)

    @staticmethod
    def get_input_from_dataset_item(elem: ItemType) -> TensorType:
        """Get the tensor that is to be feed as input to a model from a dataset element.

        Args:
            elem (ItemType): dataset element to extract input from

        Returns:
            TensorType: Input tensor
        """
        if isinstance(elem, (tuple, list)):
            tensor = elem[0]
        elif isinstance(elem, dict):
            tensor = elem[list(elem.keys())[0]]
        else:
            tensor = elem
        return tensor

    @staticmethod
    def get_label_from_dataset_item(item: ItemType):
        """Retrieve label tensor from item as a tuple/list. Label must be at index 1
        in the item tuple. If one-hot encoded, labels are converted to single value.

        Args:
            elem (ItemType): dataset element to extract label from

        Returns:
            Any: Label tensor
        """
        label = item[1]  # labels must be at index 1 in the item tuple
        # If labels are one-hot encoded, take the argmax
        if tf.rank(label) > 1 and label.shape[1] > 1:
            label = tf.reshape(label, shape=[label.shape[0], -1])
            label = tf.argmax(label, axis=1)
        # If labels are in two dimensions, squeeze them
        if len(label.shape) > 1:
            label = tf.reshape(label, [label.shape[0]])
        return label

dict_to_tuple(dataset, columns=None) staticmethod

Turn a dict based tf.data.Dataset to a tuple based tf.data.Dataset

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | Dict based tf.data.Dataset | required |
| columns | list | Columns to use for the tuple based tf.data.Dataset. If None, takes all the columns. Defaults to None. | None |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset |

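A short illustrative sketch (the toy dataset below is hypothetical, not from the library docs):

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.zeros((4, 8)), "label": tf.range(4)}
)
tuple_ds = TFDataHandler.dict_to_tuple(ds)  # all columns, in key order
print(tuple_ds.element_spec)
# (TensorSpec(shape=(8,), dtype=tf.float32, ...), TensorSpec(shape=(), dtype=tf.int32, ...))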

filter_by_value(dataset, column_name, values, excluded=False) staticmethod

Filter a tf.data.Dataset by checking if the value of a column is in 'values'

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | tf.data.Dataset to filter | required |
| column_name | str | Column to filter the dataset with | required |
| values | list | Column values to keep (if excluded is False) or to exclude | required |
| excluded | bool | To keep (False) or exclude (True) the samples with column values included in values. Defaults to False. | False |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset: Filtered dataset |

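A hypothetical filtering sketch (labels and values chosen for illustration only):

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.zeros((6, 8)), "label": tf.constant([0, 1, 2, 0, 1, 2])}
)
kept = TFDataHandler.filter_by_value(ds, "label", [0, 1])  # keep labels 0 and 1
dropped = TFDataHandler.filter_by_value(ds, "label", [0], excluded=True)  # drop label 0
print(TFDataHandler.get_dataset_length(kept))  # 4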

get_column_elements_shape(dataset, column_name) staticmethod

Get the shape of the elements of a column of dataset identified by column_name

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | a tf.data.Dataset | required |
| column_name | Union[str, int] | The column name to get the element shape from. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| tuple | tuple | the shape of an element from column_name |


get_dataset_length(dataset) staticmethod

Get the length of a dataset. Try to access it with len(), and if not available, with a reduce op.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | Dataset to process | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| int | int | Number of items in the dataset |

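A quick sketch of the two code paths (the filtered dataset has unknown cardinality, so len() fails and the reduce fallback is used):

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

ds = tf.data.Dataset.range(10)
print(TFDataHandler.get_dataset_length(ds))  # 10, via len()

filtered = ds.filter(lambda x: x % 2 == 0)
print(TFDataHandler.get_dataset_length(filtered))  # 5, via the reduce op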

get_ds_column_names(dataset) staticmethod

Get the column names of a tf.data.Dataset

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | tf.data.Dataset to get the column names from | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| list | list | List of column names |


get_input_from_dataset_item(elem) staticmethod

Get the tensor that is to be fed as input to a model from a dataset element.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| elem | ItemType | dataset element to extract input from | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| TensorType | TensorType | Input tensor |


get_item_length(dataset) staticmethod

Get the length of a dataset element. If an element is a tensor, the length is one and if it is a sequence (list or tuple), it is len(elem).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | Dataset to process | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| int | int | length of the dataset elems |


get_label_from_dataset_item(item) staticmethod

Retrieve the label tensor from an item given as a tuple/list. The label must be at index 1 in the item. If one-hot encoded, labels are converted to single values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| item | ItemType | dataset element to extract label from | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Label tensor |

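A sketch with hypothetical one-hot labels (a batched item, since the conversion reshapes along the batch axis):

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

images = tf.zeros((4, 8, 8, 3))
one_hot = tf.one_hot([2, 0, 1, 2], depth=3)  # shape (4, 3)
labels = TFDataHandler.get_label_from_dataset_item((images, one_hot))
print(labels.numpy())  # [2 0 1 2]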

load_custom_dataset(dataset_id, columns=None) classmethod

Load a custom Dataset by ensuring it has the correct format (dict-based)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset_id | Dataset | tf.data.Dataset | required |
| columns | list | Column names to use for elements if dataset_id is tuple based. If None, assigned as "input_i" for i-th column. Defaults to None. | None |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset |


load_dataset(dataset_id, columns=None, load_kwargs={}) classmethod

Load a dataset from various sources, ensuring that a dict-based tf.data.Dataset is returned.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset_id | Union[Dataset, ItemType, str] | dataset identification. Can be the name of a dataset from tensorflow_datasets, a tf.data.Dataset, or a tuple/dict of np.ndarrays/tf.Tensors. | required |
| columns | list | Column names. If None, assigned as "input_i" for i-th column. Defaults to None. | None |
| load_kwargs | dict | Additional args for loading from tensorflow_datasets. Defaults to {}. | {} |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset: A dict based tf.data.Dataset |

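A sketch of the three accepted inputs (the tfds call downloads data, so it is commented out; arrays and names are illustrative):

import numpy as np
import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

x = np.random.rand(50, 8).astype("float32")
y = np.random.randint(0, 2, size=(50,))

ds_from_arrays = TFDataHandler.load_dataset((x, y))  # keys: "input", "label"
ds_from_tf = TFDataHandler.load_dataset(
    tf.data.Dataset.from_tensor_slices((x, y))  # tuple based -> dict based
)
# ds_from_tfds = TFDataHandler.load_dataset("mnist", load_kwargs={"split": "test"})
print(ds_from_arrays.element_spec.keys())  # dict_keys(['input', 'label'])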

load_dataset_from_arrays(dataset_id, columns=None) staticmethod

Load a tf.data.Dataset from a np.ndarray, a tf.Tensor or a tuple/dict of np.ndarrays/tf.Tensors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset_id | ItemType | numpy array(s) to load. | required |
| columns | list | Column names to assign. If None, assigned as "input_i" for i-th column. Defaults to None. | None |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset |

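A sketch of custom column naming (array shapes and names are arbitrary):

import numpy as np
from oodeel.datasets.tf_data_handler import TFDataHandler

x = np.zeros((10, 4), dtype="float32")
y = np.arange(10)

default = TFDataHandler.load_dataset_from_arrays((x, y))  # keys: "input", "label"
named = TFDataHandler.load_dataset_from_arrays((x, y), columns=["image", "target"])
print(named.element_spec.keys())  # dict_keys(['image', 'target'])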

load_from_tensorflow_datasets(dataset_id, load_kwargs={}) staticmethod

Load a tf.data.Dataset from the tensorflow_datasets catalog

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset_id | str | Identifier of the dataset | required |
| load_kwargs | dict | Loading kwargs to add to tfds.load(). Defaults to {}. | {} |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset |

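A sketch, assuming tensorflow_datasets is installed and the dataset has already been downloaded (the call triggers a download otherwise):

from oodeel.datasets.tf_data_handler import TFDataHandler

mnist_test = TFDataHandler.load_from_tensorflow_datasets(
    "mnist", load_kwargs={"split": "test"}
)
print(mnist_test.element_spec.keys())  # dict_keys(['image', 'label'])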

make_channel_first(input_key, dataset) staticmethod

Make a tf.data.Dataset channel first. Make sure that the dataset is not already channel first: if it is, the resulting tensor will have the format (x_size, channel, y_size).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_key | str | input key of the dict-based tf.data.Dataset | required |
| dataset | Dataset | tf.data.Dataset to make channel first | required |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset: Channel first dataset |

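A sketch with a hypothetical channels-last image dataset:

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

ds = tf.data.Dataset.from_tensor_slices({"input": tf.zeros((2, 32, 32, 3))})
ds_cf = TFDataHandler.make_channel_first("input", ds)
print(ds_cf.element_spec["input"].shape)  # (3, 32, 32)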

map_ds(dataset, map_fn, num_parallel_calls=None) staticmethod

Map a function to a tf.data.Dataset

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | tf.data.Dataset to map the function to | required |
| map_fn | Callable | Function to map | required |
| num_parallel_calls | Optional[int] | Number of parallel processes to use. Defaults to None. | None |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset: Mapped dataset |

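A sketch applying a hypothetical normalization function to a dict-based dataset:

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

ds = tf.data.Dataset.from_tensor_slices({"input": tf.ones((4, 8)) * 255.0})

def normalize(elem):
    elem["input"] = elem["input"] / 255.0
    return elem

ds = TFDataHandler.map_ds(ds, normalize)  # parallelized with AUTOTUNE by default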

merge(id_dataset, ood_dataset, resize=False, shape=None, channel_order='channels_last') classmethod

Merge two tf.data.Datasets

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| id_dataset | Dataset | dataset of in-distribution data | required |
| ood_dataset | Dataset | dataset of out-of-distribution data | required |
| resize | Optional[bool] | toggles if input tensors of the datasets have to be resized to have the same shape. Defaults to False. | False |
| shape | Optional[Tuple[int]] | shape to use for resizing input tensors. If None, the tensors are resized with the shape of the id_dataset input tensors. Defaults to None. | None |
| channel_order | Optional[str] | channel order of the input tensors. Defaults to "channels_last". | 'channels_last' |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset: merged dataset |

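A sketch merging two hypothetical dict-based image datasets of different sizes (the shape mismatch triggers a resize of both inputs to the id_dataset shape):

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

id_ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.zeros((3, 32, 32, 3)), "label": tf.zeros((3,))}
)
ood_ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.zeros((3, 28, 28, 3)), "label": tf.ones((3,))}
)
merged = TFDataHandler.merge(id_ds, ood_ds)
print(merged.element_spec["input"].shape)  # (32, 32, 3)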

prepare(dataset, batch_size, preprocess_fn=None, augment_fn=None, columns=None, shuffle=False, dict_based_fns=True, return_tuple=True, shuffle_buffer_size=None, prefetch_buffer_size=None, drop_remainder=False) classmethod

Prepare a tf.data.Dataset for training

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | tf.data.Dataset to prepare | required |
| batch_size | int | Batch size | required |
| preprocess_fn | Callable | Preprocessing function to apply to the dataset. Defaults to None. | None |
| augment_fn | Callable | Augment function to be used (when the returned dataset is to be used for training). Defaults to None. | None |
| columns | list | List of column names corresponding to the columns that will be returned. Keep all columns if None. Defaults to None. | None |
| shuffle | bool | To shuffle the returned dataset or not. Defaults to False. | False |
| dict_based_fns | bool | Whether to use preprocess and DA functions as dict based (if True) or as tuple based (if False). Defaults to True. | True |
| return_tuple | bool | Whether to return each dataset item as a tuple. Defaults to True. | True |
| shuffle_buffer_size | int | Size of the shuffle buffer. If None, taken as the number of samples in the dataset. Defaults to None. | None |
| prefetch_buffer_size | Optional[int] | Buffer size for prefetch. If None, automatically chosen using tf.data.experimental.AUTOTUNE. Defaults to None. | None |
| drop_remainder | Optional[bool] | To drop the last batch when its size is lower than batch_size. Defaults to False. | False |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset: Prepared dataset |

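A sketch of a typical training pipeline; the preprocessing and augmentation functions are hypothetical and written dict based, as per the default:

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.zeros((100, 32, 32, 3)), "label": tf.zeros((100,), dtype=tf.int64)}
)

def preprocess(elem):
    elem["input"] = elem["input"] / 255.0
    return elem

def augment(elem):
    elem["input"] = tf.image.random_flip_left_right(elem["input"])
    return elem

train_ds = TFDataHandler.prepare(
    ds, batch_size=32, preprocess_fn=preprocess, augment_fn=augment, shuffle=True
)
for x, y in train_ds.take(1):
    print(x.shape, y.shape)  # (32, 32, 32, 3) (32,)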

split_by_class(dataset, in_labels=None, out_labels=None)

Split the dataset in two, depending on label values (typically, class ids): one split with in-distribution labels and one with out-of-distribution labels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | DatasetType | dataset to split. | required |
| in_labels | Optional[Union[ndarray, list]] | set of labels to be considered as in-distribution. Defaults to None. | None |
| out_labels | Optional[Union[ndarray, list]] | set of labels to be considered as out-of-distribution. Defaults to None. | None |

Returns:

| Type | Description |
| --- | --- |
| Optional[Tuple[DatasetType]] | Optional[Tuple[OODDataset]]: Tuple of in-distribution and out-of-distribution OODDatasets |

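A sketch splitting a hypothetical labeled dataset into ID (classes 0 to 4) and OOD (classes 5 to 9) parts; note that split_by_class is an instance method inherited from DataHandler (its source, shown below, lives in data_handler.py):

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

handler = TFDataHandler()
ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.zeros((100, 8)), "label": tf.random.uniform((100,), 0, 10, tf.int64)}
)
in_data, out_data = handler.split_by_class(ds, in_labels=[0, 1, 2, 3, 4])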
Source code in oodeel/datasets/data_handler.py
def split_by_class(
    self,
    dataset: DatasetType,
    in_labels: Optional[Union[np.ndarray, list]] = None,
    out_labels: Optional[Union[np.ndarray, list]] = None,
) -> Optional[Tuple[DatasetType]]:
    """Filter the dataset by assigning ood labels depending on labels
    value (typically, class id).

    Args:
        in_labels (Optional[Union[np.ndarray, list]], optional): set of labels
            to be considered as in-distribution. Defaults to None.
        out_labels (Optional[Union[np.ndarray, list]], optional): set of labels
            to be considered as out-of-distribution. Defaults to None.

    Returns:
        Optional[Tuple[OODDataset]]: Tuple of in-distribution and
            out-of-distribution OODDatasets
    """
    # Make sure the dataset has labels
    assert (in_labels is not None) or (
        out_labels is not None
    ), "specify labels to filter with"
    assert self.get_item_length(dataset) >= 2, "the dataset has no labels"

    # Filter the dataset depending on in_labels and out_labels given
    if (out_labels is not None) and (in_labels is not None):
        in_data = self.filter_by_value(dataset, "label", in_labels)
        out_data = self.filter_by_value(dataset, "label", out_labels)

    if out_labels is None:
        in_data = self.filter_by_value(dataset, "label", in_labels)
        out_data = self.filter_by_value(dataset, "label", in_labels, excluded=True)

    elif in_labels is None:
        in_data = self.filter_by_value(dataset, "label", out_labels, excluded=True)
        out_data = self.filter_by_value(dataset, "label", out_labels)

    # Return the filtered OODDatasets
    return in_data, out_data

tuple_to_dict(dataset, columns) staticmethod

Turn a tuple based tf.data.Dataset to a dict based tf.data.Dataset

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | Tuple based tf.data.Dataset | required |
| columns | list | Column names to use for the dict based tf.data.Dataset | required |

Returns:

| Type | Description |
| --- | --- |
| Dataset | tf.data.Dataset |

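A sketch converting a hypothetical tuple based dataset into a dict based one:

import tensorflow as tf
from oodeel.datasets.tf_data_handler import TFDataHandler

tuple_ds = tf.data.Dataset.from_tensor_slices((tf.zeros((4, 8)), tf.range(4)))
dict_ds = TFDataHandler.tuple_to_dict(tuple_ds, ["input", "label"])
print(dict_ds.element_spec.keys())  # dict_keys(['input', 'label'])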

dict_only_ds(ds_handling_method)

Decorator to ensure that the dataset is a dict dataset and that the column_name given as argument matches one of the column names. The signature of decorated functions must be function(dataset, *args, **kwargs) with column_name either in kwargs or args[0] when relevant.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| ds_handling_method | Callable | method to decorate | required |

Returns:

| Type | Description |
| --- | --- |
| Callable | decorated method |

Source code in oodeel/datasets/tf_data_handler.py
def dict_only_ds(ds_handling_method: Callable) -> Callable:
    """Decorator to ensure that the dataset is a dict dataset and that the column_name
    given as argument matches one of the column names.
    matches one of the column names. The signature of decorated functions
    must be function(dataset, *args, **kwargs) with column_name either in kwargs or
    args[0] when relevant.


    Args:
        ds_handling_method: method to decorate

    Returns:
        decorated method
    """

    def wrapper(dataset: tf.data.Dataset, *args, **kwargs):
        assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"

        if "column_name" in kwargs.keys():
            column_name = kwargs["column_name"]
        elif len(args) > 0:
            column_name = args[0]

        # If column_name is provided, check that it is in the dataset column names
        if (len(args) > 0) or ("column_name" in kwargs):
            if isinstance(column_name, str):
                column_name = [column_name]
            for name in column_name:
                assert (
                    name in dataset.element_spec.keys()
                ), f"The input dataset has no column named {name}"
        return ds_handling_method(dataset, *args, **kwargs)

    return wrapper