Skip to content

linear layers

OrthoLinear

Bases: Linear

Source code in orthogonium\layers\linear\ortho_linear.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class OrthoLinear(nn.Linear):
    """Linear layer whose weight is reparametrized with the spectral normalizer
    and orthogonalizer supplied by an ``OrthoParams`` object."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        ortho_params: OrthoParams = OrthoParams(),
    ):
        """
        Initializes an orthogonal linear layer with customizable orthogonalization parameters.

        Parameters:
            in_features : int
                The size of each input sample.
            out_features : int
                The size of each output sample.
            bias : bool
                Indicates if the layer should include a learnable bias parameter.
                Default is True.
            ortho_params : OrthoParams
                An object exposing ``spectral_normalizer`` and ``orthogonalizer``
                factories; each is called with a ``weight_shape`` keyword and its
                result is registered as a parametrization of ``weight``. Default is
                the default instance of OrthoParams.

        Notes
        -----
        The layer is initialized with orthogonal weights using `torch.nn.init.orthogonal_`.
        Two parametrizations are then registered on ``weight``, in this order: the
        spectral normalizer first, the orthogonalizer second.
        """
        super(OrthoLinear, self).__init__(in_features, out_features, bias=bias)
        torch.nn.init.orthogonal_(self.weight)
        # Registration order matters: parametrizations registered on the same
        # tensor are chained in the order they were registered.
        parametrize.register_parametrization(
            self,
            "weight",
            ortho_params.spectral_normalizer(
                weight_shape=(self.out_features, self.in_features)
            ),
        )
        parametrize.register_parametrization(
            self, "weight", ortho_params.orthogonalizer(weight_shape=self.weight.shape)
        )

    def singular_values(self):
        """
        Computes singular-value statistics of the current (parametrized) weight.

        The weight is detached, moved to CPU and decomposed with NumPy.

        Returns:
            tuple: ``(sv_min, sv_max, stable_rank)`` where ``stable_rank`` is the
            sum of squared singular values divided by the squared largest
            singular value (squared Frobenius norm over squared spectral norm).
        """
        svs = np.linalg.svd(
            self.weight.detach().cpu().numpy(), full_matrices=False, compute_uv=False
        )
        # Stable rank = ||W||_F^2 / ||W||_2^2; the squares must be taken per
        # singular value before summing, not on their mean.
        stable_rank = np.sum(svs**2) / (svs.max() ** 2)
        return svs.min(), svs.max(), stable_rank

__init__(in_features, out_features, bias=True, ortho_params=OrthoParams())

Initializes an orthogonal linear layer with customizable orthogonalization parameters.

Attributes:

Name Type Description
in_features

int Number of input features.

out_features

int Number of output features.

bias

bool Whether to include a bias term in the layer. Default is True.

ortho_params

OrthoParams Parameters for orthogonalization and spectral normalization. Default is the default instance of OrthoParams.

Parameters:

Name Type Description Default
in_features

int The size of each input sample.

required
out_features

int The size of each output sample.

required
bias

bool Indicates if the layer should include a learnable bias parameter.

True
ortho_params

OrthoParams An object containing orthogonalization and normalization configurations.

OrthoParams()
Notes

The layer is initialized with orthogonal weights using torch.nn.init.orthogonal_. Weight parameters are further parametrized for both spectral normalization and orthogonal constraints using the provided OrthoParams object.

Source code in orthogonium\layers\linear\ortho_linear.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def __init__(
    self,
    in_features: int,
    out_features: int,
    bias: bool = True,
    ortho_params: OrthoParams = OrthoParams(),
):
    """
    Build an orthogonal linear layer from the given orthogonalization settings.

    Parameters:
        in_features : int
            The size of each input sample.
        out_features : int
            The size of each output sample.
        bias : bool
            Whether the layer includes a learnable bias parameter. Default is True.
        ortho_params : OrthoParams
            Object providing the ``spectral_normalizer`` and ``orthogonalizer``
            factories used to parametrize the weight. Default is the default
            instance of OrthoParams.

    Notes
    -----
    The weight is first initialized with `torch.nn.init.orthogonal_`. It is then
    parametrized twice: with the spectral normalizer, and afterwards with the
    orthogonalizer, both built from the provided `OrthoParams` object.
    """
    super(OrthoLinear, self).__init__(in_features, out_features, bias=bias)
    torch.nn.init.orthogonal_(self.weight)

    # First parametrization: spectral normalization, sized from the layer dims.
    spectral_norm = ortho_params.spectral_normalizer(
        weight_shape=(self.out_features, self.in_features)
    )
    parametrize.register_parametrization(self, "weight", spectral_norm)

    # Second parametrization: orthogonalization, sized from the weight itself
    # (read after the first registration, as in the original ordering).
    orthogonalizer = ortho_params.orthogonalizer(weight_shape=self.weight.shape)
    parametrize.register_parametrization(self, "weight", orthogonalizer)

UnitNormLinear

Bases: Linear

Source code in orthogonium\layers\linear\ortho_linear.py
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class UnitNormLinear(nn.Linear):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """
        A custom PyTorch Linear layer that ensures weights are normalized to unit norm along a specified dimension.

        This class extends the torch.nn.Linear module and modifies the weight
        matrix to use orthogonal initialization and unit norm
        normalization during training. In this specific case, each output can be viewed as the result of a 1-Lipschitz
        function. This means that each output taken independently is 1-Lipschitz, while the whole function may have
        a Lipschitz constant greater than 1.

        Attributes:
            weight: The learnable weight tensor with orthogonal initialization
                and enforced unit norm parametrization.

        Args:
            *args: Variable length positional arguments passed to the base
                Linear class.
            **kwargs: Variable length keyword arguments passed to the base
                Linear class.
        """
        super(UnitNormLinear, self).__init__(*args, **kwargs)
        torch.nn.init.orthogonal_(self.weight)
        # L2Normalize is applied along dim=1 of the weight
        # (presumably one unit-norm row per output feature -- confirm in L2Normalize).
        parametrize.register_parametrization(
            self,
            "weight",
            L2Normalize(dtype=self.weight.dtype, dim=1),
        )

    def singular_values(self):
        """
        Computes singular-value statistics of the current (parametrized) weight.

        The weight is detached, moved to CPU and decomposed with NumPy.

        Returns:
            tuple: ``(sv_min, sv_max, stable_rank)`` where ``stable_rank`` is the
            sum of squared singular values divided by the squared largest
            singular value (squared Frobenius norm over squared spectral norm).
        """
        svs = np.linalg.svd(
            self.weight.detach().cpu().numpy(), full_matrices=False, compute_uv=False
        )
        # Stable rank = ||W||_F^2 / ||W||_2^2; the squares must be taken per
        # singular value before summing, not on their mean.
        stable_rank = np.sum(svs**2) / (svs.max() ** 2)
        return svs.min(), svs.max(), stable_rank

__init__(*args, **kwargs)

A custom PyTorch Linear layer that ensures weights are normalized to unit norm along a specified dimension.

This class extends the torch.nn.Linear module and modifies the weight matrix to use orthogonal initialization and unit norm normalization during training. In this specific case, each output can be viewed as the result of a 1-Lipschitz function. This means that each output taken independently is 1-Lipschitz, while the whole function may have a Lipschitz constant greater than 1.

Attributes:

Name Type Description
weight

The learnable weight tensor with orthogonal initialization and enforced unit norm parametrization.

Parameters:

Name Type Description Default
*args

Variable length positional arguments passed to the base Linear class.

()
**kwargs

Variable length keyword arguments passed to the base Linear class.

{}
Source code in orthogonium\layers\linear\ortho_linear.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def __init__(
    self,
    *args,
    **kwargs,
):
    """
    Linear layer whose weight is kept at unit norm along a fixed dimension.

    The layer behaves like torch.nn.Linear, but its weight is orthogonally
    initialized and then wrapped in a unit norm parametrization. Each output
    can be viewed as the result of a 1-Lipschitz function: every output taken
    independently is 1-Lipschitz, while the whole function may have a
    Lipschitz constant greater than 1.

    Attributes:
        weight: The learnable weight tensor with orthogonal initialization
            and enforced unit norm parametrization.

    Args:
        *args: Positional arguments forwarded unchanged to the base
            Linear class.
        **kwargs: Keyword arguments forwarded unchanged to the base
            Linear class.
    """
    super(UnitNormLinear, self).__init__(*args, **kwargs)
    torch.nn.init.orthogonal_(self.weight)

    # Build the normalizer with the weight's own dtype, acting along dim=1.
    unit_norm = L2Normalize(dtype=self.weight.dtype, dim=1)
    parametrize.register_parametrization(self, "weight", unit_norm)