
deel.torchlip

Containers

class deel.torchlip.LipschitzModule(coefficient_lip: float = 1.0)[source]

This class allows setting the Lipschitz factor of a layer. Lipschitz layers must inherit from this class to let users set the Lipschitz factor.

Warning

This class only groups functions that are useful when developing new Lipschitz layers. It does not ensure any property of the layer: inheriting from this class does not guarantee anything about the Lipschitz constant.

apply_lipschitz_factor()[source]

Multiply the layer weights by the Lipschitz factor.

abstract vanilla_export()[source]

Convert this layer to a corresponding vanilla torch layer (when possible).

Returns:

A vanilla torch version of this layer.

class deel.torchlip.Sequential(*args: Any, k_coef_lip: float = 1.0)[source]

Equivalent of torch.nn.Sequential, but allows setting the k-Lipschitz factor globally. Also supports condensation and vanilla export. For now, only a constant repartition is implemented: each layer gets $\sqrt[n]{k}$, the n-th root of k_coef_lip, where n is the number of layers. Other repartition functions may be implemented in the future.

Parameters:
  • layers – list of layers to add to the model.

  • name – name of the model, can be None

  • k_coef_lip – the Lipschitz coefficient to ensure globally on the model.
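The constant repartition is simple arithmetic: composing n layers multiplies their Lipschitz constants, so giving each layer the n-th root of k_coef_lip yields k_coef_lip for the whole model. The sketch below checks this; per_layer_factor is a hypothetical helper for illustration, not part of deel.torchlip.

```python
import math


def per_layer_factor(k_coef_lip: float, n_layers: int) -> float:
    # Hypothetical helper: constant repartition gives every layer the n-th root of k
    return k_coef_lip ** (1.0 / n_layers)


k, n = 8.0, 3
factors = [per_layer_factor(k, n) for _ in range(n)]
# Lipschitz constants multiply under composition, so the product recovers k
print(factors[0], math.prod(factors))
```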

Linear Layers

class deel.torchlip.SpectralLinear(in_features: int, out_features: int, bias: bool = True, k_coef_lip: float = 1.0, eps_spectral: int = 0.001, eps_bjorck: int = 0.001)[source]

This class is a Linear layer constrained such that all singular values of its kernel are 1. The computation is based on the BjorckNormalizer algorithm and is done in two steps:

  1. reduce the largest singular value to 1, using the iterated power method.

  2. increase the other singular values to 1, using the BjorckNormalizer algorithm.
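The two steps can be sketched on a plain weight matrix in PyTorch. This is a minimal illustration under assumptions: the function names, max_iter, and the first-order Björck update with step size beta=0.5 are choices made here, and torchlip's internal implementation may differ (for example in how it initializes or reuses the power-iteration vectors during training).

```python
import torch


def spectral_normalize(w: torch.Tensor, eps: float = 1e-3, max_iter: int = 100) -> torch.Tensor:
    """Step 1: divide w by an estimate of its largest singular value (power iteration)."""
    u = torch.randn(w.shape[0], dtype=w.dtype)
    u = u / u.norm()
    sigma = torch.zeros((), dtype=w.dtype)
    for _ in range(max_iter):
        v = w.t() @ u
        v = v / v.norm()
        wv = w @ v
        new_sigma = wv.norm()  # current estimate of the top singular value
        u = wv / new_sigma
        converged = torch.abs(new_sigma - sigma) < eps
        sigma = new_sigma
        if converged:
            break
    return w / sigma


def bjorck_orthonormalize(w: torch.Tensor, eps: float = 1e-3, max_iter: int = 100,
                          beta: float = 0.5) -> torch.Tensor:
    """Step 2: push every singular value towards 1 with first-order Björck iterations."""
    for _ in range(max_iter):
        # W <- (1 + beta) W - beta W W^T W
        w_next = (1 + beta) * w - beta * w @ w.t() @ w
        if (w_next - w).norm() < eps:
            return w_next
        w = w_next
    return w


torch.manual_seed(0)
weight = torch.randn(4, 6)  # (out_features, in_features), as in torch.nn.Linear
w_ortho = bjorck_orthonormalize(spectral_normalize(weight), eps=1e-6)
print(torch.linalg.svdvals(w_ortho))  # all singular values close to 1
```

Step 1 matters because the Björck iteration only converges when the singular values start in a bounded range (below sqrt(3) for beta=0.5); after the power-method rescaling they are all at most about 1.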

Parameters:
  • in_features – Size of each input sample.

  • out_features – Size of each output sample.

  • bias – If False, the layer will not learn an additive bias.

  • k_coef_lip – Lipschitz constant to ensure.

  • eps_spectral – stopping criterion for the iterative power algorithm.

  • eps_bjorck – stopping criterion for the Bjorck algorithm.

Shape:
  • Input: $(N, *, H_{in})$ where $*$ means any number of additional dimensions and $H_{in} = \text{in\_features}$

  • Output: $(N, *, H_{out})$ where all but the last dimension have the same shape as the input and $H_{out} = \text{out\_features}$.

This documentation reuses the body of the original torch.nn.Linear doc.

class deel.torchlip.FrobeniusLinear(in_features: int, out_features: int, bias: bool = True, disjoint_neurons: bool = True, k_coef_lip: float = 1.0)[source]

This class is a Linear layer constrained such that the Frobenius norm of the weight is 1. In the case of a single output neuron, it is equivalent to, and faster than, the SpectralLinear layer. In the multi-neuron case, the disjoint_neurons parameter affects the behavior:

  • if disjoint_neurons is True (default), it corresponds to stacking independent 1-Lipschitz neurons.

  • if disjoint_neurons is False, the weight matrix is normalized by its Frobenius norm.
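The two normalization modes can be sketched in a few lines of PyTorch, assuming the weight has shape (out_features, in_features) as in torch.nn.Linear. frobenius_normalize is a hypothetical helper; the sketch does not show how torchlip applies the normalization to the layer during training.

```python
import torch


def frobenius_normalize(w: torch.Tensor, disjoint_neurons: bool = True) -> torch.Tensor:
    if disjoint_neurons:
        # Each row (one output neuron) is scaled to unit L2 norm independently
        return w / w.norm(dim=1, keepdim=True)
    # The whole weight matrix is scaled by its Frobenius norm
    return w / w.norm()


torch.manual_seed(0)
w = torch.randn(3, 5)
w_disjoint = frobenius_normalize(w, disjoint_neurons=True)
w_joint = frobenius_normalize(w, disjoint_neurons=False)

# With a single output neuron the weight has rank 1, so its spectral norm equals its
# Frobenius norm: this is why the layer matches SpectralLinear in that case.
w_single = frobenius_normalize(torch.randn(1, 5), disjoint_neurons=False)
print(torch.linalg.matrix_norm(w_single, ord=2))
```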

Parameters:
  • in_features – Size of each input sample.

  • out_features – Size of each output sample.

  • bias – If False, the layer will not learn an additive bias.

  • disjoint_neurons – Whether to normalize the weight matrix independently per neuron.

  • k_coef_lip – Lipschitz constant to ensure.


Convolution Layers

class deel.torchlip.SpectralConv1d(in_channels: int, out_channels: int, kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int] = 1, padding: int | Tuple[int, int] = 0, dilation: int | Tuple[int, int] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', k_coef_lip: float = 1.0, eps_spectral: int = 0.001, eps_bjorck: int = 0.001)[source]

This class is a Conv1d layer constrained such that all singular values of its kernel are 1. The computation is based on the BjorckNormalizer algorithm. As this is not enough to ensure that the layer is 1-Lipschitz, a correction coefficient is applied to the output. The computation is done in three steps:

  1. reduce the largest singular value to 1, using the iterated power method.

  2. increase the other singular values to 1, using the BjorckNormalizer algorithm.

  3. divide the output by the Lipschitz bound of the convolution to ensure the layer is k-Lipschitz.
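Why step 3 is needed can be seen on a tiny example in plain PyTorch: a kernel whose reshaped matrix has all singular values equal to 1 can still amplify a signal, because each input sample is reused by up to kernel_size outputs. Dividing by sqrt(kernel_size) is the simple conservative bound for stride 1; it is an assumption used here for illustration, and torchlip may compute a tighter coefficient that depends on padding and stride.

```python
import math

import torch
import torch.nn.functional as F

# One input channel, one output channel, kernel_size = 3. Reshaped to a
# (out_channels, in_channels * kernel_size) = (1, 3) matrix, this kernel has
# all its singular values equal to 1 (there is only one).
kernel = torch.full((1, 1, 3), 1.0 / math.sqrt(3.0))
print(torch.linalg.matrix_norm(kernel.reshape(1, 3), ord=2))

x = torch.ones(1, 1, 100)  # constant signal, shape (batch, channels, length)
y = F.conv1d(x, kernel, padding=1)
gain = (y.norm() / x.norm()).item()  # about 1.72: steps 1-2 alone are not 1-Lipschitz
corrected = ((y / math.sqrt(3.0)).norm() / x.norm()).item()  # below 1 after step 3
print(gain, corrected)
```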

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int or tuple) – Size of the convolving kernel

  • stride (int or tuple, optional) – Stride of the convolution.

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input.

  • padding_mode (string, optional) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

  • dilation (int or tuple, optional) – Spacing between kernel elements.

  • groups (int, optional) – Number of blocked connections from input channels to output channels.

  • bias (bool, optional) – If True, adds a learnable bias to the output.

  • k_coef_lip – Lipschitz constant to ensure.

  • eps_spectral – stopping criterion for the iterative power algorithm.

  • eps_bjorck – stopping criterion for the Bjorck algorithm.

This documentation reuses the body of the original torch.nn.Conv1d doc.

class deel.torchlip.SpectralConv2d(in_channels: int, out_channels: int, kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int] = 1, padding: int | Tuple[int, int] = 0, dilation: int | Tuple[int, int] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', k_coef_lip: float = 1.0, eps_spectral: int = 0.001, eps_bjorck: int = 0.001)[source]

This class is a Conv2d layer constrained such that all singular values of its kernel are 1. The computation is based on the BjorckNormalizer algorithm. As this is not enough to ensure that the layer is 1-Lipschitz, a correction coefficient is applied to the output. The computation is done in three steps:

  1. reduce the largest singular value to 1, using the iterated power method.

  2. increase the other singular values to 1, using the BjorckNormalizer algorithm.

  3. divide the output by the Lipschitz bound of the convolution to ensure the layer is k-Lipschitz.
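A sketch of the three steps for a 2D kernel in plain PyTorch, under stated assumptions: steps 1 and 2 are replaced by an exact SVD-based orthonormalization of the reshaped (out_channels, in_channels * kh * kw) matrix (a swap for brevity, not the power method plus Björck used by the layer), and step 3 uses the conservative sqrt(kh * kw) bound for stride 1 with an odd kernel and "same" zero padding. Both function names are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def orthonormalize_kernel(kernel: torch.Tensor) -> torch.Tensor:
    """Steps 1-2 (exact variant): set every singular value of the reshaped kernel to 1."""
    out_channels = kernel.shape[0]
    mat = kernel.reshape(out_channels, -1)  # (out_c, in_c * kh * kw)
    u, _, vh = torch.linalg.svd(mat, full_matrices=False)
    return (u @ vh).reshape(kernel.shape)


def corrected_conv2d(x: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Step 3: divide by a Lipschitz bound of the convolution (stride 1, odd kernel)."""
    kh, kw = kernel.shape[-2:]
    # Each input pixel falls in at most kh * kw patches, so ||conv|| <= sqrt(kh * kw)
    coef = math.sqrt(kh * kw)
    return F.conv2d(x, kernel, padding=(kh // 2, kw // 2)) / coef


torch.manual_seed(0)
kernel = orthonormalize_kernel(torch.randn(8, 3, 3, 3))
x1, x2 = torch.randn(4, 3, 16, 16), torch.randn(4, 3, 16, 16)
diff_out = corrected_conv2d(x1, kernel) - corrected_conv2d(x2, kernel)
ratio = diff_out.flatten(1).norm(dim=1) / (x1 - x2).flatten(1).norm(dim=1)
print(ratio)  # per-sample ||f(x1) - f(x2)|| / ||x1 - x2||
```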

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int or tuple) – Size of the convolving kernel

  • stride (int or tuple, optional) – Stride of the convolution.

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input.

  • padding_mode (string, optional) – 'zeros', 'reflect', 'replicate', 'symmetric' or 'circular'. Default: 'zeros'

  • dilation (int or tuple, optional) – Spacing between kernel elements. Has to be one

  • groups (int, optional) – Number of blocked connections from input channels to output channels. Has to be one

  • bias (bool, optional) – If True, adds a learnable bias to the output.

  • k_coef_lip – Lipschitz constant to ensure.

  • eps_spectral – stopping criterion for the iterative power algorithm.

  • eps_bjorck – stopping criterion for the Bjorck algorithm.

This documentation reuses the body of the original torch.nn.Conv2d doc.

class deel.torchlip.FrobeniusConv2d(in_channels: int, out_channels: int, kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int] = 1, padding: int | Tuple[int, int] = 0, dilation: int | Tuple[int, int] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', k_coef_lip: float = 1.0)[source]

Same as SpectralConv2d, but for the case of a single output channel.

This class is a Conv2d layer with additional padding modes.

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int or tuple) – Size of the convolving kernel

  • stride (int or tuple, optional) – Stride of the convolution.

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input.

  • padding_mode (string, optional) – 'zeros', 'reflect', 'replicate', 'symmetric' or 'circular'. Default: 'zeros'

  • dilation (int or tuple, optional) – Spacing between kernel elements. Has to be one

  • groups (int, optional) – Number of blocked connections from input channels to output channels. Has to be one

  • bias (bool, optional) – If True, adds a learnable bias to the output.

This documentation reuses the body of the original torch.nn.Conv2d doc.

class deel.torchlip.SpectralConvTranspose2d(in_channels: int, out_channels: int, kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int] = 1, padding: int | Tuple[int, int] = 0, output_padding: int | Tuple[int, int] = 0, groups: int = 1, bias: bool = True, dilation: int | Tuple[int, int] = 1, padding_mode: str = 'zeros', device=None, dtype=None, k_coef_lip: float = 1.0, eps_spectral: int = 0.001, eps_bjorck: int = 0.001)[source]

Applies a 2D transposed convolution operator over an input image such that all singular values of its kernel are 1. The computations are the same as for the SpectralConv2d layer.
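One way to see why the same computations carry over: with stride 1 and matching padding, a transposed convolution is the adjoint of the convolution that uses the same kernel tensor, and an operator and its adjoint share the same spectral norm. The sketch below checks the adjoint identity numerically in plain PyTorch; the stride-1, padding-1, 3x3 setup is an assumption chosen so that the shapes line up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
kernel = torch.randn(8, 3, 3, 3)  # conv2d weight: (out_channels, in_channels, kh, kw)
x = torch.randn(2, 3, 10, 10)     # input of the convolution
y = torch.randn(2, 8, 10, 10)     # input of the transposed convolution

# <conv2d(x), y> == <x, conv_transpose2d(y)>: conv_transpose2d is the adjoint of conv2d
lhs = (F.conv2d(x, kernel, padding=1) * y).sum()
rhs = (x * F.conv_transpose2d(y, kernel, padding=1)).sum()
print(lhs.item(), rhs.item())
```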

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int or tuple) – Size of the convolving kernel

  • stride (int or tuple, optional) – Stride of the convolution.

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input.

  • output_padding – only 0 or None are supported.

  • padding_mode (string, optional) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

  • dilation (int or tuple, optional) – Spacing between kernel elements. Has to be one.

  • groups (int, optional) – Number of blocked connections from input channels to output channels. Has to be one.

  • bias (bool, optional) – If True, adds a learnable bias to the output.

  • k_coef_lip – Lipschitz constant to ensure.

  • eps_spectral – stopping criterion for the iterative power algorithm.

  • eps_bjorck – stopping criterion for the Bjorck algorithm.

This documentation reuses the body of the original torch.nn.ConvTranspose2d doc.


Pooling Layers

class deel.torchlip.ScaledAvgPool2d(kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int] | None = None, padding: int | Tuple[int, int] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: bool = None, k_coef_lip: float = 1.0)[source]

Average pooling operation for spatial data, but with a Lipschitz bound.

Parameters:
  • kernel_size – The size of the window.

  • stride – The stride of the window. Must be None or equal to kernel_size. Default value is kernel_size.

  • padding – Implicit zero-padding to be added on both sides. Must be zero.

  • ceil_mode – When True, will use ceil instead of floor to compute the output shape.

  • count_include_pad – When True, will include the zero-padding in the averaging calculation.

  • divisor_override – If specified, it will be used as divisor, otherwise kernel_size will be used.

  • k_coef_lip – The Lipschitz factor to ensure. The output will be scaled by this factor.

This documentation reuses the body of the original torch.nn.AvgPool2d doc.
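Why the output needs rescaling: by Cauchy-Schwarz, averaging k*k values gives a result whose square is at most (1/k^2) times their sum of squares, so plain average pooling over disjoint k x k windows is (1/k)-Lipschitz in L2 norm, with equality on constant inputs. Multiplying by sqrt(k*k) = k restores a 1-Lipschitz, non-contracting map. The sketch below assumes that this is the correction applied (before the k_coef_lip scaling); scaled_avg_pool2d is a hypothetical helper, not a reading of torchlip's source.

```python
import math

import torch
import torch.nn.functional as F


def scaled_avg_pool2d(x: torch.Tensor, kernel_size: int, k_coef_lip: float = 1.0) -> torch.Tensor:
    # Disjoint windows (stride defaults to kernel_size); rescale by sqrt of the window area
    return F.avg_pool2d(x, kernel_size) * math.sqrt(kernel_size * kernel_size) * k_coef_lip


x = torch.ones(1, 2, 8, 8)  # constant input: the 1/k bound of plain pooling is attained
print(F.avg_pool2d(x, 2).norm() / x.norm(), scaled_avg_pool2d(x, 2).norm() / x.norm())

torch.manual_seed(0)
a, b = torch.randn(1, 2, 8, 8), torch.randn(1, 2, 8, 8)
gap = (scaled_avg_pool2d(a, 2) - scaled_avg_pool2d(b, 2)).norm() / (a - b).norm()
print(gap)  # at most 1
```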

class deel.torchlip.ScaledL2NormPool2d(kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int] | None = None, ceil_mode: bool = False, k_coef_lip: float = 1.0)[source]

L2-norm pooling operation for spatial data, with a Lipschitz bound. This pooling operation is norm preserving (gradient=1 almost everywhere).

[1] Y.-L. Boureau, J. Ponce, and Y. LeCun, “A Theoretical Analysis of Feature Pooling in Visual Recognition”, p. 8.

Parameters:
  • kernel_size – The size of the window.

  • stride – The stride of the window. Must be None or equal to kernel_size. Default value is kernel_size.

  • ceil_mode – When True, will use ceil instead of floor to compute the output shape.

  • k_coef_lip – The Lipschitz factor to ensure. The output will be scaled by this factor.
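A sketch of L2-norm pooling over disjoint windows, using PyTorch's F.lp_pool2d with norm_type=2 (swapped in for illustration; torchlip may implement it differently). Each output is the L2 norm of one window, so when the windows tile the input exactly, the sum of squared outputs equals the sum of squared inputs: the total norm is preserved. l2_norm_pool2d is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def l2_norm_pool2d(x: torch.Tensor, kernel_size: int, k_coef_lip: float = 1.0) -> torch.Tensor:
    # sqrt(sum of squares) over each kernel_size x kernel_size window (stride = kernel_size)
    return F.lp_pool2d(x, norm_type=2, kernel_size=kernel_size) * k_coef_lip


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
y = l2_norm_pool2d(x, 2)
# The 2x2 windows tile the 8x8 input, so ||y|| equals ||x||
print(y.shape, x.norm().item(), y.norm().item())
```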

class deel.torchlip.ScaledAdaptiveAvgPool2d(output_size: int | Tuple[int, int], k_coef_lip: float = 1.0)[source]

Applies a 2D adaptive average pooling over an input signal composed of several input planes.

The output is of size H x W, for any input size. The number of output features is equal to the number of input planes.

Parameters:
  • output_size – The target output size of the image of the form H x W. Can be a tuple (H, W) or a single H for a square image H x H. H and W can be either an int, or None, which means the size will be the same as that of the input.

  • k_coef_lip – The Lipschitz factor to ensure. The output will be scaled by this factor.

This documentation reuses the body of the original nn.AdaptiveAvgPool2d doc.

class deel.torchlip.ScaledAdaptativeL2NormPool2d(output_size: int | Tuple[int, int] = (1, 1), k_coef_lip: float = 1.0)[source]

L2-norm pooling operation for spatial data, with a Lipschitz bound. This pooling operation is norm preserving (i.e. gradient=1 almost everywhere).

[1] Y.-L. Boureau, J. Ponce, and Y. LeCun, “A Theoretical Analysis of Feature Pooling in Visual Recognition”, p. 8.

Parameters:
  • output_size – the target output size; must be (1, 1).

  • k_coef_lip – the Lipschitz factor to ensure.

Input shape:

4D tensor with shape (batch_size, channels, rows, cols).

Output shape:

4D tensor with shape (batch_size, channels, 1, 1).

class deel.torchlip.InvertibleDownSampling(kernel_size: int, k_coef_lip: float = 1.0)[source]

A combination of torch.nn.PixelUnshuffle and LipschitzModule. This module is used to downsample the input tensor by a factor of kernel_size. The resulting output tensor has kernel_size^2 times more channels than the input tensor.
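torch.nn.PixelUnshuffle only rearranges entries, moving each kernel_size x kernel_size spatial block into channels, so it preserves the L2 norm exactly and torch.nn.PixelShuffle undoes it. The sketch below shows this with the two PyTorch modules directly, leaving aside the k_coef_lip scaling of the torchlip wrappers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 3, 4, 4)
down = nn.PixelUnshuffle(downscale_factor=2)
up = nn.PixelShuffle(upscale_factor=2)

y = down(x)  # shape (1, 3 * 2**2, 2, 2): the same values, only moved around
print(y.shape, x.norm().item(), y.norm().item())
```

Because the map is a permutation of coordinates, it is an isometry, hence 1-Lipschitz in both directions.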


class deel.torchlip.InvertibleUpSampling(kernel_size: int, k_coef_lip: float = 1.0)[source]

A combination of torch.nn.PixelShuffle and LipschitzModule. This module is used to upsample the input tensor by a factor of kernel_size. The resulting output tensor has kernel_size^2 times fewer channels than the input tensor.


Non-linear Activations

class deel.torchlip.MaxMin(dim: int | None = None, k_coef_lip: float = 1.0)[source]

Applies max-min activation.

If input is a tensor of shape $(N, C)$ and dim is None, the output can be described as:

$$\text{out}(N_i, C_{2j}) = \max(\text{input}(N_i, C_j), 0) \\ \text{out}(N_i, C_{2j + 1}) = \max(-\text{input}(N_i, C_j), 0)$$

where $N$ is the batch size and $C$ is the size of the tensor.
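A direct transcription of the formula above for an (N, C) input with dim=None, in plain PyTorch. The interleaved channel order (2j, 2j + 1) follows the formula; the memory layout used by functional.max_min() is not asserted here, and max_min_interleaved is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def max_min_interleaved(x: torch.Tensor) -> torch.Tensor:
    # (N, C) -> (N, 2C) with out[:, 2j] = max(x[:, j], 0) and out[:, 2j + 1] = max(-x[:, j], 0)
    return torch.stack((F.relu(x), F.relu(-x)), dim=-1).reshape(x.shape[0], -1)


x = torch.tensor([[1.5, -2.0, 0.0]])
print(max_min_interleaved(x))
```

For each input entry at most one of the two outputs is nonzero, and it has the same magnitude, so the activation preserves the L2 norm of its input.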

See also functional.max_min().

Parameters:
  • dim – The dimension along which to apply max-min. If None, it is applied to the 0th dimension if the input shape is $(C)$, or to the first dimension if it is $(N, C, *)$.

  • k_coef_lip – The Lipschitz coefficient to enforce.

Shape:
  • Input: $(C)$ or $(N, C, *)$ where * means any number of additional dimensions.

  • Output: $(2C)$ if the input shape was $(C)$, $(N, 2C, *)$ if dim is None, otherwise $(N, *, 2C_{dim}, *)$ where $C_{dim}$ is the dimension corresponding to the dim parameter.

Note

M. Blot, M. Cord, and N. Thome, “Max-min convolutional neural networks for image classification”, in 2016 IEEE International Conference on Image Processing (ICIP), Phoenix, AZ, USA, 2016, pp. 3678-3682.

class deel.torchlip.GroupSort(group_size: int | None = None, k_coef_lip: float = 1.0)[source]

Applies group-sort activation.

The activation works by first reshaping the input to a tensor of shape $(N', G)$ where $G$ is the group size and $N'$ the number of groups, then sorting each group of size $G$, and finally reshaping to the original input shape.

See also functional.group_sort().

Parameters:
  • group_size – group size used when sorting. When None, the group size is set to the input size.

  • data_format – either channels_first or channels_last.

  • k_coef_lip – The Lipschitz coefficient to enforce.

Shape:
  • Input: $(N, *)$ where * means any number of additional dimensions.

  • Output: $(N, *)$, same shape as the input.

Example

>>> import torch
>>> from deel import torchlip
>>> m = torch.randn(2, 4)
>>> m
tensor([[ 0.2805, -2.0528,  0.6478,  0.5745],
        [-1.4075,  0.0435, -1.2408,  0.2945]])
>>> torchlip.GroupSort(4)(m)
tensor([[-2.0528,  0.2805,  0.5745,  0.6478],
        [-1.4075, -1.2408,  0.0435,  0.2945]])

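The reshape, sort, and reshape procedure described above can be sketched in plain PyTorch for a 2D input whose width is divisible by the group size. group_sort here is a hypothetical helper (the None group size case is not handled); it reproduces the example output with group size 4, and shows how group size 2 sorts consecutive pairs instead.

```python
import torch


def group_sort(x: torch.Tensor, group_size: int) -> torch.Tensor:
    n, c = x.shape
    if c % group_size != 0:
        raise ValueError("the number of features must be divisible by the group size")
    groups = x.reshape(-1, group_size)  # (N * C / G, G): one row per group
    return groups.sort(dim=-1).values.reshape(n, c)


m = torch.tensor([[0.2805, -2.0528, 0.6478, 0.5745],
                  [-1.4075, 0.0435, -1.2408, 0.2945]])
print(group_sort(m, 4))  # same result as torchlip.GroupSort(4)(m) in the example
print(group_sort(m, 2))  # sorts (m[:, 0], m[:, 1]) and (m[:, 2], m[:, 3]) separately
```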
class deel.torchlip.GroupSort2(k_coef_lip: float = 1.0)[source]

Applies group-sort activation with a group size of 2.

See GroupSort for details.

See also functional.group_sort_2().

Parameters:

k_coef_lip – The Lipschitz coefficient to enforce.

class deel.torchlip.FullSort(k_coef_lip: float = 1.0)[source]

Applies full-sort activation. This is equivalent to group-sort with a group size equal to the size of the input.

See GroupSort for details.

See also functional.full_sort().

Parameters:

k_coef_lip – The Lipschitz coefficient to enforce.

class deel.torchlip.LPReLU(num_parameters: int = 1, init: float = 0.25, k_coef_lip: float = 1.0)[source]

Applies element-wise PReLU activation with Lipschitz constraint:

$$\text{LPReLU}(x) = \max(0, x) + a' \cdot \min(0, x)$$

or

$$\text{LPReLU}(x) = \text{LipschitzPReLU}(x) = \begin{cases} x, & \text{if } x \geq 0 \\ a' \cdot x, & \text{otherwise} \end{cases}$$

where $a' = \max(\min(a, k), -k)$, and $a$ is a learnable parameter.

See also torch.nn.PReLU and functional.lipschitz_prelu().

Parameters:
  • num_parameters – Number of $a$ to learn. Although it takes an int as input, there are only two legitimate values: 1, or the number of channels at input.

  • init – The initial value of $a$.

  • k_coef_lip – The Lipschitz coefficient to enforce.
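The formula above translates directly into PyTorch. lprelu_sketch is a hypothetical function that takes the learned a as an explicit tensor argument; it is an illustration of the equation, not functional.lipschitz_prelu() itself.

```python
import torch


def lprelu_sketch(x: torch.Tensor, a: torch.Tensor, k: float = 1.0) -> torch.Tensor:
    a_clipped = torch.clamp(a, min=-k, max=k)  # a' = max(min(a, k), -k)
    # max(0, x) + a' * min(0, x)
    return torch.clamp(x, min=0) + a_clipped * torch.clamp(x, max=0)


x = torch.tensor([-2.0, -0.5, 0.0, 3.0])
print(lprelu_sketch(x, torch.tensor(0.25)))      # negative part scaled by 0.25
print(lprelu_sketch(x, torch.tensor(4.0), k=1))  # a = 4 is clipped to a' = 1
```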

class deel.torchlip.HouseHolder(channels, k_coef_lip: float = 1.0, theta_initializer=None)[source]

Householder activation, see [this review](https://openreview.net/pdf?id=tD7eCtaSkR). Adapted from [this repository](https://github.com/singlasahil14/SOC).

Loss Functions

class deel.torchlip.KRLoss(multi_gpu=False, reduction: str = 'mean', true_values=None)[source]

Loss that estimates the Wasserstein-1 distance using the Kantorovich-Rubinstein duality. The Kantorovich-Rubinstein duality is formulated as follows:

$$ W_1(\mu, \nu) = \sup_{f \in Lip_1(\Omega)} \underset{\textbf{x} \sim \mu}{\mathbb{E}} \left[f(\textbf{x})\right] - \underset{\textbf{x} \sim \nu}{\mathbb{E}} \left[f(\textbf{x})\right] $$

where $\mu$ and $\nu$ stand for the two distributions: the distribution of samples whose label is 1, and the distribution of the remaining samples.
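For a fixed network f, the quantity inside the supremum is estimated on a batch by the mean output on label-1 samples minus the mean output on the others. The sketch below computes that batch estimate for (batch_size, 1) predictions; kr_estimate is hypothetical and does not reproduce the loss's reduction modes or multi-GPU label processing. NegKRLoss is the negation of this estimate, which is what gets minimized during training.

```python
import torch


def kr_estimate(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    # y_pred, y_true: (batch_size, 1); labels in {0, 1} or {-1, 1}
    is_mu = y_true > 0  # samples from mu (label 1); the others come from nu
    return y_pred[is_mu].mean() - y_pred[~is_mu].mean()


y_pred = torch.tensor([[2.0], [4.0], [1.0], [-1.0]])
y_true = torch.tensor([[1.0], [1.0], [0.0], [0.0]])
print(kr_estimate(y_pred, y_true))  # mean(2, 4) - mean(1, -1)
```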

Note that input and target must be of rank 2: (batch_size, 1), or (batch_size, C) for multilabel classification (with C categories). target accepts label values in {0, 1} or {-1, 1}, or labels pre-processed with the deel.torchlip.functional.process_labels_for_multi_gpu() function.

Using a multi-GPU/TPU strategy requires setting multi_gpu to True and pre-processing the target labels with the deel.torchlip.functional.process_labels_for_multi_gpu() function.

Parameters:
  • multi_gpu (bool) – set to True when running on multi-GPU/TPU

  • reduction – type of reduction applied to the output. Possible values are ‘none’ | ‘mean’ | ‘sum’ | ‘auto’; default is ‘mean’ (‘auto’ is equivalent to ‘mean’).

  • true_values – deprecated.

class deel.torchlip.NegKRLoss(multi_gpu=False, reduction: str = 'mean', true_values=None)[source]

Loss that estimates the negative of the Wasserstein-1 distance using the Kantorovich-Rubinstein duality. See KRLoss for more details.

Parameters:
  • multi_gpu (bool) – set to True when running on multi-GPU/TPU

  • reduction – type of reduction applied to the output. Possible values are ‘none’ | ‘mean’ | ‘sum’ | ‘auto’; default is ‘mean’ (‘auto’ is equivalent to ‘mean’).

  • true_values – deprecated.

class deel.torchlip.HingeMarginLoss(min_margin: float = 1.0, reduction: str = 'mean')[source]

Hinge margin loss.

Parameters:
  • min_margin – The minimal margin to enforce.

  • reduction – type of reduction applied to the output. Possible values are ‘none’ | ‘mean’ | ‘sum’ | ‘auto’; default is ‘mean’ (‘auto’ is equivalent to ‘mean’).

class deel.torchlip.HKRLoss(alpha: float, min_margin: float = 1.0, multi_gpu=False, reduction: str = 'mean', true_values=None)[source]

Loss that estimates the Wasserstein-1 distance using the Kantorovich-Rubinstein duality with a hinge regularization.

[1] M. Serrurier, F. Mamalet, et al. «Achieving robustness in classification using optimal transport with hinge regularization», 2021.

Note that input and target must be of rank 2: (batch_size, 1), or (batch_size, C) for multilabel classification (with C categories). target accepts label values in {0, 1} or {-1, 1}, or labels pre-processed with the deel.torchlip.functional.process_labels_for_multi_gpu() function.

Using a multi-GPU/TPU strategy requires setting multi_gpu to True and pre-processing the target labels with the deel.torchlip.functional.process_labels_for_multi_gpu() function.

The regularization factor alpha is a value between 0 and 1. It controls the trade-off between the hinge and the KR loss. When alpha is 0, the loss is equivalent to the KR loss, and when alpha is 1, the loss is equivalent to the hinge loss.
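The trade-off can be sketched as a convex combination of the two terms, for {-1, 1} labels of shape (batch_size, 1). This rests on stated assumptions: the KR term enters with a negative sign so that minimizing the loss maximizes the Wasserstein estimate (as in NegKRLoss), and the hinge term is taken as ReLU(min_margin - y_true * y_pred); the library may use a different margin convention. hkr_sketch is a hypothetical helper.

```python
import torch


def hkr_sketch(y_pred: torch.Tensor, y_true: torch.Tensor, alpha: float,
               min_margin: float = 1.0) -> torch.Tensor:
    # y_pred, y_true: (batch_size, 1), labels in {-1, 1}
    is_pos = y_true > 0
    kr = y_pred[is_pos].mean() - y_pred[~is_pos].mean()  # Wasserstein-1 estimate
    hinge = torch.relu(min_margin - y_true * y_pred).mean()  # assumed hinge form
    return alpha * hinge - (1.0 - alpha) * kr


y_pred = torch.tensor([[2.0], [0.5], [-1.5], [0.2]])
y_true = torch.tensor([[1.0], [1.0], [-1.0], [-1.0]])
for alpha in (0.0, 0.5, 1.0):
    print(alpha, hkr_sketch(y_pred, y_true, alpha).item())
```

With alpha = 0 only the (negated) KR estimate remains; with alpha = 1 only the hinge term remains, matching the two limiting cases described above.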

Parameters:
  • alpha – Regularization factor ([0,1]) between the hinge and the KR loss.

  • min_margin – Minimal margin for the hinge loss.

  • multi_gpu (bool) – set to True when running on multi-GPU/TPU

  • reduction – type of reduction applied to the output. Possible values are ‘none’ | ‘mean’ | ‘sum’ | ‘auto’; default is ‘mean’ (‘auto’ is equivalent to ‘mean’).

  • true_values – deprecated.

class deel.torchlip.HKRMulticlassLoss(alpha: float, min_margin: float = 1.0, multi_gpu=False, reduction: str = 'mean')[source]

Loss that estimates the Wasserstein-1 distance using the Kantorovich-Rubinstein duality with a hinge regularization.

[1] M. Serrurier, F. Mamalet, et al. «Achieving robustness in classification using optimal transport with hinge regularization», 2021.

Note that target should be one-hot encoded or pre-processed with the deel.torchlip.functional.process_labels_for_multi_gpu() function.

Using a multi-GPU/TPU strategy requires setting multi_gpu to True and pre-processing the target labels with the deel.torchlip.functional.process_labels_for_multi_gpu() function.

The regularization factor alpha is a value between 0 and 1. It controls the trade-off between the hinge and the KR loss. When alpha is 0, the loss is equivalent to the KR loss, and when alpha is 1, the loss is equivalent to the hinge loss.

Parameters:
  • alpha – Regularization factor ([0,1]) between the hinge and the KR loss.

  • min_margin – Minimal margin for the hinge loss.

  • multi_gpu (bool) – set to True when running on multi-GPU/TPU

  • reduction – type of reduction applied to the output. Possible values are ‘none’ | ‘mean’ | ‘sum’ | ‘auto’; default is ‘mean’ (‘auto’ is equivalent to ‘mean’).

class deel.torchlip.SoftHKRMulticlassLoss(alpha=0.9, min_margin=1.0, alpha_mean=0.99, temperature=1.0, reduction: str = 'mean')[source]

The multiclass version of HKR with softmax. This is done by computing the HKR term over each class and averaging the results.

[2] M. Serrurier, F. Mamalet, T. Fel et al. “On the explainable properties of 1-Lipschitz Neural Networks: An Optimal Transport Perspective.”, 2024

Note that target should be one-hot encoded with +/-1 values.

The regularization factor alpha is a value between 0 and 1. It controls the trade-off between the hinge and the KR loss. When alpha is 0, the loss is equivalent to the KR loss, and when alpha is 1, the loss is equivalent to the hinge loss.

Parameters:
  • alpha (float) – regularization factor (0 <= alpha <= 1).

  • min_margin (float) – margin to enforce.

  • alpha_mean (float) – geometric mean factor.

  • temperature (float) – factor for softmax temperature (a higher value increases the weight of the highest non-y_true logits).

  • reduction – type of reduction applied to the output. Possible values are ‘none’ | ‘mean’ | ‘sum’ | ‘auto’; default is ‘mean’ (‘auto’ is equivalent to ‘mean’).

class deel.torchlip.TauCrossEntropyLoss(tau: float, weight: Tensor | None = None, size_average=None, ignore_index: int = -100, reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0)[source]

This loss adds a temperature factor (tau) to CrossEntropyLoss: it computes CrossEntropyLoss(tau * input, target).

See CrossEntropyLoss for more details on arguments.

Parameters:

tau (float) – factor for temperature
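The definition above can be written with torch.nn.functional.cross_entropy applied to rescaled logits; a larger tau sharpens the softmax over the network outputs. The sketch below compares that call with the same quantity written out as the batch mean of the negative log-softmax at the target class.

```python
import torch
import torch.nn.functional as F

tau = 5.0
logits = torch.tensor([[0.2, -0.1, 0.4], [0.0, 0.3, -0.2]])
target = torch.tensor([2, 1])

# CrossEntropyLoss(tau * input, target), with the default 'mean' reduction
loss = F.cross_entropy(tau * logits, target)
# Same value, written out: mean over the batch of -log softmax(tau * logits)[target]
manual = -(torch.log_softmax(tau * logits, dim=1)[torch.arange(2), target]).mean()
print(loss.item(), manual.item())
```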

class deel.torchlip.TauBCEWithLogitsLoss(tau: float, weight: Tensor | None = None, size_average=None, reduce=None, reduction: str = 'mean', pos_weight=None)[source]

This loss adds a temperature factor (tau) to BCEWithLogitsLoss: it computes BCEWithLogitsLoss(tau * input, target).

See BCEWithLogitsLoss for more details on arguments.

Parameters:

tau (float) – factor for temperature

class deel.torchlip.CategoricalHingeLoss(min_margin: float = 1.0, reduction: str = 'mean')[source]

This implementation is slightly different from the PyTorch MultiMarginLoss.

target and input must be of shape (batch_size, # classes). Note that target should be one-hot encoded with +/-1 values. $\text{ReLU}(\text{min\_margin} - (\text{input}[\text{target}>0] - \max(\text{input}[\text{target}\leq 0])))$ is computed element-wise and averaged over the batch.

Parameters:
  • min_margin (float) – margin parameter.

  • reduction – type of reduction applied to the output. Possible values are ‘none’ | ‘mean’ | ‘sum’ | ‘auto’; default is ‘mean’ (‘auto’ is equivalent to ‘mean’).
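The formula can be sketched in plain PyTorch for one-hot +/-1 targets of shape (batch_size, n_classes), with the 'mean' reduction: the true-class logit is compared with the largest logit among the other classes. categorical_hinge_sketch is a hypothetical helper written from the formula, not the library's code.

```python
import torch


def categorical_hinge_sketch(y_pred: torch.Tensor, y_true: torch.Tensor,
                             min_margin: float = 1.0) -> torch.Tensor:
    is_true = y_true > 0
    true_logit = (y_pred * is_true).sum(dim=1)  # input[target > 0]
    # max(input[target <= 0]): mask out the true class before taking the max
    other_max = y_pred.masked_fill(is_true, float("-inf")).max(dim=1).values
    return torch.relu(min_margin - (true_logit - other_max)).mean()


y_pred = torch.tensor([[3.0, 1.0, 0.5], [0.2, 0.4, 0.0]])
y_true = torch.tensor([[1.0, -1.0, -1.0], [1.0, -1.0, -1.0]])
print(categorical_hinge_sketch(y_pred, y_true))  # (relu(1 - 2) + relu(1 + 0.2)) / 2
```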


© Copyright 2020, IRT Antoine de Saint Exupéry - All rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry, CRIAQ and ANITI.

Built with Sphinx using PyTorch's theme provided originally by Read the Docs.