
Progressive Growing of GANs (ProGAN)

View colab tutorial | View source | 📰 Paper

Progressive Growing GAN, also known as ProGAN, is an extension of the GAN training process that stabilizes the training of generative models so they can produce large, high-quality images.

Training starts with a very small image; blocks of layers are then added gradually, increasing the output size of the generator and the input size of the discriminator until the desired image size is reached. This approach has proven very effective at producing highly realistic, high-quality synthetic images.

It basically includes 4 steps (the last three are sketched in code below):

  • Progressive growing (of model and layers)

  • Minibatch std on Discriminator

  • Normalization with PixelNorm

  • Equalized Learning Rate
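
To make the last three ingredients concrete, here is a minimal PyTorch sketch of each. These are illustrative implementations following the ProGAN paper, not Augmentare's internal code; the class names and shapes are assumptions:

import torch
import torch.nn as nn

class PixelNorm(nn.Module):
    """Normalizes each pixel's feature vector to unit average magnitude."""
    def forward(self, x, eps=1e-8):
        # x: (N, C, H, W); divide by the RMS over the channel dimension
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + eps)

class MinibatchStd(nn.Module):
    """Appends the average std across the minibatch as an extra feature map,
    giving the discriminator a signal about batch diversity."""
    def forward(self, x):
        std = x.std(dim=0).mean()                        # one scalar statistic
        stat = std.expand(x.size(0), 1, x.size(2), x.size(3))
        return torch.cat([x, stat], dim=1)               # (N, C+1, H, W)

class EqualizedConv2d(nn.Module):
    """Conv layer rescaled at runtime by the He-init constant, so all
    weights see a comparable effective learning rate."""
    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.scale = (2 / (in_ch * kernel_size ** 2)) ** 0.5
        nn.init.normal_(self.conv.weight)                # plain N(0, 1) init
        nn.init.zeros_(self.conv.bias)
    def forward(self, x):
        # Scaling the input is equivalent to scaling the weights (bias unscaled)
        return self.conv(x * self.scale)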

Simplified view of ProGAN (Image source)

As the figure above shows, Progressive Growing GAN uses a generator and discriminator with the traditional GAN structure, and training starts with very small images, such as 4×4 pixels.

During training, new convolutional blocks are systematically added to both the generator and the discriminator. This gradual addition of layers lets both models learn coarse-level details early on and progressively finer details later.
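
When a new block is added, it is not switched on abruptly: its output is faded in with a blending factor alpha that ramps from 0 to 1 (this is the role of the alpha argument in the example further below). A minimal sketch of that blending, with hypothetical tensors:

import torch
import torch.nn.functional as F

def fade_in(alpha, upscaled_old, new_block_out):
    """Blend the old path (upsampled output of the previous resolution)
    with the freshly added block; alpha ramps 0 -> 1 during training."""
    return alpha * new_block_out + (1 - alpha) * upscaled_old

# Hypothetical usage while growing from 8x8 to 16x16:
old_8x8 = torch.randn(4, 3, 8, 8)                  # output of the old toRGB head
upscaled = F.interpolate(old_8x8, scale_factor=2)  # 8x8 -> 16x16 (nearest neighbor)
new_16x16 = torch.randn(4, 3, 16, 16)              # output of the new block's toRGB
blended = fade_in(alpha=0.3, upscaled_old=upscaled, new_block_out=new_16x16)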

ProGAN goals:

  • Produce high-quality, high-resolution images.

  • Greater diversity of images in the output.

  • Improve stability in GANs.

  • Increase variation in the generated images.

NETWORK ARCHITECTURE : ProGAN

GENERATOR NETWORK

A generator that grows its output incrementally: it starts with a very small image, then blocks of layers are added step by step, increasing the generator's output size (and the discriminator's input size) until the desired image size is reached.

DISCRIMINATOR NETWORK

A discriminator for discerning real from generated images.

LOSS FUNCTIONS

ProGAN uses one of the common GAN loss functions, the Wasserstein loss with gradient penalty, also known as WGAN-GP, from the paper Improved Training of Wasserstein GANs.

\[Loss_{G} = -D(x')\]

\[GP = \left(\lVert \nabla D\big(a x' + (1-a)x\big) \rVert_2 - 1\right)^2\]

\[Loss_{D} = -D(x) + D(x') + \lambda \cdot GP\]

Where:

  • x' is the generated image.

  • x is an image from the training set.

  • D is the discriminator.

  • GP is a gradient penalty that helps stabilize training.

  • The a term in the gradient penalty is a tensor of numbers drawn uniformly at random from [0, 1].

  • The parameter λ is commonly set to 10.
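
For concreteness, here is a minimal PyTorch sketch of this gradient penalty; the critic D and the tensor shapes are placeholders, not Augmentare's API:

import torch

def gradient_penalty(D, real, fake, lam=10.0):
    """WGAN-GP penalty: push the critic's gradient norm toward 1
    on random interpolations between real and generated images."""
    a = torch.rand(real.size(0), 1, 1, 1, device=real.device)  # uniform in [0, 1]
    interp = (a * fake + (1 - a) * real).requires_grad_(True)
    scores = D(interp)
    grads, = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return lam * ((grad_norm - 1) ** 2).mean()

# Discriminator loss: -D(x) + D(x') + lambda * GP
# loss_dis = -D(real).mean() + D(fake).mean() + gradient_penalty(D, real, fake)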

Example

# Imports
import torch
from torch.optim import Adam

# Augmentare Imports
import augmentare
from augmentare.methods.gan import *

# Create GAN Generator
net_gen = PROGANGenerator(
    latent_size=128,
    in_channels=128,
    img_channels=3,
    alpha=1e-5,
    steps=4
)

# Create GAN Discriminator
net_dis = PROGANDiscriminator(
    in_channels=128,
    img_channels=3,
    alpha=1e-5,
    steps=4
)

# Device, optimizers and loss functions
device = "cuda" if torch.cuda.is_available() else "cpu"
optimizer_gen = Adam(net_gen.parameters(), lr=1e-3, betas=(0.0, 0.999))
optimizer_dis = Adam(net_dis.parameters(), lr=1e-3, betas=(0.0, 0.999))
# AMP GradScalers are passed in place of explicit loss callables; the
# WGAN-GP loss described above is presumably computed inside PROGAN
loss_fn_gen = torch.cuda.amp.GradScaler()
loss_fn_dis = torch.cuda.amp.GradScaler()

# Create GAN network
gan = PROGAN(
    net_gen,
    net_dis,
    optimizer_gen,
    optimizer_dis,
    loss_fn_gen,
    loss_fn_dis,
    device,
    latent_size=128
)

# Training the ProGAN network
# `dataloader` is assumed to be a torch Dataset (or DataLoader) of training images
gen_losses, dis_losses = gan.train(
    subset_a=dataloader,
    num_epochs=5,
    num_decay_epochs=None,
    num_classes=None,
    batch_size=[32, 32, 32, 16, 16, 16, 16, 8, 4],  # one batch size per resolution step
    subset_b=None
)

# Sample images from the Generator
img_list = gan.generate_samples(
    nb_samples=36,
    num_classes=None,
    real_image_a=None,
    real_image_b=None
)
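
The generated samples can then be inspected, for example by arranging them in a grid. A short sketch, assuming img_list holds (C, H, W) image tensors:

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Assumes img_list is a list of (C, H, W) tensors; normalize maps them to [0, 1]
grid = make_grid(img_list, nrow=6, normalize=True)
plt.imshow(grid.permute(1, 2, 0).cpu())
plt.axis("off")
plt.show()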


PROGAN

A basic ProGAN class for synthesizing high-resolution, high-quality images via incremental growing of the discriminator and generator networks during training.

__init__(self,
         generator: augmentare.methods.gan.base.BaseGenerator,
         discriminator: augmentare.methods.gan.base.BaseDiscriminator,
         optimizer_gen: torch.optim.optimizer.Optimizer,
         optimizer_dis: torch.optim.optimizer.Optimizer,
         loss_fn_gen: Callable,
         loss_fn_dis: Callable,
         device,
         latent_size: Optional[int] = None)

generate_samples(self,
                 nb_samples: int,
                 num_classes=None,
                 real_image_a=None,
                 real_image_b=None)

Sample images from the generator.

Return

  • img_list

    • A list of generated images


train(self,
      subset_a: Union[torch.Tensor, torch.utils.data.dataset.Dataset],
      num_epochs: int,
      num_decay_epochs: Optional[int] = None,
      num_classes: Optional[int] = None,
      batch_size: Optional[int] = None,
      subset_b: Union[torch.Tensor, torch.utils.data.dataset.Dataset, None] = None)

Train both networks and return the losses.

Parameters

  • subset_a : Union[torch.Tensor, torch.utils.data.dataset.Dataset]

    • A torch.Tensor or Dataset of training images

  • num_epochs : int

    • The number of epochs to train the ProGAN for

  • batch_size : Optional[int]

    • Training batch size

Return

  • gen_losses, dis_losses

    • The losses of the generator and the discriminator, respectively


train_discriminator(self,
                    real_samples,
                    noise)

Train the discriminator one step and return the loss.

Parameters

  • real_samples

    • Real samples from your dataset

  • noise

    • Noise used to train the discriminator

Return

  • dis_loss

    • The loss of the discriminator


train_generator(self,
                noise)

Train the generator one step and return the loss.

Parameters

  • noise

    • Noise used to train the generator

Return

  • gen_loss

    • The loss of the generator
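
For finer-grained control, these per-step methods can in principle be combined into a manual loop. A hedged sketch; the exact batch format and noise shape expected by these methods are assumptions:

import torch

# One hypothetical optimization step, assuming `dataloader` yields batches of
# real images and a latent size of 128 as in the example above
real_batch = next(iter(dataloader))
noise = torch.randn(real_batch.size(0), 128, device=device)

dis_loss = gan.train_discriminator(real_batch, noise)  # critic step
gen_loss = gan.train_generator(noise)                  # generator step
print("D loss:", dis_loss, "| G loss:", gen_loss)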


PROGANGenerator

A generator that grows its output incrementally: it starts with a very small image, then blocks of layers are added step by step, increasing the output size until the desired image size is reached.

__init__(self,
         latent_size,
         in_channels,
         img_channels=3,
         alpha=1e-05,
         steps=4)

add_module(self,
           name: str,
           module: Optional[ForwardRef('Module')]) -> None

Adds a child module to the current module.


apply(self: ~T,
      fn: Callable[[ForwardRef('Module')], None]) -> ~T

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).


bfloat16(self: ~T) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.


buffers(self,
        recurse: bool = True) -> Iterator[torch.Tensor]

Returns an iterator over module buffers.


children(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over immediate children modules.


compile(self,
        *args,
        **kwargs)

Compile this Module's forward using :func:torch.compile.


cpu(self: ~T) -> ~T

Moves all model parameters and buffers to the CPU.


cuda(self: ~T,
     device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the GPU.


double(self: ~T) -> ~T

Casts all floating point parameters and buffers to double datatype.


eval(self: ~T) -> ~T

Sets the module in evaluation mode.


extra_repr(self) -> str

Set the extra representation of the module


float(self: ~T) -> ~T

Casts all floating point parameters and buffers to float datatype.


forward(self,
        noise,
        labels=None)

The forward function of PROGANGenerator.


get_buffer(self,
           target: str) -> 'Tensor'

Returns the buffer given by target if it exists, otherwise throws an error.


get_extra_state(self) -> Any

Returns any extra state to include in the module's state_dict. Implement this and a corresponding :func:set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().


get_parameter(self,
              target: str) -> 'Parameter'

Returns the parameter given by target if it exists, otherwise throws an error.


get_submodule(self,
              target: str) -> 'Module'

Returns the submodule given by target if it exists, otherwise throws an error.


half(self: ~T) -> ~T

Casts all floating point parameters and buffers to half datatype.


ipu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the IPU.


load_state_dict(self,
                state_dict: Mapping[str, Any],
                strict: bool = True,
                assign: bool = False)

Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.


modules(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over all modules in the network.


named_buffers(self,
              prefix: str = '',
              recurse: bool = True,
              remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.


named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.


named_modules(self,
              memo: Optional[Set[ForwardRef('Module')]] = None,
              prefix: str = '',
              remove_duplicate: bool = True)

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.


named_parameters(self,
                 prefix: str = '',
                 recurse: bool = True,
                 remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.


parameters(self,
           recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]

Returns an iterator over module parameters.


register_backward_hook(self,
                       hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                       Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                       Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_buffer(self,
                name: str,
                tensor: Optional[torch.Tensor],
                persistent: bool = True) -> None

Adds a buffer to the module.


register_forward_hook(self,
                      hook: Union[Callable[[~T, Tuple[Any, ...], Any],
                      Optional[Any]],
                      Callable[[~T, Tuple[Any, ...],
                      Dict[str, Any], Any],
                      Optional[Any]]],
                      *,
                      prepend: bool = False,
                      with_kwargs: bool = False,
                      always_call: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward hook on the module.


register_forward_pre_hook(self,
                          hook: Union[Callable[[~T, Tuple[Any, ...]],
                          Optional[Any]],
                          Callable[[~T, Tuple[Any, ...],
                          Dict[str, Any]],
                          Optional[Tuple[Any, Dict[str, Any]]]]],
                          *,
                          prepend: bool = False,
                          with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward pre-hook on the module.


register_full_backward_hook(self,
                            hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                            Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                            Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                            prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_full_backward_pre_hook(self,
                                hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                                Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                                prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward pre-hook on the module.


register_load_state_dict_post_hook(self,
                                   hook)

Registers a post hook to be run after module's load_state_dict is called.


register_module(self,
                name: str,
                module: Optional[ForwardRef('Module')]) -> None

Alias for :func:add_module.


register_parameter(self,
                   name: str,
                   param: Optional[torch.nn.parameter.Parameter]) -> None

Adds a parameter to the module.


register_state_dict_pre_hook(self,
                             hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.


requires_grad_(self: ~T,
               requires_grad: bool = True) -> ~T

Change if autograd should record operations on parameters in this module.


set_extra_state(self,
                state: Any)

This function is called from :func:load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding :func:get_extra_state for your module if you need to store extra state within its state_dict.


share_memory(self: ~T) -> ~T

See :meth:torch.Tensor.share_memory_


state_dict(self,
           *args,
           destination=None,
           prefix='',
           keep_vars=False)

Returns a dictionary containing references to the whole state of the module.


to(self,
   *args,
   **kwargs)

Moves and/or casts the parameters and buffers.


to_empty(self: ~T,
         *,
         device: Union[str, torch.device],
         recurse: bool = True) -> ~T

Moves the parameters and buffers to the specified device without copying storage.


train(self: ~T,
      mode: bool = True) -> ~T

Sets the module in training mode.


type(self: ~T,
     dst_type: Union[torch.dtype, str]) -> ~T

Casts all parameters and buffers to :attr:dst_type.


xpu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the XPU.


zero_grad(self,
          set_to_none: bool = True) -> None

Resets gradients of all model parameters. See similar function under :class:torch.optim.Optimizer for more context.


PROGANDiscriminator

A discriminator for discerning real from generated images.

__init__(self,
         in_channels,
         img_channels=3,
         alpha=1e-05,
         steps=4)

add_module(self,
           name: str,
           module: Optional[ForwardRef('Module')]) -> None

Adds a child module to the current module.


apply(self: ~T,
      fn: Callable[[ForwardRef('Module')], None]) -> ~T

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).


bfloat16(self: ~T) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.


buffers(self,
        recurse: bool = True) -> Iterator[torch.Tensor]

Returns an iterator over module buffers.


children(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over immediate children modules.


compile(self,
        *args,
        **kwargs)

Compile this Module's forward using :func:torch.compile.


cpu(self: ~T) -> ~T

Moves all model parameters and buffers to the CPU.


cuda(self: ~T,
     device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the GPU.


double(self: ~T) -> ~T

Casts all floating point parameters and buffers to double datatype.


eval(self: ~T) -> ~T

Sets the module in evaluation mode.


extra_repr(self) -> str

Set the extra representation of the module


float(self: ~T) -> ~T

Casts all floating point parameters and buffers to float datatype.


forward(self,
        noise,
        labels=None)

The forward function of PROGANDiscriminator.


get_buffer(self,
           target: str) -> 'Tensor'

Returns the buffer given by target if it exists, otherwise throws an error.


get_extra_state(self) -> Any

Returns any extra state to include in the module's state_dict. Implement this and a corresponding :func:set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().


get_parameter(self,
              target: str) -> 'Parameter'

Returns the parameter given by target if it exists, otherwise throws an error.


get_submodule(self,
              target: str) -> 'Module'

Returns the submodule given by target if it exists, otherwise throws an error.


half(self: ~T) -> ~T

Casts all floating point parameters and buffers to half datatype.


ipu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the IPU.


load_state_dict(self,
                state_dict: Mapping[str, Any],
                strict: bool = True,
                assign: bool = False)

Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.


modules(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over all modules in the network.


named_buffers(self,
              prefix: str = '',
              recurse: bool = True,
              remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.


named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.


named_modules(self,
              memo: Optional[Set[ForwardRef('Module')]] = None,
              prefix: str = '',
              remove_duplicate: bool = True)

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.


named_parameters(self,
                 prefix: str = '',
                 recurse: bool = True,
                 remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.


parameters(self,
           recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]

Returns an iterator over module parameters.


register_backward_hook(self,
                       hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                       Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                       Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_buffer(self,
                name: str,
                tensor: Optional[torch.Tensor],
                persistent: bool = True) -> None

Adds a buffer to the module.


register_forward_hook(self,
                      hook: Union[Callable[[~T, Tuple[Any, ...], Any],
                      Optional[Any]],
                      Callable[[~T, Tuple[Any, ...],
                      Dict[str, Any], Any],
                      Optional[Any]]],
                      *,
                      prepend: bool = False,
                      with_kwargs: bool = False,
                      always_call: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward hook on the module.


register_forward_pre_hook(self,
                          hook: Union[Callable[[~T, Tuple[Any, ...]],
                          Optional[Any]],
                          Callable[[~T, Tuple[Any, ...],
                          Dict[str, Any]],
                          Optional[Tuple[Any, Dict[str, Any]]]]],
                          *,
                          prepend: bool = False,
                          with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward pre-hook on the module.


register_full_backward_hook(self,
                            hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                            Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                            Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                            prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_full_backward_pre_hook(self,
                                hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                                Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                                prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward pre-hook on the module.


register_load_state_dict_post_hook(self,
                                   hook)

Registers a post hook to be run after module's load_state_dict is called.


register_module(self,
                name: str,
                module: Optional[ForwardRef('Module')]) -> None

Alias for :func:add_module.


register_parameter(self,
                   name: str,
                   param: Optional[torch.nn.parameter.Parameter]) -> None

Adds a parameter to the module.


register_state_dict_pre_hook(self,
                             hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.


requires_grad_(self: ~T,
               requires_grad: bool = True) -> ~T

Change if autograd should record operations on parameters in this module.


set_extra_state(self,
                state: Any)

This function is called from :func:load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding :func:get_extra_state for your module if you need to store extra state within its state_dict.


share_memory(self: ~T) -> ~T

See :meth:torch.Tensor.share_memory_


state_dict(self,
           *args,
           destination=None,
           prefix='',
           keep_vars=False)

Returns a dictionary containing references to the whole state of the module.


to(self,
   *args,
   **kwargs)

Moves and/or casts the parameters and buffers.


to_empty(self: ~T,
         *,
         device: Union[str, torch.device],
         recurse: bool = True) -> ~T

Moves the parameters and buffers to the specified device without copying storage.


train(self: ~T,
      mode: bool = True) -> ~T

Sets the module in training mode.


type(self: ~T,
     dst_type: Union[torch.dtype, str]) -> ~T

Casts all parameters and buffers to :attr:dst_type.


xpu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the XPU.


zero_grad(self,
          set_to_none: bool = True) -> None

Resets gradients of all model parameters. See similar function under :class:torch.optim.Optimizer for more context.


Progressive Growing of GANs for Improved Quality, Stability, and Variation by Tero Karras et al. (2018).