
deel.torchlip.functional

Non-linear activation functions

invertible down/up sample

deel.torchlip.functional.invertible_downsample(input: torch.Tensor, kernel_size: Union[int, Tuple[int, ...]]) -> torch.Tensor

Downsamples the input in an invertible way.

The number of elements in the output tensor is the same as the number of elements in the input tensor.

Parameters
  • input – A tensor of shape (N, C, W), (N, C, W, H) or (N, C, D, W, H) to downsample.

  • kernel_size – The downsample scale. If a single value is passed, the same value is used along all dimensions; otherwise the length of kernel_size must match the number of spatial dimensions of the input (1, 2 or 3).

Raises

ValueError – If there is a mismatch between kernel_size and the input shape.

Examples

>>> import torch
>>> from deel.torchlip.functional import invertible_downsample
>>> x = torch.rand(16, 16, 32, 32)
>>> x.shape
torch.Size([16, 16, 32, 32])
>>> y = invertible_downsample(x, (2, 4))
>>> y.shape
torch.Size([16, 128, 16, 8])
deel.torchlip.functional.invertible_upsample(input: torch.Tensor, kernel_size: Union[int, Tuple[int, ...]]) -> torch.Tensor

Upsamples the input in an invertible way. The number of elements in the output tensor is the same as the number of elements in the input tensor.

The number of input channels must be a multiple of the product of the kernel sizes, i.e.

C \equiv 0 \mod (k_1 * \ldots * k_l)

where C is the number of input channels, k_i the kernel size for dimension i, and l the number of dimensions.

Parameters
  • input – A tensor of shape (N, C, W), (N, C, W, H) or (N, C, D, W, H) to upsample.

  • kernel_size – The upsample scale. If a single value is passed, the same value is used along all dimensions; otherwise the length of kernel_size must match the number of spatial dimensions of the input (1, 2 or 3).

Raises

ValueError – If there is a mismatch between kernel_size and the input shape.

Examples

>>> import torch
>>> from deel.torchlip.functional import invertible_upsample
>>> x = torch.rand(16, 128, 16, 8)
>>> x.shape
torch.Size([16, 128, 16, 8])
>>> y = invertible_upsample(x, (2, 4))
>>> y.shape
torch.Size([16, 16, 32, 32])
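
Chaining the two operations should recover the original tensor; a minimal sketch (assuming invertible_upsample inverts the channel arrangement produced by invertible_downsample, as the shapes in the examples above suggest):

>>> import torch
>>> from deel.torchlip.functional import invertible_downsample, invertible_upsample
>>> x = torch.rand(16, 16, 32, 32)
>>> y = invertible_downsample(x, (2, 4))
>>> torch.allclose(invertible_upsample(y, (2, 4)), x)
True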

max_min

deel.torchlip.functional.max_min(input: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor

Applies max-min activation on the given tensor.

If input is a tensor of shape (N, C) and dim is None, the output can be described as:

\text{out}(N_i, C_{2j}) = \max(\text{input}(N_i, C_j), 0) \\ \text{out}(N_i, C_{2j + 1}) = \max(-\text{input}(N_i, C_j), 0)

where N is the batch size and C is the number of channels.

Parameters
  • input – A tensor of arbitrary shape.

  • dim – The dimension along which to apply max-min. If None, it is applied to the 0th dimension if the shape of input is (C), or to the first dimension if it is (N, C, *).

Returns

A tensor of shape (2C) or (N, 2C, *) depending on the shape of the input.

Note

M. Blot, M. Cord, and N. Thome, "Max-min convolutional neural networks for image classification," in 2016 IEEE International Conference on Image Processing (ICIP), Phoenix, AZ, USA, 2016, pp. 3678-3682.
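
Example

A minimal shape-only sketch (the import path is assumed from this module; only shapes are checked, since the channel ordering follows the formula above):

>>> import torch
>>> from deel.torchlip.functional import max_min
>>> x = torch.randn(8, 16)
>>> max_min(x).shape
torch.Size([8, 32])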

group_sort

deel.torchlip.functional.group_sort(input: torch.Tensor, group_size: Optional[int] = None) -> torch.Tensor

Applies the GroupSort activation to the given tensor: channels are split into contiguous groups of group_size values and each group is sorted (if group_size is None, the whole channel dimension is sorted).
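
Example

A minimal sketch (the printed output is an assumption: it supposes contiguous groups along the channel dimension, each sorted in ascending order, per the GroupSort formulation):

>>> import torch
>>> from deel.torchlip.functional import group_sort
>>> x = torch.tensor([[3.0, 1.0, 4.0, 2.0]])
>>> group_sort(x, group_size=2)
tensor([[1., 3., 2., 4.]])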

deel.torchlip.functional.group_sort_2(input: torch.Tensor) -> torch.Tensor

Applies GroupSort-2 activation on the given tensor. This function is equivalent to group_sort(input, 2).

See also

group_sort()

deel.torchlip.functional.full_sort(input: torch.Tensor) -> torch.Tensor

Applies FullSort activation on the given tensor. This function is equivalent to group_sort(input, None).

See also

group_sort()
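
Example

A minimal sketch for full_sort (the printed output assumes an ascending sort over the whole channel dimension):

>>> import torch
>>> from deel.torchlip.functional import full_sort
>>> full_sort(torch.tensor([[3.0, 1.0, 2.0]]))
tensor([[1., 2., 3.]])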

others

deel.torchlip.functional.lipschitz_prelu(input: torch.Tensor, weight: torch.Tensor, k_coef_lip: float = 1.0) -> torch.Tensor

Applies the k-Lipschitz version of PReLU by clamping its weight a to the interval [-k, k]:

\text{LPReLU}(x) = \begin{cases} x, & \text{if } x \geq 0 \\ \min(\max(a, -k), k) * x, & \text{otherwise} \end{cases}

where k is k_coef_lip.
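
Example

A minimal sketch (the printed values are an assumption derived from the formula above: with the default k_coef_lip = 1.0, the weight 0.25 lies inside [-1, 1], so negative inputs are scaled by 0.25):

>>> import torch
>>> from deel.torchlip.functional import lipschitz_prelu
>>> x = torch.tensor([-2.0, -1.0, 0.0, 1.0])
>>> weight = torch.tensor([0.25])
>>> lipschitz_prelu(x, weight)
tensor([-0.5000, -0.2500,  0.0000,  1.0000])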

Loss functions

Binary losses

deel.torchlip.functional.kr_loss(input: torch.Tensor, target: torch.Tensor, true_values: Tuple[int, int] = (0, 1)) -> torch.Tensor

Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality, as per

\mathcal{W}(\mu, \nu) = \sup\limits_{f \in Lip_1(\Omega)} \underset{\mathbf{x} \sim \mu}{\mathbb{E}}[f(\mathbf{x})] - \underset{\mathbf{x} \sim \nu}{\mathbb{E}}[f(\mathbf{x})]

where \mu and \nu are the distributions corresponding to the two possible labels, as specified by true_values.

Parameters
  • input – Tensor of arbitrary shape.

  • target – Tensor of the same shape as input.

  • true_values – Tuple containing the two labels for the predicted class.

Returns

The Wasserstein-1 loss between input and target.
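
Example

A minimal usage sketch (the comment describes the estimator implied by the formula above; the sign convention and reduction are assumptions):

>>> import torch
>>> from deel.torchlip.functional import kr_loss
>>> y_pred = torch.tensor([1.0, 0.5, -1.0, -0.5])
>>> y_true = torch.tensor([1, 1, 0, 0])
>>> loss = kr_loss(y_pred, y_true, true_values=(0, 1))  # empirical mean over label-1 samples minus mean over label-0 samples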

deel.torchlip.functional.neg_kr_loss(input: torch.Tensor, target: torch.Tensor, true_values: Tuple[int, int] = (0, 1)) -> torch.Tensor

Loss to estimate the negative Wasserstein-1 distance using Kantorovich-Rubinstein duality.

Parameters
  • input – Tensor of arbitrary shape.

  • target – Tensor of the same shape as input.

  • true_values – Tuple containing the two labels for the predicted class.

Returns

The negative Wasserstein-1 loss between input and target.

See also

kr_loss()
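
Example

By its description this loss is the negation of kr_loss; a minimal sketch of that relation:

>>> import torch
>>> from deel.torchlip.functional import kr_loss, neg_kr_loss
>>> y_pred = torch.tensor([1.0, 0.5, -1.0, -0.5])
>>> y_true = torch.tensor([1, 1, 0, 0])
>>> torch.allclose(neg_kr_loss(y_pred, y_true), -kr_loss(y_pred, y_true))
True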

deel.torchlip.functional.hinge_margin_loss(input: torch.Tensor, target: torch.Tensor, min_margin: float = 1) -> torch.Tensor

Computes the hinge margin loss as per

\underset{\mathbf{x}}{\mathbb{E}}[\max(0, 1 - \mathbf{y} f(\mathbf{x}))]
Parameters
  • input – Tensor of arbitrary shape.

  • target – Tensor of the same shape as input containing target labels (-1 and +1).

  • min_margin – The minimal margin to enforce.

Returns

The hinge margin loss.
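
Example

A minimal sketch (the value 0.25 assumes the expectation is taken as the mean over the batch: the two hinge terms are max(0, 1 - 0.5) = 0.5 and max(0, 1 - 2.0) = 0.0):

>>> import torch
>>> from deel.torchlip.functional import hinge_margin_loss
>>> y_pred = torch.tensor([0.5, -2.0])
>>> y_true = torch.tensor([1.0, -1.0])
>>> hinge_margin_loss(y_pred, y_true, min_margin=1.0)
tensor(0.2500)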

deel.torchlip.functional.hkr_loss(input: torch.Tensor, target: torch.Tensor, alpha: float, min_margin: float = 1.0, true_values: Tuple[int, int] = (-1, 1)) -> torch.Tensor

Loss to estimate the Wasserstein-1 distance with a hinge regularization using Kantorovich-Rubinstein duality.

Parameters
  • input – Tensor of arbitrary shape.

  • target – Tensor of the same shape as input.

  • alpha – Regularization factor between the hinge and the KR loss.

  • min_margin – Minimal margin for the hinge loss.

  • true_values – Tuple containing the two labels for the predicted class.

Returns

The regularized Wasserstein-1 loss.
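
Example

A minimal usage sketch (alpha = 10.0 is an arbitrary illustrative value balancing the hinge term against the KR term):

>>> import torch
>>> from deel.torchlip.functional import hkr_loss
>>> y_pred = torch.tensor([0.5, -2.0])
>>> y_true = torch.tensor([1.0, -1.0])
>>> loss = hkr_loss(y_pred, y_true, alpha=10.0, min_margin=1.0)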

multiclass losses

deel.torchlip.functional.kr_multiclass_loss(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor

Loss to estimate the average Wasserstein-1 distance over the outputs using Kantorovich-Rubinstein duality. In this multiclass setup, the KR term is computed for each class and then averaged.

Parameters
  • input – Tensor of arbitrary shape.

  • target – Tensor of the same shape as input. target must be one-hot encoded (labels being 0s and 1s).

Returns

The Wasserstein multiclass loss between input and target.
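
Example

A minimal usage sketch with one-hot encoded targets (random values, so only the call pattern is meaningful):

>>> import torch
>>> from deel.torchlip.functional import kr_multiclass_loss
>>> y_pred = torch.randn(8, 5)
>>> y_true = torch.nn.functional.one_hot(torch.randint(0, 5, (8,)), num_classes=5).float()
>>> loss = kr_multiclass_loss(y_pred, y_true)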

deel.torchlip.functional.hinge_multiclass_loss(input: torch.Tensor, target: torch.Tensor, min_margin: float = 1) -> torch.Tensor

Loss to estimate the hinge loss in a multiclass setup. It computes the element-wise hinge term. Note that this formulation differs from the one commonly found in TensorFlow/PyTorch (which maximizes the difference between the two largest logits). This formulation is consistent with the binary classification loss used in a multiclass fashion.

Parameters
  • input – Tensor of arbitrary shape.

  • target – Tensor of the same shape as input containing one-hot encoded target labels (0 and +1).

  • min_margin – The minimal margin to enforce.

Note

target should be one-hot encoded, with labels in {0, 1}.

Returns

The hinge margin multiclass loss.
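
Example

A minimal usage sketch with one-hot encoded targets:

>>> import torch
>>> from deel.torchlip.functional import hinge_multiclass_loss
>>> y_pred = torch.randn(8, 5)
>>> y_true = torch.nn.functional.one_hot(torch.randint(0, 5, (8,)), num_classes=5).float()
>>> loss = hinge_multiclass_loss(y_pred, y_true, min_margin=1.0)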

deel.torchlip.functional.hkr_multiclass_loss(input: torch.Tensor, target: torch.Tensor, alpha: float = 0.0, min_margin: float = 1.0) -> torch.Tensor

Loss to estimate the Wasserstein-1 distance with a hinge regularization using Kantorovich-Rubinstein duality.

Parameters
  • input – Tensor of arbitrary shape.

  • target – Tensor of the same shape as input.

  • alpha – Regularization factor between the hinge and the KR loss.

  • min_margin – Minimal margin for the hinge loss.


Returns

The regularized Wasserstein-1 loss.
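
Example

A minimal usage sketch with one-hot encoded targets (alpha = 10.0 is an arbitrary illustrative value):

>>> import torch
>>> from deel.torchlip.functional import hkr_multiclass_loss
>>> y_pred = torch.randn(8, 5)
>>> y_true = torch.nn.functional.one_hot(torch.randint(0, 5, (8,)), num_classes=5).float()
>>> loss = hkr_multiclass_loss(y_pred, y_true, alpha=10.0, min_margin=1.0)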

