reparametrizers
BJORCK_PASS_THROUGH_ORTHO_PARAMS = OrthoParams(spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=0.0001), orthogonalizer=ClassParam(BatchedBjorckOrthogonalization, beta=0.5, niters=12, pass_through=True))
module-attribute
Orthogonalization parameters that use the Björck orthogonalization method with the pass-through optimization. This configuration greatly reduces memory consumption, at the cost of slower convergence and worse performance.
CHOLESKY_ORTHO_PARAMS = OrthoParams(spectral_normalizer=BatchedIdentity, orthogonalizer=ClassParam(BatchedCholeskyOrthogonalization))
module-attribute
Settings that use the Cholesky orthogonalization method. This method is memory- and time-efficient but cannot converge to the exact orthogonal matrix (tests pass with epsilon=5e-5, meaning the layer may be 1.05-Lipschitz).
CHOLESKY_STABLE_ORTHO_PARAMS = OrthoParams(spectral_normalizer=BatchedIdentity, orthogonalizer=ClassParam(BatchedCholeskyOrthogonalization, stable=True))
module-attribute
Settings that use the Cholesky orthogonalization method and store some values for the backward pass to ensure numerical stability.
DEFAULT_ORTHO_PARAMS = OrthoParams()
module-attribute
The default orthogonalization parameters used by our library.
They are suitable for most applications and include:
- A BatchedPowerIteration for spectral normalization
- A BatchedBjorckOrthogonalization for orthogonalization
DEFAULT_TEST_ORTHO_PARAMS = OrthoParams(spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=4, eps=0.0001), orthogonalizer=ClassParam(BatchedBjorckOrthogonalization, beta=0.5, niters=25))
module-attribute
Settings with more iterations to ensure that tests pass with epsilon=1e-4.
EXP_ORTHO_PARAMS = OrthoParams(spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=1e-06), orthogonalizer=ClassParam(BatchedExponentialOrthogonalization, niters=12))
module-attribute
Settings that use the exponential orthogonalization method with 12 iterations. The matrix is preconditioned with the power iteration method.
QR_ORTHO_PARAMS = OrthoParams(spectral_normalizer=ClassParam(BatchedPowerIteration, power_it_niter=3, eps=0.001), orthogonalizer=ClassParam(BatchedQROrthogonalization))
module-attribute
Settings that use the QR orthogonalization method. The matrix is preconditioned with the power iteration method.
BatchedBjorckOrthogonalization
Bases: Module
Source code in orthogonium\reparametrizers.py
__init__(weight_shape, beta=0.5, niters=12, pass_through=False)
Initialize the BatchedBjorckOrthogonalization module.
This module implements the Björck orthogonalization method, which iteratively refines a weight matrix towards orthogonality. The method is especially effective when the weight matrix columns are nearly orthonormal. It balances computational efficiency with convergence speed through a user-defined beta parameter and iteration count.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weight_shape | tuple | The shape of the weight matrix to be orthogonalized. | required |
beta | float | Coefficient controlling the convergence of the orthogonalization process. Default is 0.5. | 0.5 |
niters | int | Number of iterations for the orthogonalization algorithm. Default is 12. | 12 |
pass_through | bool | If True, most iterations are performed without gradient computation, which can improve efficiency. | False |
Source code in orthogonium\reparametrizers.py
forward(w)
Apply the Björck orthogonalization process to the weight matrix.
The algorithm adjusts the input matrix to approximate the closest orthogonal matrix by iteratively applying transformations based on the Björck algorithm.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
w | Tensor | The weight matrix to be orthogonalized. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The orthogonalized weight matrix. |
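The iterative refinement described above can be illustrated with the textbook first-order Björck update, W ← (1 + β)·W − β·W·Wᵀ·W. The following is a minimal plain-Python sketch under that assumption, not the library's batched implementation: the helper names `matmul`, `transpose`, and `bjorck` are hypothetical, and the pass_through option is not modelled.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*a)]


def bjorck(w, beta=0.5, niters=12):
    """Textbook first-order Björck iteration (illustrative, not the library code).

    Each step applies W <- (1 + beta) * W - beta * W W^T W, which pushes the
    singular values of W towards 1 when they start close enough to 1.
    """
    for _ in range(niters):
        wwtw = matmul(matmul(w, transpose(w)), w)
        w = [
            [(1 + beta) * x - beta * y for x, y in zip(row_w, row_c)]
            for row_w, row_c in zip(w, wwtw)
        ]
    return w


# A nearly orthogonal 2x2 matrix: its singular values are already close to 1.
w = [[0.9, 0.1], [-0.2, 1.0]]
q = bjorck(w)
gram = matmul(q, transpose(q))  # should be close to the identity matrix
```

The input needs singular values near 1 for the iteration to converge, which is consistent with the presets above pairing this orthogonalizer with a spectral normalizer.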
Source code in orthogonium\reparametrizers.py
BatchedCholeskyOrthogonalization
Bases: Module
Source code in orthogonium\reparametrizers.py
__init__(weight_shape, stable=False)
Initialize the BatchedCholeskyOrthogonalization module.
This module orthogonalizes a weight matrix using the Cholesky decomposition method. It first computes the positive definite matrix \( V V^T \), then performs a Cholesky decomposition to obtain a lower triangular matrix. Solving the resulting triangular system yields an orthogonal matrix. This method is efficient and numerically stable, making it suitable for a wide range of applications.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weight_shape | tuple | The shape of the weight matrix. | required |
stable | bool | Whether to use the stable version of the Cholesky-based orthogonalization function, which adds a small positive diagonal element to ensure numerical stability. Default is False. | False |
Source code in orthogonium\reparametrizers.py
forward(w)
Apply Cholesky-based orthogonalization to the weight matrix.
This method constructs a symmetric positive definite matrix from the input weight matrix, performs Cholesky decomposition, and solves the triangular system to produce an orthogonal matrix. It mimics the results of the Gram-Schmidt process but with improved numerical stability.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
w | Tensor | The weight matrix to be orthogonalized. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The orthogonalized weight matrix. |
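The three steps named above (form V Vᵀ, factor it, solve the triangular system) can be sketched in plain Python as follows. This is an illustrative, unbatched version that assumes V has full row rank; the helper names are hypothetical and the `stable` variant is not modelled.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*a)]


def cholesky(g):
    """Return lower-triangular L with L L^T = g, for symmetric positive definite g."""
    n = len(g)
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(lower[i][k] * lower[j][k] for k in range(j))
            if i == j:
                lower[i][j] = math.sqrt(g[i][i] - s)
            else:
                lower[i][j] = (g[i][j] - s) / lower[j][j]
    return lower


def solve_lower(lower, b):
    """Solve L X = B for X by forward substitution."""
    n, m = len(lower), len(b[0])
    x = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for c in range(m):
            s = sum(lower[i][k] * x[k][c] for k in range(i))
            x[i][c] = (b[i][c] - s) / lower[i][i]
    return x


def cholesky_orthogonalize(v):
    """Return Q = L^{-1} V, where L L^T = V V^T, so that Q Q^T = I."""
    gram_v = matmul(v, transpose(v))
    return solve_lower(cholesky(gram_v), v)


v = [[2.0, 0.0, 1.0], [1.0, 3.0, 0.0]]  # 2x3 matrix with independent rows
q = cholesky_orthogonalize(v)
gram = matmul(q, transpose(q))  # rows of q are orthonormal, so this is close to I
```

Because L is lower triangular, row i of Q depends only on rows 0..i of V, which is the same dependency structure as Gram-Schmidt applied to the rows.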
Source code in orthogonium\reparametrizers.py
BatchedExponentialOrthogonalization
Bases: Module
Source code in orthogonium\reparametrizers.py
__init__(weight_shape, niters=7)
Initialize the BatchedExponentialOrthogonalization module.
This module orthogonalizes a weight matrix using the exponential map of a skew-symmetric matrix. By converting the matrix into a skew-symmetric form and applying the matrix exponential, it produces an orthogonal matrix. This approach is particularly useful in contexts where smooth transitions between matrices are required.
Non-square matrices are padded to the largest dimension to ensure that the matrix can be converted to a skew-symmetric matrix. The resulting matrix is cropped to the original dimension.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weight_shape | tuple | The shape of the weight matrix. | required |
niters | int | Number of iterations for the series expansion approximation of the matrix exponential. Default is 7. | 7 |
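To make the role of niters concrete, here is a minimal sketch of the exponential map on a square matrix. It assumes one common construction, A = W − Wᵀ, for the skew-symmetric matrix and a truncated Taylor series for exp(A); how the library builds A and pads non-square inputs is not shown here, and all function names are hypothetical.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def identity(n):
    """Return the n x n identity matrix."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def skew(w):
    """Assumed construction: A = W - W^T, which satisfies A^T = -A."""
    n = len(w)
    return [[w[i][j] - w[j][i] for j in range(n)] for i in range(n)]


def expm_series(a, niters=12):
    """Approximate exp(A) by the truncated series sum_{k=0}^{niters} A^k / k!."""
    n = len(a)
    result = identity(n)
    term = identity(n)  # holds A^k / k! for the current k
    for k in range(1, niters + 1):
        term = [[x / k for x in row] for row in matmul(term, a)]
        result = [[r + t for r, t in zip(rr, tr)] for rr, tr in zip(result, term)]
    return result


w = [[0.0, 0.3], [-0.1, 0.0]]
a = skew(w)         # [[0, 0.4], [-0.4, 0]]
q = expm_series(a)  # for a 2x2 skew matrix this is a rotation by 0.4 rad
```

In the 2x2 case the result can be checked against the closed form [[cos θ, sin θ], [−sin θ, cos θ]] with θ = 0.4. The truncation error grows with the norm of A, which may be why the exponential preset above preconditions the matrix with power iteration.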
Source code in orthogonium\reparametrizers.py
BatchedIdentity
Bases: Module
Source code in orthogonium\reparametrizers.py
__init__(weight_shape)
Class representing a batched identity matrix with a specific weight shape. The matrix is initialized based on the provided shape of the weights. It is a convenient utility for applications where identity-like operations are required in a batched manner.
Attributes:
Name | Type | Description |
---|---|---|
weight_shape | Tuple[int, int] | A tuple representing the shape of the |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weight_shape | | A tuple specifying the shape of the individual weight matrix. | required |
Source code in orthogonium\reparametrizers.py
BatchedPowerIteration
Bases: Module
Source code in orthogonium\reparametrizers.py
__init__(weight_shape, power_it_niter=3, eps=1e-12)
BatchedPowerIteration is a class that performs spectral normalization on weights using the power iteration method in a batched manner. It initializes singular vectors 'u' and 'v', which are used to approximate the largest singular value of the associated weight matrix during training. The L2 normalization is applied to stabilize these singular vector parameters.
Attributes:
Name | Type | Description |
---|---|---|
weight_shape | tuple | Shape of the weight tensor. Normalization is applied to the last two dimensions. |
power_it_niter | int | Number of iterations to perform for the power iteration method. |
eps | float | A small constant to ensure numerical stability during calculations. Used in the power iteration method to avoid dividing by zero. |
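The power iteration idea described above can be sketched for a single matrix in plain Python. This is an assumption-laden illustration rather than the library's code: it starts from a fixed vector instead of stored 'u' and 'v' parameters, it is not batched, and the function names are hypothetical.

```python
import math


def matvec(a, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(p * q for p, q in zip(row, x)) for row in a]


def transpose(a):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*a)]


def l2_normalize(x, eps=1e-12):
    """Scale x to unit L2 norm; eps guards against division by zero."""
    norm = math.sqrt(sum(t * t for t in x))
    return [t / (norm + eps) for t in x]


def spectral_normalize(w, power_it_niter=3, eps=1e-12):
    """Estimate the largest singular value of w and divide w by it."""
    wt = transpose(w)
    v = l2_normalize([1.0] * len(w[0]), eps)  # fixed start vector (assumption)
    for _ in range(power_it_niter):
        u = l2_normalize(matvec(w, v), eps)
        v = l2_normalize(matvec(wt, u), eps)
    # u^T W v approaches the largest singular value as u and v converge.
    sigma = sum(ui * wi for ui, wi in zip(u, matvec(w, v)))
    return [[x / (sigma + eps) for x in row] for row in w], sigma


w = [[3.0, 0.0], [0.0, 1.0]]  # singular values are 3 and 1
w_normalized, sigma = spectral_normalize(w)
```

The estimate uᵀWv never exceeds the true largest singular value, so after only a few iterations the normalized matrix can have a spectral norm slightly above 1.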
Source code in orthogonium\reparametrizers.py
BatchedQROrthogonalization
Bases: Module
Source code in orthogonium\reparametrizers.py
__init__(weight_shape)
Initialize the BatchedQROrthogonalization module.
This module uses QR decomposition to orthogonalize a weight matrix in a batched manner. It computes the orthogonal component (Q) from the decomposition, ensuring that the output satisfies orthogonality constraints.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weight_shape | tuple | The shape of the weight matrix to be orthogonalized. | required |
Source code in orthogonium\reparametrizers.py
forward(w)
Perform QR decomposition to compute the orthogonalized weight matrix.
The QR decomposition splits the input matrix into an orthogonal matrix (Q) and an upper triangular matrix (R). This module returns the orthogonal component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
w | Tensor | The weight matrix to be orthogonalized. | required |
Returns:
Type | Description |
---|---|
torch.Tensor: The orthogonalized weight matrix ( |
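As an illustration of what the Q factor is, the sketch below computes it for a tall matrix with modified Gram-Schmidt in plain Python. This swaps in Gram-Schmidt for a library QR routine purely for readability; it is not the module's implementation, the function names are hypothetical, and the input is assumed to have linearly independent columns.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*a)]


def qr_q_factor(w):
    """Return Q from W = Q R using modified Gram-Schmidt on the columns of W."""
    q_cols = []
    for col in transpose(w):
        # Remove the components along the columns already orthonormalized.
        for q in q_cols:
            proj = sum(a * b for a, b in zip(q, col))
            col = [a - proj * b for a, b in zip(col, q)]
        norm = math.sqrt(sum(a * a for a in col))
        q_cols.append([a / norm for a in col])
    return transpose(q_cols)


w = [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # 3x2 matrix with independent columns
q = qr_q_factor(w)
gram = matmul(transpose(q), q)  # columns of q are orthonormal, so this is close to I
```

Production QR routines typically use Householder reflections, which are more robust than Gram-Schmidt when the columns are nearly dependent.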
Source code in orthogonium\reparametrizers.py
L2Normalize
Bases: Module
Source code in orthogonium\reparametrizers.py
__init__(dtype, dim=None)
A class that performs L2 normalization for the given input tensor.
L2 normalization is a process that normalizes the input over a specified dimension such that the sum of squares of the elements along that dimension equals 1. It ensures that the resulting tensor has a unit norm. This operation is widely used in machine learning and deep learning applications to standardize feature representations.
Attributes:
Name | Type | Description |
---|---|---|
dim | Optional[int] | The specific dimension along which normalization is performed. If None, normalization is done over all dimensions. |
dtype | Any | The data type of the tensor to be normalized. |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dtype | | The data type of the tensor to be normalized. | required |
dim | | An optional integer specifying the dimension along which to normalize. If not provided, the input will be normalized globally across all dimensions. | None |
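The difference between global normalization (dim left as None) and normalization along one dimension can be shown with a small plain-Python sketch on nested lists. The two function names are hypothetical stand-ins, and zero-norm inputs are deliberately not handled.

```python
import math


def l2_normalize_all(matrix):
    """Global case: scale the whole matrix so its squared entries sum to 1."""
    norm = math.sqrt(sum(x * x for row in matrix for x in row))
    return [[x / norm for x in row] for row in matrix]


def l2_normalize_rows(matrix):
    """Per-dimension case: scale each row independently to unit L2 norm."""
    out = []
    for row in matrix:
        norm = math.sqrt(sum(x * x for x in row))
        out.append([x / norm for x in row])
    return out


global_unit = l2_normalize_all([[3.0, 0.0], [0.0, 4.0]])   # overall norm is 5
row_unit = l2_normalize_rows([[3.0, 4.0], [6.0, 8.0]])     # row norms are 5 and 10
```

In the global case only the matrix as a whole has unit norm, while in the per-row case every row does.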
Source code in orthogonium\reparametrizers.py
OrthoParams
dataclass
Represents the parameters and configurations used for orthogonalization and spectral normalization.
This class encapsulates the modules and settings required for spectral normalization and orthogonalization of tensors in a parameterized way. It accommodates various implementations of normalizers and orthogonalization techniques to provide flexibility in their application. This makes it easy to switch between different normalization techniques inside a layer, even though each technique takes different parameters.
Attributes:
Name | Type | Description |
---|---|---|
spectral_normalizer | Callable[Tuple[int, ...], Module] | A callable that produces a module for spectral normalization. Default is configured to use BatchedPowerIteration with specific parameters. This callable can be provided either as a |
orthogonalizer | Callable[Tuple[int, ...], Module] | A callable that produces a module for orthogonalization. Default is configured to use BatchedBjorckOrthogonalization with specific parameters. This callable can be provided either as a |
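The pattern of storing "a callable that produces a module" from a weight shape can be sketched with the standard library alone. The sketch assumes that a ClassParam entry behaves roughly like `functools.partial` (a class with some keyword arguments bound in advance), which the text above does not confirm. `ToyNormalizer`, `ToyOrthogonalizer`, and `ToyOrthoParams` are hypothetical stand-ins, not classes from the library.

```python
from dataclasses import dataclass
from functools import partial
from typing import Callable, Tuple


class ToyNormalizer:
    """Hypothetical stand-in for a spectral normalization module."""

    def __init__(self, weight_shape, niter=3):
        self.weight_shape = weight_shape
        self.niter = niter


class ToyOrthogonalizer:
    """Hypothetical stand-in for an orthogonalization module."""

    def __init__(self, weight_shape, niters=12):
        self.weight_shape = weight_shape
        self.niters = niters


@dataclass
class ToyOrthoParams:
    """Holds two factories; each one maps a weight shape to a module."""

    spectral_normalizer: Callable[[Tuple[int, ...]], object]
    orthogonalizer: Callable[[Tuple[int, ...]], object]


# Bind the hyperparameters now; the weight shape is supplied later by the layer.
params = ToyOrthoParams(
    spectral_normalizer=partial(ToyNormalizer, niter=4),
    orthogonalizer=partial(ToyOrthogonalizer, niters=25),
)

weight_shape = (1, 16, 16)
normalizer = params.spectral_normalizer(weight_shape)
orthogonalizer = params.orthogonalizer(weight_shape)
```

Deferring construction this way lets a layer choose its own weight shape while the caller chooses only the algorithm and its hyperparameters, which matches the motivation given in the description above.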
Source code in orthogonium\reparametrizers.py