Skip to content

Utils

import_backend_specific_stuff(model)

Get backend specific data handler, operator and feature extractor class.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `Callable` | A model (Keras or PyTorch) used to identify the backend. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `str` | Backend name (`"tensorflow"` or `"torch"`). |
| `DataHandler` | Torch or TF data handler. |
| `Operator` | Torch or TF operator. |
| `FeatureExtractor` | Torch or TF feature extractor class. |

Source code in `oodeel/utils/general_utils.py` (lines 46–81):
def import_backend_specific_stuff(model: Callable):
    """Get backend specific data handler, operator and feature extractor class.

    The backend is inferred from the model's class hierarchy with `is_from`:
    Keras models are checked first, then PyTorch models. The backend-specific
    modules are imported inside each branch, so only the detected backend's
    modules are loaded.

    Args:
        model (Callable): a model (Keras or PyTorch) used to identify the backend.

    Returns:
        str: backend name ("tensorflow" or "torch")
        DataHandler: torch or tf data handler
        Operator: torch or tf operator
        FeatureExtractor: torch or tf feature extractor class

    Raises:
        NotImplementedError: if the model is neither a Keras nor a PyTorch model.
    """
    if is_from(model, "keras"):
        from ..extractor.keras_feature_extractor import KerasFeatureExtractor
        from ..datasets.tf_data_handler import TFDataHandler
        from ..utils import TFOperator

        backend = "tensorflow"
        data_handler = TFDataHandler()
        op = TFOperator()
        FeatureExtractorClass = KerasFeatureExtractor

    elif is_from(model, "torch"):
        from ..extractor.torch_feature_extractor import TorchFeatureExtractor
        from ..datasets.torch_data_handler import TorchDataHandler
        from ..utils import TorchOperator

        backend = "torch"
        data_handler = TorchDataHandler()
        # Unlike TFOperator, the torch operator is built from the model itself.
        op = TorchOperator(model)
        FeatureExtractorClass = TorchFeatureExtractor

    else:
        # Name the offending type so users know why their model was rejected.
        model_type = type(model)
        raise NotImplementedError(
            f"Unsupported model type "
            f"'{model_type.__module__}.{model_type.__qualname__}': "
            "expected a Keras or PyTorch model."
        )

    return backend, data_handler, op, FeatureExtractorClass

is_from(model_or_tensor, framework)

Check whether a model or tensor belongs to a specific framework

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model_or_tensor` | `Any` | Neural network or tensor. | *required* |
| `framework` | `str` | Model or tensor framework: `"torch"`, `"keras"` or `"tensorflow"`. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `bool` | Whether the model or tensor belongs to the specified framework. |

Source code in `oodeel/utils/general_utils.py` (lines 27–43):
def is_from(model_or_tensor: Any, framework: str) -> bool:
    """Check whether a model or tensor belongs to a specific framework.

    Every class in the object's MRO contributes the dot-separated components
    of its module path and qualified name. For example, a class
    ``Module`` defined in module ``torch.nn.modules.module`` contributes
    ``"torch"``, ``"nn"``, ``"modules"``, ``"module"`` and ``"Module"``.
    The check succeeds if ``framework`` is exactly one of those components.

    Args:
        model_or_tensor (Any): Neural network or Tensor
        framework (str): Model or tensor framework ("torch" | "keras" | "tensorflow")

    Returns:
        bool: Whether the model belongs to specified framework or not
    """
    keywords = set()
    for cls in model_or_tensor.__class__.__mro__:
        # Read the naming attributes directly instead of parsing repr(cls):
        # a metaclass may override __repr__, which broke the old parsing.
        # Like type.__repr__, skip the module prefix for builtin classes.
        module = getattr(cls, "__module__", None)
        if isinstance(module, str) and module != "builtins":
            keywords.update(module.split("."))
        keywords.update(cls.__qualname__.split("."))
    return framework in keywords