losses

CosineLoss

Bases: Module

Source code in orthogonium\losses.py
class CosineLoss(nn.Module):
    def __init__(self):
        """
        A class that implements the Cosine Loss for measuring the cosine similarity
        between predictions and targets. Designed for use in scenarios involving
        angle-based loss calculations or similarity measurements.

        Attributes:
            None

        """
        super(CosineLoss, self).__init__()

    def forward(self, yp, yt):
        # yt holds integer class indices: one-hot encode them, then return the negative
        # mean cosine similarity so that minimizing the loss aligns yp with its target class
        return -torch.nn.functional.cosine_similarity(
            yp, torch.nn.functional.one_hot(yt, yp.shape[1])
        ).mean()

__init__()

A class that implements the Cosine Loss for measuring the cosine similarity between predictions and targets. Designed for use in scenarios involving angle-based loss calculations or similarity measurements.

Source code in orthogonium\losses.py
def __init__(self):
    """
    A class that implements the Cosine Loss for measuring the cosine similarity
    between predictions and targets. Designed for use in scenarios involving
    angle-based loss calculations or similarity measurements.

    Attributes:
        None

    """
    super(CosineLoss, self).__init__()
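
A minimal usage sketch (assuming `CosineLoss` is importable from `orthogonium.losses`, as the source path above suggests): the loss takes raw predictions and integer class indices, one-hot encodes the targets internally, and returns the negative mean cosine similarity, so minimizing it pushes each prediction toward its target direction.

```python
import torch
from orthogonium.losses import CosineLoss  # import path assumed from this page

criterion = CosineLoss()
logits = torch.randn(8, 10, requires_grad=True)  # batch of 8 samples, 10 classes
labels = torch.randint(0, 10, (8,))              # integer class indices (not one-hot)

loss = criterion(logits, labels)                 # negative mean cosine similarity
loss.backward()
```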

LossXent

Bases: Module

Source code in orthogonium\losses.py
class LossXent(nn.Module):
    def __init__(self, n_classes, offset=2.12132, temperature=0.25):
        """
        A custom loss function class for cross-entropy calculation.

        This class initializes a cross-entropy loss criterion along with additional
        parameters, such as an offset and a temperature factor, to allow a finer control over
        the accuracy/robustness tradeoff during training.

        Attributes:
            criterion (nn.CrossEntropyLoss): The PyTorch cross-entropy loss criterion.
            n_classes (int): The number of classes present in the dataset.
            offset (float): An offset value for customizing the loss computation.
            temperature (float): A temperature factor for scaling logits during loss calculation.

        Parameters:
            n_classes (int): The number of classes in the dataset.
            offset (float, optional): The offset value for loss computation. Default is 2.12132.
            temperature (float, optional): The temperature scaling factor. Default is 0.25.
        """
        super(LossXent, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.n_classes = n_classes
        self.offset = offset
        self.temperature = temperature

    def __call__(self, outputs, labels):
        # subtract the offset from the true-class logit only, then rescale all logits
        # by the temperature before applying standard cross-entropy
        one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=self.n_classes)
        offset_outputs = outputs - self.offset * one_hot_labels
        offset_outputs /= self.temperature
        loss = self.criterion(offset_outputs, labels) * self.temperature
        return loss

__init__(n_classes, offset=2.12132, temperature=0.25)

A custom loss function class for cross-entropy calculation.

This class initializes a cross-entropy loss criterion along with additional parameters, an offset and a temperature factor, that allow finer control over the accuracy/robustness tradeoff during training.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `criterion` | `CrossEntropyLoss` | The PyTorch cross-entropy loss criterion. |
| `n_classes` | `int` | The number of classes present in the dataset. |
| `offset` | `float` | An offset value for customizing the loss computation. |
| `temperature` | `float` | A temperature factor for scaling logits during loss calculation. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_classes` | `int` | The number of classes in the dataset. | required |
| `offset` | `float` | The offset value for loss computation. | `2.12132` |
| `temperature` | `float` | The temperature scaling factor. | `0.25` |

Source code in orthogonium\losses.py
def __init__(self, n_classes, offset=2.12132, temperature=0.25):
    """
    A custom loss function class for cross-entropy calculation.

    This class initializes a cross-entropy loss criterion along with additional
    parameters, such as an offset and a temperature factor, to allow a finer control over
    the accuracy/robustness tradeoff during training.

    Attributes:
        criterion (nn.CrossEntropyLoss): The PyTorch cross-entropy loss criterion.
        n_classes (int): The number of classes present in the dataset.
        offset (float): An offset value for customizing the loss computation.
        temperature (float): A temperature factor for scaling logits during loss calculation.

    Parameters:
        n_classes (int): The number of classes in the dataset.
        offset (float, optional): The offset value for loss computation. Default is 2.12132.
        temperature (float, optional): The temperature scaling factor. Default is 0.25.
    """
    super(LossXent, self).__init__()
    self.criterion = nn.CrossEntropyLoss()
    self.n_classes = n_classes
    self.offset = offset
    self.temperature = temperature
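
A minimal usage sketch (import path assumed from this page). The call subtracts `offset` from the true-class logit only, divides all logits by `temperature`, applies standard cross-entropy, and rescales the result by `temperature`; subtracting the offset from the correct class effectively acts as a margin.

```python
import torch
from orthogonium.losses import LossXent  # import path assumed from this page

criterion = LossXent(n_classes=10)       # offset=2.12132, temperature=0.25 by default
logits = torch.randn(8, 10, requires_grad=True)
labels = torch.randint(0, 10, (8,))      # integer class indices

loss = criterion(logits, labels)
loss.backward()
```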

SoftHKRMulticlassLoss

Bases: Module

Source code in orthogonium\losses.py
class SoftHKRMulticlassLoss(torch.nn.Module):
    def __init__(
        self,
        alpha=10.0,
        min_margin=1.0,
        alpha_mean=0.99,
        temperature=1.0,
    ):
        """
        The multiclass version of HKR with softmax. This is done by computing
        the HKR term over each class and averaging the results.

        Note that `y_true` could be either one-hot encoded, +/-1 values.


        Args:
            alpha (float): regularization factor (0 <= alpha <= 1),
                0 for KR only, 1 for hinge only
            min_margin (float): margin to enforce.
            temperature (float): factor for softmax  temperature
                (higher value increases the weight of the highest non y_true logits)
            alpha_mean (float): geometric mean factor
            one_hot_ytrue (bool): set to True when y_true are one hot encoded (0 or 1),
                and False when y_true already signed bases (for instance +/-1)
            reduction: passed to tf.keras.Loss constructor
            name (str): passed to tf.keras.Loss constructor

        """
        assert (alpha >= 0) and (alpha <= 1), "alpha must in [0,1]"
        self.alpha = torch.tensor(alpha, dtype=torch.float32)
        self.min_margin_v = min_margin
        self.alpha_mean = alpha_mean

        self.current_mean = torch.tensor((self.min_margin_v,), dtype=torch.float32)
        """    constraint=lambda x: torch.clamp(x, 0.005, 1000),
            name="current_mean",
        )"""

        self.temperature = temperature * self.min_margin_v
        if alpha == 1.0:  # alpha = 1.0 => hinge only
            self.fct = self.multiclass_hinge_soft
        else:
            if alpha == 0.0:  # alpha = 0.0 => KR only
                self.fct = self.kr_soft
            else:
                self.fct = self.hkr

        super(SoftHKRMulticlassLoss, self).__init__()

    def clamp_current_mean(self, x):
        return torch.clamp(x, 0.005, 1000)

    def _update_mean(self, y_pred):
        self.current_mean = self.current_mean.to(y_pred.device)
        current_global_mean = torch.mean(torch.abs(y_pred)).to(
            dtype=self.current_mean.dtype
        )
        current_global_mean = (
            self.alpha_mean * self.current_mean
            + (1 - self.alpha_mean) * current_global_mean
        )
        self.current_mean = self.clamp_current_mean(current_global_mean).detach()
        total_mean = current_global_mean
        total_mean = torch.clamp(total_mean, self.min_margin_v, 20000)
        return total_mean

    def computeTemperatureSoftMax(self, y_true, y_pred):
        total_mean = self._update_mean(y_pred)
        current_temperature = (
            torch.clamp(self.temperature / total_mean, 0.005, 250)
            .to(dtype=y_pred.dtype)
            .detach()
        )
        min_value = torch.tensor(torch.finfo(torch.float32).min, dtype=y_pred.dtype).to(
            device=y_pred.device
        )
        opposite_values = torch.where(
            y_true > 0, min_value, current_temperature * y_pred
        )
        F_soft_KR = torch.softmax(opposite_values, dim=-1)
        one_value = torch.tensor(1.0, dtype=F_soft_KR.dtype).to(device=y_pred.device)
        F_soft_KR = torch.where(y_true > 0, one_value, F_soft_KR)
        return F_soft_KR

    def signed_y_pred(self, y_true, y_pred):
        """Return for each item sign(y_true)*y_pred."""
        sign_y_true = torch.where(y_true > 0, 1, -1)  # switch to +/-1
        sign_y_true = sign_y_true.to(dtype=y_pred.dtype)
        return y_pred * sign_y_true

    def multiclass_hinge_preproc(self, signed_y_pred, min_margin):
        """From multiclass_hinge(y_true, y_pred, min_margin)
        simplified to use precalculated signed_y_pred"""
        # compute the elementwise hinge term
        hinge = torch.nn.functional.relu(min_margin / 2.0 - signed_y_pred)
        return hinge

    def multiclass_hinge_soft_preproc(self, signed_y_pred, F_soft_KR):
        hinge = self.multiclass_hinge_preproc(signed_y_pred, self.min_margin_v)
        b = hinge * F_soft_KR
        b = torch.sum(b, axis=-1)
        return b

    def multiclass_hinge_soft(self, y_true, y_pred):
        F_soft_KR = self.computeTemperatureSoftMax(y_true, y_pred)
        signed_y_pred = self.signed_y_pred(y_true, y_pred)
        return self.multiclass_hinge_soft_preproc(signed_y_pred, F_soft_KR)

    def kr_soft_preproc(self, signed_y_pred, F_soft_KR):
        kr = -signed_y_pred
        a = kr * F_soft_KR
        a = torch.sum(a, axis=-1)
        return a

    def kr_soft(self, y_true, y_pred):
        F_soft_KR = self.computeTemperatureSoftMax(y_true, y_pred)
        signed_y_pred = self.signed_y_pred(y_true, y_pred)
        return self.kr_soft_preproc(signed_y_pred, F_soft_KR)

    def hkr(self, y_true, y_pred):
        F_soft_KR = self.computeTemperatureSoftMax(y_true, y_pred)
        signed_y_pred = self.signed_y_pred(y_true, y_pred)

        loss_softkr = self.kr_soft_preproc(signed_y_pred, F_soft_KR)

        loss_softhinge = self.multiclass_hinge_soft_preproc(signed_y_pred, F_soft_KR)
        return (1 - self.alpha) * loss_softkr + self.alpha * loss_softhinge

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        target = torch.nn.functional.one_hot(target, num_classes=input.shape[1])
        if not (isinstance(input, torch.Tensor)):  # required for dtype.max
            input = torch.Tensor(input, dtype=input.dtype)
        if not (isinstance(target, torch.Tensor)):
            target = torch.Tensor(target, dtype=input.dtype)
        loss_batch = self.fct(target, input)
        return torch.mean(loss_batch)

current_mean = torch.tensor((self.min_margin_v,), dtype=torch.float32) instance-attribute

Running estimate of the mean absolute prediction, updated by _update_mean and clamped to [0.005, 1000].

__init__(alpha=10.0, min_margin=1.0, alpha_mean=0.99, temperature=1.0)

The multiclass version of HKR with softmax. This is done by computing the HKR term over each class and averaging the results.

Note that y_true can be either one-hot encoded (0/1) or signed (+/-1) values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `alpha` | `float` | Regularization factor (0 <= alpha <= 1); 0 for KR only, 1 for hinge only. | `10.0` |
| `min_margin` | `float` | Margin to enforce. | `1.0` |
| `temperature` | `float` | Factor for the softmax temperature (a higher value increases the weight of the highest non-`y_true` logits). | `1.0` |
| `alpha_mean` | `float` | Geometric mean factor. | `0.99` |
| `one_hot_ytrue` | `bool` | Set to True when `y_true` is one-hot encoded (0 or 1), and False when `y_true` values are already signed (for instance +/-1). | required |
| `reduction` |  | Passed to the `tf.keras.Loss` constructor. | required |
| `name` | `str` | Passed to the `tf.keras.Loss` constructor. | required |

Note: `one_hot_ytrue`, `reduction`, and `name` appear in the docstring but are not parameters of this constructor's signature.
Source code in orthogonium\losses.py
def __init__(
    self,
    alpha=10.0,
    min_margin=1.0,
    alpha_mean=0.99,
    temperature=1.0,
):
    """
    The multiclass version of HKR with softmax. This is done by computing
    the HKR term over each class and averaging the results.

    Note that `y_true` could be either one-hot encoded, +/-1 values.


    Args:
        alpha (float): regularization factor (0 <= alpha <= 1),
            0 for KR only, 1 for hinge only
        min_margin (float): margin to enforce.
        temperature (float): factor for softmax  temperature
            (higher value increases the weight of the highest non y_true logits)
        alpha_mean (float): geometric mean factor
        one_hot_ytrue (bool): set to True when y_true are one hot encoded (0 or 1),
            and False when y_true already signed bases (for instance +/-1)
        reduction: passed to tf.keras.Loss constructor
        name (str): passed to tf.keras.Loss constructor

    """
    assert (alpha >= 0) and (alpha <= 1), "alpha must in [0,1]"
    self.alpha = torch.tensor(alpha, dtype=torch.float32)
    self.min_margin_v = min_margin
    self.alpha_mean = alpha_mean

    self.current_mean = torch.tensor((self.min_margin_v,), dtype=torch.float32)
    """    constraint=lambda x: torch.clamp(x, 0.005, 1000),
        name="current_mean",
    )"""

    self.temperature = temperature * self.min_margin_v
    if alpha == 1.0:  # alpha = 1.0 => hinge only
        self.fct = self.multiclass_hinge_soft
    else:
        if alpha == 0.0:  # alpha = 0.0 => KR only
            self.fct = self.kr_soft
        else:
            self.fct = self.hkr

    super(SoftHKRMulticlassLoss, self).__init__()

multiclass_hinge_preproc(signed_y_pred, min_margin)

From multiclass_hinge(y_true, y_pred, min_margin) simplified to use precalculated signed_y_pred

Source code in orthogonium\losses.py
def multiclass_hinge_preproc(self, signed_y_pred, min_margin):
    """From multiclass_hinge(y_true, y_pred, min_margin)
    simplified to use precalculated signed_y_pred"""
    # compute the elementwise hinge term
    hinge = torch.nn.functional.relu(min_margin / 2.0 - signed_y_pred)
    return hinge

signed_y_pred(y_true, y_pred)

Return for each item sign(y_true)*y_pred.

Source code in orthogonium\losses.py
def signed_y_pred(self, y_true, y_pred):
    """Return for each item sign(y_true)*y_pred."""
    sign_y_true = torch.where(y_true > 0, 1, -1)  # switch to +/-1
    sign_y_true = sign_y_true.to(dtype=y_pred.dtype)
    return y_pred * sign_y_true
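
A minimal usage sketch for SoftHKRMulticlassLoss (import path assumed from this page). Note that the default `alpha=10.0` does not satisfy the constructor's assert that alpha lie in [0, 1], so alpha should be passed explicitly; targets are integer class indices and are one-hot encoded inside `forward`.

```python
import torch
from orthogonium.losses import SoftHKRMulticlassLoss  # import path assumed from this page

# alpha must be in [0, 1]: 0 -> KR term only, 1 -> hinge term only, in between -> blend
criterion = SoftHKRMulticlassLoss(alpha=0.95, min_margin=1.0, temperature=1.0)
logits = torch.randn(8, 10, requires_grad=True)
labels = torch.randint(0, 10, (8,))  # integer class indices

loss = criterion(logits, labels)
loss.backward()
```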

VRA(output, class_indices, last_layer_type='classwise', L=1.0, eps=36 / 255, return_certs=False)

Compute the verified robust accuracy (VRA) of a model's output.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `output` | `torch.Tensor` | The output of the model. | required |
| `class_indices` | `torch.Tensor` | The indices of the correct classes. Should not be one-hot encoded. | required |
| `last_layer_type` | `str` | The type of the last layer of the model. Should be either "classwise" (L-lip per class) or "global" (L-lip globally). | `'classwise'` |
| `L` | `float` | The Lipschitz constant of the model. | `1.0` |
| `eps` | `float` | The perturbation size. | `36 / 255` |
| `return_certs` | `bool` | Whether to return the certificates instead of the VRA. | `False` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `vra` | `torch.Tensor` | The VRA of the model. |

Source code in orthogonium\losses.py
def VRA(
    output,
    class_indices,
    last_layer_type="classwise",
    L=1.0,
    eps=36 / 255,
    return_certs=False,
):
    """Compute the verified robust accuracy (VRA) of a model's output.

    Args:
        output : torch.Tensor
            The output of the model.
        class_indices : torch.Tensor
            The indices of the correct classes. Should not be one-hot encoded.
        last_layer_type : str
            The type of the last layer of the model. Should be either "classwise" (L-lip per class) or "global" (L-lip globally).
        L : float
            The Lipschitz constant of the model.
        eps : float
            The perturbation size.
        return_certs : bool
            Whether to return the certificates instead of the VRA.

    Returns:
        vra : torch.Tensor
            The VRA of the model.
    """
    batch_size = output.shape[0]
    batch_indices = torch.arange(batch_size)

    # get the values of the correct class
    output_class_indices = output[batch_indices, class_indices]
    # get the values of the top class that is not the correct class
    # create a mask indicating the correct class
    onehot = torch.zeros_like(output).cuda()
    onehot[torch.arange(output.shape[0]), class_indices] = 1.0
    # subtracting a large number from the correct class to ensure it is not the max
    # doing so will allow us to find the top of the output that is not the correct class
    output_trunc = output - onehot * 1e6
    output_nextmax = torch.max(output_trunc, dim=1)[0]
    # now we can compute the certificates
    output_diff = output_class_indices - output_nextmax
    if last_layer_type == "global":
        den = math.sqrt(2) * L
    elif last_layer_type == "classwise":
        den = 2 * L
    else:
        raise ValueError(
            "[VRA] last_layer_type should be either 'global' or 'classwise'"
        )
    certs = output_diff / den
    # now we can compute the vra
    # vra is percentage of certs > eps
    vra = (certs > eps).float()
    if return_certs:
        return certs
    return vra
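
A minimal usage sketch (import path assumed from this page). The implementation shown above allocates the one-hot mask with `.cuda()`, so `output` is expected to live on a CUDA device; the mean of the returned per-sample indicator gives the verified robust accuracy.

```python
import torch
from orthogonium.losses import VRA  # import path assumed from this page

logits = torch.randn(8, 10).cuda()          # model outputs, on GPU (see .cuda() above)
labels = torch.randint(0, 10, (8,)).cuda()  # integer class indices

per_sample = VRA(logits, labels, last_layer_type="classwise", L=1.0, eps=36 / 255)
vra = per_sample.mean().item()              # fraction of certifiably robust samples

# pass return_certs=True to get the per-sample certified radii instead
certs = VRA(logits, labels, last_layer_type="classwise", L=1.0, return_certs=True)
```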

check_last_linear_layer_type(model)

Determines the type of the last linear layer in a given model.

This function inspects the architecture of the model and identifies the last linear layer of specific types (nn.Linear, OrthoLinear, UnitNormLinear). It then returns a string indicating the type of that layer based on its class, which determines the last_layer_type parameter to use when computing the VRA of a model's output.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` |  | The model containing layers to be inspected. | required |

Returns:

| Type | Description |
| --- | --- |
| `str` | A string indicating the type of the last linear layer. The possible values are: "global" if the layer is of type OrthoLinear; "classwise" if the layer is of type UnitNormLinear; "unknown" if the layer is of any other type or if no linear layer is found. |

Source code in orthogonium\losses.py
def check_last_linear_layer_type(model):
    """
    Determines the type of the last linear layer in a given model.

    This function inspects the architecture of the model and identifies the last
    linear layer of specific types (nn.Linear, OrthoLinear, UnitNormLinear). It
    then returns a string indicating the type of the last linear layer based on
    its class. This allows to determine the parameter to use for computing the
    VRA of a model's output.

    Args:
        model: The model containing layers to be inspected.

    Returns:
        str: A string indicating the type of the last linear layer.
             The possible values are:
                 - "global" if the layer is of type OrthoLinear.
                 - "classwise" if the layer is of type UnitNormLinear.
                 - "unknown" if the layer is of any other type or if no
                   linear layer is found.
    """
    # Find the last linear layer in the model
    last_linear_layer = None
    layers = list(model.children())
    for layer in reversed(layers):
        if (
            isinstance(layer, nn.Linear)
            or isinstance(layer, OrthoLinear)
            or isinstance(layer, UnitNormLinear)
        ):
            last_linear_layer = layer
            break

    # Check the type of the last linear layer
    if isinstance(last_linear_layer, OrthoLinear):
        return "global"
    elif isinstance(last_linear_layer, UnitNormLinear):
        return "classwise"
    else:
        return "unknown"