deel.torchlip

Containers

class deel.torchlip.LipschitzModule(coefficient_lip: float = 1.0)[source]

This class allows setting the Lipschitz factor of a layer. Lipschitz layers must inherit this class to allow the user to set the Lipschitz factor.

Warning

This class only regroups functions useful for developing new Lipschitz layers; it does not ensure any property about the layer itself. This means that inheriting from this class alone does not guarantee anything about the Lipschitz constant.

abstract vanilla_export()[source]

Convert this layer to a corresponding vanilla torch layer (when possible).

Returns

A vanilla torch version of this layer.

class deel.torchlip.Sequential(*args: Any, k_coef_lip: float = 1.0)[source]

Equivalent of torch.nn.Sequential, but allows setting the k-Lipschitz factor globally. Also supports condensation and vanilla export. For now only a constant repartition is implemented (each layer gets the n-th root of k_coef_lip, where n is the number of layers), but other repartition functions may be implemented in the future.

Parameters
  • layers – list of layers to add to the model.

  • name – name of the model, can be None

  • k_coef_lip – the Lipschitz coefficient to ensure globally on the model.
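
Example: a minimal sketch of building a globally 1-Lipschitz model and exporting it after training (the layer sizes are arbitrary; SpectralLinear and FullSort are documented further down this page):

>>> import torch
>>> from deel import torchlip
>>> model = torchlip.Sequential(
...     torchlip.SpectralLinear(64, 32),
...     torchlip.FullSort(),
...     torchlip.SpectralLinear(32, 1),
...     k_coef_lip=1.0,
... )
>>> y = model(torch.randn(8, 64))
>>> vanilla = model.vanilla_export()  # plain torch version, when possible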

Convolution Layers

class deel.torchlip.SpectralConv2d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', k_coef_lip: float = 1.0, niter_spectral: int = 3, niter_bjorck: int = 15)[source]

This class is a Conv2d layer constrained such that all singular values of its kernel are 1. The computation is based on the BjorckNormalizer algorithm. As this alone is not enough to ensure the layer is 1-Lipschitz, a corrective coefficient is applied to the output. The computation is done in three steps:

  1. reduce the largest singular value to 1, using the iterated power method;

  2. increase the other singular values to 1, using the BjorckNormalizer algorithm;

  3. divide the output by the Lipschitz bound to ensure the layer is k-Lipschitz.
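
The first two normalization steps can be illustrated on a plain weight matrix. The sketch below is only illustrative and not the library's internal implementation (the helper names are hypothetical, and the layer itself operates on convolution kernels rather than dense matrices); the iteration counts mirror the niter_spectral and niter_bjorck defaults.

import torch

def power_iteration_sigma(w, n_iter=3):
    # Estimate the largest singular value of w with the iterated power method.
    v = torch.randn(w.shape[1])
    for _ in range(n_iter):
        u = w @ v
        u = u / u.norm()
        v = w.t() @ u
        v = v / v.norm()
    return torch.dot(u, w @ v)

def bjorck_normalize(w, n_iter=15, beta=0.5):
    # Push the remaining singular values towards 1 (Bjorck orthonormalization).
    for _ in range(n_iter):
        w = (1 + beta) * w - beta * w @ w.t() @ w
    return w

w = torch.randn(16, 32)
w = w / power_iteration_sigma(w)  # step 1: largest singular value brought to ~1
w = bjorck_normalize(w)           # step 2: other singular values pushed to ~1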

Parameters
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int or tuple) – Size of the convolving kernel

  • stride (int or tuple, optional) – Stride of the convolution.

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input.

  • padding_mode (string, optional) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

  • dilation (int or tuple, optional) – Spacing between kernel elements.

  • groups (int, optional) – Number of blocked connections from input channels to output channels.

  • bias (bool, optional) – If True, adds a learnable bias to the output.

  • k_coef_lip – Lipschitz constant to ensure.

  • niter_spectral – Number of iterations to find the maximum singular value.

  • niter_bjorck – Number of iterations of the BjorckNormalizer algorithm.

This documentation reuses the body of the original torch.nn.Conv2d documentation.
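
Example: a minimal usage sketch (the shapes are arbitrary):

>>> import torch
>>> from deel import torchlip
>>> conv = torchlip.SpectralConv2d(3, 16, kernel_size=3, padding=1)
>>> conv(torch.randn(4, 3, 32, 32)).shape
torch.Size([4, 16, 32, 32])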

class deel.torchlip.FrobeniusConv2d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', k_coef_lip: float = 1.0)[source]

Same as SpectralConv2d but in the case of a single output.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Pooling Layers

class deel.torchlip.ScaledAdaptiveAvgPool2d(output_size: Union[int, Tuple[int, int]], k_coef_lip: float = 1.0)[source]

Applies a 2D adaptive average pooling over an input signal composed of several input planes.

The output is of size H x W, for any input size. The number of output features is equal to the number of input planes.

Parameters
  • output_size – The target output size of the image of the form H x W. Can be a tuple (H, W) or a single H for a square image H x H. H and W can be either an int or None, in which case the size will be the same as that of the input.

  • k_coef_lip – The Lipschitz factor to ensure. The output will be scaled by this factor.

This documentation reuses the body of the original torch.nn.AdaptiveAvgPool2d documentation.
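
Example: a minimal sketch, used here as a global pooling (the shapes are arbitrary):

>>> import torch
>>> from deel import torchlip
>>> pool = torchlip.ScaledAdaptiveAvgPool2d(output_size=(1, 1))
>>> pool(torch.randn(4, 16, 8, 8)).shape
torch.Size([4, 16, 1, 1])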

class deel.torchlip.ScaledAvgPool2d(kernel_size: Union[int, Tuple[int, int]], stride: Optional[Union[int, Tuple[int, int]]] = None, padding: Union[int, Tuple[int, int]] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[bool] = None, k_coef_lip: float = 1.0)[source]

Average pooling operation for spatial data, but with a Lipschitz bound.

Parameters
  • kernel_size – The size of the window.

  • stride – The stride of the window. Must be None or equal to kernel_size. Default value is kernel_size.

  • padding – Implicit zero-padding to be added on both sides. Must be zero.

  • ceil_mode – When True, will use ceil instead of floor to compute the output shape.

  • count_include_pad – When True, will include the zero-padding in the averaging calculation.

  • divisor_override – If specified, it will be used as divisor, otherwise kernel_size will be used.

  • k_coef_lip – The Lipschitz factor to ensure. The output will be scaled by this factor.

This documentation reuses the body of the original torch.nn.AvgPool2d documentation.

class deel.torchlip.ScaledL2NormPool2d(kernel_size: Union[int, Tuple[int, int]], stride: Optional[Union[int, Tuple[int, int]]] = None, padding: Union[int, Tuple[int, int]] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[bool] = None, k_coef_lip: float = 1.0, eps_grad_sqrt: float = 1e-06)[source]

Pooling operation for spatial data based on the L2 norm of each window, with a Lipschitz bound. This pooling operation is norm preserving (gradient equal to 1 almost everywhere).

[1] Y.-L. Boureau, J. Ponce, and Y. LeCun, "A Theoretical Analysis of Feature Pooling in Visual Recognition", p. 8.

Parameters
  • kernel_size – The size of the window.

  • stride – The stride of the window. Must be None or equal to kernel_size. Default value is kernel_size.

  • padding – Implicit zero-padding to be added on both sides. Must be zero.

  • ceil_mode – When True, will use ceil instead of floor to compute the output shape.

  • count_include_pad – When True, will include the zero-padding in the averaging calculation.

  • divisor_override – If specified, it will be used as divisor, otherwise kernel_size will be used.

  • k_coef_lip – The lipschitz factor to ensure. The output will be scaled by this factor.

  • eps_grad_sqrt – Epsilon value to avoid numerical instability due to non-defined gradient at 0 in the sqrt function
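
Example: a minimal usage sketch (the shapes are arbitrary):

>>> import torch
>>> from deel import torchlip
>>> pool = torchlip.ScaledL2NormPool2d(kernel_size=(2, 2))
>>> pool(torch.randn(1, 8, 32, 32)).shape
torch.Size([1, 8, 16, 16])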

Non-linear Activations

class deel.torchlip.InvertibleDownSampling(kernel_size: Tuple[int, int], k_coef_lip: float = 1.0)[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class deel.torchlip.InvertibleUpSampling(kernel_size: Union[int, Tuple[int, ...]], k_coef_lip: float = 1.0)[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class deel.torchlip.MaxMin(dim: Optional[int] = None, k_coef_lip: float = 1.0)[source]

Applies max-min activation.

If the input is a tensor of shape (N, C) and dim is None, the output can be described as:

\text{out}(N_i, C_{2j}) = \max(\text{input}(N_i, C_j), 0)
\text{out}(N_i, C_{2j+1}) = \max(-\text{input}(N_i, C_j), 0)

where N is the batch size and C is the size of the tensor.

See also functional.max_min().

Parameters
  • dim – The dimension to apply max-min. If None, it will apply to the 0th dimension if the shape of the input is (C), or to the first dimension if the shape is (N, C, *).

  • k_coef_lip – The lipschitz coefficient to enforce.

Shape:
  • Input: (C) or (N, C, *) where * means any number of additional dimensions.

  • Output: (2C) if the input shape was (C), or (N, 2C, *) if dim is None, otherwise (N, *, 2C_dim, *) where C_dim is the size of the dimension given by the dim parameter.

Note

M. Blot, M. Cord, and N. Thome, "Max-min convolutional neural networks for image classification", in 2016 IEEE International Conference on Image Processing (ICIP), Phoenix, AZ, USA, 2016, pp. 3678-3682.
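
Example: a minimal sketch showing that the channel dimension is doubled (the shapes are arbitrary):

>>> import torch
>>> from deel import torchlip
>>> x = torch.randn(2, 4)
>>> torchlip.MaxMin()(x).shape
torch.Size([2, 8])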

class deel.torchlip.GroupSort(group_size: Optional[int] = None, k_coef_lip: float = 1.0)[source]

Applies group-sort activation.

The activation works by first reshaping the input to a tensor of shape (N', G), where G is the group size and N' the number of groups, then sorting each group of size G, and finally reshaping back to the original input shape.

See also functional.group_sort().

Parameters
  • group_size – The group size used when sorting. When None, the group size is set to the input size.

  • data_format – Either channels_first or channels_last.

  • k_coef_lip – The Lipschitz coefficient to enforce.

Shape:
  • Input: (N, *) where * means any number of additional dimensions.

  • Output: (N, *), same shape as the input.

Example

>>> m = torch.randn(2, 4)
>>> m
tensor([[ 0.2805, -2.0528,  0.6478,  0.5745],
        [-1.4075,  0.0435, -1.2408,  0.2945]])
>>> torchlip.GroupSort(4)(m)
tensor([[-2.0528,  0.2805,  0.5745,  0.6478],
        [-1.4075, -1.2408,  0.0435,  0.2945]])

class deel.torchlip.GroupSort2(k_coef_lip: float = 1.0)[source]

Applies group-sort activation with a group size of 2.

See GroupSort for details.

See also functional.group_sort_2().

Parameters

k_coef_lip – The lipschitz coefficient to enforce.

class deel.torchlip.FullSort(k_coef_lip: float = 1.0)[source]

Applies full-sort activation. This is equivalent to group-sort with a group size equal to the size of the input.

See GroupSort for details.

See also functional.full_sort().

Parameters

k_coef_lip – The lipschitz coefficient to enforce.

class deel.torchlip.LPReLU(num_parameters: int = 1, init: float = 0.25, k_coef_lip: float = 1.0)[source]

Applies element-wise PReLU activation with Lipschitz constraint:

\text{LPReLU}(x) = \max(0, x) + a' \min(0, x)

or

\text{LPReLU}(x) = \text{LipschitzPReLU}(x) = \begin{cases} x, & \text{if } x \geq 0 \\ a' x, & \text{otherwise} \end{cases}

where a' = \max(\min(a, k), -k) and a is a learnable parameter.

See also torch.nn.PReLU and functional.lipschitz_prelu().

Parameters
  • num_parameters – Number of a to learn. Although it takes an int as input, there are only two legitimate values: 1, or the number of channels at input.

  • init – The initial value of a.

  • k_coef_lip – The Lipschitz coefficient to enforce.
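
Example: a minimal usage sketch (the shapes are arbitrary):

>>> import torch
>>> from deel import torchlip
>>> act = torchlip.LPReLU(num_parameters=16)  # one learnable slope per channel
>>> y = act(torch.randn(8, 16))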

Linear Layers

class deel.torchlip.SpectralLinear(in_features: int, out_features: int, bias: bool = True, k_coef_lip: float = 1.0, niter_spectral: int = 3, niter_bjorck: int = 15)[source]

This class is a Linear layer constrained such that all singular values of its kernel are 1. The computation is based on the BjorckNormalizer algorithm and is done in two steps:

  1. reduce the largest singular value to 1, using the iterated power method;

  2. increase the other singular values to 1, using the BjorckNormalizer algorithm.

Parameters
  • in_features – Size of each input sample.

  • out_features – Size of each output sample.

  • bias – If False, the layer will not learn an additive bias.

  • k_coef_lip – Lipschitz constant to ensure.

  • niter_spectral – Number of iterations to find the maximum singular value.

  • niter_bjorck – Number of iterations of the BjorckNormalizer algorithm.

Shape:
  • Input: (N, *, H_in) where * means any number of additional dimensions and H_in = in_features.

  • Output: (N, *, H_out) where all but the last dimension are the same shape as the input and H_out = out_features.

This documentation reuses the body of the original torch.nn.Linear documentation.
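
Example: a minimal usage sketch, including export to a plain torch layer (the sizes are arbitrary):

>>> import torch
>>> from deel import torchlip
>>> fc = torchlip.SpectralLinear(128, 64)
>>> y = fc(torch.randn(32, 128))
>>> vanilla_fc = fc.vanilla_export()  # vanilla torch version, when possible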

class deel.torchlip.FrobeniusLinear(in_features: int, out_features: int, bias: bool = True, k_coef_lip: float = 1.0)[source]

Same as SpectralLinear, but in the case of a single output.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Loss Functions

class deel.torchlip.KRLoss(true_values: Tuple[int, int] = (0, 1))[source]

Loss that estimates the Wasserstein-1 distance using the Kantorovich-Rubinstein duality.

Parameters

true_values – Tuple containing the two labels, one for each predicted class.

class deel.torchlip.NegKRLoss(true_values: Tuple[int, int] = (0, 1))[source]

Loss that estimates the negative of the Wasserstein-1 distance using the Kantorovich-Rubinstein duality.

Parameters

true_values – Tuple containing the two labels, one for each predicted class.

class deel.torchlip.HingeMarginLoss(min_margin: float = 1.0)[source]

Hinge margin loss.

Parameters

min_margin – The minimal margin to enforce.

class deel.torchlip.HKRLoss(alpha: float, min_margin: float = 1.0, true_values: Tuple[int, int] = (- 1, 1))[source]

Loss that estimates the Wasserstein-1 distance using the Kantorovich-Rubinstein duality with a hinge regularization.

Parameters
  • alpha – Regularization factor between the hinge and the KR loss.

  • min_margin – Minimal margin for the hinge loss.

  • true_values – Tuple containing the two labels, one for each predicted class.
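
Example: a minimal sketch assuming the usual torch loss calling convention loss(input, target) and labels matching the default true_values of (-1, 1); the alpha value is arbitrary:

>>> import torch
>>> from deel import torchlip
>>> hkr = torchlip.HKRLoss(alpha=10.0, min_margin=1.0)
>>> y_pred = torch.randn(16, 1)
>>> y_true = torch.randint(0, 2, (16, 1)).float() * 2 - 1  # labels in {-1, 1}
>>> loss = hkr(y_pred, y_true)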

