
deel.lip.normalizers

This module contains computation functions for Björck and spectral normalization. It is intended for internal use only.

bjorck_normalization

bjorck_normalization(
    w,
    eps=DEFAULT_EPS_BJORCK,
    beta=DEFAULT_BETA_BJORCK,
    maxiter=DEFAULT_MAXITER_BJORCK,
)

Apply Björck normalization on w.

PARAMETER DESCRIPTION
w

weight to normalize. To work properly, the weight must satisfy max_eigenval(w) ~= 1.

TYPE: Tensor

eps

epsilon stopping criterion: the iteration stops when norm(w_t - w_{t-1}) is less than eps.

TYPE: float DEFAULT: DEFAULT_EPS_BJORCK

beta

beta used in each iteration; must be in the interval ]0, 0.5]

TYPE: float DEFAULT: DEFAULT_BETA_BJORCK

maxiter

maximum number of iterations for the algorithm

TYPE: int DEFAULT: DEFAULT_MAXITER_BJORCK

RETURNS DESCRIPTION

tf.Tensor: the orthonormal weights

Source code in deel/lip/normalizers.py
def bjorck_normalization(
    w, eps=DEFAULT_EPS_BJORCK, beta=DEFAULT_BETA_BJORCK, maxiter=DEFAULT_MAXITER_BJORCK
):
    """
    Apply Björck normalization on w.

    Args:
        w (tf.Tensor): weight to normalize. To work properly, the weight must
            satisfy max_eigenval(w) ~= 1.
        eps (float): epsilon stopping criterion: norm(w_t - w_{t-1}) must be less
            than eps.
        beta (float): beta used in each iteration; must be in the interval ]0, 0.5].
        maxiter (int): maximum number of iterations for the algorithm

    Returns:
        tf.Tensor: the orthonormal weights

    """
    # create a fake old_w that doesn't pass the loop condition
    # it won't affect the computation, as the first action in the loop overwrites it.
    old_w = 10 * w
    # define the loop condition

    def cond(w, old_w):
        return tf.linalg.norm(w - old_w) >= eps

    # define the loop body
    def body(w, old_w):
        old_w = w
        w = (1 + beta) * w - beta * _wwtw(w)
        return w, old_w

    # apply the loop
    w, old_w = tf.while_loop(
        cond,
        body,
        (w, old_w),
        parallel_iterations=30,
        maximum_iterations=maxiter,
        swap_memory=SWAP_MEMORY,
    )
    return w
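
Below is a minimal usage sketch (not part of the library docs); the weight shape is an illustrative assumption. Spectral normalization is applied first so that the largest singular value of the input is close to 1, as the Björck iteration requires.

import tensorflow as tf
from deel.lip.normalizers import bjorck_normalization, spectral_normalization

w = tf.random.normal((64, 32))                         # illustrative weight matrix
w_sn, _u, _sigma = spectral_normalization(w, u=None)   # max singular value ~= 1
w_ortho = bjorck_normalization(w_sn)

# The columns of w_ortho are approximately orthonormal:
gram = tf.transpose(w_ortho) @ w_ortho
print(tf.reduce_max(tf.abs(gram - tf.eye(32))).numpy())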

get_conv_operators

get_conv_operators(
    kernel,
    u_shape,
    stride=1.0,
    conv_first=True,
    pad_func=None,
)

Return two functions corresponding to the linear convolution operator and its adjoint.

PARAMETER DESCRIPTION
kernel

the convolution kernel to normalize

TYPE: Tensor

u_shape

shape of a singular vector (as a 4D tensor).

TYPE: tuple

stride

stride parameter of convolutions. Defaults to 1.

TYPE: int DEFAULT: 1.0

conv_first

RO or CO case; should be True in the CO case (stride^2 * C < M). Defaults to True.

TYPE: bool DEFAULT: True

pad_func

function for applying padding (None means 'same' padding). Defaults to None.

TYPE: Callable DEFAULT: None

RETURNS DESCRIPTION
tuple

two functions for the linear convolution operator and its adjoint operator.

Source code in deel/lip/normalizers.py
def get_conv_operators(kernel, u_shape, stride=1.0, conv_first=True, pad_func=None):
    """
    Return two functions corresponding to the linear convolution operator and its
    adjoint.

    Args:
        kernel (tf.Tensor): the convolution kernel to normalize
        u_shape (tuple): shape of a singular vector (as a 4D tensor).
        stride (int, optional): stride parameter of convolutions. Defaults to 1.
        conv_first (bool, optional): RO or CO case; should be True in the CO case
            (stride^2 * C < M). Defaults to True.
        pad_func (Callable, optional): function for applying padding (None means
            'same' padding). Defaults to None.

    Returns:
        tuple: two functions for the linear convolution operator and its adjoint
            operator.
    """

    def identity(x):
        return x

    # If pad_func is None, standard convolution with SAME padding
    # Else, pad_func padding function (externally defined)
    #       + standard convolution with VALID padding.
    if pad_func is None:
        pad_type = "SAME"
        _pad_func = identity
    else:
        pad_type = "VALID"
        _pad_func = pad_func

    def _conv(u, w, stride):
        u_pad = _pad_func(u)
        return tf.nn.conv2d(u_pad, w, stride, pad_type)

    def _conv_transpose(u, w, output_shape, stride):
        if pad_func is None:
            return tf.nn.conv2d_transpose(u, w, output_shape, stride, pad_type)
        else:
            u_upscale = _zero_upscale2D(u, (stride, stride))
            w_adj = _maybe_transpose_kernel(w, True)
            return _conv(u_upscale, w_adj, stride=1)

    if conv_first:

        def linear_op(u):
            return _conv(u, kernel, stride)

        def adjoint_op(v):
            return _conv_transpose(v, kernel, u_shape, stride)

    else:
        v_shape = (
            (u_shape[0],)
            + (u_shape[1] * stride, u_shape[2] * stride)
            + (kernel.shape[-2],)
        )

        def linear_op(u):
            return _conv_transpose(u, kernel, v_shape, stride)

        def adjoint_op(v):
            return _conv(v, kernel, stride)

    return linear_op, adjoint_op
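
A minimal sketch of how the returned operator pair behaves, with illustrative channels-last shapes (conv_first=True, so the singular vector u lives on the input-channel side):

import tensorflow as tf
from deel.lip.normalizers import get_conv_operators

kernel = tf.random.normal((3, 3, 3, 16))  # (kh, kw, C_in, C_out), illustrative
u = tf.random.normal((1, 8, 8, 3))        # singular vector as a 4D tensor

linear_op, adjoint_op = get_conv_operators(kernel, u.shape, stride=1, conv_first=True)

v = linear_op(u)        # forward convolution, shape (1, 8, 8, 16)
u_back = adjoint_op(v)  # adjoint (transposed) convolution, shape (1, 8, 8, 3)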

reshaped_kernel_orthogonalization

reshaped_kernel_orthogonalization(
    kernel,
    u,
    adjustment_coef,
    eps_spectral=DEFAULT_EPS_SPECTRAL,
    eps_bjorck=DEFAULT_EPS_BJORCK,
    beta=DEFAULT_BETA_BJORCK,
    maxiter_spectral=DEFAULT_MAXITER_SPECTRAL,
    maxiter_bjorck=DEFAULT_MAXITER_BJORCK,
)

Perform reshaped kernel orthogonalization (RKO) on the kernel given as input. It applies the power method to find the largest singular value, then applies the Björck algorithm to the rescaled kernel. This greatly improves the stability and convergence speed of the Björck algorithm.

PARAMETER DESCRIPTION
kernel

the kernel to orthogonalize

TYPE: Tensor

u

the vector used to do the power iteration method

TYPE: Tensor

adjustment_coef

the adjustment coefficient as used in convolution

TYPE: float

eps_spectral

stopping criterion in the spectral algorithm

TYPE: float DEFAULT: DEFAULT_EPS_SPECTRAL

eps_bjorck

stopping criterion in the Björck algorithm

TYPE: float DEFAULT: DEFAULT_EPS_BJORCK

beta

the beta parameter used in the Björck algorithm

TYPE: float DEFAULT: DEFAULT_BETA_BJORCK

maxiter_spectral

maximum number of iterations for the power iteration

TYPE: int DEFAULT: DEFAULT_MAXITER_SPECTRAL

maxiter_bjorck

maximum number of iterations for the Björck algorithm

TYPE: int DEFAULT: DEFAULT_MAXITER_BJORCK

RETURNS DESCRIPTION

tf.Tensor: the orthogonalized kernel, the new u, and sigma which is the largest singular value

Source code in deel/lip/normalizers.py
def reshaped_kernel_orthogonalization(
    kernel,
    u,
    adjustment_coef,
    eps_spectral=DEFAULT_EPS_SPECTRAL,
    eps_bjorck=DEFAULT_EPS_BJORCK,
    beta=DEFAULT_BETA_BJORCK,
    maxiter_spectral=DEFAULT_MAXITER_SPECTRAL,
    maxiter_bjorck=DEFAULT_MAXITER_BJORCK,
):
    """
    Perform reshaped kernel orthogonalization (RKO) on the kernel given as input. It
    applies the power method to find the largest singular value, then applies the
    Björck algorithm to the rescaled kernel. This greatly improves the stability and
    convergence speed of the Björck algorithm.

    Args:
        kernel (tf.Tensor): the kernel to orthogonalize
        u (tf.Tensor): the vector used to do the power iteration method
        adjustment_coef (float): the adjustment coefficient as used in convolution
        eps_spectral (float): stopping criterion in spectral algorithm
        eps_bjorck (float): stopping criterion in bjorck algorithm
        beta (float): the beta used in the bjorck algorithm
        maxiter_spectral (int): maximum number of iterations for the power iteration
        maxiter_bjorck (int): maximum number of iterations for bjorck algorithm

    Returns:
        tf.Tensor: the orthogonalized kernel, the new u, and sigma which is the largest
            singular value

    """
    W_shape = kernel.shape
    # Flatten the Tensor
    W_reshaped = tf.reshape(kernel, [-1, W_shape[-1]])
    W_bar, u, sigma = spectral_normalization(
        W_reshaped, u, eps=eps_spectral, maxiter=maxiter_spectral
    )
    if (eps_bjorck is not None) and (beta is not None):
        W_bar = bjorck_normalization(
            W_bar, eps=eps_bjorck, beta=beta, maxiter=maxiter_bjorck
        )
    W_bar = W_bar * adjustment_coef
    W_bar = K.reshape(W_bar, kernel.shape)
    return W_bar, u, sigma
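
A minimal sketch with an illustrative convolution kernel and adjustment_coef=1.0 (no rescaling); in the library this function is mainly called internally when the constrained layers process their kernels.

import tensorflow as tf
from deel.lip.normalizers import reshaped_kernel_orthogonalization

kernel = tf.random.normal((3, 3, 16, 32))  # (kh, kw, C_in, C_out), illustrative
u = tf.random.normal((1, 32))              # power-iteration vector, shape (1, C_out)

w_bar, u, sigma = reshaped_kernel_orthogonalization(kernel, u, adjustment_coef=1.0)
# w_bar keeps the original kernel shape; sigma is the largest singular value of
# the reshaped (kh*kw*C_in, C_out) matrix.
print(w_bar.shape, sigma.numpy())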

set_stop_grad_spectral

set_stop_grad_spectral(value)

Set the global STOP_GRAD_SPECTRAL to value. This function must be called before constructing the model (i.e. before the first call to reshaped_kernel_orthogonalization) in order to take effect.

PARAMETER DESCRIPTION
value

boolean; when set to True, disables back-propagation through the power iteration algorithm. Back-propagation then accounts for how updates affect the maximum singular value, but not for how they affect the largest singular vector. When set to False, back-propagation goes through the while loop.

TYPE: bool

Source code in deel/lip/normalizers.py
def set_stop_grad_spectral(value: bool):
    """
    Set the global STOP_GRAD_SPECTRAL to value. This function must be called before
    constructing the model (i.e. before the first call to
    `reshaped_kernel_orthogonalization`) in order to take effect.

    Args:
        value: boolean; when set to True, disables back-propagation through the power
            iteration algorithm. Back-propagation then accounts for how updates affect
            the maximum singular value, but not for how they affect the largest
            singular vector. When set to False, back-propagation goes through the
            while loop.

    """
    global STOP_GRAD_SPECTRAL
    STOP_GRAD_SPECTRAL = value
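
Typical usage is a single call at the top of a script, before any constrained layer is built (a minimal sketch):

from deel.lip.normalizers import set_stop_grad_spectral

# Do not back-propagate through the power-iteration loop: gradients account for
# the maximum singular value but not for the singular vector.
set_stop_grad_spectral(True)
# ... construct the model afterwards ...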

set_swap_memory

set_swap_memory(value)

Set the global SWAP_MEMORY to value. This function must be called before constructing the model (i.e. before the first call to reshaped_kernel_orthogonalization) in order to take effect.

PARAMETER DESCRIPTION
value

boolean used as the swap_memory parameter of the tf.while_loop calls in the spectral and Björck algorithms.

TYPE: bool

Source code in deel/lip/normalizers.py
def set_swap_memory(value: bool):
    """
    Set the global SWAP_MEMORY to value. This function must be called before
    constructing the model (i.e. before the first call to
    `reshaped_kernel_orthogonalization`) in order to take effect.

    Args:
        value: boolean used as the swap_memory parameter of the tf.while_loop calls
            in the spectral and Björck algorithms.

    """
    global SWAP_MEMORY
    SWAP_MEMORY = value
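
Typical usage, also before building the model (a minimal sketch):

from deel.lip.normalizers import set_swap_memory

# Disable CPU/GPU memory swapping inside the tf.while_loop calls of the
# spectral and Björck algorithms.
set_swap_memory(False)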

spectral_normalization

spectral_normalization(
    kernel,
    u,
    eps=DEFAULT_EPS_SPECTRAL,
    maxiter=DEFAULT_MAXITER_SPECTRAL,
)

Normalize the kernel to have its maximum singular value equal to 1.

PARAMETER DESCRIPTION
kernel

the kernel to normalize, assuming a 2D kernel.

TYPE: Tensor

u

initialization of the maximum singular vector.

TYPE: Tensor

eps

stopping criterion of the algorithm: the iteration stops when norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL.

TYPE: float DEFAULT: DEFAULT_EPS_SPECTRAL

maxiter

maximum number of iterations for the algorithm. Defaults to DEFAULT_MAXITER_SPECTRAL.

TYPE: int DEFAULT: DEFAULT_MAXITER_SPECTRAL

RETURNS DESCRIPTION

the normalized kernel, the maximum singular vector, and the maximum singular value.

Source code in deel/lip/normalizers.py
def spectral_normalization(
    kernel, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL
):
    """
    Normalize the kernel to have its maximum singular value equal to 1.

    Args:
        kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel.
        u (tf.Tensor): initialization of the maximum singular vector.
        eps (float, optional): stopping criterion of the algorithm, when
            norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL.
        maxiter (int, optional): maximum number of iterations for the algorithm.
            Defaults to DEFAULT_MAXITER_SPECTRAL.

    Returns:
        the normalized kernel, the maximum singular vector, and the maximum singular
            value.
    """

    if u is None:
        u = tf.random.uniform(
            shape=(1, kernel.shape[-1]), minval=0.0, maxval=1.0, dtype=kernel.dtype
        )

    def linear_op(u):
        return u @ tf.transpose(kernel)

    def adjoint_op(v):
        return v @ kernel

    u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter)

    # Compute the largest singular value and the normalized kernel.
    # We assume that in the worst case we converged to sigma + eps (as u and v are
    # normalized after each iteration)
    # In order to be sure that operator norm of normalized kernel is strictly less than
    # one we use sigma + eps, which ensures stability of Björck algorithm even when
    # beta=0.5
    sigma = tf.reshape(tf.norm(linear_op(u)), (1, 1))
    normalized_kernel = kernel / (sigma + eps)
    return normalized_kernel, u, sigma
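
A minimal sketch with an illustrative 2D weight; passing u=None lets the function draw a random starting vector of shape (1, kernel.shape[-1]):

import tensorflow as tf
from deel.lip.normalizers import spectral_normalization

w = tf.random.normal((128, 64))
w_sn, u, sigma = spectral_normalization(w, u=None)

# The largest singular value of w_sn is slightly below 1 (division by sigma + eps).
print(tf.reduce_max(tf.linalg.svd(w_sn, compute_uv=False)).numpy())
# u can be stored and reused as a warm start for the next call.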

spectral_normalization_conv

spectral_normalization_conv(
    kernel,
    u,
    stride=1.0,
    conv_first=True,
    pad_func=None,
    eps=DEFAULT_EPS_SPECTRAL,
    maxiter=DEFAULT_MAXITER_SPECTRAL,
)

Normalize the convolution kernel to have its max eigenvalue == 1.

PARAMETER DESCRIPTION
kernel

the convolution kernel to normalize

TYPE: Tensor

u

initialization for the max eigen vector (as a 4d tensor)

TYPE: Tensor

stride

stride parameter of convolutions

TYPE: int DEFAULT: 1.0

conv_first

RO or CO case; should be True in the CO case (stride^2 * C < M)

TYPE: bool DEFAULT: True

pad_func

function for applying padding (None means 'same' padding)

TYPE: Callable DEFAULT: None

eps

epsilon stopping criterion: norm(u_t - u_{t-1}) must be less than eps

TYPE: float DEFAULT: DEFAULT_EPS_SPECTRAL

maxiter

maximum number of iterations for the power iteration algorithm.

TYPE: int DEFAULT: DEFAULT_MAXITER_SPECTRAL

RETURNS DESCRIPTION

the normalized kernel w_bar, the maximum eigen vector, and the maximum eigen value

Source code in deel/lip/normalizers.py
def spectral_normalization_conv(
    kernel,
    u,
    stride=1.0,
    conv_first=True,
    pad_func=None,
    eps=DEFAULT_EPS_SPECTRAL,
    maxiter=DEFAULT_MAXITER_SPECTRAL,
):
    """
    Normalize the convolution kernel to have its max eigenvalue == 1.

    Args:
        kernel (tf.Tensor): the convolution kernel to normalize
        u (tf.Tensor): initialization for the max eigen vector (as a 4d tensor)
        stride (int): stride parameter of convolutions
        conv_first (bool): RO or CO case; should be True in the CO case
            (stride^2 * C < M)
        pad_func (Callable): function for applying padding (None means 'same' padding)
        eps (float): epsilon stopping criterion: norm(u_t - u_{t-1}) must be less
            than eps
        maxiter (int): maximum number of iterations for the power iteration algorithm.

    Returns:
        the normalized kernel w_bar, the maximum eigen vector, and the maximum eigen
            value
    """

    if eps < 0:
        return kernel, u, 1.0

    linear_op, adjoint_op = get_conv_operators(
        kernel, u.shape, stride, conv_first, pad_func
    )

    u = tf.math.l2_normalize(u) + tf.random.uniform(u.shape, minval=-eps, maxval=eps)
    u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter)

    # Compute the largest singular value and the normalized kernel
    sigma = tf.norm(linear_op(u))
    normalized_kernel = kernel / (sigma + eps)
    return normalized_kernel, u, sigma
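
A minimal sketch with illustrative channels-last shapes (conv_first=True, default 'same' padding):

import tensorflow as tf
from deel.lip.normalizers import spectral_normalization_conv

kernel = tf.random.normal((3, 3, 3, 16))  # (kh, kw, C_in, C_out), illustrative
u = tf.random.normal((1, 8, 8, 3))        # 4D power-iteration vector

w_bar, u, sigma = spectral_normalization_conv(kernel, u, stride=1, conv_first=True)
# sigma estimates the operator norm of the convolution; w_bar = kernel / (sigma + eps).
print(sigma.numpy())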