Conditional GAN (CGAN)

Conditional GAN Architecture

NETWORK ARCHITECTURE : CGAN

GENERATOR NETWORK

The CGAN Generator is parameterized to learn and produce realistic samples for each label in the training dataset. It receives an input noise vector of size \(batch\ size \times latent\ size\). It outputs a tensor of \(batch\ size \times channel \times height \times width\) corresponding to a batch of generated image samples.
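One common way to condition the generator on a label is to append a one-hot label vector to each noise vector before the first layer. The sketch below shows only that input construction in plain Python. The helper names `one_hot` and `conditioned_input` are hypothetical, and Augmentare's own generator may combine noise and labels differently (for example through a learned embedding).

```python
import random
from typing import List

def one_hot(label: int, num_classes: int) -> List[float]:
    """Encode an integer label as a one-hot vector of length num_classes."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

def conditioned_input(labels: List[int], latent_size: int,
                      num_classes: int, rng: random.Random) -> List[List[float]]:
    """Build one (latent_size + num_classes)-dimensional row per label."""
    batch = []
    for label in labels:
        # Standard-normal noise vector of length latent_size
        noise = [rng.gauss(0.0, 1.0) for _ in range(latent_size)]
        # Append the label information to the noise
        batch.append(noise + one_hot(label, num_classes))
    return batch

rng = random.Random(0)
batch = conditioned_input([3, 7], latent_size=100, num_classes=10, rng=rng)
print(len(batch), len(batch[0]))  # 2 110
```

With this scheme the first generator layer sees `latent_size + num_classes` inputs per sample, so the same noise vector can yield a different image for each requested class.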

The intermediate layers use the LeakyReLU activation function, because plain ReLU tends to kill gradients and slow down convergence. Any other activation that ensures good gradient flow can be used instead. The last layer uses the Tanh activation to constrain the pixel values to the range \((-1, 1)\).
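The gradient behaviour of ReLU versus LeakyReLU, and the output range of Tanh, can be checked with a few lines of plain Python. This is an illustrative sketch only: the negative slope of 0.2 is an assumed, commonly used value, and the helper names are hypothetical rather than part of Augmentare.

```python
import math

NEGATIVE_SLOPE = 0.2  # assumed value, not taken from Augmentare

def relu_grad(x: float) -> float:
    """Derivative of ReLU: zero for every non-positive input."""
    return 1.0 if x > 0 else 0.0

def leaky_relu(x: float, slope: float = NEGATIVE_SLOPE) -> float:
    """LeakyReLU: identity for positive inputs, scaled by slope otherwise."""
    return x if x > 0 else slope * x

def leaky_relu_grad(x: float, slope: float = NEGATIVE_SLOPE) -> float:
    """Derivative of LeakyReLU: a small non-zero slope for negative inputs."""
    return 1.0 if x > 0 else slope

pre_activations = [-3.0, -0.5, 0.5, 3.0]
relu_grads = [relu_grad(x) for x in pre_activations]
leaky_grads = [leaky_relu_grad(x) for x in pre_activations]
tanh_outputs = [math.tanh(x) for x in pre_activations]

print(relu_grads)   # [0.0, 0.0, 1.0, 1.0]
print(leaky_grads)  # [0.2, 0.2, 1.0, 1.0]
```

Units with negative pre-activations receive no gradient under ReLU but still receive a scaled gradient under LeakyReLU, and every value in `tanh_outputs` lies strictly between -1 and 1.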

DISCRIMINATOR NETWORK

The CGAN Discriminator learns to distinguish fake from real samples, given the label information. Its architecture is symmetric to the generator's. It maps an image to a confidence score that classifies the image as real (i.e. it comes from the dataset) or fake (i.e. it was sampled by the generator).

We use the LeakyReLU activation for the Discriminator. We also use Dropout, an effective regularization technique that prevents the co-adaptation of neurons.
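As a rough illustration of what Dropout does during training, the sketch below implements "inverted" dropout in plain Python: each activation is zeroed with probability `p`, and the survivors are rescaled by `1 / (1 - p)` so the expected value is unchanged. The function name `dropout` and the rate `p=0.4` are assumptions made for this example, not values read from Augmentare.

```python
import random
from typing import List

def dropout(values: List[float], p: float, rng: random.Random,
            training: bool = True) -> List[float]:
    """Zero each value with probability p and rescale the survivors."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        # Dropout is a no-op at evaluation time
        return list(values)
    scale = 1.0 / (1.0 - p)
    return [v * scale if rng.random() >= p else 0.0 for v in values]

rng = random.Random(42)
activations = [1.0] * 10000
dropped = dropout(activations, p=0.4, rng=rng)

kept_fraction = sum(1 for v in dropped if v != 0.0) / len(dropped)
mean_after = sum(dropped) / len(dropped)
eval_output = dropout(activations[:5], p=0.4, rng=rng, training=False)
```

Here `kept_fraction` comes out close to 0.6 and `mean_after` close to 1.0, while `eval_output` is returned unchanged, which mirrors the difference between training mode and evaluation mode.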

The last layer of the CGAN Discriminator is a Sigmoid, which constrains the confidence score to the range \((0, 1)\) so that it can be interpreted as the probability that the image is real. This interpretation applies only to the Minimax Loss proposed in the original GAN paper; losses such as the Wasserstein Loss require no such interpretation. If required, one can set the last-layer activation to Sigmoid by passing it as a parameter at initialization time.
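To make the probability interpretation concrete, the following plain-Python sketch squashes a few raw scores with a sigmoid and evaluates the binary cross-entropy formula against "real" (1) and "fake" (0) targets. The helper names `sigmoid` and `bce` and the clamping constant `eps` are choices made for this illustration; they are not Augmentare or PyTorch functions.

```python
import math

def sigmoid(score: float) -> float:
    """Map a raw score to a value strictly between 0 and 1."""
    return 1.0 / (1.0 + math.exp(-score))

def bce(prob: float, target: float, eps: float = 1e-12) -> float:
    """Binary cross-entropy for one prediction; prob is clamped for stability."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))

raw_scores = [-4.0, 0.0, 4.0]
probs = [sigmoid(s) for s in raw_scores]

# Loss when the image is actually real (target 1) or fake (target 0)
loss_if_real = [bce(p, 1.0) for p in probs]
loss_if_fake = [bce(p, 0.0) for p in probs]
```

A score of 0 maps to a probability of exactly 0.5, where both losses equal log 2. Confident scores are rewarded when they match the target and penalized heavily when they do not.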

Example

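# Imports
import torch
from torch import nn
from torch.optim import Adam
import augmentare
from augmentare.methods.gan import *

# image_shape (shape of one training image) and dataloader (training data)
# are assumed to be defined beforehand for your dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")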

# Create GAN Generator
net_gen = CGANGenerator(
    latent_size=100,
    num_classes=10,
    image_shape=image_shape
)

# Create GAN Discriminator
net_dis = CGANDiscriminator(
    num_classes=10,
    image_shape=image_shape
)

# Optimizers and Loss functions
optimizer_gen = Adam(net_gen.parameters(), lr=0.0001)
optimizer_dis = Adam(net_dis.parameters(), lr=0.0001)
loss_fn_gen = nn.BCELoss()
loss_fn_dis = nn.BCELoss()

# Create GAN network
gan = CGAN(
    net_gen,
    net_dis,
    optimizer_gen,
    optimizer_dis,
    loss_fn_gen,
    loss_fn_dis,
    device,
    latent_size=100,
    init_weights=False
)

# Training the CGAN network
gen_losses, dis_losses = gan.train(
    subset_a=dataloader,
    num_epochs=20,
    num_decay_epochs=None,
    num_classes=10,
    batch_size=32,
    subset_b=None
)

# Sample images from the Generator
img_list = gan.generate_samples(
    nb_samples=32,
    num_classes=8,
    real_image_a=None,
    real_image_b=None
)

Notebooks

CGAN

A basic CGAN class for generating images conditioned on class labels.

__init__(self,
         generator: augmentare.methods.gan.base.BaseGenerator,
         discriminator: augmentare.methods.gan.base.BaseDiscriminator,
         optimizer_gen: torch.optim.optimizer.Optimizer,
         optimizer_dis: torch.optim.optimizer.Optimizer,
         loss_fn_gen: Callable,
         loss_fn_dis: Callable,
         device,
         latent_size: Optional[int] = None,
         init_weights: bool = True)

generate_samples(self,
                 nb_samples: int,
                 num_classes=typing.Optional[int],
                 real_image_a=None,
                 real_image_b=None)

Sample images from the generator.

Return

  • img_list

    • A list of generated images (the images of each class follow one another in the list)


train(self,
      subset_a: Union[torch.Tensor, torch.utils.data.dataset.Dataset],
      num_epochs: int,
      num_decay_epochs=typing.Optional[int],
      num_classes=typing.Optional[int],
      batch_size=typing.Optional[int],
      subset_b=typing.Union[torch.Tensor, torch.utils.data.dataset.Dataset, None])

Train both networks and return the losses.
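For reference, the conditional minimax objective from the CGAN paper cited at the end of this page is reproduced below, where \(y\) is the label, \(x\) a real sample, and \(z\) the noise vector. Whether a given run optimizes exactly this objective depends on the loss functions passed to CGAN; using binary cross-entropy for both networks, as in the example above, corresponds to this objective or to its common non-saturating variant.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x \mid y)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]
```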

Parameters

  • subset_a : Union[torch.Tensor, torch.utils.data.dataset.Dataset]

    • A torch.Tensor or a Dataset

  • num_epochs : int

    • The number of epochs to train the CGAN for

  • num_classes : typing.Optional[int]

    • Number of classes in the dataset

  • batch_size : typing.Optional[int]

    • Training batch size

Return

  • gen_losses, dis_losses

    • The losses of both the generator and discriminator


train_discriminator(self,
                    real_samples,
                    num_classes,
                    batch_size)

Train the discriminator one step and return the loss.

Parameters

  • real_samples : real_samples

    • Real samples from your dataset

  • num_classes : num_classes

    • Number of classes in dataset

  • batch_size : batch_size

    • Training batch size

Return

  • dis_loss

    • The loss of the discriminator


train_generator(self,
                num_classes,
                batch_size)

Train the generator one step and return the loss.

Parameters

  • num_classes : num_classes

    • Number of classes in dataset

  • batch_size : batch_size

    • Training batch size

Return

  • gen_loss

    • The loss of the generator


CGANGenerator

A generator is parameterized to learn and produce realistic samples for each label in the training dataset.

__init__(self,
         latent_size,
         num_classes,
         image_shape)

add_module(self,
           name: str,
           module: Optional[ForwardRef('Module')]) -> None

Adds a child module to the current module.


apply(self: ~T,
      fn: Callable[[ForwardRef('Module')], None]) -> ~T

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).


bfloat16(self: ~T) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.


buffers(self,
        recurse: bool = True) -> Iterator[torch.Tensor]

Returns an iterator over module buffers.


children(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over immediate children modules.


compile(self,
        *args,
        **kwargs)

Compile this Module's forward using torch.compile.


cpu(self: ~T) -> ~T

Moves all model parameters and buffers to the CPU.


cuda(self: ~T,
     device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the GPU.


double(self: ~T) -> ~T

Casts all floating point parameters and buffers to double datatype.


eval(self: ~T) -> ~T

Sets the module in evaluation mode.


extra_repr(self) -> str

Sets the extra representation of the module.


float(self: ~T) -> ~T

Casts all floating point parameters and buffers to float datatype.


forward(self,
        noise,
        labels=typing.Optional[torch.Tensor])

The forward function of CGANGenerator.


get_buffer(self,
           target: str) -> 'Tensor'

Returns the buffer given by target if it exists, otherwise throws an error.


get_extra_state(self) -> Any

Returns any extra state to include in the module's state_dict. Implement this and a corresponding set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().


get_parameter(self,
              target: str) -> 'Parameter'

Returns the parameter given by target if it exists, otherwise throws an error.


get_submodule(self,
              target: str) -> 'Module'

Returns the submodule given by target if it exists, otherwise throws an error.


half(self: ~T) -> ~T

Casts all floating point parameters and buffers to half datatype.


ipu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the IPU.


load_state_dict(self,
                state_dict: Mapping[str, Any],
                strict: bool = True,
                assign: bool = False)

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.


modules(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over all modules in the network.


named_buffers(self,
              prefix: str = '',
              recurse: bool = True,
              remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.


named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.


named_modules(self,
              memo: Optional[Set[ForwardRef('Module')]] = None,
              prefix: str = '',
              remove_duplicate: bool = True)

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.


named_parameters(self,
                 prefix: str = '',
                 recurse: bool = True,
                 remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.


parameters(self,
           recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]

Returns an iterator over module parameters.


register_backward_hook(self,
                       hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                       Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                       Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_buffer(self,
                name: str,
                tensor: Optional[torch.Tensor],
                persistent: bool = True) -> None

Adds a buffer to the module.


register_forward_hook(self,
                      hook: Union[Callable[[~T, Tuple[Any, ...], Any],
                      Optional[Any]],
                      Callable[[~T, Tuple[Any, ...],
                      Dict[str, Any], Any],
                      Optional[Any]]],
                      *,
                      prepend: bool = False,
                      with_kwargs: bool = False,
                      always_call: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward hook on the module.


register_forward_pre_hook(self,
                          hook: Union[Callable[[~T, Tuple[Any, ...]],
                          Optional[Any]],
                          Callable[[~T, Tuple[Any, ...],
                          Dict[str, Any]],
                          Optional[Tuple[Any, Dict[str, Any]]]]],
                          *,
                          prepend: bool = False,
                          with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward pre-hook on the module.


register_full_backward_hook(self,
                            hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                            Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                            Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                            prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_full_backward_pre_hook(self,
                                hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                                Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                                prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward pre-hook on the module.


register_load_state_dict_post_hook(self,
                                   hook)

Registers a post hook to be run after module's load_state_dict is called.


register_module(self,
                name: str,
                module: Optional[ForwardRef('Module')]) -> None

Alias for add_module.


register_parameter(self,
                   name: str,
                   param: Optional[torch.nn.parameter.Parameter]) -> None

Adds a parameter to the module.


register_state_dict_pre_hook(self,
                             hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.


requires_grad_(self: ~T,
               requires_grad: bool = True) -> ~T

Change if autograd should record operations on parameters in this module.


set_extra_state(self,
                state: Any)

This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.


share_memory(self: ~T) -> ~T

See torch.Tensor.share_memory_.


state_dict(self,
           *args,
           destination=None,
           prefix='',
           keep_vars=False)

Returns a dictionary containing references to the whole state of the module.


to(self,
   *args,
   **kwargs)

Moves and/or casts the parameters and buffers.


to_empty(self: ~T,
         *,
         device: Union[str, torch.device],
         recurse: bool = True) -> ~T

Moves the parameters and buffers to the specified device without copying storage.


train(self: ~T,
      mode: bool = True) -> ~T

Sets the module in training mode.


type(self: ~T,
     dst_type: Union[torch.dtype, str]) -> ~T

Casts all parameters and buffers to dst_type.


xpu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the XPU.


zero_grad(self,
          set_to_none: bool = True) -> None

Resets gradients of all model parameters. See the similar function under torch.optim.Optimizer for more context.


CGANDiscriminator

A discriminator learns to distinguish fake and real samples, given the label information. Output activation is Sigmoid.

__init__(self,
         num_classes,
         image_shape)

add_module(self,
           name: str,
           module: Optional[ForwardRef('Module')]) -> None

Adds a child module to the current module.


apply(self: ~T,
      fn: Callable[[ForwardRef('Module')], None]) -> ~T

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).


bfloat16(self: ~T) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.


buffers(self,
        recurse: bool = True) -> Iterator[torch.Tensor]

Returns an iterator over module buffers.


children(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over immediate children modules.


compile(self,
        *args,
        **kwargs)

Compile this Module's forward using torch.compile.


cpu(self: ~T) -> ~T

Moves all model parameters and buffers to the CPU.


cuda(self: ~T,
     device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the GPU.


double(self: ~T) -> ~T

Casts all floating point parameters and buffers to double datatype.


eval(self: ~T) -> ~T

Sets the module in evaluation mode.


extra_repr(self) -> str

Sets the extra representation of the module.


float(self: ~T) -> ~T

Casts all floating point parameters and buffers to float datatype.


forward(self,
        noise,
        labels=typing.Optional[torch.Tensor])

The forward function of CGANDiscriminator.


get_buffer(self,
           target: str) -> 'Tensor'

Returns the buffer given by target if it exists, otherwise throws an error.


get_extra_state(self) -> Any

Returns any extra state to include in the module's state_dict. Implement this and a corresponding set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().


get_parameter(self,
              target: str) -> 'Parameter'

Returns the parameter given by target if it exists, otherwise throws an error.


get_submodule(self,
              target: str) -> 'Module'

Returns the submodule given by target if it exists, otherwise throws an error.


half(self: ~T) -> ~T

Casts all floating point parameters and buffers to half datatype.


ipu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the IPU.


load_state_dict(self,
                state_dict: Mapping[str, Any],
                strict: bool = True,
                assign: bool = False)

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.


modules(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over all modules in the network.


named_buffers(self,
              prefix: str = '',
              recurse: bool = True,
              remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.


named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.


named_modules(self,
              memo: Optional[Set[ForwardRef('Module')]] = None,
              prefix: str = '',
              remove_duplicate: bool = True)

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.


named_parameters(self,
                 prefix: str = '',
                 recurse: bool = True,
                 remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.


parameters(self,
           recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]

Returns an iterator over module parameters.


register_backward_hook(self,
                       hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                       Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                       Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_buffer(self,
                name: str,
                tensor: Optional[torch.Tensor],
                persistent: bool = True) -> None

Adds a buffer to the module.


register_forward_hook(self,
                      hook: Union[Callable[[~T, Tuple[Any, ...], Any],
                      Optional[Any]],
                      Callable[[~T, Tuple[Any, ...],
                      Dict[str, Any], Any],
                      Optional[Any]]],
                      *,
                      prepend: bool = False,
                      with_kwargs: bool = False,
                      always_call: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward hook on the module.


register_forward_pre_hook(self,
                          hook: Union[Callable[[~T, Tuple[Any, ...]],
                          Optional[Any]],
                          Callable[[~T, Tuple[Any, ...],
                          Dict[str, Any]],
                          Optional[Tuple[Any, Dict[str, Any]]]]],
                          *,
                          prepend: bool = False,
                          with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward pre-hook on the module.


register_full_backward_hook(self,
                            hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                            Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                            Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                            prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_full_backward_pre_hook(self,
                                hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                                Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                                prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward pre-hook on the module.


register_load_state_dict_post_hook(self,
                                   hook)

Registers a post hook to be run after module's load_state_dict is called.


register_module(self,
                name: str,
                module: Optional[ForwardRef('Module')]) -> None

Alias for add_module.


register_parameter(self,
                   name: str,
                   param: Optional[torch.nn.parameter.Parameter]) -> None

Adds a parameter to the module.


register_state_dict_pre_hook(self,
                             hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.


requires_grad_(self: ~T,
               requires_grad: bool = True) -> ~T

Change if autograd should record operations on parameters in this module.


set_extra_state(self,
                state: Any)

This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.


share_memory(self: ~T) -> ~T

See torch.Tensor.share_memory_.


state_dict(self,
           *args,
           destination=None,
           prefix='',
           keep_vars=False)

Returns a dictionary containing references to the whole state of the module.


to(self,
   *args,
   **kwargs)

Moves and/or casts the parameters and buffers.


to_empty(self: ~T,
         *,
         device: Union[str, torch.device],
         recurse: bool = True) -> ~T

Moves the parameters and buffers to the specified device without copying storage.


train(self: ~T,
      mode: bool = True) -> ~T

Sets the module in training mode.


type(self: ~T,
     dst_type: Union[torch.dtype, str]) -> ~T

Casts all parameters and buffers to dst_type.


xpu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the XPU.


zero_grad(self,
          set_to_none: bool = True) -> None

Resets gradients of all model parameters. See the similar function under torch.optim.Optimizer for more context.


Conditional Generative Adversarial Nets by Mehdi Mirza & Simon Osindero (2014).