Skip to content

TFOperator

TFOperator

Bases: Operator

Class to handle tensorflow operations with a unified API

Source code in oodeel/utils/tf_operator.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
class TFOperator(Operator):
    """Class to handle tensorflow operations with a unified API"""

    @staticmethod
    def softmax(tensor: TensorType) -> tf.Tensor:
        """Softmax function along the last dimension"""
        return tf.keras.activations.softmax(tensor, axis=-1)

    @staticmethod
    def argmax(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
        """Argmax function.

        When ``dim`` is None the tensor is flattened to 1D first, so the
        returned index refers to the flattened tensor.
        """
        if dim is None:
            return tf.argmax(tf.reshape(tensor, [-1]))
        return tf.argmax(tensor, axis=dim)

    @staticmethod
    def max(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> tf.Tensor:
        """Max function (reduces over all dimensions when ``dim`` is None)"""
        return tf.reduce_max(tensor, axis=dim, keepdims=keepdim)

    @staticmethod
    def min(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> tf.Tensor:
        """Min function (reduces over all dimensions when ``dim`` is None)"""
        return tf.reduce_min(tensor, axis=dim, keepdims=keepdim)

    @staticmethod
    def one_hot(tensor: TensorType, num_classes: int) -> tf.Tensor:
        """One hot function"""
        return tf.one_hot(tensor, num_classes)

    @staticmethod
    def sign(tensor: TensorType) -> tf.Tensor:
        """Sign function"""
        return tf.sign(tensor)

    @staticmethod
    def CrossEntropyLoss(reduction: str = "mean"):
        """Cross Entropy Loss from logits

        Args:
            reduction (str): Either "mean" or "sum". Any other value raises a
                KeyError from the lookup below.

        Returns:
            Callable: a function ``(inputs, targets)`` that evaluates
                ``tf.keras.losses.SparseCategoricalCrossentropy`` with
                ``from_logits=True``.
        """
        # Translate the generic reduction name into the keras one.
        tf_reduction = {"mean": "sum_over_batch_size", "sum": "sum"}[reduction]

        def sanitized_ce_loss(inputs, targets):
            # The keras loss is called with (targets, inputs), i.e. the
            # reverse of this wrapper's own argument order.
            return tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf_reduction
            )(targets, inputs)

        return sanitized_ce_loss

    @staticmethod
    def norm(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
        """Tensor Norm"""
        return tf.norm(tensor, axis=dim)

    @staticmethod
    @tf.function
    def matmul(tensor_1: TensorType, tensor_2: TensorType) -> tf.Tensor:
        """Matmul operation"""
        return tf.matmul(tensor_1, tensor_2)

    @staticmethod
    def convert_to_numpy(tensor: TensorType) -> np.ndarray:
        """Convert tensor into a np.ndarray"""
        return tensor.numpy()

    @staticmethod
    def gradient(func: Callable, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        """Compute gradients for a batch of samples.

        Args:
            func (Callable): Function used for computing gradient. Must be built with
                tensorflow differentiable operations only, and return a scalar.
            inputs (tf.Tensor): Input tensor wrt which the gradients are computed
            *args: Additional Args for func.
            **kwargs: Additional Kwargs for func.

        Returns:
            tf.Tensor: Gradients computed, with the same shape as the inputs.
        """
        # Automatic variable watching is disabled: only `inputs` is watched.
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(inputs)
            outputs = func(inputs, *args, **kwargs)
        return tape.gradient(outputs, inputs)

    @staticmethod
    def stack(tensors: List[TensorType], dim: int = 0) -> tf.Tensor:
        """Stack tensors along a new dimension"""
        return tf.stack(tensors, dim)

    @staticmethod
    def cat(tensors: List[TensorType], dim: int = 0) -> tf.Tensor:
        """Concatenate tensors in a given dimension"""
        return tf.concat(tensors, dim)

    @staticmethod
    def mean(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
        """Mean function"""
        return tf.reduce_mean(tensor, dim)

    @staticmethod
    def flatten(tensor: TensorType) -> tf.Tensor:
        """Flatten to 2D tensor of shape (tensor.shape[0], -1)"""
        # Flatten the features to 2D (n_batch, n_features)
        return tf.reshape(tensor, shape=[tf.shape(tensor)[0], -1])

    @staticmethod
    def from_numpy(arr: np.ndarray) -> tf.Tensor:
        """Convert a NumPy array to a tensor"""
        # TODO change dtype
        return tf.convert_to_tensor(arr)

    @staticmethod
    def t(tensor: TensorType) -> tf.Tensor:
        """Transpose function for tensor of rank 2"""
        return tf.transpose(tensor)

    @staticmethod
    def permute(tensor: TensorType, dims) -> tf.Tensor:
        """Permute the dimensions of a tensor according to dims"""
        return tf.transpose(tensor, dims)

    @staticmethod
    def diag(tensor: TensorType) -> tf.Tensor:
        """Diagonal function: return the diagonal of a 2D tensor"""
        return tf.linalg.diag_part(tensor)

    @staticmethod
    def reshape(tensor: TensorType, shape: List[int]) -> tf.Tensor:
        """Reshape function"""
        return tf.reshape(tensor, shape)

    @staticmethod
    def equal(tensor: TensorType, other: Union[TensorType, int, float]) -> tf.Tensor:
        """Computes element-wise equality"""
        return tf.math.equal(tensor, other)

    @staticmethod
    def pinv(tensor: TensorType) -> tf.Tensor:
        """Computes the pseudoinverse (Moore-Penrose inverse) of a matrix."""
        return tf.linalg.pinv(tensor)

    @staticmethod
    def eigh(tensor: TensorType) -> tuple:
        """Computes the eigen decomposition of a self-adjoint matrix.

        Returns:
            tuple: ``(eigval, eigvec)`` exactly as returned by ``tf.linalg.eigh``.
        """
        eigval, eigvec = tf.linalg.eigh(tensor)
        return eigval, eigvec

    @staticmethod
    def quantile(
        tensor: TensorType, q: float, dim: Optional[int] = None
    ) -> Union[float, tf.Tensor]:
        """Computes the quantile of a tensor's components. q in (0,1)

        Returns:
            A python float when ``dim`` is None, otherwise the tensor returned
            by ``tfp.stats.percentile`` along ``dim``.
        """
        # tfp works with percentiles, hence the scaling of q by 100.
        value = tfp.stats.percentile(tensor, q * 100, axis=dim)
        return float(value) if dim is None else value

    @staticmethod
    def relu(tensor: TensorType) -> tf.Tensor:
        """Apply relu to a tensor"""
        return tf.nn.relu(tensor)

    @staticmethod
    def einsum(equation: str, *tensors: TensorType) -> tf.Tensor:
        """Computes the einsum between tensors following equation"""
        return tf.einsum(equation, *tensors)

    @staticmethod
    def tril(tensor: TensorType, diagonal: int = 0) -> tf.Tensor:
        """Set the upper triangle of the matrix formed by the last two dimensions
        of tensor to zero"""
        return tf.experimental.numpy.tril(tensor, k=diagonal)

    @staticmethod
    def sum(
        tensor: TensorType, dim: Optional[Union[tuple, list, int]] = None
    ) -> tf.Tensor:
        """Sum along dim (over all dimensions when dim is None)"""
        return tf.reduce_sum(tensor, axis=dim)

    @staticmethod
    def unsqueeze(tensor: TensorType, dim: int) -> tf.Tensor:
        """expand_dim along dim"""
        return tf.expand_dims(tensor, dim)

    @staticmethod
    def abs(tensor: TensorType) -> tf.Tensor:
        """Compute absolute value"""
        return tf.abs(tensor)

    @staticmethod
    def where(
        condition: TensorType,
        input: Union[TensorType, float],
        other: Union[TensorType, float],
    ) -> tf.Tensor:
        """Applies where function to condition"""
        return tf.where(condition, input, other)

    @staticmethod
    def percentile(x, q):
        """Thin wrapper around ``tfp.stats.percentile(x, q)``.

        Unlike ``quantile``, ``q`` is forwarded unscaled, so it is presumably
        expected in percent (0-100) rather than in (0, 1).
        """
        return tfp.stats.percentile(x, q)

CrossEntropyLoss(reduction='mean') staticmethod

Cross Entropy Loss from logits

Source code in oodeel/utils/tf_operator.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
@staticmethod
def CrossEntropyLoss(reduction: str = "mean"):
    """Build a cross-entropy loss function operating on logits.

    ``reduction`` must be "mean" or "sum"; it is mapped to the matching keras
    reduction name. The returned callable takes ``(inputs, targets)``.
    """
    reduction_names = {"mean": "sum_over_batch_size", "sum": "sum"}
    keras_reduction = reduction_names[reduction]

    def sanitized_ce_loss(inputs, targets):
        # keras expects (targets, inputs), the reverse of this signature
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=keras_reduction
        )
        return loss_fn(targets, inputs)

    return sanitized_ce_loss

abs(tensor) staticmethod

compute absolute value

Source code in oodeel/utils/tf_operator.py
234
235
236
237
@staticmethod
def abs(tensor: TensorType) -> tf.Tensor:
    """Return the element-wise absolute value of ``tensor``."""
    magnitude = tf.math.abs(tensor)
    return magnitude

argmax(tensor, dim=None) staticmethod

Argmax function

Source code in oodeel/utils/tf_operator.py
60
61
62
63
64
65
@staticmethod
def argmax(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
    """Index of the maximum, along ``dim`` or over the flattened tensor."""
    if dim is not None:
        return tf.argmax(tensor, axis=dim)
    # No dimension requested: search the 1D flattened view instead.
    flat = tf.reshape(tensor, [-1])
    return tf.argmax(flat)

cat(tensors, dim=0) staticmethod

Concatenate tensors in a given dimension

Source code in oodeel/utils/tf_operator.py
144
145
146
147
@staticmethod
def cat(tensors: List[TensorType], dim: int = 0) -> tf.Tensor:
    """Concatenate the given tensors along dimension ``dim``."""
    return tf.concat(tensors, axis=dim)

convert_to_numpy(tensor) staticmethod

Convert tensor into a np.ndarray

Source code in oodeel/utils/tf_operator.py
115
116
117
118
@staticmethod
def convert_to_numpy(tensor: TensorType) -> np.ndarray:
    """Return the result of calling ``tensor.numpy()``."""
    as_array = tensor.numpy()
    return as_array

diag(tensor) staticmethod

Diagonal function: return the diagonal of a 2D tensor

Source code in oodeel/utils/tf_operator.py
176
177
178
179
@staticmethod
def diag(tensor: TensorType) -> tf.Tensor:
    """Extract the diagonal of a 2D tensor."""
    diagonal = tf.linalg.diag_part(tensor)
    return diagonal

eigh(tensor) staticmethod

Computes the eigen decomposition of a self-adjoint matrix.

Source code in oodeel/utils/tf_operator.py
196
197
198
199
200
@staticmethod
def eigh(tensor: TensorType) -> tuple:
    """Computes the eigen decomposition of a self-adjoint matrix.

    Returns:
        tuple: ``(eigval, eigvec)`` exactly as returned by ``tf.linalg.eigh``.
    """
    eigval, eigvec = tf.linalg.eigh(tensor)
    return eigval, eigvec

einsum(equation, *tensors) staticmethod

Computes the einsum between tensors following equation

Source code in oodeel/utils/tf_operator.py
213
214
215
216
@staticmethod
def einsum(equation: str, *tensors: TensorType) -> tf.Tensor:
    """Evaluate the einsum ``equation`` on the given tensors."""
    contracted = tf.einsum(equation, *tensors)
    return contracted

equal(tensor, other) staticmethod

Computes element-wise equality

Source code in oodeel/utils/tf_operator.py
186
187
188
189
@staticmethod
def equal(tensor: TensorType, other: Union[TensorType, int, float]) -> tf.Tensor:
    """Element-wise equality test between ``tensor`` and ``other``."""
    return tf.equal(tensor, other)

flatten(tensor) staticmethod

Flatten to 2D tensor of shape (tensor.shape[0], -1)

Source code in oodeel/utils/tf_operator.py
154
155
156
157
158
@staticmethod
def flatten(tensor: TensorType) -> tf.Tensor:
    """Reshape to 2D, keeping the first dimension: (tensor.shape[0], -1)."""
    # The leading dimension is read dynamically; the rest is collapsed.
    leading = tf.shape(tensor)[0]
    return tf.reshape(tensor, shape=[leading, -1])

from_numpy(arr) staticmethod

Convert a NumPy array to a tensor

Source code in oodeel/utils/tf_operator.py
160
161
162
163
164
@staticmethod
def from_numpy(arr: np.ndarray) -> tf.Tensor:
    """Build a tensor from a NumPy array."""
    # TODO change dtype
    converted = tf.convert_to_tensor(arr)
    return converted

gradient(func, inputs, *args, **kwargs) staticmethod

Compute gradients for a batch of samples.

Parameters:

Name Type Description Default
func Callable

Function used for computing gradient. Must be built with tensorflow differentiable operations only, and return a scalar.

required
inputs Tensor

Input tensor wrt which the gradients are computed

required
*args

Additional Args for func.

()
**kwargs

Additional Kwargs for func.

{}

Returns:

Type Description
Tensor

tf.Tensor: Gradients computed, with the same shape as the inputs.

Source code in oodeel/utils/tf_operator.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
@staticmethod
def gradient(func: Callable, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
    """Differentiate ``func`` with respect to ``inputs``.

    Args:
        func (Callable): Function to differentiate. It must only use
            differentiable tensorflow operations and return a scalar.
        inputs (tf.Tensor): Tensor with respect to which gradients are taken.
        *args: Extra positional arguments forwarded to ``func``.
        **kwargs: Extra keyword arguments forwarded to ``func``.

    Returns:
        tf.Tensor: The gradients, shaped like ``inputs``.
    """
    # Only `inputs` is recorded: automatic variable watching is turned off.
    grad_tape = tf.GradientTape(watch_accessed_variables=False)
    with grad_tape:
        grad_tape.watch(inputs)
        value = func(inputs, *args, **kwargs)
    return grad_tape.gradient(value, inputs)

matmul(tensor_1, tensor_2) staticmethod

Matmul operation

Source code in oodeel/utils/tf_operator.py
109
110
111
112
113
@staticmethod
@tf.function
def matmul(tensor_1: TensorType, tensor_2: TensorType) -> tf.Tensor:
    """Matrix product of ``tensor_1`` and ``tensor_2``."""
    product = tf.linalg.matmul(tensor_1, tensor_2)
    return product

max(tensor, dim=None, keepdim=False) staticmethod

Max function

Source code in oodeel/utils/tf_operator.py
67
68
69
70
71
72
@staticmethod
def max(
    tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
) -> tf.Tensor:
    """Maximum of ``tensor`` along ``dim`` (``keepdim`` keeps that axis)."""
    reduce_options = {"axis": dim, "keepdims": keepdim}
    return tf.reduce_max(tensor, **reduce_options)

mean(tensor, dim=None) staticmethod

Mean function

Source code in oodeel/utils/tf_operator.py
149
150
151
152
@staticmethod
def mean(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
    """Average of ``tensor`` along ``dim``."""
    return tf.reduce_mean(tensor, axis=dim)

min(tensor, dim=None, keepdim=False) staticmethod

Min function

Source code in oodeel/utils/tf_operator.py
74
75
76
77
78
79
@staticmethod
def min(
    tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
) -> tf.Tensor:
    """Minimum of ``tensor`` along ``dim`` (``keepdim`` keeps that axis)."""
    reduce_options = {"axis": dim, "keepdims": keepdim}
    return tf.reduce_min(tensor, **reduce_options)

norm(tensor, dim=None) staticmethod

Tensor Norm

Source code in oodeel/utils/tf_operator.py
104
105
106
107
@staticmethod
def norm(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
    """Norm of ``tensor`` computed along ``dim`` with ``tf.norm`` defaults."""
    magnitude = tf.norm(tensor, axis=dim)
    return magnitude

one_hot(tensor, num_classes) staticmethod

One hot function

Source code in oodeel/utils/tf_operator.py
81
82
83
84
@staticmethod
def one_hot(tensor: TensorType, num_classes: int) -> tf.Tensor:
    """One-hot encode ``tensor`` with a depth of ``num_classes``."""
    return tf.one_hot(tensor, depth=num_classes)

permute(tensor, dims) staticmethod

Permute the dimensions of a tensor according to dims

Source code in oodeel/utils/tf_operator.py
171
172
173
174
@staticmethod
def permute(tensor: TensorType, dims) -> tf.Tensor:
    """Permute the dimensions of a tensor.

    Args:
        tensor (TensorType): Tensor whose dimensions are reordered.
        dims: Permutation forwarded to ``tf.transpose`` as ``perm``.

    Returns:
        tf.Tensor: The tensor with its dimensions reordered.
    """
    return tf.transpose(tensor, perm=dims)

pinv(tensor) staticmethod

Computes the pseudoinverse (Moore-Penrose inverse) of a matrix.

Source code in oodeel/utils/tf_operator.py
191
192
193
194
@staticmethod
def pinv(tensor: TensorType) -> tf.Tensor:
    """Moore-Penrose pseudoinverse of a matrix."""
    pseudo_inverse = tf.linalg.pinv(tensor)
    return pseudo_inverse

quantile(tensor, q, dim=None) staticmethod

Computes the quantile of a tensor's components. q in (0,1)

Source code in oodeel/utils/tf_operator.py
202
203
204
205
206
@staticmethod
def quantile(
    tensor: TensorType, q: float, dim: Optional[int] = None
) -> Union[float, tf.Tensor]:
    """Computes the quantile of a tensor's components. q in (0,1)

    Args:
        tensor (TensorType): Tensor whose quantile is computed.
        q (float): Quantile level, in (0, 1).
        dim (Optional[int]): Dimension forwarded as ``axis``. Defaults to None.

    Returns:
        A python float when ``dim`` is None, otherwise the tensor returned by
        ``tfp.stats.percentile`` along ``dim``.
    """
    # tfp works with percentiles, hence the scaling of q by 100.
    value = tfp.stats.percentile(tensor, q * 100, axis=dim)
    return float(value) if dim is None else value

relu(tensor) staticmethod

Apply relu to a tensor

Source code in oodeel/utils/tf_operator.py
208
209
210
211
@staticmethod
def relu(tensor: TensorType) -> tf.Tensor:
    """Pass ``tensor`` through ``tf.nn.relu``."""
    activated = tf.nn.relu(tensor)
    return activated

reshape(tensor, shape) staticmethod

Reshape function

Source code in oodeel/utils/tf_operator.py
181
182
183
184
@staticmethod
def reshape(tensor: TensorType, shape: List[int]) -> tf.Tensor:
    """Give ``tensor`` the requested ``shape``."""
    return tf.reshape(tensor, shape=shape)

sign(tensor) staticmethod

Sign function

Source code in oodeel/utils/tf_operator.py
86
87
88
89
@staticmethod
def sign(tensor: TensorType) -> tf.Tensor:
    """Element-wise sign of ``tensor``."""
    signs = tf.math.sign(tensor)
    return signs

softmax(tensor) staticmethod

Softmax function along the last dimension

Source code in oodeel/utils/tf_operator.py
55
56
57
58
@staticmethod
def softmax(tensor: TensorType) -> tf.Tensor:
    """Apply softmax over the last dimension of ``tensor``."""
    last_axis = -1
    return tf.keras.activations.softmax(tensor, axis=last_axis)

stack(tensors, dim=0) staticmethod

Stack tensors along a new dimension

Source code in oodeel/utils/tf_operator.py
139
140
141
142
@staticmethod
def stack(tensors: List[TensorType], dim: int = 0) -> tf.Tensor:
    """Stack the given tensors along a new dimension inserted at ``dim``."""
    return tf.stack(tensors, axis=dim)

sum(tensor, dim=None) staticmethod

sum along dim

Source code in oodeel/utils/tf_operator.py
224
225
226
227
@staticmethod
def sum(tensor: TensorType, dim: Union[tuple, list, int] = None) -> tf.Tensor:
    """Sum the elements of ``tensor`` along ``dim``."""
    total = tf.math.reduce_sum(tensor, axis=dim)
    return total

t(tensor) staticmethod

Transpose function for tensor of rank 2

Source code in oodeel/utils/tf_operator.py
166
167
168
169
@staticmethod
def t(tensor: TensorType) -> tf.Tensor:
    """Transpose a rank-2 tensor."""
    transposed = tf.transpose(tensor)
    return transposed

tril(tensor, diagonal=0) staticmethod

Set the upper triangle of the matrix formed by the last two dimensions of tensor to zero

Source code in oodeel/utils/tf_operator.py
218
219
220
221
222
@staticmethod
def tril(tensor: TensorType, diagonal: int = 0) -> tf.Tensor:
    """Set the upper triangle of the matrix formed by the last two dimensions
    of tensor to zero.

    Args:
        tensor (TensorType): Input tensor.
        diagonal (int): Forwarded as ``k`` to ``tf.experimental.numpy.tril``.
            Defaults to 0.

    Returns:
        tf.Tensor: Result of ``tf.experimental.numpy.tril``.
    """
    return tf.experimental.numpy.tril(tensor, k=diagonal)

unsqueeze(tensor, dim) staticmethod

expand_dim along dim

Source code in oodeel/utils/tf_operator.py
229
230
231
232
@staticmethod
def unsqueeze(tensor: TensorType, dim: int) -> tf.Tensor:
    """Insert a new dimension into ``tensor`` at position ``dim``."""
    return tf.expand_dims(tensor, axis=dim)

where(condition, input, other) staticmethod

Applies where function to condition

Source code in oodeel/utils/tf_operator.py
239
240
241
242
243
244
245
246
@staticmethod
def where(
    condition: TensorType,
    input: Union[TensorType, float],
    other: Union[TensorType, float],
) -> tf.Tensor:
    """Select from ``input`` where ``condition`` holds, else from ``other``."""
    selected = tf.where(condition, x=input, y=other)
    return selected

sanitize_input(tensor_arg_func)

ensures the decorated function receives a tf.Tensor

Source code in oodeel/utils/tf_operator.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def sanitize_input(tensor_arg_func: Callable):
    """Ensure the decorated function receives a tf.Tensor.

    The wrapped callable is expected to take ``(obj, tensor, *args, **kwargs)``.
    Its second positional argument is converted as follows:

    * a ``tf.Tensor`` is passed through untouched;
    * an object identified by ``is_from(tensor, "torch")`` is detached from
      its graph, moved to CPU and converted through NumPy;
    * anything else is handed directly to ``tf.convert_to_tensor``.

    Args:
        tensor_arg_func (Callable): Function to decorate.

    Returns:
        Callable: The wrapper, carrying the name and docstring of
            ``tensor_arg_func``.
    """
    # Local import so this block does not rely on a module-level import.
    import functools

    @functools.wraps(tensor_arg_func)
    def wrapper(obj, tensor, *args, **kwargs):
        if not isinstance(tensor, tf.Tensor):
            if is_from(tensor, "torch"):
                # A bare .numpy() raises on tensors that require grad or that
                # live on an accelerator, hence detach() and cpu() first.
                tensor = tensor.detach().cpu().numpy()
            tensor = tf.convert_to_tensor(tensor)

        return tensor_arg_func(obj, tensor, *args, **kwargs)

    return wrapper