Cycle GAN

View colab tutorial | View source | 📰 Paper

Have you ever harbored the dark secret wish to turn a horse into a zebra? CycleGAN was developed to do just that. Learn how to turn a horse into a zebra, and much more.

Back to the story of the dark secret: how would we go about it without CycleGAN?

We would climb over the neighbor's fence in the middle of the night to paint a horse with stripes. First, take a photo of the horse before starting. Then paint the horse quickly, before its owner notices. Take a photo of the resulting zebra-striped horse, then run for the fence before the homeowner's pit bulls spot you. And you would have to repeat this over and over until you had enough paired examples in your dataset to train your neural network.

Forget all that, because you have CycleGAN. You build a Generator, modeled on the Pix2Pix architecture, and train it inside a GAN to turn a horse into a zebra. Then you build a second Generator (again modeled on the Pix2Pix architecture) for a second, inverse GAN that takes a picture of a zebra and turns it into an image of a horse.

Let's look at the image below, which shows the first half of CycleGAN trying to create a fake zebra from a horse. The second half of CycleGAN tries to create a fake horse from a zebra. Both halves include a cycle consistency loss, which pushes the output of the inverse generator to match the original input of the forward generator.

Simplified view of CycleGAN architecture (Image source)

NETWORK ARCHITECTURE : CycleGAN

GENERATOR NETWORK

The CycleGAN Generator has 3 components:

  1. A downsampling network: It is composed of 3 convolutional layers (together with the usual padding, normalization and activation layers).
  2. A chain of residual networks built using the Residual Block. You can vary the ResidualBlock parameters and compare the results.
  3. An upsampling network: It is composed of 3 transposed convolutional layers.

In the CycleGAN Generator, we use Instance Norm instead of Batch Norm, and we replace the Zero Padding of the Convolutional Layers with Reflection Padding.
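The three components above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not Augmentare's implementation: the names `conv_block`, `SketchResidualBlock` and `SketchGenerator`, the channel widths, the kernel sizes and the default of nine residual blocks are all chosen for the example.

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int, kernel: int, stride: int) -> list:
    """Convolution with reflection padding, then Instance Norm and ReLU."""
    return [
        nn.Conv2d(in_ch, out_ch, kernel, stride=stride,
                  padding=kernel // 2, padding_mode="reflect"),
        nn.InstanceNorm2d(out_ch),
        nn.ReLU(inplace=True),
    ]


class SketchResidualBlock(nn.Module):
    """Hypothetical residual block: two 3x3 convolutions plus a skip connection."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.body = nn.Sequential(
            *conv_block(channels, channels, kernel=3, stride=1),
            nn.Conv2d(channels, channels, 3, padding=1, padding_mode="reflect"),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.body(x)


class SketchGenerator(nn.Module):
    """Hypothetical generator: downsampling, residual chain, upsampling."""

    def __init__(self, channels: int = 3, width: int = 64,
                 num_residual_blocks: int = 9) -> None:
        super().__init__()
        # 1. Downsampling network: three convolutional layers.
        down = (conv_block(channels, width, kernel=7, stride=1)
                + conv_block(width, 2 * width, kernel=3, stride=2)
                + conv_block(2 * width, 4 * width, kernel=3, stride=2))
        # 2. Chain of residual blocks.
        residual = [SketchResidualBlock(4 * width)
                    for _ in range(num_residual_blocks)]
        # 3. Upsampling network: three transposed convolutional layers.
        up = [
            nn.ConvTranspose2d(4 * width, 2 * width, 3, stride=2,
                               padding=1, output_padding=1),
            nn.InstanceNorm2d(2 * width),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(2 * width, width, 3, stride=2,
                               padding=1, output_padding=1),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(width, channels, 7, stride=1, padding=3),
            nn.Tanh(),  # squashes pixel values into [-1, 1]
        ]
        self.model = nn.Sequential(*down, *residual, *up)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


# A small instance, so the example runs quickly on a CPU.
generator = SketchGenerator(width=8, num_residual_blocks=2)
images = torch.randn(2, 3, 32, 32)
translated = generator(images)  # same shape as the input batch
```

The two stride-2 convolutions halve the spatial resolution twice and the two stride-2 transposed convolutions restore it, so the translated image has the same size as the input.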

DISCRIMINATOR NETWORK

The CycleGAN Discriminator is like the standard DCGAN Discriminator. The only difference is that it uses Instance Normalization.
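A DCGAN-style discriminator with Instance Normalization might look like the sketch below. The class name `SketchDiscriminator`, the layer widths, and the choice to output one score per image patch rather than a single score are assumptions made for illustration; Augmentare's `CYCLEGANDiscriminator` may be built differently.

```python
import torch
from torch import nn


class SketchDiscriminator(nn.Module):
    """Hypothetical DCGAN-style discriminator that uses Instance Norm."""

    def __init__(self, channels: int = 3, width: int = 64) -> None:
        super().__init__()
        self.model = nn.Sequential(
            # No normalization on the first layer, as in DCGAN.
            nn.Conv2d(channels, width, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(width, 2 * width, 4, stride=2, padding=1),
            nn.InstanceNorm2d(2 * width),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(2 * width, 4 * width, 4, stride=2, padding=1),
            nn.InstanceNorm2d(4 * width),
            nn.LeakyReLU(0.2, inplace=True),
            # Raw scores without a sigmoid, suitable for a least-squares loss.
            nn.Conv2d(4 * width, 1, 4, stride=1, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


discriminator = SketchDiscriminator(width=8)
scores = discriminator(torch.randn(2, 3, 64, 64))  # shape (2, 1, 7, 7)
```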

LOSS FUNCTIONS

The Generator Loss is composed of 3 parts. They are described below:

  1. GAN Loss: It is the standard generator loss of the Least Squares GAN. We use the functional forms of the losses to implement this part. \[L_{GAN} = \frac{1}{2} \times ((D_A(G_{B2A}(Image_B)) - 1)^2 + (D_B(G_{A2B}(Image_A)) - 1)^2)\]
  2. Identity Loss: As defined in the CycleGAN paper, it measures how much each generator changes a real image that already belongs to its target domain: a real image of type B passed through \(G_{A2B}\) should come out unchanged, and vice versa. The similarity is measured using the \(L_1\) Loss. \[L_{identity} = \frac{1}{2} \times (||G_{A2B}(Image_B) - Image_B||_1 + ||G_{B2A}(Image_A) - Image_A||_1)\]
  3. Cycle Consistency Loss: This loss computes the similarity of the original image and the image generated by a composition of the 2 generators. This allows CycleGAN to deal with unpaired images. We reconstruct the original image and try to minimize the \(L_1\) norm between the original image and this reconstructed image. \[L_{cycle\_consistency} = \frac{1}{2} \times (||G_{B2A}(G_{A2B}(Image_A)) - Image_A||_1 + ||G_{A2B}(G_{B2A}(Image_B)) - Image_B||_1)\]
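The adversarial and cycle-consistency terms above can be sketched with PyTorch's functional losses. This is a minimal sketch, not the library's code: the function name and its arguments are hypothetical, and the generators and discriminators are passed in as plain callables. Note that `F.mse_loss` and `F.l1_loss` average over all elements, so they equal the squared error and the \(L_1\) norm up to a constant factor.

```python
import torch
import torch.nn.functional as F


def generator_gan_and_cycle_loss(gen_a2b, gen_b2a, dis_a, dis_b,
                                 image_a, image_b):
    """Return (L_GAN, L_cycle_consistency) for one batch (hypothetical helper)."""
    fake_b = gen_a2b(image_a)
    fake_a = gen_b2a(image_b)

    # Least-squares GAN loss: the generators want the discriminators to output 1.
    score_fake_a = dis_a(fake_a)
    score_fake_b = dis_b(fake_b)
    loss_gan = 0.5 * (
        F.mse_loss(score_fake_a, torch.ones_like(score_fake_a))
        + F.mse_loss(score_fake_b, torch.ones_like(score_fake_b))
    )

    # Cycle consistency: A -> B -> A and B -> A -> B should reproduce the input.
    loss_cycle = 0.5 * (
        F.l1_loss(gen_b2a(fake_b), image_a)
        + F.l1_loss(gen_a2b(fake_a), image_b)
    )
    return loss_gan, loss_cycle


# Toy stand-ins for the networks, so the arithmetic is easy to follow.
double = lambda x: 2.0 * x                          # "generator" that doubles pixels
always_zero = lambda x: torch.zeros(x.shape[0], 1)  # "discriminator" that outputs 0

image_a = torch.ones(4, 3, 8, 8)
image_b = torch.ones(4, 3, 8, 8)
loss_gan, loss_cycle = generator_gan_and_cycle_loss(
    double, double, always_zero, always_zero, image_a, image_b
)
print(loss_gan.item(), loss_cycle.item())  # 1.0 3.0
```

With these stand-ins every discriminator score is 0, so each squared term is \((0 - 1)^2 = 1\) and the GAN loss is 1. A round trip multiplies each pixel by 4, so the mean absolute error against the all-ones input is 3.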

As mentioned before, the Discriminator is the same as the normal DCGAN Discriminator. Its loss is the standard discriminator loss of the Least Squares GAN, which pushes each discriminator's output toward 1 on real images and toward 0 on generated images:

\[L_{D} = \frac{1}{2} \times (((D_A(Image_A) - 1)^2 + (D_A(G_{B2A}(Image_B)))^2) + ((D_B(Image_B) - 1)^2 + (D_B(G_{A2B}(Image_A)))^2))\]
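A least-squares discriminator objective, in which each discriminator is pushed toward 1 on real images and toward 0 on generated ones, can be sketched as follows. The function name and arguments are hypothetical, and the networks are again plain callables rather than Augmentare classes.

```python
import torch


def discriminator_lsgan_loss(gen_a2b, gen_b2a, dis_a, dis_b, image_a, image_b):
    """Least-squares discriminator loss for one batch (hypothetical helper)."""
    # The generators are not updated in this step, so no gradients are needed.
    with torch.no_grad():
        fake_a = gen_b2a(image_b)
        fake_b = gen_a2b(image_a)

    # Real images should score 1, generated images should score 0.
    loss_a = ((dis_a(image_a) - 1) ** 2).mean() + (dis_a(fake_a) ** 2).mean()
    loss_b = ((dis_b(image_b) - 1) ** 2).mean() + (dis_b(fake_b) ** 2).mean()
    return 0.5 * (loss_a + loss_b)


# Toy stand-ins: an identity "generator" and an undecided "discriminator".
same = lambda x: x
undecided = lambda x: torch.full((x.shape[0], 1), 0.5)

image_a = torch.rand(4, 3, 8, 8)
image_b = torch.rand(4, 3, 8, 8)
loss = discriminator_lsgan_loss(same, same, undecided, undecided,
                                image_a, image_b)
print(loss.item())  # 0.5
```

An undecided discriminator that always outputs 0.5 pays \(0.25\) on every real and every generated image, which gives \(\frac{1}{2} \times (0.5 + 0.5) = 0.5\).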

Example

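# PyTorch imports
import torch
from torch import nn
from torch.optim import Adam

# Augmentare Imports
import augmentare
from augmentare.methods.gan import *

# Device on which the networks are trained
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")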

# Create GAN Generator
net_gen = CYCLEGANGenerator()

# Create GAN Discriminator
net_dis = CYCLEGANDiscriminator()

# Optimizers and Loss functions
optimizer_gen = Adam(net_gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_dis = Adam(net_dis.parameters(), lr=0.0002, betas=(0.5, 0.999))
loss_fn_gen = nn.L1Loss()
loss_fn_dis = nn.L1Loss()

# Create GAN network
gan = CYCLEGAN(
    net_gen,
    net_dis,
    optimizer_gen,
    optimizer_dis,
    loss_fn_gen,
    loss_fn_dis,
    device,
    latent_size=None
)

# Training the CycleGAN network
# `data` is assumed to hold the two unpaired image sets under the keys "A" and "B"
gen_losses, dis_losses = gan.train(
    subset_a=data["A"],
    num_epochs=30,
    num_decay_epochs=15,
    num_classes=None,
    batch_size=None,
    subset_b=data["B"]
)

# Sample images from the Generator
real_a = data["A"][:36]
real_b = data["B"][:36]

fake_image_a, fake_image_b = gan.generate_samples(
    nb_samples=None,
    num_classes=None,
    real_image_a=real_a,
    real_image_b=real_b
)

CYCLEGAN

A basic CycleGAN class for training an image-to-image translation model without paired examples.

__init__(self,
         generator: augmentare.methods.gan.cyclegan.CYCLEGANGenerator,
         discriminator: augmentare.methods.gan.cyclegan.CYCLEGANDiscriminator,
         optimizer_gen: torch.optim.optimizer.Optimizer,
         optimizer_dis: torch.optim.optimizer.Optimizer,
         loss_fn_gen: Callable,
         loss_fn_dis: Callable,
         device,
         latent_size=typing.Optional[int])

generate_samples(self,
                 nb_samples=None,
                 num_classes=None,
                 real_image_a=typing.Optional[],
                 real_image_b=typing.Optional[])

Sample images from the generator.

Return

  • fake_image_a

    • A list of generated images A

  • fake_image_b

    • A list of generated images B


train(self,
      subset_a: Union[torch.Tensor, torch.utils.data.dataset.Dataset],
      num_epochs: int,
      num_decay_epochs=typing.Optional[int],
      num_classes=typing.Optional[int],
      batch_size=typing.Optional[int],
      subset_b=typing.Union[torch.Tensor, torch.utils.data.dataset.Dataset, None])

The corresponding training function

Parameters

  • subset_a : Union[torch.Tensor, torch.utils.data.dataset.Dataset]

    • A torch.Tensor or Dataset

  • num_epochs : int

    • The number of epochs for which you want to train your CycleGAN

  • num_decay_epochs : Optional[int]

    • The number of epochs to start linearly decaying the learning rate to 0

  • subset_b : Union[torch.Tensor, torch.utils.data.dataset.Dataset, None]

    • The second torch.Tensor or Dataset

Return

  • gen_losses, dis_losses

    • The losses of the generator and of the discriminator, in that order
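The linear learning-rate decay controlled by `num_decay_epochs` can be illustrated with a small helper. This sketch rests on an assumption about the schedule: the rate is held constant until epoch `num_decay_epochs` and then falls linearly to 0 at epoch `num_epochs`. The function name is hypothetical, and Augmentare's actual schedule may differ.

```python
def linear_decay_factor(epoch: int, num_epochs: int, num_decay_epochs: int) -> float:
    """Multiplier for the base learning rate at a 0-based epoch.

    Assumed schedule: constant until `num_decay_epochs`, then a straight
    line down to 0 at `num_epochs`.
    """
    if epoch < num_decay_epochs or num_epochs <= num_decay_epochs:
        return 1.0
    decay_span = num_epochs - num_decay_epochs
    return max(0.0, 1.0 - (epoch - num_decay_epochs) / decay_span)


base_lr = 0.0002
for epoch in (0, 15, 18, 30):
    print(epoch, base_lr * linear_decay_factor(epoch, num_epochs=30, num_decay_epochs=15))
```

A multiplier of this form can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR` to apply the schedule to an optimizer.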


train_discriminator(self,
                    real_samples_a,
                    real_samples_b)

Train the discriminator one step and return the loss.

Parameters

  • real_samples_a

    • Real samples from your dataset A

  • real_samples_b

    • Real samples from your dataset B

Return

  • dis_loss

    • The loss of the discriminator


train_generator(self,
                real_samples_a,
                real_samples_b)

Train the generator one step and return the loss.

Parameters

  • real_samples_a

    • Real samples from your dataset A

  • real_samples_b

    • Real samples from your dataset B

Return

  • gen_loss

    • The loss of the generator


CYCLEGANGenerator

A conditional generator for synthesizing an image given an input image.

__init__(self)

add_module(self,
           name: str,
           module: Optional[ForwardRef('Module')]) -> None

Adds a child module to the current module.


apply(self: ~T,
      fn: Callable[[ForwardRef('Module')], None]) -> ~T

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).
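As an illustration of parameter initialization with `apply`, the sketch below draws convolution weights from a zero-mean Gaussian with standard deviation 0.02, a common choice for GAN training (it is the value used in the DCGAN paper). The `init_weights` function and the small stand-in network are hypothetical; whether Augmentare initializes its networks this way is not stated here.

```python
import torch
from torch import nn


def init_weights(module: nn.Module) -> None:
    """Draw convolution weights from N(0, 0.02) and zero the biases."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


torch.manual_seed(0)
# Stand-in network; `apply` visits every submodule and then the container itself.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.ConvTranspose2d(8, 3, 3))
net.apply(init_weights)
```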


bfloat16(self: ~T) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.


buffers(self,
        recurse: bool = True) -> Iterator[torch.Tensor]

Returns an iterator over module buffers.


children(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over immediate children modules.


compile(self,
        *args,
        **kwargs)

Compile this Module's forward using :func:torch.compile.


cpu(self: ~T) -> ~T

Moves all model parameters and buffers to the CPU.


cuda(self: ~T,
     device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the GPU.


double(self: ~T) -> ~T

Casts all floating point parameters and buffers to double datatype.


eval(self: ~T) -> ~T

Sets the module in evaluation mode.


extra_repr(self) -> str

Set the extra representation of the module


float(self: ~T) -> ~T

Casts all floating point parameters and buffers to float datatype.


forward(self,
        noise,
        labels=typing.Optional[])

The forward function of CYCLEGANGenerator.


get_buffer(self,
           target: str) -> 'Tensor'

Returns the buffer given by target if it exists, otherwise throws an error.


get_extra_state(self) -> Any

Returns any extra state to include in the module's state_dict. Implement this and a corresponding :func:set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().


get_parameter(self,
              target: str) -> 'Parameter'

Returns the parameter given by target if it exists, otherwise throws an error.


get_submodule(self,
              target: str) -> 'Module'

Returns the submodule given by target if it exists, otherwise throws an error.


half(self: ~T) -> ~T

Casts all floating point parameters and buffers to half datatype.


ipu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the IPU.


load_state_dict(self,
                state_dict: Mapping[str, Any],
                strict: bool = True,
                assign: bool = False)

Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.
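A short sketch of saving and restoring weights with `state_dict` and `load_state_dict`. An `nn.Linear` layer stands in for the generator, and an in-memory buffer stands in for a checkpoint file; both are choices made only to keep the example self-contained.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)
trained = nn.Linear(4, 2)  # stand-in for a trained network

# Save the weights to an in-memory buffer (a file path works the same way).
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Restore into a freshly constructed module of the same architecture.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))

# With strict=False, missing keys are reported instead of raising an error.
partial = {"weight": trained.weight.detach().clone()}
result = nn.Linear(4, 2).load_state_dict(partial, strict=False)
print(result.missing_keys)  # ['bias']
```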


modules(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over all modules in the network.


named_buffers(self,
              prefix: str = '',
              recurse: bool = True,
              remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.


named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.


named_modules(self,
              memo: Optional[Set[ForwardRef('Module')]] = None,
              prefix: str = '',
              remove_duplicate: bool = True)

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.


named_parameters(self,
                 prefix: str = '',
                 recurse: bool = True,
                 remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.


parameters(self,
           recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]

Returns an iterator over module parameters.


register_backward_hook(self,
                       hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                       Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                       Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_buffer(self,
                name: str,
                tensor: Optional[torch.Tensor],
                persistent: bool = True) -> None

Adds a buffer to the module.


register_forward_hook(self,
                      hook: Union[Callable[[~T, Tuple[Any, ...], Any],
                      Optional[Any]],
                      Callable[[~T, Tuple[Any, ...],
                      Dict[str, Any], Any],
                      Optional[Any]]],
                      *,
                      prepend: bool = False,
                      with_kwargs: bool = False,
                      always_call: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward hook on the module.
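For example, a forward hook can capture an intermediate activation for inspection without changing the network. The small network and the `save_output` function below are hypothetical stand-ins used only for illustration.

```python
import torch
from torch import nn

net = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 3, 3, padding=1),
)
captured = {}


def save_output(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
    """Store a detached copy of the module's output."""
    captured["relu"] = output.detach()


handle = net[1].register_forward_hook(save_output)
net(torch.randn(1, 3, 8, 8))
handle.remove()  # later forward passes no longer call the hook
```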


register_forward_pre_hook(self,
                          hook: Union[Callable[[~T, Tuple[Any, ...]],
                          Optional[Any]],
                          Callable[[~T, Tuple[Any, ...],
                          Dict[str, Any]],
                          Optional[Tuple[Any, Dict[str, Any]]]]],
                          *,
                          prepend: bool = False,
                          with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward pre-hook on the module.


register_full_backward_hook(self,
                            hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                            Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                            Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                            prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_full_backward_pre_hook(self,
                                hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                                Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                                prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward pre-hook on the module.


register_load_state_dict_post_hook(self,
                                   hook)

Registers a post hook to be run after module's load_state_dict is called.


register_module(self,
                name: str,
                module: Optional[ForwardRef('Module')]) -> None

Alias for :func:add_module.


register_parameter(self,
                   name: str,
                   param: Optional[torch.nn.parameter.Parameter]) -> None

Adds a parameter to the module.


register_state_dict_pre_hook(self,
                             hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.


requires_grad_(self: ~T,
               requires_grad: bool = True) -> ~T

Change if autograd should record operations on parameters in this module.
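One common use in GAN training is to freeze the discriminator while the generator is updated, so that the backward pass does not accumulate gradients in the discriminator. The sketch below shows that pattern with small linear layers as stand-ins; it is an assumption about typical usage, not a description of what Augmentare does internally.

```python
import torch
from torch import nn

torch.manual_seed(0)
generator = nn.Linear(4, 4)      # stand-in for a generator
discriminator = nn.Linear(4, 1)  # stand-in for a discriminator

# Freeze the discriminator for the generator update.
discriminator.requires_grad_(False)
loss = discriminator(generator(torch.randn(2, 4))).mean()
loss.backward()

# Gradients reached the generator but not the frozen discriminator.
generator_has_grad = generator.weight.grad is not None
discriminator_has_grad = discriminator.weight.grad is not None

# Unfreeze before the next discriminator update.
discriminator.requires_grad_(True)
```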


set_extra_state(self,
                state: Any)

This function is called from :func:load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding :func:get_extra_state for your module if you need to store extra state within its state_dict.


share_memory(self: ~T) -> ~T

See :meth:torch.Tensor.share_memory_


state_dict(self,
           *args,
           destination=None,
           prefix='',
           keep_vars=False)

Returns a dictionary containing references to the whole state of the module.


to(self,
   *args,
   **kwargs)

Moves and/or casts the parameters and buffers.


to_empty(self: ~T,
         *,
         device: Union[str, torch.device],
         recurse: bool = True) -> ~T

Moves the parameters and buffers to the specified device without copying storage.


train(self: ~T,
      mode: bool = True) -> ~T

Sets the module in training mode.


type(self: ~T,
     dst_type: Union[torch.dtype, str]) -> ~T

Casts all parameters and buffers to :attr:dst_type.


xpu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the XPU.


zero_grad(self,
          set_to_none: bool = True) -> None

Resets gradients of all model parameters. See similar function under :class:torch.optim.Optimizer for more context.


CYCLEGANDiscriminator

A discriminator for predicting how likely the generated image is to have come from the target image collection.

__init__(self)

add_module(self,
           name: str,
           module: Optional[ForwardRef('Module')]) -> None

Adds a child module to the current module.


apply(self: ~T,
      fn: Callable[[ForwardRef('Module')], None]) -> ~T

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).


bfloat16(self: ~T) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.


buffers(self,
        recurse: bool = True) -> Iterator[torch.Tensor]

Returns an iterator over module buffers.


children(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over immediate children modules.


compile(self,
        *args,
        **kwargs)

Compile this Module's forward using :func:torch.compile.


cpu(self: ~T) -> ~T

Moves all model parameters and buffers to the CPU.


cuda(self: ~T,
     device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the GPU.


double(self: ~T) -> ~T

Casts all floating point parameters and buffers to double datatype.


eval(self: ~T) -> ~T

Sets the module in evaluation mode.


extra_repr(self) -> str

Set the extra representation of the module


float(self: ~T) -> ~T

Casts all floating point parameters and buffers to float datatype.


forward(self,
        noise,
        labels=typing.Optional[])

The forward function of CYCLEGANDiscriminator.


get_buffer(self,
           target: str) -> 'Tensor'

Returns the buffer given by target if it exists, otherwise throws an error.


get_extra_state(self) -> Any

Returns any extra state to include in the module's state_dict. Implement this and a corresponding :func:set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().


get_parameter(self,
              target: str) -> 'Parameter'

Returns the parameter given by target if it exists, otherwise throws an error.


get_submodule(self,
              target: str) -> 'Module'

Returns the submodule given by target if it exists, otherwise throws an error.


half(self: ~T) -> ~T

Casts all floating point parameters and buffers to half datatype.


ipu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the IPU.


load_state_dict(self,
                state_dict: Mapping[str, Any],
                strict: bool = True,
                assign: bool = False)

Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.


modules(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over all modules in the network.


named_buffers(self,
              prefix: str = '',
              recurse: bool = True,
              remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.


named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.


named_modules(self,
              memo: Optional[Set[ForwardRef('Module')]] = None,
              prefix: str = '',
              remove_duplicate: bool = True)

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.


named_parameters(self,
                 prefix: str = '',
                 recurse: bool = True,
                 remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.


parameters(self,
           recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]

Returns an iterator over module parameters.


register_backward_hook(self,
                       hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                       Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                       Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_buffer(self,
                name: str,
                tensor: Optional[torch.Tensor],
                persistent: bool = True) -> None

Adds a buffer to the module.


register_forward_hook(self,
                      hook: Union[Callable[[~T, Tuple[Any, ...], Any],
                      Optional[Any]],
                      Callable[[~T, Tuple[Any, ...],
                      Dict[str, Any], Any],
                      Optional[Any]]],
                      *,
                      prepend: bool = False,
                      with_kwargs: bool = False,
                      always_call: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward hook on the module.


register_forward_pre_hook(self,
                          hook: Union[Callable[[~T, Tuple[Any, ...]],
                          Optional[Any]],
                          Callable[[~T, Tuple[Any, ...],
                          Dict[str, Any]],
                          Optional[Tuple[Any, Dict[str, Any]]]]],
                          *,
                          prepend: bool = False,
                          with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward pre-hook on the module.


register_full_backward_hook(self,
                            hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                            Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                            Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                            prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_full_backward_pre_hook(self,
                                hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                                Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                                prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward pre-hook on the module.


register_load_state_dict_post_hook(self,
                                   hook)

Registers a post hook to be run after module's load_state_dict is called.


register_module(self,
                name: str,
                module: Optional[ForwardRef('Module')]) -> None

Alias for :func:add_module.


register_parameter(self,
                   name: str,
                   param: Optional[torch.nn.parameter.Parameter]) -> None

Adds a parameter to the module.


register_state_dict_pre_hook(self,
                             hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.


requires_grad_(self: ~T,
               requires_grad: bool = True) -> ~T

Change if autograd should record operations on parameters in this module.


set_extra_state(self,
                state: Any)

This function is called from :func:load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding :func:get_extra_state for your module if you need to store extra state within its state_dict.


share_memory(self: ~T) -> ~T

See :meth:torch.Tensor.share_memory_


state_dict(self,
           *args,
           destination=None,
           prefix='',
           keep_vars=False)

Returns a dictionary containing references to the whole state of the module.


to(self,
   *args,
   **kwargs)

Moves and/or casts the parameters and buffers.


to_empty(self: ~T,
         *,
         device: Union[str, torch.device],
         recurse: bool = True) -> ~T

Moves the parameters and buffers to the specified device without copying storage.


train(self: ~T,
      mode: bool = True) -> ~T

Sets the module in training mode.


type(self: ~T,
     dst_type: Union[torch.dtype, str]) -> ~T

Casts all parameters and buffers to :attr:dst_type.


xpu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the XPU.


zero_grad(self,
          set_to_none: bool = True) -> None

Resets gradients of all model parameters. See similar function under :class:torch.optim.Optimizer for more context.


Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by Jun-Yan Zhu et al. (2017).