CycleGAN¶
View colab tutorial | View source | 📰 Paper
Have you ever dreamed of turning a horse into a zebra? CycleGAN was developed to do just that. Learn how to turn a horse into a zebra and much more.
So how would we do that without CycleGAN?
We would need a paired dataset: climb over the neighbor's fence in the middle of the night, snap a photo of a horse, quickly paint stripes on it before the owners notice, photograph the resulting "zebra", and run back to the fence before the homeowner's pit bulls spot us. And we would have to keep doing that over and over again until we had enough paired examples in our dataset to train a neural network.
Forget all of that, because CycleGAN needs no paired examples. You build one generator (based on the Pix2Pix architecture) that the GAN trains to turn a horse into a zebra. Then you build a second generator (again based on the Pix2Pix architecture) for an inverse GAN that takes a picture of a zebra and turns it into a picture of a horse.
Let's look at the image below, which shows the first half of CycleGAN trying to create a fake zebra from a horse, while the second half tries to create a fake horse from a zebra. Both halves include a cycle consistency loss that pushes the output of the inverting generator to match the input of the non-inverting generator.
Simplified view of CycleGAN architecture (Image source)
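To make the data flow concrete, here is a minimal sketch of the first half in plain PyTorch. The names gen_a2b and gen_b2a and the stand-in convolution layers are illustrative placeholders, not the library's API:
import torch
import torch.nn.functional as F

# Stand-in generators: any image-to-image networks with matching shapes would do
gen_a2b = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)  # horse -> zebra
gen_b2a = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)  # zebra -> horse

real_a = torch.randn(1, 3, 256, 256)  # a "horse" image

# First half: translate A -> B with the non-inverting generator,
# then map the fake zebra back to A with the inverting generator
fake_b = gen_a2b(real_a)
reconstructed_a = gen_b2a(fake_b)

# Cycle consistency: the reconstruction should match the original input
cycle_loss_a = F.l1_loss(reconstructed_a, real_a)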
NETWORK ARCHITECTURE: CycleGAN¶
GENERATOR NETWORK¶
The CycleGAN Generator has 3 components:
- A downsampling network, composed of 3 convolutional layers (together with the usual padding, normalization, and activation layers).
- A chain of residual blocks built using the ResidualBlock module. You can try varying the number of ResidualBlock layers and compare the results.
- An upsampling network, composed of 3 transposed convolutional layers.
In the CycleGAN Generator we use Instance Norm instead of Batch Norm, and we swap the zero padding of the convolutional layers for reflection padding. A sketch of this layout follows below.
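Here is a minimal sketch of that layout, assuming 3-channel images and 6 residual blocks; the classes and layer sizes below are illustrative assumptions, not the library's implementation:
import torch.nn as nn

class ResidualBlock(nn.Module):
    # Two reflection-padded 3x3 convolutions with InstanceNorm and a skip connection
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1), nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
        )
    def forward(self, x):
        return x + self.block(x)

class SketchGenerator(nn.Module):
    # Downsampling convs -> residual chain -> upsampling transposed convs
    def __init__(self, num_residual_blocks=6):
        super().__init__()
        layers = [
            # Downsampling: 3 convolutional layers, with reflection padding
            nn.ReflectionPad2d(3), nn.Conv2d(3, 64, 7),
            nn.InstanceNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.InstanceNorm2d(256), nn.ReLU(inplace=True),
        ]
        # Residual chain: vary num_residual_blocks and compare the results
        layers += [ResidualBlock(256) for _ in range(num_residual_blocks)]
        layers += [
            # Upsampling back to the input resolution
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(3), nn.Conv2d(64, 3, 7), nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)
    def forward(self, x):
        return self.model(x)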
DISCRIMINATOR NETWORK¶
The CycleGAN Discriminator is essentially the standard DCGAN Discriminator; the only difference is that it uses Instance Normalization.
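A minimal sketch of such a discriminator follows; the layer widths are illustrative assumptions, not the library's exact configuration:
import torch.nn as nn

class SketchDiscriminator(nn.Module):
    # DCGAN-style stack of strided convolutions, with InstanceNorm in place of BatchNorm
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.InstanceNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, padding=1),  # map of real/fake scores
        )
    def forward(self, x):
        return self.model(x)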
LOSS FUNCTIONS¶
The Generator Loss is composed of 3 parts, described below (a combined sketch of the losses follows at the end of this section):
- GAN Loss: It is the standard generator loss of the Least Squares GAN. We use the functional forms of the losses to implement this part. $$L_{GAN} = \frac{1}{2} \times ((D_A(G_{B2A}(Image_B)) - 1)^2 + (D_B(G_{A2B}(Image_A)) - 1)^2)$$
- Identity Loss: It encourages each generator to act as the identity mapping when fed a real image from its own output domain: \(G_{B2A}\) should leave an image of type A unchanged, and \(G_{A2B}\) should leave an image of type B unchanged. The similarity is measured using the \(L_1\) Loss. $$L_{identity} = \frac{1}{2} \times (||G_{B2A}(Image_A) - Image_A||_1 + ||G_{A2B}(Image_B) - Image_B||_1)$$
- Cycle Consistency Loss: This loss computes the similarity between the original image and the image generated by the composition of the 2 generators. This is what allows CycleGAN to deal with unpaired images: we reconstruct the original image and minimize the \(L_1\) norm between the original image and this reconstruction. $$L_{cycle\_consistency} = \frac{1}{2} \times (||G_{B2A}(G_{A2B}(Image_A)) - Image_A||_1 + ||G_{A2B}(G_{B2A}(Image_B)) - Image_B||_1)$$
The Discriminator, as mentioned before, is the same as the normal DCGAN Discriminator. Since the generator uses the least-squares objective, the discriminator is trained with the corresponding least-squares loss, pushing real images toward 1 and generated images toward 0: $$L_{D} = \frac{1}{2} \times ((D_A(Image_A) - 1)^2 + D_A(G_{B2A}(Image_B))^2) + \frac{1}{2} \times ((D_B(Image_B) - 1)^2 + D_B(G_{A2B}(Image_A))^2)$$
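Putting the pieces together, here is a hedged sketch of the loss computations using the functional forms mentioned above. The names gen_a2b, gen_b2a, dis_a, dis_b, real_a, real_b and the weighting factors lambda_cycle and lambda_identity (taken from the CycleGAN paper's defaults) are illustrative assumptions, not the library's API:
import torch
import torch.nn.functional as F

def generator_loss(gen_a2b, gen_b2a, dis_a, dis_b, real_a, real_b,
                   lambda_cycle=10.0, lambda_identity=5.0):
    fake_b = gen_a2b(real_a)
    fake_a = gen_b2a(real_b)
    # GAN loss (least squares): fool both discriminators into predicting 1
    loss_gan = 0.5 * (F.mse_loss(dis_a(fake_a), torch.ones_like(dis_a(fake_a)))
                      + F.mse_loss(dis_b(fake_b), torch.ones_like(dis_b(fake_b))))
    # Identity loss: each generator should leave its own output domain unchanged
    loss_identity = 0.5 * (F.l1_loss(gen_b2a(real_a), real_a)
                           + F.l1_loss(gen_a2b(real_b), real_b))
    # Cycle consistency loss: A -> B -> A and B -> A -> B reconstructions
    loss_cycle = 0.5 * (F.l1_loss(gen_b2a(fake_b), real_a)
                        + F.l1_loss(gen_a2b(fake_a), real_b))
    return loss_gan + lambda_cycle * loss_cycle + lambda_identity * loss_identity

def discriminator_loss(dis, real, fake):
    # Least-squares discriminator loss: real -> 1, fake (detached) -> 0
    return 0.5 * (F.mse_loss(dis(real), torch.ones_like(dis(real)))
                  + F.mse_loss(dis(fake.detach()), torch.zeros_like(dis(fake.detach()))))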
Example¶
# Torch imports
import torch
import torch.nn as nn
from torch.optim import Adam

# Augmentare Imports
import augmentare
from augmentare.methods.gan import *

# Run on GPU if one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Create GAN Generator
net_gen = CYCLEGANGenerator()
# Create GAN Discriminator
net_dis = CYCLEGANDiscriminator()
# Optimizers and Loss functions
optimizer_gen = Adam(net_gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_dis = Adam(net_dis.parameters(), lr=0.0002, betas=(0.5, 0.999))
loss_fn_gen = nn.L1Loss()
loss_fn_dis = nn.L1Loss()
# Create GAN network
gan = CYCLEGAN(
    net_gen,
    net_dis,
    optimizer_gen,
    optimizer_dis,
    loss_fn_gen,
    loss_fn_dis,
    device,
    latent_size=None
)
# Training the CycleGAN network
# data["A"] and data["B"] are assumed to hold the two unpaired image subsets
gen_losses, dis_losses = gan.train(
    subset_a=data["A"],
    num_epochs=30,
    num_decay_epochs=15,
    num_classes=None,
    batch_size=None,
    subset_b=data["B"]
)
# Sample images from the Generator
real_a = data["A"][:36]
real_b = data["B"][:36]
fake_image_a, fake_image_b = gan.generate_samples(
    nb_samples=None,
    num_classes=None,
    real_image_a=real_a,
    real_image_b=real_b
)
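To inspect the results, a grid of samples can be saved with torchvision. This assumes the returned samples are image tensors of equal size; check the actual return type before relying on it:
import torchvision

# Arrange the fake zebras in a 6x6 grid and write them to disk
grid_b = torchvision.utils.make_grid(fake_image_b, nrow=6, normalize=True)
torchvision.utils.save_image(grid_b, "fake_images_b.png")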
Notebooks¶
CYCLEGAN¶
A basic CycleGAN class for training an image-to-image translation model without paired examples.
__init__(self,
         generator: augmentare.methods.gan.cyclegan.CYCLEGANGenerator,
         discriminator: augmentare.methods.gan.cyclegan.CYCLEGANDiscriminator,
         optimizer_gen: torch.optim.optimizer.Optimizer,
         optimizer_dis: torch.optim.optimizer.Optimizer,
         loss_fn_gen: Callable,
         loss_fn_dis: Callable,
         device,
         latent_size: Optional[int] = None)¶
generate_samples(self,
                 nb_samples=None,
                 num_classes=None,
                 real_image_a=None,
                 real_image_b=None)¶
Sample images from the generator.
Return
- fake_image_a
  A list of generated images A
- fake_image_b
  A list of generated images B
train(self,
      subset_a: Union[torch.Tensor, torch.utils.data.dataset.Dataset],
      num_epochs: int,
      num_decay_epochs: Optional[int] = None,
      num_classes: Optional[int] = None,
      batch_size: Optional[int] = None,
      subset_b: Union[torch.Tensor, torch.utils.data.dataset.Dataset, None] = None)¶
The corresponding training function.
Parameters
- subset_a : Union[torch.Tensor, torch.utils.data.dataset.Dataset]
  A torch.Tensor or Dataset with the images of domain A
- num_epochs : int
  The number of epochs you want to train your CycleGAN
- num_decay_epochs : Optional[int]
  The number of epochs to start linearly decaying the learning rate to 0 (see the sketch below)
- subset_b : Union[torch.Tensor, torch.utils.data.dataset.Dataset, None]
  The second torch.Tensor or Dataset, with the images of domain B
Return
- gen_losses, dis_losses
  The losses of both the generator and the discriminator
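For intuition about num_decay_epochs, the schedule resembles the linear learning-rate decay used in the CycleGAN paper. Below is a hedged sketch with torch.optim.lr_scheduler.LambdaLR; the exact decay rule inside the library may differ:
import torch
from torch.optim.lr_scheduler import LambdaLR

num_epochs, num_decay_epochs = 30, 15
optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.0002)

def linear_decay(epoch):
    # Full learning rate for the first (num_epochs - num_decay_epochs) epochs,
    # then a linear ramp down to 0 over the remaining num_decay_epochs epochs
    start = num_epochs - num_decay_epochs
    return 1.0 if epoch < start else max(0.0, 1.0 - (epoch - start) / num_decay_epochs)

scheduler = LambdaLR(optimizer, lr_lambda=linear_decay)
for epoch in range(num_epochs):
    # ... one epoch of generator/discriminator updates would go here ...
    scheduler.step()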
train_discriminator(self,
                    real_samples_a,
                    real_samples_b)¶
Train the discriminator one step and return the loss.
Parameters
- real_samples_a
  True samples of your dataset A
- real_samples_b
  True samples of your dataset B
Return
- dis_loss
  The loss of the discriminator
train_generator(self,
                real_samples_a,
                real_samples_b)¶
Train the generator one step and return the loss.
Parameters
- real_samples_a
  True samples of your dataset A
- real_samples_b
  True samples of your dataset B
Return
- gen_loss
  The loss of the generator
CYCLEGANGenerator¶
A conditional generator for synthesizing an image given an input image.
__init__(self)¶
add_module(self, name: str, module: Optional[Module]) -> None¶
Adds a child module to the current module.

apply(self: ~T, fn: Callable[[Module], None]) -> ~T¶
Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).

bfloat16(self: ~T) -> ~T¶
Casts all floating point parameters and buffers to bfloat16 datatype.

buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]¶
Returns an iterator over module buffers.

children(self) -> Iterator[Module]¶
Returns an iterator over immediate children modules.

compile(self, *args, **kwargs)¶
Compile this Module's forward using torch.compile.

cpu(self: ~T) -> ~T¶
Moves all model parameters and buffers to the CPU.

cuda(self: ~T, device: Union[int, torch.device, None] = None) -> ~T¶
Moves all model parameters and buffers to the GPU.

double(self: ~T) -> ~T¶
Casts all floating point parameters and buffers to double datatype.

eval(self: ~T) -> ~T¶
Sets the module in evaluation mode.

extra_repr(self) -> str¶
Sets the extra representation of the module.

float(self: ~T) -> ~T¶
Casts all floating point parameters and buffers to float datatype.

forward(self, noise, labels=None)¶
The forward function of CYCLEGANGenerator.

get_buffer(self, target: str) -> 'Tensor'¶
Returns the buffer given by target if it exists, otherwise throws an error.

get_extra_state(self) -> Any¶
Returns any extra state to include in the module's state_dict. Implement this and a corresponding set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().

get_parameter(self, target: str) -> 'Parameter'¶
Returns the parameter given by target if it exists, otherwise throws an error.

get_submodule(self, target: str) -> 'Module'¶
Returns the submodule given by target if it exists, otherwise throws an error.

half(self: ~T) -> ~T¶
Casts all floating point parameters and buffers to half datatype.

ipu(self: ~T, device: Union[int, torch.device, None] = None) -> ~T¶
Moves all model parameters and buffers to the IPU.

load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.

modules(self) -> Iterator[Module]¶
Returns an iterator over all modules in the network.

named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children(self) -> Iterator[Tuple[str, Module]]¶
Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules(self, memo: Optional[Set[Module]] = None, prefix: str = '', remove_duplicate: bool = True)¶
Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters(self, recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]¶
Returns an iterator over module parameters.

register_backward_hook(self, hook: Callable[[Module, Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle¶
Registers a backward hook on the module.

register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None¶
Adds a buffer to the module.

register_forward_hook(self, hook: Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) -> torch.utils.hooks.RemovableHandle¶
Registers a forward hook on the module.

register_forward_pre_hook(self, hook: Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]], *, prepend: bool = False, with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle¶
Registers a forward pre-hook on the module.

register_full_backward_hook(self, hook: Callable[[Module, Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[None, Tuple[torch.Tensor, ...], torch.Tensor]], prepend: bool = False) -> torch.utils.hooks.RemovableHandle¶
Registers a backward hook on the module.

register_full_backward_pre_hook(self, hook: Callable[[Module, Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[None, Tuple[torch.Tensor, ...], torch.Tensor]], prepend: bool = False) -> torch.utils.hooks.RemovableHandle¶
Registers a backward pre-hook on the module.

register_load_state_dict_post_hook(self, hook)¶
Registers a post hook to be run after module's load_state_dict is called.

register_module(self, name: str, module: Optional[Module]) -> None¶
Alias for add_module.

register_parameter(self, name: str, param: Optional[torch.nn.parameter.Parameter]) -> None¶
Adds a parameter to the module.

register_state_dict_pre_hook(self, hook)¶
These hooks will be called with arguments self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.

requires_grad_(self: ~T, requires_grad: bool = True) -> ~T¶
Change if autograd should record operations on parameters in this module.

set_extra_state(self, state: Any)¶
This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.

share_memory(self: ~T) -> ~T¶
See torch.Tensor.share_memory_.

state_dict(self, *args, destination=None, prefix='', keep_vars=False)¶
Returns a dictionary containing references to the whole state of the module.

to(self, *args, **kwargs)¶
Moves and/or casts the parameters and buffers.

to_empty(self: ~T, *, device: Union[str, torch.device], recurse: bool = True) -> ~T¶
Moves the parameters and buffers to the specified device without copying storage.

train(self: ~T, mode: bool = True) -> ~T¶
Sets the module in training mode.

type(self: ~T, dst_type: Union[torch.dtype, str]) -> ~T¶
Casts all parameters and buffers to dst_type.

xpu(self: ~T, device: Union[int, torch.device, None] = None) -> ~T¶
Moves all model parameters and buffers to the XPU.

zero_grad(self, set_to_none: bool = True) -> None¶
Resets gradients of all model parameters. See the similar function under torch.optim.Optimizer for more context.
CYCLEGANDiscriminator¶
A discriminator for predicting how likely the generated image is to have come from the target image collection.
__init__(self)¶
Apart from the forward function below, CYCLEGANDiscriminator exposes the same inherited torch.nn.Module methods documented above for CYCLEGANGenerator.
forward(self, noise, labels=None)¶
The forward function of CYCLEGANDiscriminator.
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by Jun-Yan Zhu et al. (2017).