
reparametrizers

BJORCK_PASS_THROUGH_ORTHO_PARAMS = OrthoParams(spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=0.0001), orthogonalizer=ClassParam(BatchedBjorckOrthogonalization, beta=0.5, niters=12, pass_through=True)) module-attribute

Orthogonalization parameters that use the Bjorck orthogonalization method with a pass-through optimization. This configuration greatly reduces memory consumption, at the cost of slower convergence and worse performance.

CHOLESKY_ORTHO_PARAMS = OrthoParams(spectral_normalizer=BatchedIdentity, orthogonalizer=ClassParam(BatchedCholeskyOrthogonalization)) module-attribute

Settings that use the Cholesky orthogonalization method. This method is memory- and time-efficient but cannot converge to an exactly orthogonal matrix (tests pass with epsilon=5e-5, meaning the layer may be 1.05-Lipschitz).

CHOLESKY_STABLE_ORTHO_PARAMS = OrthoParams(spectral_normalizer=BatchedIdentity, orthogonalizer=ClassParam(BatchedCholeskyOrthogonalization, stable=True)) module-attribute

Settings that use the Cholesky orthogonalization method and store some values during the forward pass for the backward pass, to ensure numerical stability.

DEFAULT_ORTHO_PARAMS = OrthoParams() module-attribute

The default orthogonalization parameters used by our library. Suitable for most applications; includes:

- a BatchedPowerIteration for spectral normalization
- a BatchedBjorckOrthogonalization for orthogonalization

DEFAULT_TEST_ORTHO_PARAMS = OrthoParams(spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=4, eps=0.0001), orthogonalizer=ClassParam(BatchedBjorckOrthogonalization, beta=0.5, niters=25)) module-attribute

Settings with more iterations, used to ensure that tests pass with epsilon=1e-4.

EXP_ORTHO_PARAMS = OrthoParams(spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-06), orthogonalizer=ClassParam(BatchedExponentialOrthogonalization, niters=12)) module-attribute

Settings that use the exponential orthogonalization method with 12 iterations. The matrix is preconditioned with the power iteration method.

QR_ORTHO_PARAMS = OrthoParams(spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=0.001), orthogonalizer=ClassParam(BatchedQROrthogonalization)) module-attribute

Settings that use the QR orthogonalization method. The matrix is preconditioned with the power iteration method.

BatchedBjorckOrthogonalization

Bases: Module

Source code in orthogonium\reparametrizers.py
class BatchedBjorckOrthogonalization(nn.Module):
    def __init__(self, weight_shape, beta=0.5, niters=12, pass_through=False):
        """
        Initialize the BatchedBjorckOrthogonalization module.

        This module implements the Björck orthogonalization method, which iteratively refines
        a weight matrix towards orthogonality. The method is especially effective when the
        weight matrix columns are nearly orthonormal. It balances computational efficiency
        with convergence speed through a user-defined `beta` parameter and iteration count.

        Args:
            weight_shape (tuple): The shape of the weight matrix to be orthogonalized.
            beta (float): Coefficient controlling the convergence of the orthogonalization process.
                Default is 0.5.
            niters (int): Number of iterations for the orthogonalization algorithm. Default is 12.
            pass_through (bool): If True, most iterations are performed without gradient computation,
                which can improve efficiency.
        """
        self.weight_shape = weight_shape
        self.beta = beta
        self.niters = niters
        self.pass_through = pass_through
        if weight_shape[-2] < weight_shape[-1]:
            self.wwtw_op = BatchedBjorckOrthogonalization.wwt_w_op
        else:
            self.wwtw_op = BatchedBjorckOrthogonalization.w_wtw_op
        super(BatchedBjorckOrthogonalization, self).__init__()

    @staticmethod
    def w_wtw_op(w):
        return w @ (w.transpose(-1, -2) @ w)

    @staticmethod
    def wwt_w_op(w):
        return (w @ w.transpose(-1, -2)) @ w

    def forward(self, w):
        """
        Apply the Björck orthogonalization process to the weight matrix.

        The algorithm adjusts the input matrix to approximate the closest orthogonal matrix
        by iteratively applying transformations based on the Björck algorithm.

        Args:
            w (torch.Tensor): The weight matrix to be orthogonalized.

        Returns:
            torch.Tensor: The orthogonalized weight matrix.
        """
        if self.pass_through:
            with torch.no_grad():
                for _ in range(self.niters):
                    w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
            # Final iteration without no_grad, using parameters:
            w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
        else:
            for _ in range(self.niters):
                w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
        return w

    def right_inverse(self, w):
        return w

__init__(weight_shape, beta=0.5, niters=12, pass_through=False)

Initialize the BatchedBjorckOrthogonalization module.

This module implements the Björck orthogonalization method, which iteratively refines a weight matrix towards orthogonality. The method is especially effective when the weight matrix columns are nearly orthonormal. It balances computational efficiency with convergence speed through a user-defined beta parameter and iteration count.

Parameters:

- weight_shape (tuple): The shape of the weight matrix to be orthogonalized. Required.
- beta (float): Coefficient controlling the convergence of the orthogonalization process. Default: 0.5.
- niters (int): Number of iterations for the orthogonalization algorithm. Default: 12.
- pass_through (bool): If True, most iterations are performed without gradient computation, which can improve efficiency. Default: False.

Source code in orthogonium\reparametrizers.py
def __init__(self, weight_shape, beta=0.5, niters=12, pass_through=False):
    """
    Initialize the BatchedBjorckOrthogonalization module.

    This module implements the Björck orthogonalization method, which iteratively refines
    a weight matrix towards orthogonality. The method is especially effective when the
    weight matrix columns are nearly orthonormal. It balances computational efficiency
    with convergence speed through a user-defined `beta` parameter and iteration count.

    Args:
        weight_shape (tuple): The shape of the weight matrix to be orthogonalized.
        beta (float): Coefficient controlling the convergence of the orthogonalization process.
            Default is 0.5.
        niters (int): Number of iterations for the orthogonalization algorithm. Default is 12.
        pass_through (bool): If True, most iterations are performed without gradient computation,
            which can improve efficiency.
    """
    self.weight_shape = weight_shape
    self.beta = beta
    self.niters = niters
    self.pass_through = pass_through
    if weight_shape[-2] < weight_shape[-1]:
        self.wwtw_op = BatchedBjorckOrthogonalization.wwt_w_op
    else:
        self.wwtw_op = BatchedBjorckOrthogonalization.w_wtw_op
    super(BatchedBjorckOrthogonalization, self).__init__()

forward(w)

Apply the Björck orthogonalization process to the weight matrix.

The algorithm adjusts the input matrix to approximate the closest orthogonal matrix by iteratively applying transformations based on the Björck algorithm.

Parameters:

- w (torch.Tensor): The weight matrix to be orthogonalized. Required.

Returns:

- torch.Tensor: The orthogonalized weight matrix.

Source code in orthogonium\reparametrizers.py
def forward(self, w):
    """
    Apply the Björck orthogonalization process to the weight matrix.

    The algorithm adjusts the input matrix to approximate the closest orthogonal matrix
    by iteratively applying transformations based on the Björck algorithm.

    Args:
        w (torch.Tensor): The weight matrix to be orthogonalized.

    Returns:
        torch.Tensor: The orthogonalized weight matrix.
    """
    if self.pass_through:
        with torch.no_grad():
            for _ in range(self.niters):
                w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
        # Final iteration without no_grad, using parameters:
        w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
    else:
        for _ in range(self.niters):
            w = (1 + self.beta) * w - self.beta * self.wwtw_op(w)
    return w
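
As a numerical illustration of the iteration above, here is a minimal NumPy sketch (not library code; the module itself operates on batched torch tensors, and the seed and shapes are arbitrary). The update w ← (1 + β)·w − β·w(wᵀw) drives a spectrally pre-normalized matrix toward orthonormal columns:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((5, 3))
w /= np.linalg.norm(w, 2)  # precondition to spectral norm 1 (the library uses power iteration)

beta, niters = 0.5, 12
for _ in range(niters):
    # tall matrix: use w (w^T w), matching w_wtw_op
    w = (1 + beta) * w - beta * (w @ (w.T @ w))

err = np.linalg.norm(w.T @ w - np.eye(3))
assert err < 1e-6
```

Each iteration maps every singular value s to (1 + β)s − βs³, whose stable fixed point is 1, so all singular values converge to 1 while the singular vectors are preserved; this is why the preconditioning step matters (singular values above sqrt(3) would diverge).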

BatchedCholeskyOrthogonalization

Bases: Module

Source code in orthogonium\reparametrizers.py
class BatchedCholeskyOrthogonalization(nn.Module):
    def __init__(self, weight_shape, stable=False):
        """
        Initialize the BatchedCholeskyOrthogonalization module.

        This module orthogonalizes a weight matrix using the Cholesky decomposition method.
        It first computes the positive definite matrix \( V V^T \), then performs a Cholesky
        decomposition to obtain a lower triangular matrix. Solving the resulting triangular
        system yields an orthogonal matrix. This method is efficient and numerically stable,
        making it suitable for a wide range of applications.

        Args:
            weight_shape (tuple): The shape of the weight matrix.
            stable (bool): Whether to use the stable version of the Cholesky-based orthogonalization
                function, which adds a small positive diagonal element to ensure numerical stability.
                Default is False.
        """
        self.weight_shape = weight_shape
        super(BatchedCholeskyOrthogonalization, self).__init__()
        if stable:
            self.orth = BatchedCholeskyOrthogonalization.CholeskyOrthfn_stable.apply
        else:
            self.orth = BatchedCholeskyOrthogonalization.CholeskyOrthfn.apply

    # @staticmethod
    # def orth(X):
    #     S = X @ X.mT
    #     eps = S.diagonal(dim1=1, dim2=2).mean(1).mul(1e-3).detach()
    #     eye = torch.eye(S.size(-1), dtype=S.dtype, device=S.device)
    #     S = S + eps.view(-1, 1, 1) * eye.unsqueeze(0)
    #     L = torch.linalg.cholesky(S)
    #     W = torch.linalg.solve_triangular(L, X, upper=False)
    #     return W

    class CholeskyOrthfn(torch.autograd.Function):
        @staticmethod
        # def forward(ctx, X):
        #     S = X @ X.mT
        #     eps = S.diagonal(dim1=1, dim2=2).mean(1).mul(1e-3)
        #     eye = torch.eye(S.size(-1), dtype=S.dtype, device=S.device)
        #     S = S + eps.view(-1, 1, 1) * eye.unsqueeze(0)
        #     L = torch.linalg.cholesky(S)
        #     W = torch.linalg.solve_triangular(L, X, upper=False)
        #     ctx.save_for_backward(W, L)
        #     return W
        def forward(ctx, X):
            S = X @ X.mT
            eps = 1e-5  # A common stable choice
            S = S + eps * torch.eye(
                S.size(-1), dtype=S.dtype, device=S.device
            ).unsqueeze(0)
            L = torch.linalg.cholesky(S)
            W = torch.linalg.solve_triangular(L, X, upper=False)
            ctx.save_for_backward(W, L)
            return W

        @staticmethod
        def backward(ctx, grad_output):
            W, L = ctx.saved_tensors
            LmT = L.mT.contiguous()
            gB = torch.linalg.solve_triangular(LmT, grad_output, upper=True)
            gA = (-gB @ W.mT).tril()
            gS = (LmT @ gA).tril()
            gS = gS + gS.tril(-1).mT
            gS = torch.linalg.solve_triangular(LmT, gS, upper=True)
            gX = gS @ W + gB
            return gX

    class CholeskyOrthfn_stable(torch.autograd.Function):
        @staticmethod
        def forward(ctx, X):
            S = X @ X.mT
            eps = 1e-5  # A common stable choice
            S = S + eps * torch.eye(
                S.size(-1), dtype=S.dtype, device=S.device
            ).unsqueeze(0)
            L = torch.linalg.cholesky(S)
            W = torch.linalg.solve_triangular(L, X, upper=False)
            ctx.save_for_backward(X, W, L)
            return W

        @staticmethod
        def backward(ctx, grad_output):
            X, W, L = ctx.saved_tensors
            gB = torch.linalg.solve_triangular(L.mT, grad_output, upper=True)
            gA = (-gB @ W.mT).tril()
            gS = (L.mT @ gA).tril()
            gS = gS + gS.tril(-1).mT
            gS = torch.linalg.solve_triangular(L.mT, gS, upper=True)
            gS = torch.linalg.solve_triangular(L, gS, upper=False, left=False)
            gX = gS @ X + gB
            return gX

    def forward(self, w):
        """
        Apply Cholesky-based orthogonalization to the weight matrix.

        This method constructs a symmetric positive definite matrix from the input weight
        matrix, performs Cholesky decomposition, and solves the triangular system to produce
        an orthogonal matrix. It mimics the results of the Gram-Schmidt process but with
        improved numerical stability.

        Args:
            w (torch.Tensor): The weight matrix to be orthogonalized.

        Returns:
            torch.Tensor: The orthogonalized weight matrix.
        """
        return self.orth(w).view(*self.weight_shape)

    def right_inverse(self, w):
        return w

__init__(weight_shape, stable=False)

Initialize the BatchedCholeskyOrthogonalization module.

This module orthogonalizes a weight matrix using the Cholesky decomposition method. It first computes the positive definite matrix V Vᵀ, then performs a Cholesky decomposition to obtain a lower triangular matrix. Solving the resulting triangular system yields an orthogonal matrix. This method is efficient and numerically stable, making it suitable for a wide range of applications.

Parameters:

- weight_shape (tuple): The shape of the weight matrix. Required.
- stable (bool): Whether to use the stable version of the Cholesky-based orthogonalization function, which adds a small positive diagonal element to ensure numerical stability. Default: False.

Source code in orthogonium\reparametrizers.py
def __init__(self, weight_shape, stable=False):
    """
    Initialize the BatchedCholeskyOrthogonalization module.

    This module orthogonalizes a weight matrix using the Cholesky decomposition method.
    It first computes the positive definite matrix \( V V^T \), then performs a Cholesky
    decomposition to obtain a lower triangular matrix. Solving the resulting triangular
    system yields an orthogonal matrix. This method is efficient and numerically stable,
    making it suitable for a wide range of applications.

    Args:
        weight_shape (tuple): The shape of the weight matrix.
        stable (bool): Whether to use the stable version of the Cholesky-based orthogonalization
            function, which adds a small positive diagonal element to ensure numerical stability.
            Default is False.
    """
    self.weight_shape = weight_shape
    super(BatchedCholeskyOrthogonalization, self).__init__()
    if stable:
        self.orth = BatchedCholeskyOrthogonalization.CholeskyOrthfn_stable.apply
    else:
        self.orth = BatchedCholeskyOrthogonalization.CholeskyOrthfn.apply

forward(w)

Apply Cholesky-based orthogonalization to the weight matrix.

This method constructs a symmetric positive definite matrix from the input weight matrix, performs Cholesky decomposition, and solves the triangular system to produce an orthogonal matrix. It mimics the results of the Gram-Schmidt process but with improved numerical stability.

Parameters:

- w (torch.Tensor): The weight matrix to be orthogonalized. Required.

Returns:

- torch.Tensor: The orthogonalized weight matrix.

Source code in orthogonium\reparametrizers.py
def forward(self, w):
    """
    Apply Cholesky-based orthogonalization to the weight matrix.

    This method constructs a symmetric positive definite matrix from the input weight
    matrix, performs Cholesky decomposition, and solves the triangular system to produce
    an orthogonal matrix. It mimics the results of the Gram-Schmidt process but with
    improved numerical stability.

    Args:
        w (torch.Tensor): The weight matrix to be orthogonalized.

    Returns:
        torch.Tensor: The orthogonalized weight matrix.
    """
    return self.orth(w).view(*self.weight_shape)
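
The forward pass above amounts to whitening the rows of the weight against their Gram matrix. A minimal NumPy sketch (the module does this on torch tensors with torch.linalg.solve_triangular; np.linalg.solve stands in here, and the seed and shapes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((3, 5))       # rows to be orthonormalized

S = X @ X.T + 1e-5 * np.eye(3)        # Gram matrix, regularized as in the module
L = np.linalg.cholesky(S)             # S = L L^T with L lower triangular
W = np.linalg.solve(L, X)             # W = L^{-1} X, so W W^T ≈ I

assert np.allclose(W @ W.T, np.eye(3), atol=1e-3)
```

Since W Wᵀ = L⁻¹(S − εI)L⁻ᵀ = I − ε·S⁻¹, the result is orthogonal only up to the ε-regularization, which is consistent with the near-orthogonality caveat in the CHOLESKY_ORTHO_PARAMS description above.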

BatchedExponentialOrthogonalization

Bases: Module

Source code in orthogonium\reparametrizers.py
class BatchedExponentialOrthogonalization(nn.Module):
    def __init__(self, weight_shape, niters=7):
        """
        Initialize the BatchedExponentialOrthogonalization module.

        This module orthogonalizes a weight matrix using the exponential map of a skew-symmetric
        matrix. By converting the matrix into a skew-symmetric form and applying the matrix
        exponential, it produces an orthogonal matrix. This approach is particularly useful
        in contexts where smooth transitions between matrices are required.

        Non-square matrices are padded to the largest dimension to ensure that the matrix can
        be converted to a skew-symmetric matrix. The resulting matrix is cropped to the original
        dimension.

        Args:
            weight_shape (tuple): The shape of the weight matrix.
            niters (int): Number of iterations for the series expansion approximation of the
                matrix exponential. Default is 7.
        """
        self.weight_shape = weight_shape
        self.max_dim = max(weight_shape[-2:])
        self.niters = niters
        super(BatchedExponentialOrthogonalization, self).__init__()

    def forward(self, w):
        # fill w with zero to have a square matrix over the last two dimensions
        # if ((self.max_dim - w.shape[-1]) != 0) and ((self.max_dim - w.shape[-2]) != 0):
        w = torch.nn.functional.pad(
            w, (0, self.max_dim - w.shape[-1], 0, self.max_dim - w.shape[-2])
        )
        # makes w skew symmetric
        w = (w - w.transpose(-1, -2)) / 2
        acc = w
        res = torch.eye(acc.shape[-2], acc.shape[-1], device=w.device) + acc
        for i in range(2, self.niters):
            acc = torch.einsum("...ij,...jk->...ik", acc, w) / i
            res = res + acc
        # if transpose:
        #     res = res.transpose(-1, -2)
        res = res[..., : self.weight_shape[-2], : self.weight_shape[-1]]
        return res.contiguous()

    def right_inverse(self, w):
        return w

__init__(weight_shape, niters=7)

Initialize the BatchedExponentialOrthogonalization module.

This module orthogonalizes a weight matrix using the exponential map of a skew-symmetric matrix. By converting the matrix into a skew-symmetric form and applying the matrix exponential, it produces an orthogonal matrix. This approach is particularly useful in contexts where smooth transitions between matrices are required.

Non-square matrices are padded to the largest dimension to ensure that the matrix can be converted to a skew-symmetric matrix. The resulting matrix is cropped to the original dimension.

Parameters:

- weight_shape (tuple): The shape of the weight matrix. Required.
- niters (int): Number of iterations for the series expansion approximation of the matrix exponential. Default: 7.

Source code in orthogonium\reparametrizers.py
def __init__(self, weight_shape, niters=7):
    """
    Initialize the BatchedExponentialOrthogonalization module.

    This module orthogonalizes a weight matrix using the exponential map of a skew-symmetric
    matrix. By converting the matrix into a skew-symmetric form and applying the matrix
    exponential, it produces an orthogonal matrix. This approach is particularly useful
    in contexts where smooth transitions between matrices are required.

    Non-square matrices are padded to the largest dimension to ensure that the matrix can
    be converted to a skew-symmetric matrix. The resulting matrix is cropped to the original
    dimension.

    Args:
        weight_shape (tuple): The shape of the weight matrix.
        niters (int): Number of iterations for the series expansion approximation of the
            matrix exponential. Default is 7.
    """
    self.weight_shape = weight_shape
    self.max_dim = max(weight_shape[-2:])
    self.niters = niters
    super(BatchedExponentialOrthogonalization, self).__init__()
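
A minimal NumPy sketch of the truncated-series idea (not library code; the 0.2 scaling is my addition to keep the 7-term series accurate, playing the role of the power-iteration preconditioning mentioned in EXP_ORTHO_PARAMS):

```python
import numpy as np

rng = np.random.default_rng(2)
a = rng.standard_normal((4, 4))
a = 0.2 * (a - a.T) / 2               # small skew-symmetric matrix: exp(a) is orthogonal

# truncated power series for exp(a), mirroring the module's accumulator loop
niters = 7
acc = a.copy()
res = np.eye(4) + acc
for i in range(2, niters):
    acc = (acc @ a) / i               # accumulates a^i / i!
    res = res + acc

assert np.allclose(res @ res.T, np.eye(4), atol=1e-4)
```

Because exp of a skew-symmetric matrix is exactly orthogonal, the only error here is the series truncation, which shrinks rapidly when the matrix norm is below 1.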

BatchedIdentity

Bases: Module

Source code in orthogonium\reparametrizers.py
class BatchedIdentity(nn.Module):
    def __init__(self, weight_shape):
        """
        Class representing a batched identity matrix with a specific weight shape. The
        matrix is initialized based on the provided shape of the weights. It is a
        convenient utility for applications where identity-like operations are
        required in a batched manner.

        Attributes:
            weight_shape (Tuple[int, int]): A tuple representing the shape of the
            weight matrix for each batch. (unused)

        Args:
            weight_shape: A tuple specifying the shape of the individual weight matrix.
        """
        super(BatchedIdentity, self).__init__()

    def forward(self, w):
        return w

    def right_inverse(self, w):
        return w

__init__(weight_shape)

Class representing a batched identity matrix with a specific weight shape. The matrix is initialized based on the provided shape of the weights. It is a convenient utility for applications where identity-like operations are required in a batched manner.

Attributes:

- weight_shape (Tuple[int, int]): A tuple representing the shape of the weight matrix for each batch (unused).

Parameters:

- weight_shape: A tuple specifying the shape of the individual weight matrix. Required.
Source code in orthogonium\reparametrizers.py
def __init__(self, weight_shape):
    """
    Class representing a batched identity matrix with a specific weight shape. The
    matrix is initialized based on the provided shape of the weights. It is a
    convenient utility for applications where identity-like operations are
    required in a batched manner.

    Attributes:
        weight_shape (Tuple[int, int]): A tuple representing the shape of the
        weight matrix for each batch. (unused)

    Args:
        weight_shape: A tuple specifying the shape of the individual weight matrix.
    """
    super(BatchedIdentity, self).__init__()

BatchedPowerIteration

Bases: Module

Source code in orthogonium\reparametrizers.py
class BatchedPowerIteration(nn.Module):
    def __init__(self, weight_shape, power_it_niter=3, eps=1e-12):
        """
        BatchedPowerIteration is a class that performs spectral normalization on weights
        using the power iteration method in a batched manner. It initializes singular
        vectors 'u' and 'v', which are used to approximate the largest singular value
        of the associated weight matrix during training. The L2 normalization is applied
        to stabilize these singular vector parameters.

        Attributes:
            weight_shape: tuple
                Shape of the weight tensor. Normalization is applied to the last two dimensions.
            power_it_niter: int
                Number of iterations to perform for the power iteration method.
            eps: float
                A small constant to ensure numerical stability during calculations. Used in the power iteration
                method to avoid dividing by zero.
        """
        super(BatchedPowerIteration, self).__init__()
        self.weight_shape = weight_shape
        self.power_it_niter = power_it_niter
        self.eps = eps
        # init u
        # u will be weight_shape[:-2] + (weight_shape[-2], 1)
        # v will be weight_shape[:-2] + (weight_shape[-1], 1)
        self.u = nn.Parameter(
            torch.Tensor(torch.randn(*weight_shape[:-2], weight_shape[-2], 1)),
            requires_grad=False,
        )
        self.v = nn.Parameter(
            torch.Tensor(torch.randn(*weight_shape[:-2], weight_shape[-1], 1)),
            requires_grad=False,
        )
        parametrize.register_parametrization(
            self, "u", L2Normalize(dtype=self.u.dtype, dim=(-2))
        )
        parametrize.register_parametrization(
            self, "v", L2Normalize(dtype=self.v.dtype, dim=(-2))
        )

    def forward(self, X, init_u=None, n_iters=3, return_uv=True):
        for _ in range(n_iters):
            self.v = X.transpose(-1, -2) @ self.u
            self.u = X @ self.v
        # stop gradient on u and v
        u = self.u.detach()
        v = self.v.detach()
        # but keep gradient on s
        s = u.transpose(-1, -2) @ X @ v
        return X / (s + self.eps)

    def right_inverse(self, normalized_kernel):
        # we assume that the kernel is normalized
        return normalized_kernel.to(self.u.dtype)

__init__(weight_shape, power_it_niter=3, eps=1e-12)

BatchedPowerIteration is a class that performs spectral normalization on weights using the power iteration method in a batched manner. It initializes singular vectors 'u' and 'v', which are used to approximate the largest singular value of the associated weight matrix during training. The L2 normalization is applied to stabilize these singular vector parameters.

Attributes:

- weight_shape (tuple): Shape of the weight tensor. Normalization is applied to the last two dimensions.
- power_it_niter (int): Number of iterations to perform for the power iteration method.
- eps (float): A small constant to ensure numerical stability; used in the power iteration method to avoid division by zero.

Source code in orthogonium\reparametrizers.py
def __init__(self, weight_shape, power_it_niter=3, eps=1e-12):
    """
    BatchedPowerIteration is a class that performs spectral normalization on weights
    using the power iteration method in a batched manner. It initializes singular
    vectors 'u' and 'v', which are used to approximate the largest singular value
    of the associated weight matrix during training. The L2 normalization is applied
    to stabilize these singular vector parameters.

    Attributes:
        weight_shape: tuple
            Shape of the weight tensor. Normalization is applied to the last two dimensions.
        power_it_niter: int
            Number of iterations to perform for the power iteration method.
        eps: float
            A small constant to ensure numerical stability during calculations. Used in the power iteration
            method to avoid dividing by zero.
    """
    super(BatchedPowerIteration, self).__init__()
    self.weight_shape = weight_shape
    self.power_it_niter = power_it_niter
    self.eps = eps
    # init u
    # u will be weight_shape[:-2] + (weight_shape[-2], 1)
    # v will be weight_shape[:-2] + (weight_shape[-1], 1)
    self.u = nn.Parameter(
        torch.Tensor(torch.randn(*weight_shape[:-2], weight_shape[-2], 1)),
        requires_grad=False,
    )
    self.v = nn.Parameter(
        torch.Tensor(torch.randn(*weight_shape[:-2], weight_shape[-1], 1)),
        requires_grad=False,
    )
    parametrize.register_parametrization(
        self, "u", L2Normalize(dtype=self.u.dtype, dim=(-2))
    )
    parametrize.register_parametrization(
        self, "v", L2Normalize(dtype=self.v.dtype, dim=(-2))
    )
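
The alternating updates in forward implement classic power iteration on the top singular pair. A minimal NumPy sketch (the module normalizes u and v through the L2Normalize parametrization instead of explicitly, and runs far fewer iterations; seed, shapes, and iteration count here are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(3)
X = rng.standard_normal((4, 6))
u = rng.standard_normal((4, 1))

# power iteration: u, v converge to the top left/right singular vectors of X
for _ in range(100):
    v = X.T @ u
    v /= np.linalg.norm(v)
    u = X @ v
    u /= np.linalg.norm(u)

s = float(u.T @ X @ v)                # estimate of the largest singular value
X_norm = X / (s + 1e-12)              # eps guards against division by zero

assert abs(np.linalg.norm(X_norm, 2) - 1.0) < 1e-4
```

Dividing by the estimated top singular value rescales the matrix to spectral norm ≈ 1, which is exactly the preconditioning the Björck and exponential orthogonalizers rely on.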

BatchedQROrthogonalization

Bases: Module

Source code in orthogonium\reparametrizers.py
class BatchedQROrthogonalization(nn.Module):
    def __init__(self, weight_shape):
        """
        Initialize the BatchedQROrthogonalization module.

        This module uses QR decomposition to orthogonalize a weight matrix in a batched manner.
        It computes the orthogonal component (`Q`) from the decomposition, ensuring that the
        output satisfies orthogonality constraints.

        Args:
            weight_shape (tuple): The shape of the weight matrix to be orthogonalized.
        """
        super(BatchedQROrthogonalization, self).__init__()

    def forward(self, w):
        """
        Perform QR decomposition to compute the orthogonalized weight matrix.

        The QR decomposition splits the input matrix into an orthogonal matrix (`Q`) and
        an upper triangular matrix (`R`). This module returns the orthogonal component.

        Args:
            w (torch.Tensor): The weight matrix to be orthogonalized.

        Returns:
            torch.Tensor: The orthogonalized weight matrix (`Q` from the QR decomposition).
        """
        transpose = w.shape[-2] < w.shape[-1]
        if transpose:
            w = w.transpose(-1, -2)
        q, r = torch.linalg.qr(w, mode="reduced")
        # compute the sign of the diagonal of d
        diag_sign = torch.sign(torch.diagonal(r, dim1=-2, dim2=-1)).unsqueeze(-2)
        # multiply the sign with the diagonal of r
        q = q * diag_sign
        if transpose:
            q = q.transpose(-1, -2)
        return q.contiguous()

    def right_inverse(self, w):
        return w

__init__(weight_shape)

Initialize the BatchedQROrthogonalization module.

This module uses QR decomposition to orthogonalize a weight matrix in a batched manner. It computes the orthogonal component (Q) from the decomposition, ensuring that the output satisfies orthogonality constraints.

Parameters:

- weight_shape (tuple): The shape of the weight matrix to be orthogonalized. Required.

Source code in orthogonium\reparametrizers.py
def __init__(self, weight_shape):
    """
    Initialize the BatchedQROrthogonalization module.

    This module uses QR decomposition to orthogonalize a weight matrix in a batched manner.
    It computes the orthogonal component (`Q`) from the decomposition, ensuring that the
    output satisfies orthogonality constraints.

    Args:
        weight_shape (tuple): The shape of the weight matrix to be orthogonalized.
    """
    super(BatchedQROrthogonalization, self).__init__()

forward(w)

Perform QR decomposition to compute the orthogonalized weight matrix.

The QR decomposition splits the input matrix into a matrix with orthonormal columns (Q) and an upper triangular matrix (R). This module returns the orthonormal component Q.

Parameters:

    w (Tensor): The weight matrix to be orthogonalized. Required.

Returns:

    torch.Tensor: The orthogonalized weight matrix (Q from the QR decomposition).

Source code in orthogonium\reparametrizers.py
def forward(self, w):
    """
    Perform QR decomposition to compute the orthogonalized weight matrix.

    The QR decomposition splits the input matrix into a matrix with orthonormal
    columns (`Q`) and an upper triangular matrix (`R`). This module returns the
    orthonormal component `Q`.

    Args:
        w (torch.Tensor): The weight matrix to be orthogonalized.

    Returns:
        torch.Tensor: The orthogonalized weight matrix (`Q` from the QR decomposition).
    """
    transpose = w.shape[-2] < w.shape[-1]
    if transpose:
        w = w.transpose(-1, -2)
    q, r = torch.linalg.qr(w, mode="reduced")
    # compute the signs of the diagonal entries of r
    diag_sign = torch.sign(torch.diagonal(r, dim1=-2, dim2=-1)).unsqueeze(-2)
    # flip the signs of q's columns so that the diagonal of r is positive,
    # making the factorization deterministic
    q = q * diag_sign
    if transpose:
        q = q.transpose(-1, -2)
    return q.contiguous()
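To make the sign-fixing step concrete, here is a minimal standalone sketch of the same procedure in NumPy (an illustration only; `qr_orthogonalize` is a hypothetical name, not part of the library):

```python
import numpy as np

def qr_orthogonalize(w):
    # Work on a tall matrix so the reduced QR yields orthonormal columns.
    transpose = w.shape[-2] < w.shape[-1]
    if transpose:
        w = w.swapaxes(-1, -2)
    q, r = np.linalg.qr(w, mode="reduced")
    # QR is only unique up to the signs of r's diagonal; forcing that
    # diagonal to be positive makes the factorization deterministic.
    diag_sign = np.sign(np.diagonal(r, axis1=-2, axis2=-1))[..., None, :]
    q = q * diag_sign
    if transpose:
        q = q.swapaxes(-1, -2)
    return q

w = np.random.randn(5, 3)
q = qr_orthogonalize(w)
assert np.allclose(q.T @ q, np.eye(3), atol=1e-8)  # columns are orthonormal
```

Note that sign-flipping the columns of `Q` leaves `Q.T @ Q` unchanged, so orthonormality is preserved while the output becomes a deterministic function of `w`.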

L2Normalize

Bases: Module

Source code in orthogonium\reparametrizers.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class L2Normalize(nn.Module):
    def __init__(self, dtype, dim=None):
        """
        A class that performs L2 normalization for the given input tensor.

        L2 normalization is a process that normalizes the input over a specified
        dimension such that the sum of squares of the elements along that
        dimension equals 1. It ensures that the resulting tensor has a unit norm.
        This operation is widely used in machine learning and deep learning
        applications to standardize feature representations.

        Attributes:
            dim (Optional[int]): The specific dimension along which normalization
                is performed. If None, normalization is done over all dimensions.
            dtype (Any): The data type of the tensor to be normalized.

        Parameters:
            dtype: The data type of the tensor to be normalized.
            dim: An optional integer specifying the dimension along which to
                normalize. If not provided, the input will be normalized globally
                across all dimensions.
        """
        super(L2Normalize, self).__init__()
        self.dim = dim
        self.dtype = dtype

    def forward(self, x):
        return x / (torch.norm(x, dim=self.dim, keepdim=True, dtype=self.dtype) + 1e-8)

    def right_inverse(self, x):
        return x / (torch.norm(x, dim=self.dim, keepdim=True, dtype=self.dtype) + 1e-8)
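As a quick illustration of the same operation (a NumPy sketch, not library code), dividing a vector by its norm plus a small epsilon:

```python
import numpy as np

# L2-normalize a vector the same way the module does: divide by its norm,
# with a small epsilon guarding against division by zero.
x = np.array([3.0, 4.0])
x_unit = x / (np.linalg.norm(x) + 1e-8)

assert np.allclose(x_unit, [0.6, 0.8], atol=1e-6)  # approximately a unit vector
```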

__init__(dtype, dim=None)

A class that performs L2 normalization for the given input tensor.

L2 normalization is a process that normalizes the input over a specified dimension such that the sum of squares of the elements along that dimension equals 1. It ensures that the resulting tensor has a unit norm. This operation is widely used in machine learning and deep learning applications to standardize feature representations.

Attributes:

    dim (Optional[int]): The specific dimension along which normalization is performed. If None, normalization is done over all dimensions.
    dtype (Any): The data type of the tensor to be normalized.

Parameters:

    dtype: The data type of the tensor to be normalized. Required.
    dim: An optional integer specifying the dimension along which to normalize. If not provided, the input will be normalized globally across all dimensions. Default: None.
Source code in orthogonium\reparametrizers.py
def __init__(self, dtype, dim=None):
    """
    A class that performs L2 normalization for the given input tensor.

    L2 normalization is a process that normalizes the input over a specified
    dimension such that the sum of squares of the elements along that
    dimension equals 1. It ensures that the resulting tensor has a unit norm.
    This operation is widely used in machine learning and deep learning
    applications to standardize feature representations.

    Attributes:
        dim (Optional[int]): The specific dimension along which normalization
            is performed. If None, normalization is done over all dimensions.
        dtype (Any): The data type of the tensor to be normalized.

    Parameters:
        dtype: The data type of the tensor to be normalized.
        dim: An optional integer specifying the dimension along which to
            normalize. If not provided, the input will be normalized globally
            across all dimensions.
    """
    super(L2Normalize, self).__init__()
    self.dim = dim
    self.dtype = dtype

OrthoParams dataclass

Represents the parameters and configurations used for orthogonalization and spectral normalization.

This class encapsulates the necessary modules and settings required for performing spectral normalization and orthogonalization of tensors in a parameterized way. It accommodates various implementations of normalizers and orthogonalization techniques to provide flexibility in their application. This makes it easy to switch between normalization techniques inside a layer, even though each technique takes different parameters.

Attributes:

    spectral_normalizer (Callable[Tuple[int, ...], Module]): A callable that produces a module for spectral normalization. The default is configured to use BatchedPowerIteration with specific parameters. This callable can be provided either as a functools.partial or as an orthogonium.ClassParam. It will receive the shape of the weight tensor as its argument.
    orthogonalizer (Callable[Tuple[int, ...], Module]): A callable that produces a module for orthogonalization. The default is configured to use BatchedBjorckOrthogonalization with specific parameters. This callable can be provided either as a functools.partial or as an orthogonium.ClassParam. It will receive the shape of the weight tensor as its argument.

Source code in orthogonium\reparametrizers.py
@dataclass
class OrthoParams:
    """
    Represents the parameters and configurations used for orthogonalization
    and spectral normalization.

    This class encapsulates the necessary modules and settings required
    for performing spectral normalization and orthogonalization of tensors
    in a parameterized way. It accommodates various implementations of
    normalizers and orthogonalization techniques to provide flexibility
    in their application. This makes it easy to switch between normalization
    techniques inside a layer, even though each technique takes different
    parameters.

    Attributes:
        spectral_normalizer (Callable[Tuple[int, ...], nn.Module]): A callable
            that produces a module for spectral normalization. Default is
            configured to use BatchedPowerIteration with specific parameters.
            This callable can be provided either as a `functools.partial` or as an
            `orthogonium.ClassParam`. It will receive the shape of the weight tensor
            as its argument.
        orthogonalizer (Callable[Tuple[int, ...], nn.Module]): A callable
            that produces a module for orthogonalization. Default is
            configured to use BatchedBjorckOrthogonalization with specific
            parameters. This callable can be provided either as a `functools.partial`
            or as an `orthogonium.ClassParam`. It will receive the shape of the weight
            tensor as its argument.
    """

    # spectral_normalizer: Callable[Tuple[int, ...], nn.Module] = BatchedIdentity
    spectral_normalizer: Callable[Tuple[int, ...], nn.Module] = ClassParam(  # type: ignore
        BatchedPowerIteration, power_it_niter=3, eps=1e-6
    )
    orthogonalizer: Callable[Tuple[int, ...], nn.Module] = ClassParam(  # type: ignore
        BatchedBjorckOrthogonalization,
        beta=0.5,
        niters=12,
        pass_through=False,
        # ClassParam(BatchedExponentialOrthogonalization, niters=12)
        # BatchedCholeskyOrthogonalization,
        # BatchedQROrthogonalization,
    )
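Since both fields are plain callables, swapping techniques amounts to constructing a new `OrthoParams`. A hedged sketch (the class names appear on this page, but the exact import paths are assumptions and may differ in your installation):

```python
# Assumed import locations; all class names are documented on this page.
from orthogonium.reparametrizers import (
    OrthoParams,
    BatchedPowerIteration,
    BatchedQROrthogonalization,
)
from orthogonium import ClassParam

# Keep the default power-iteration pre-conditioning, but orthogonalize
# with the QR-based method instead of Bjorck iterations.
qr_based_params = OrthoParams(
    spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-6),
    orthogonalizer=ClassParam(BatchedQROrthogonalization),
)
```

This mirrors the pattern used by the predefined settings such as CHOLESKY_ORTHO_PARAMS and EXP_ORTHO_PARAMS shown at the top of this page.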