Skip to content

TorchOperator

TorchOperator

Bases: Operator

Class to handle torch operations with a unified API

Source code in oodeel/utils/torch_operator.py (lines 52–269)
class TorchOperator(Operator):
    """Class to handle torch operations with a unified API.

    Methods are thin wrappers around torch functions. Only ``from_numpy``
    uses instance state (``self._device``, chosen in ``__init__``).
    """

    def __init__(self, model: Optional[torch.nn.Module] = None):
        """Select the device that ``from_numpy`` moves tensors to.

        Args:
            model (Optional[torch.nn.Module]): When given and it owns at least
                one parameter, the device of its first parameter is used.
                Otherwise CUDA is used if available, else CPU.
        """
        # ``next(..., None)`` avoids a bare StopIteration for models that have
        # no parameters; such models fall back to the default device.
        first_param = None
        if model is not None:
            first_param = next(model.parameters(), None)
        if first_param is not None:
            self._device = first_param.device
        else:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def softmax(tensor: TensorType) -> torch.Tensor:
        """Softmax function along the last dimension"""
        return torch.nn.functional.softmax(tensor, dim=-1)

    @staticmethod
    def argmax(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
        """Indices of the maximum values (over the flattened tensor if dim is None)"""
        return torch.argmax(tensor, dim=dim)

    @staticmethod
    def max(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> torch.Tensor:
        """Maximum values of ``tensor``.

        Args:
            tensor: Input tensor.
            dim: Dimension to reduce. If None, the maximum over all elements is
                returned.
            keepdim: Whether to retain the reduced dimension (only used when
                ``dim`` is given).

        Returns:
            torch.Tensor: The maximum values; the indices produced by
            ``torch.max`` along ``dim`` are discarded.
        """
        if dim is None:
            return torch.max(tensor)
        values, _ = torch.max(tensor, dim, keepdim=keepdim)
        return values

    @staticmethod
    def min(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> torch.Tensor:
        """Minimum values of ``tensor``.

        Args:
            tensor: Input tensor.
            dim: Dimension to reduce. If None, the minimum over all elements is
                returned.
            keepdim: Whether to retain the reduced dimension (only used when
                ``dim`` is given).

        Returns:
            torch.Tensor: The minimum values; the indices produced by
            ``torch.min`` along ``dim`` are discarded.
        """
        if dim is None:
            return torch.min(tensor)
        values, _ = torch.min(tensor, dim, keepdim=keepdim)
        return values

    @staticmethod
    def one_hot(tensor: TensorType, num_classes: int) -> torch.Tensor:
        """One-hot encode a tensor of integer class indices into num_classes"""
        return torch.nn.functional.one_hot(tensor, num_classes)

    @staticmethod
    def sign(tensor: TensorType) -> torch.Tensor:
        """Element-wise sign (-1, 0 or 1)"""
        return torch.sign(tensor)

    @staticmethod
    def CrossEntropyLoss(reduction: str = "mean"):
        """Cross Entropy Loss from logits.

        Args:
            reduction (str): Reduction mode forwarded to
                ``torch.nn.CrossEntropyLoss``.

        Returns:
            Callable: ``loss(inputs, targets)`` where ``inputs`` are logits.
        """
        # Build the loss module once instead of on every call.
        criterion = torch.nn.CrossEntropyLoss(reduction=reduction)

        def sanitized_ce_loss(inputs, targets):
            return criterion(inputs, targets)

        return sanitized_ce_loss

    @staticmethod
    def norm(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
        """Default ``torch.norm`` of the tensor, over all elements or along dim"""
        return torch.norm(tensor, dim=dim)

    @staticmethod
    def matmul(tensor_1: TensorType, tensor_2: TensorType) -> torch.Tensor:
        """Matrix product with ``torch.matmul`` semantics"""
        return torch.matmul(tensor_1, tensor_2)

    @staticmethod
    def convert_from_tensorflow(tensor: TensorType) -> torch.Tensor:
        """Convert a tensorflow tensor into a torch tensor

        Used when using a pytorch model on a dataset loaded from tensorflow datasets.
        The data goes through ``tensor.numpy()`` and the ``torch.Tensor``
        constructor, so the result has torch's default floating dtype.
        """
        return torch.Tensor(tensor.numpy())

    @staticmethod
    def convert_to_numpy(tensor: TensorType) -> np.ndarray:
        """Convert tensor into a np.ndarray (detached from autograd, on CPU)"""
        return tensor.detach().cpu().numpy()

    @staticmethod
    def gradient(func: Callable, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Compute gradients for a batch of samples.

        Args:
            func (Callable): Function used for computing gradient. Must be built with
                torch differentiable operations only, and return a scalar.
            inputs (torch.Tensor): Input tensor wrt which the gradients are computed
            *args: Additional Args for func.
            **kwargs: Additional Kwargs for func.

        Returns:
            torch.Tensor: Gradients computed, with the same shape as the inputs.
        """
        # Only toggle the flag when the caller had not enabled it already. This
        # preserves the caller's autograd setup and never calls
        # ``requires_grad_(False)`` on a non-leaf tensor, which torch rejects.
        toggled = not inputs.requires_grad
        if toggled:
            inputs.requires_grad_(True)
        try:
            outputs = func(inputs, *args, **kwargs)
            gradients = torch.autograd.grad(outputs, inputs)
        finally:
            # Restore the flag even if ``func`` or the backward pass raises.
            if toggled:
                inputs.requires_grad_(False)
        return gradients[0]

    @staticmethod
    def stack(tensors: List[TensorType], dim: int = 0) -> torch.Tensor:
        "Stack tensors along a new dimension"
        return torch.stack(tensors, dim)

    @staticmethod
    def cat(tensors: List[TensorType], dim: int = 0) -> torch.Tensor:
        "Concatenate tensors in a given dimension"
        return torch.cat(tensors, dim)

    @staticmethod
    def mean(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
        "Mean over all elements, or along dim when given"
        if dim is None:
            return torch.mean(tensor)
        return torch.mean(tensor, dim)

    @staticmethod
    def flatten(tensor: TensorType) -> torch.Tensor:
        """Flatten the features to 2D (n_batch, n_features).

        ``reshape`` (rather than ``view``) accepts non-contiguous inputs too;
        contiguous inputs still get a view without copying.
        """
        return tensor.reshape(tensor.size(0), -1)

    def from_numpy(self, arr: np.ndarray) -> torch.Tensor:
        "Convert a NumPy array to a tensor placed on self._device"
        # TODO change dtype
        return torch.tensor(arr).to(self._device)

    @staticmethod
    def t(tensor: TensorType) -> torch.Tensor:
        "Transpose function for tensor of rank 2"
        return tensor.t()

    @staticmethod
    def permute(tensor: TensorType, dims) -> torch.Tensor:
        """Reorder the dimensions of ``tensor`` according to ``dims``.

        Args:
            tensor: Input tensor.
            dims: New ordering of the dimensions, e.g. ``(0, 2, 1)``.
        """
        return torch.permute(tensor, dims)

    @staticmethod
    def diag(tensor: TensorType) -> torch.Tensor:
        "Diagonal function: return the diagonal of a 2D tensor"
        return tensor.diag()

    @staticmethod
    def reshape(tensor: TensorType, shape: List[int]) -> torch.Tensor:
        """Reshape ``tensor`` to ``shape``.

        ``reshape`` (rather than ``view``) accepts non-contiguous inputs too;
        contiguous inputs still get a view without copying.
        """
        return tensor.reshape(*shape)

    @staticmethod
    def equal(tensor: TensorType, other: Union[TensorType, int, float]) -> torch.Tensor:
        "Computes element-wise equality"
        return torch.eq(tensor, other)

    @staticmethod
    def pinv(tensor: TensorType) -> torch.Tensor:
        "Computes the pseudoinverse (Moore-Penrose inverse) of a matrix."
        return torch.linalg.pinv(tensor)

    @staticmethod
    def eigh(tensor: TensorType) -> tuple:
        """Computes the eigen decomposition of a self-adjoint matrix.

        Returns:
            tuple: ``(eigval, eigvec)`` as returned by ``torch.linalg.eigh``.
        """
        eigval, eigvec = torch.linalg.eigh(tensor)
        return eigval, eigvec

    @staticmethod
    def quantile(
        tensor: TensorType, q: float, dim: Optional[int] = None
    ) -> Union[float, torch.Tensor]:
        """Computes the ``q``-quantile of a tensor's components, ``q`` in [0, 1].

        Returns a Python float when ``dim`` is None (quantile over all
        components), otherwise a tensor reduced along ``dim``.
        """
        if dim is None:
            # keep the 16 millions first elements (see torch.quantile issue:
            # https://github.com/pytorch/pytorch/issues/64947)
            # ``reshape`` also handles non-contiguous tensors, unlike ``view``.
            tensor_flatten = tensor.reshape(-1)[:16_000_000]
            return torch.quantile(tensor_flatten, q).item()
        return torch.quantile(tensor, q, dim)

    @staticmethod
    def relu(tensor: TensorType) -> torch.Tensor:
        "Apply relu to a tensor"
        return torch.nn.functional.relu(tensor)

    @staticmethod
    def einsum(equation: str, *tensors: TensorType) -> torch.Tensor:
        "Computes the einsum between tensors following equation"
        return torch.einsum(equation, *tensors)

    @staticmethod
    def tril(tensor: TensorType, diagonal: int = 0) -> torch.Tensor:
        """Keep the lower triangle of the matrix formed by the last two dimensions.

        Elements above the ``diagonal``-th diagonal are set to zero
        (``diagonal=0`` is the main diagonal).
        """
        return torch.tril(tensor, diagonal)

    @staticmethod
    def sum(
        tensor: TensorType, dim: Optional[Union[tuple, list, int]] = None
    ) -> torch.Tensor:
        "Sum over all elements, or along dim (int, tuple or list) when given"
        if dim is None:
            # Some torch versions reject an explicit ``dim=None``; use the
            # full-reduction overload, as ``mean``/``max``/``min`` do.
            return torch.sum(tensor)
        return torch.sum(tensor, dim)

    @staticmethod
    def unsqueeze(tensor: TensorType, dim: int) -> torch.Tensor:
        "unsqueeze along dim"
        return torch.unsqueeze(tensor, dim)

    @staticmethod
    def abs(tensor: TensorType) -> torch.Tensor:
        "compute absolute value"
        return torch.abs(tensor)

    @staticmethod
    def where(
        condition: TensorType,
        input: Union[TensorType, float],
        other: Union[TensorType, float],
    ) -> torch.Tensor:
        "Element-wise select: input where condition is True, other elsewhere"
        return torch.where(condition, input, other)

CrossEntropyLoss(reduction='mean') staticmethod

Cross Entropy Loss from logits

Source code in oodeel/utils/torch_operator.py
101
102
103
104
105
106
107
108
@staticmethod
def CrossEntropyLoss(reduction: str = "mean"):
    """Build a cross-entropy loss callable operating on logits.

    Args:
        reduction (str): Reduction mode forwarded to torch.nn.CrossEntropyLoss.

    Returns:
        Callable: function ``(inputs, targets) -> loss``.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction=reduction)

    def sanitized_ce_loss(inputs, targets):
        return criterion(inputs, targets)

    return sanitized_ce_loss

abs(tensor) staticmethod

compute absolute value

Source code in oodeel/utils/torch_operator.py
257
258
259
260
@staticmethod
def abs(tensor: TensorType) -> torch.Tensor:
    "Return the element-wise absolute value of ``tensor``"
    magnitude = torch.abs(tensor)
    return magnitude

argmax(tensor, dim=None) staticmethod

Argmax function

Source code in oodeel/utils/torch_operator.py
66
67
68
69
@staticmethod
def argmax(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
    """Return the indices of the maximum values.

    When ``dim`` is None the index refers to the flattened tensor.
    """
    indices = torch.argmax(tensor, dim=dim)
    return indices

cat(tensors, dim=0) staticmethod

Concatenate tensors in a given dimension

Source code in oodeel/utils/torch_operator.py
160
161
162
163
@staticmethod
def cat(tensors: List[TensorType], dim: int = 0) -> torch.Tensor:
    "Join a list of tensors along the existing dimension ``dim``"
    joined = torch.cat(tensors, dim)
    return joined

convert_from_tensorflow(tensor) staticmethod

Convert a tensorflow tensor into a torch tensor

Used when using a pytorch model on a dataset loaded from tensorflow datasets

Source code in oodeel/utils/torch_operator.py
120
121
122
123
124
125
126
@staticmethod
def convert_from_tensorflow(tensor: TensorType) -> torch.Tensor:
    """Build a torch tensor from a tensorflow tensor.

    Meant for feeding a pytorch model with a dataset loaded from tensorflow
    datasets. The value goes through ``tensor.numpy()`` and is wrapped with the
    ``torch.Tensor`` constructor, so the result has torch's default float dtype.
    """
    as_array = tensor.numpy()
    return torch.Tensor(as_array)

convert_to_numpy(tensor) staticmethod

Convert tensor into a np.ndarray

Source code in oodeel/utils/torch_operator.py
128
129
130
131
132
133
@staticmethod
def convert_to_numpy(tensor: TensorType) -> np.ndarray:
    """Return ``tensor`` as a np.ndarray, moved to CPU and detached from autograd."""
    on_cpu = tensor if tensor.device == "cpu" else tensor.to("cpu")
    return on_cpu.detach().numpy()

diag(tensor) staticmethod

Diagonal function: return the diagonal of a 2D tensor

Source code in oodeel/utils/torch_operator.py
194
195
196
197
@staticmethod
def diag(tensor: TensorType) -> torch.Tensor:
    "Extract the main diagonal of a 2D tensor as a 1D tensor"
    return torch.diag(tensor)

eigh(tensor) staticmethod

Computes the eigen decomposition of a self-adjoint matrix.

Source code in oodeel/utils/torch_operator.py
214
215
216
217
218
@staticmethod
def eigh(tensor: TensorType) -> tuple:
    """Computes the eigen decomposition of a self-adjoint matrix.

    Returns:
        tuple: ``(eigval, eigvec)`` as returned by ``torch.linalg.eigh``
        (eigenvalues, and the matching eigenvectors as columns).
    """
    eigval, eigvec = torch.linalg.eigh(tensor)
    return eigval, eigvec

einsum(equation, *tensors) staticmethod

Computes the einsum between tensors following equation

Source code in oodeel/utils/torch_operator.py
236
237
238
239
@staticmethod
def einsum(equation: str, *tensors: TensorType) -> torch.Tensor:
    "Evaluate the Einstein-summation ``equation`` over the given operands"
    operands = tensors
    result = torch.einsum(equation, *operands)
    return result

equal(tensor, other) staticmethod

Computes element-wise equality

Source code in oodeel/utils/torch_operator.py
204
205
206
207
@staticmethod
def equal(tensor: TensorType, other: Union[TensorType, int, float]) -> torch.Tensor:
    "Element-wise ``tensor == other`` as a boolean tensor (other may be a scalar)"
    mask = torch.eq(tensor, other)
    return mask

flatten(tensor) staticmethod

Flatten function

Source code in oodeel/utils/torch_operator.py
173
174
175
176
177
@staticmethod
def flatten(tensor: TensorType) -> torch.Tensor:
    """Flatten the features to 2D (n_batch, n_features).

    ``reshape`` is used instead of ``view`` so that non-contiguous inputs
    (e.g. permuted feature maps) are accepted; contiguous inputs still get a
    view without copying.
    """
    return tensor.reshape(tensor.size(0), -1)

from_numpy(arr)

Convert a NumPy array to a tensor

Source code in oodeel/utils/torch_operator.py
179
180
181
182
def from_numpy(self, arr: np.ndarray) -> torch.Tensor:
    "Convert a NumPy array to a tensor placed on ``self._device``"
    # TODO change dtype
    converted = torch.tensor(arr)
    return converted.to(self._device)

gradient(func, inputs, *args, **kwargs) staticmethod

Compute gradients for a batch of samples.

Parameters:

Name Type Description Default
func Callable

Function used for computing gradient. Must be built with torch differentiable operations only, and return a scalar.

required
inputs Tensor

Input tensor wrt which the gradients are computed

required
*args

Additional Args for func.

()
**kwargs

Additional Kwargs for func.

{}

Returns:

Type Description
Tensor

torch.Tensor: Gradients computed, with the same shape as the inputs.

Source code in oodeel/utils/torch_operator.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
@staticmethod
def gradient(func: Callable, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Compute gradients for a batch of samples.

    Args:
        func (Callable): Function used for computing gradient. Must be built with
            torch differentiable operations only, and return a scalar.
        inputs (torch.Tensor): Input tensor wrt which the gradients are computed
        *args: Additional Args for func.
        **kwargs: Additional Kwargs for func.

    Returns:
        torch.Tensor: Gradients computed, with the same shape as the inputs.
    """
    # Only toggle the flag when the caller had not enabled it already. This
    # preserves the caller's autograd setup and never calls
    # ``requires_grad_(False)`` on a non-leaf tensor, which torch rejects.
    toggled = not inputs.requires_grad
    if toggled:
        inputs.requires_grad_(True)
    try:
        outputs = func(inputs, *args, **kwargs)
        gradients = torch.autograd.grad(outputs, inputs)
    finally:
        # Restore the flag even if ``func`` or the backward pass raises.
        if toggled:
            inputs.requires_grad_(False)
    return gradients[0]

matmul(tensor_1, tensor_2) staticmethod

Matmul operation

Source code in oodeel/utils/torch_operator.py
115
116
117
118
@staticmethod
def matmul(tensor_1: TensorType, tensor_2: TensorType) -> torch.Tensor:
    """Matrix product of ``tensor_1`` and ``tensor_2`` (torch.matmul semantics)"""
    product = torch.matmul(tensor_1, tensor_2)
    return product

max(tensor, dim=None, keepdim=False) staticmethod

Max function

Source code in oodeel/utils/torch_operator.py
71
72
73
74
75
76
77
78
79
@staticmethod
def max(
    tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
) -> torch.Tensor:
    """Maximum values of ``tensor``.

    Args:
        tensor: Input tensor.
        dim: Dimension to reduce. If None, the maximum over all elements is
            returned.
        keepdim: Whether to retain the reduced dimension (only used when
            ``dim`` is given).

    Returns:
        torch.Tensor: The maximum values; the indices produced by ``torch.max``
        along ``dim`` are discarded.
    """
    if dim is None:
        return torch.max(tensor)
    values, _ = torch.max(tensor, dim, keepdim=keepdim)
    return values

mean(tensor, dim=None) staticmethod

Mean function

Source code in oodeel/utils/torch_operator.py
165
166
167
168
169
170
171
@staticmethod
def mean(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
    "Average of ``tensor`` over all elements, or along ``dim`` when given"
    if dim is not None:
        return torch.mean(tensor, dim)
    return torch.mean(tensor)

min(tensor, dim=None, keepdim=False) staticmethod

Min function

Source code in oodeel/utils/torch_operator.py
81
82
83
84
85
86
87
88
89
@staticmethod
def min(
    tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
) -> torch.Tensor:
    """Minimum values of ``tensor``.

    Over all elements when ``dim`` is None; otherwise along ``dim`` (honouring
    ``keepdim``), keeping only the values and dropping the indices.
    """
    if dim is not None:
        return torch.min(tensor, dim, keepdim=keepdim).values
    return torch.min(tensor)

norm(tensor, dim=None) staticmethod

Tensor Norm

Source code in oodeel/utils/torch_operator.py
110
111
112
113
@staticmethod
def norm(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
    """Default ``torch.norm`` of ``tensor``, over all elements or along ``dim``"""
    result = torch.norm(tensor, dim=dim)
    return result

one_hot(tensor, num_classes) staticmethod

One hot function

Source code in oodeel/utils/torch_operator.py
91
92
93
94
@staticmethod
def one_hot(tensor: TensorType, num_classes: int) -> torch.Tensor:
    """One-hot encode integer class indices into ``num_classes`` columns"""
    encoded = torch.nn.functional.one_hot(tensor, num_classes)
    return encoded

permute(tensor, dims) staticmethod

Reorder the dimensions of a tensor according to dims.

Source code in oodeel/utils/torch_operator.py
189
190
191
192
@staticmethod
def permute(tensor: TensorType, dims) -> torch.Tensor:
    """Reorder the dimensions of ``tensor``.

    Args:
        tensor: Input tensor.
        dims: New ordering of the dimensions, e.g. ``(0, 2, 1)``.
    """
    return torch.permute(tensor, dims)

pinv(tensor) staticmethod

Computes the pseudoinverse (Moore-Penrose inverse) of a matrix.

Source code in oodeel/utils/torch_operator.py
209
210
211
212
@staticmethod
def pinv(tensor: TensorType) -> torch.Tensor:
    "Moore-Penrose pseudoinverse of a matrix, via torch.linalg.pinv"
    inverse = torch.linalg.pinv(tensor)
    return inverse

quantile(tensor, q, dim=None) staticmethod

Computes the q-quantile of a tensor's components, with q in [0, 1].

Source code in oodeel/utils/torch_operator.py
220
221
222
223
224
225
226
227
228
229
@staticmethod
def quantile(
    tensor: TensorType, q: float, dim: Optional[int] = None
) -> Union[float, torch.Tensor]:
    """Computes the ``q``-quantile of a tensor's components, ``q`` in [0, 1].

    Args:
        tensor: Input tensor (torch.quantile expects a floating dtype).
        q: Quantile level.
        dim: Dimension to reduce. If None, the quantile of all components is
            computed.

    Returns:
        A Python float when ``dim`` is None, otherwise a torch.Tensor.
    """
    if dim is None:
        # keep the 16 millions first elements (see torch.quantile issue:
        # https://github.com/pytorch/pytorch/issues/64947)
        # ``reshape`` also handles non-contiguous tensors, unlike ``view``.
        tensor_flatten = tensor.reshape(-1)[:16_000_000]
        return torch.quantile(tensor_flatten, q).item()
    return torch.quantile(tensor, q, dim)

relu(tensor) staticmethod

Apply relu to a tensor

Source code in oodeel/utils/torch_operator.py
231
232
233
234
@staticmethod
def relu(tensor: TensorType) -> torch.Tensor:
    "Rectified linear unit: negative entries become zero"
    rectified = torch.nn.functional.relu(tensor)
    return rectified

reshape(tensor, shape) staticmethod

Reshape function

Source code in oodeel/utils/torch_operator.py
199
200
201
202
@staticmethod
def reshape(tensor: TensorType, shape: List[int]) -> torch.Tensor:
    """Reshape ``tensor`` to ``shape``.

    ``Tensor.reshape`` is used instead of ``Tensor.view`` so that
    non-contiguous inputs are accepted; contiguous inputs still get a view
    without copying.
    """
    return tensor.reshape(*shape)

sign(tensor) staticmethod

Sign function

Source code in oodeel/utils/torch_operator.py
96
97
98
99
@staticmethod
def sign(tensor: TensorType) -> torch.Tensor:
    """Element-wise sign of ``tensor`` (-1, 0 or 1)"""
    signs = torch.sign(tensor)
    return signs

softmax(tensor) staticmethod

Softmax function along the last dimension

Source code in oodeel/utils/torch_operator.py
61
62
63
64
@staticmethod
def softmax(tensor: TensorType) -> torch.Tensor:
    """Normalize ``tensor`` with softmax over its last dimension"""
    last_dim = -1
    return torch.nn.functional.softmax(tensor, dim=last_dim)

stack(tensors, dim=0) staticmethod

Stack tensors along a new dimension

Source code in oodeel/utils/torch_operator.py
155
156
157
158
@staticmethod
def stack(tensors: List[TensorType], dim: int = 0) -> torch.Tensor:
    "Combine same-shaped tensors along a newly inserted dimension ``dim``"
    stacked = torch.stack(tensors, dim)
    return stacked

sum(tensor, dim=None) staticmethod

sum along dim

Source code in oodeel/utils/torch_operator.py
247
248
249
250
@staticmethod
def sum(
    tensor: TensorType, dim: Optional[Union[tuple, list, int]] = None
) -> torch.Tensor:
    """Sum of ``tensor`` over all elements, or along ``dim`` when given.

    Args:
        tensor: Input tensor.
        dim: Dimension(s) to reduce (int, tuple or list). None reduces all.
    """
    if dim is None:
        # Some torch versions reject an explicit ``dim=None``; use the
        # full-reduction overload, as ``mean``/``max``/``min`` do.
        return torch.sum(tensor)
    return torch.sum(tensor, dim)

t(tensor) staticmethod

Transpose function for tensor of rank 2

Source code in oodeel/utils/torch_operator.py
184
185
186
187
@staticmethod
def t(tensor: TensorType) -> torch.Tensor:
    "Swap the two axes of a rank-2 tensor"
    return torch.t(tensor)

tril(tensor, diagonal=0) staticmethod

Keep the lower triangle of the matrix formed by the last two dimensions of tensor; elements above the given diagonal are set to zero.

Source code in oodeel/utils/torch_operator.py
241
242
243
244
245
@staticmethod
def tril(tensor: TensorType, diagonal: int = 0) -> torch.Tensor:
    """Keep the lower triangle of the matrix formed by the last two dimensions.

    Elements above the ``diagonal``-th diagonal are set to zero
    (``diagonal=0`` is the main diagonal).
    """
    return torch.tril(tensor, diagonal)

unsqueeze(tensor, dim) staticmethod

unsqueeze along dim

Source code in oodeel/utils/torch_operator.py
252
253
254
255
@staticmethod
def unsqueeze(tensor: TensorType, dim: int) -> torch.Tensor:
    "Insert a new size-1 axis at position ``dim``"
    expanded = torch.unsqueeze(tensor, dim)
    return expanded

where(condition, input, other) staticmethod

Select elements from input where condition is True, and from other elsewhere.

Source code in oodeel/utils/torch_operator.py
262
263
264
265
266
267
268
269
@staticmethod
def where(
    condition: TensorType,
    input: Union[TensorType, float],
    other: Union[TensorType, float],
) -> torch.Tensor:
    """Pick ``input`` where ``condition`` holds and ``other`` elsewhere.

    ``input`` and ``other`` may be tensors or scalars; selection follows
    ``torch.where``.
    """
    selected = torch.where(condition, input, other)
    return selected

sanitize_input(tensor_arg_func)

Ensures the decorated function receives a torch.Tensor.

Source code in oodeel/utils/torch_operator.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def sanitize_input(tensor_arg_func: Callable):
    """Decorator ensuring the decorated function receives a torch.Tensor.

    The decorated callable is called as ``(obj, tensor, *args, **kwargs)``.
    ``tensor`` is passed through when it already is a torch.Tensor. A
    tensorflow tensor is converted through ``.numpy()``. Anything else is
    handed to the ``torch.Tensor`` constructor.
    """
    # Local import: preserves the wrapped function's name and docstring
    # without requiring a new module-level import.
    from functools import wraps

    @wraps(tensor_arg_func)
    def wrapper(obj, tensor, *args, **kwargs):
        if isinstance(tensor, torch.Tensor):
            pass
        elif is_from(tensor, "tensorflow"):
            tensor = torch.Tensor(tensor.numpy())
        else:
            tensor = torch.Tensor(tensor)

        return tensor_arg_func(obj, tensor, *args, **kwargs)

    return wrapper