Conditional Deep Convolutional GAN (CDCGAN)
Conditional Deep Convolutional GAN is a conditional GAN that uses the same convolutional layers as the DCGAN described previously. Thanks to these convolutional layers, CDCGAN generates more realistic images than CGAN.
NETWORK ARCHITECTURE: CDCGAN
GENERATOR NETWORK
The CDCGAN generator is parameterized to learn to produce realistic samples for each label in the training dataset. It receives an input noise vector of size \(batch\ size \times latent\ size\), together with the labels that condition generation, and outputs a tensor of size \(batch\ size \times channel \times height \times width\) corresponding to a batch of generated image samples.
The intermediate layers use the ReLU activation function to avoid vanishing gradients and speed up convergence; any other activation that ensures good gradient flow can also be used. The last layer uses the Tanh activation to constrain the pixel values to the range \([-1, 1]\).
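To make the data flow above concrete, here is a minimal PyTorch sketch of a conditional DCGAN generator. It is not Augmentare's CDCGANGenerator: the class name `SketchConditionalGenerator`, the 32x32 output size, the layer widths, and the choice to embed the label with `nn.Embedding` and concatenate it to the noise vector are all assumptions made for illustration. It shows the ReLU intermediate layers and the Tanh output described above.

```python
import torch
from torch import nn


class SketchConditionalGenerator(nn.Module):
    """Hypothetical conditional DCGAN generator producing 32x32 images."""

    def __init__(self, num_classes=10, latent_size=10, label_embed_size=5,
                 channels=3, conv_dim=64):
        super().__init__()
        # Learn a dense vector for each class label.
        self.label_embedding = nn.Embedding(num_classes, label_embed_size)
        in_dim = latent_size + label_embed_size
        self.net = nn.Sequential(
            # (N, in_dim, 1, 1) -> (N, conv_dim*4, 4, 4)
            nn.ConvTranspose2d(in_dim, conv_dim * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(conv_dim * 4),
            nn.ReLU(inplace=True),
            # -> (N, conv_dim*2, 8, 8)
            nn.ConvTranspose2d(conv_dim * 4, conv_dim * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(conv_dim * 2),
            nn.ReLU(inplace=True),
            # -> (N, conv_dim, 16, 16)
            nn.ConvTranspose2d(conv_dim * 2, conv_dim, 4, 2, 1, bias=False),
            nn.BatchNorm2d(conv_dim),
            nn.ReLU(inplace=True),
            # -> (N, channels, 32, 32); Tanh squashes pixel values into [-1, 1]
            nn.ConvTranspose2d(conv_dim, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        # Concatenate noise and label embedding, then reshape to a 1x1 "image".
        x = torch.cat([noise, self.label_embedding(labels)], dim=1)
        return self.net(x.unsqueeze(-1).unsqueeze(-1))


gen = SketchConditionalGenerator()
noise = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
fake = gen(noise, labels)
print(fake.shape)  # torch.Size([8, 3, 32, 32])
```

Each ConvTranspose2d with kernel 4, stride 2 and padding 1 doubles the spatial size, which is how the 1x1 input grows to 32x32.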
DISCRIMINATOR NETWORK
The CDCGAN discriminator learns to distinguish fake samples from real ones, given the label information. Its architecture mirrors that of the generator. It maps an image to a confidence score that classifies the image as real (i.e. drawn from the dataset) or fake (i.e. sampled by the generator).
We use the LeakyReLU activation in the discriminator.
The last layer of the CDCGAN discriminator is a Sigmoid, which squashes the confidence score into \((0, 1)\) so that it can be read as the probability that the image is real. This interpretation, however, applies only to the minimax loss proposed in the original GAN paper; losses such as the Wasserstein loss require no such interpretation. Note that the documented CDCGANDiscriminator constructor takes no parameter for changing this output activation.
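A matching discriminator sketch, again hypothetical rather than Augmentare's CDCGANDiscriminator, is shown below. Assumptions: 32x32 inputs, and a label-conditioning scheme in which each label is embedded into a full-resolution map that is stacked onto the image as an extra channel. It uses LeakyReLU in the hidden layers and a final Sigmoid, so every score lies in (0, 1).

```python
import torch
from torch import nn


class SketchConditionalDiscriminator(nn.Module):
    """Hypothetical conditional DCGAN discriminator for 32x32 images."""

    def __init__(self, num_classes=10, channels=3, conv_dim=64, image_size=32):
        super().__init__()
        self.image_size = image_size
        # Map each label to an image_size x image_size map used as an extra channel.
        self.label_embedding = nn.Embedding(num_classes, image_size * image_size)
        self.net = nn.Sequential(
            # (N, channels + 1, 32, 32) -> (N, conv_dim, 16, 16)
            nn.Conv2d(channels + 1, conv_dim, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, conv_dim*2, 8, 8)
            nn.Conv2d(conv_dim, conv_dim * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(conv_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, conv_dim*4, 4, 4)
            nn.Conv2d(conv_dim * 2, conv_dim * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(conv_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1, 1, 1); Sigmoid turns the score into a probability of "real"
            nn.Conv2d(conv_dim * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, images, labels):
        label_map = self.label_embedding(labels).view(
            -1, 1, self.image_size, self.image_size)
        x = torch.cat([images, label_map], dim=1)
        return self.net(x).view(-1, 1)


dis = SketchConditionalDiscriminator()
scores = dis(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
print(scores.shape)  # torch.Size([8, 1])
```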
Example
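```python
# Augmentare imports
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from augmentare.methods.gan import CDCGAN, CDCGANDiscriminator, CDCGANGenerator

# Setup: `device`, `image_size` and `dataloader` are assumed values, not part of Augmentare
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 32  # assumed; must match the resolution the generator produces

# Assumed data: any labelled 3-channel, 10-class dataset (CIFAR-10 here),
# normalized to [-1, 1] to match the generator's Tanh output
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

# Create GAN generator
net_gen = CDCGANGenerator(
    num_classes=10,
    latent_size=10,
    label_embed_size=5,
    channels=3,
    conv_dim=64,
)

# Create GAN discriminator
net_dis = CDCGANDiscriminator(
    num_classes=10,
    channels=3,
    conv_dim=64,
    image_size=image_size,
)

# Optimizers and loss functions
optimizer_gen = Adam(net_gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_dis = Adam(net_dis.parameters(), lr=0.0002, betas=(0.5, 0.999))
loss_fn_gen = nn.BCELoss()
loss_fn_dis = nn.BCELoss()

# Create GAN network
gan = CDCGAN(
    net_gen,
    net_dis,
    optimizer_gen,
    optimizer_dis,
    loss_fn_gen,
    loss_fn_dis,
    device,
    latent_size=10,
    init_weights=False,
)

# Train the CDCGAN network
gen_losses, dis_losses = gan.train(
    subset_a=dataloader,
    num_epochs=200,
    num_decay_epochs=None,
    num_classes=None,
    batch_size=256,
    subset_b=None,
)

# Sample images from the generator
img_list = gan.generate_samples(
    nb_samples=32,
    num_classes=8,
    real_image_a=None,
    real_image_b=None,
)
```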
CDCGAN
A basic CDCGAN class for generating images.
__init__(self,
generator: augmentare.methods.gan.base.BaseGenerator,
discriminator: augmentare.methods.gan.base.BaseDiscriminator,
optimizer_gen: torch.optim.optimizer.Optimizer,
optimizer_dis: torch.optim.optimizer.Optimizer,
loss_fn_gen: Callable,
loss_fn_dis: Callable,
device,
latent_size: Optional[int] = None,
init_weights: bool = True)
generate_samples(self,
nb_samples: int,
num_classes=typing.Optional[int],
real_image_a=None,
real_image_b=None)
Sample images from the generator.
Return
- img_list
  A list of generated images
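The generator ends in Tanh, so generated pixels lie in [-1, 1], while most image viewers and savers expect [0, 1]. Assuming the entries of img_list are image tensors in the generator's output range (the return type is not documented above), a small hypothetical helper such as `to_unit_range` can rescale them:

```python
import torch


def to_unit_range(images: torch.Tensor) -> torch.Tensor:
    """Map Tanh outputs from [-1, 1] to [0, 1] (hypothetical helper)."""
    # x -> (x + 1) / 2, clamped to guard against values slightly outside [-1, 1]
    return ((images + 1) / 2).clamp(0.0, 1.0)


samples = torch.tensor([-1.0, 0.0, 1.0])
rescaled = to_unit_range(samples)  # expected values: 0.0, 0.5, 1.0
```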
train(self,
subset_a: Union[torch.Tensor, torch.utils.data.dataset.Dataset],
num_epochs: int,
num_decay_epochs=typing.Optional[int],
num_classes=typing.Optional[int],
batch_size=typing.Optional[int],
subset_b=typing.Union[torch.Tensor, torch.utils.data.dataset.Dataset, None])
Train both networks and return the losses.
Parameters
- subset_a : Union[torch.Tensor, torch.utils.data.dataset.Dataset]
  Training data, as a torch.Tensor or a Dataset
- num_epochs : int
  Number of epochs to train the CDCGAN
- batch_size : Optional[int]
  Training batch size
Return
- gen_losses, dis_losses
  The losses of the generator and of the discriminator
train_discriminator(self,
real_samples,
real_labels,
batch_size)
Train the discriminator for one step and return the loss.
Parameters
- real_samples
  Real samples from your dataset
- real_labels
  Labels of the real samples
- batch_size
  Batch size
Return
- dis_loss
  The loss of the discriminator
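The sketch below illustrates what one discriminator step typically looks like with the minimax BCE loss used in the example. It is an illustration under stated assumptions, not Augmentare's train_discriminator: the toy models, the `discriminator_step` helper and the flattened 3x8x8 images are hypothetical, and the real and fake losses are summed here (some implementations average them instead).

```python
import torch
from torch import nn

torch.manual_seed(0)
LATENT_SIZE, NUM_CLASSES, IMG_DIM = 10, 10, 3 * 8 * 8  # toy sizes (assumed)


class ToyGenerator(nn.Module):
    """Hypothetical tiny conditional generator working on flattened images."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_CLASSES, NUM_CLASSES)
        self.net = nn.Sequential(nn.Linear(LATENT_SIZE + NUM_CLASSES, IMG_DIM), nn.Tanh())

    def forward(self, noise, labels):
        return self.net(torch.cat([noise, self.embed(labels)], dim=1))


class ToyDiscriminator(nn.Module):
    """Hypothetical tiny conditional discriminator with a Sigmoid output."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_CLASSES, NUM_CLASSES)
        self.net = nn.Sequential(nn.Linear(IMG_DIM + NUM_CLASSES, 1), nn.Sigmoid())

    def forward(self, images, labels):
        return self.net(torch.cat([images, self.embed(labels)], dim=1))


def discriminator_step(gen, dis, optimizer_dis, loss_fn, real_samples, real_labels):
    """One minimax discriminator update; returns the loss as a Python float."""
    batch_size = real_samples.size(0)
    optimizer_dis.zero_grad()
    # Real (image, label) pairs should be scored as 1.
    real_loss = loss_fn(dis(real_samples, real_labels), torch.ones(batch_size, 1))
    # Generated images for the same labels should be scored as 0;
    # detach() keeps this update from backpropagating into the generator.
    noise = torch.randn(batch_size, LATENT_SIZE)
    fake_samples = gen(noise, real_labels).detach()
    fake_loss = loss_fn(dis(fake_samples, real_labels), torch.zeros(batch_size, 1))
    dis_loss = real_loss + fake_loss
    dis_loss.backward()
    optimizer_dis.step()
    return dis_loss.item()


gen, dis = ToyGenerator(), ToyDiscriminator()
optimizer_dis = torch.optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999))
real_samples = torch.rand(16, IMG_DIM) * 2 - 1  # stand-in data scaled to [-1, 1]
real_labels = torch.randint(0, NUM_CLASSES, (16,))
dis_loss = discriminator_step(gen, dis, optimizer_dis, nn.BCELoss(), real_samples, real_labels)
```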
train_generator(self,
real_samples,
real_labels,
batch_size)
Train the generator for one step and return the loss.
Parameters
- real_samples
  Real samples from your dataset
- real_labels
  Labels of the real samples
- batch_size
  Batch size
Return
- gen_loss
  The loss of the generator
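A corresponding generator step can be sketched as follows, again with hypothetical toy models and a hypothetical `generator_step` helper rather than Augmentare's train_generator. It uses the common non-saturating form of the generator loss: fakes are scored against a target of 1 ("real"), and only the generator's optimizer steps. In this sketch the real samples are not needed, since the labels alone condition the fakes.

```python
import torch
from torch import nn

torch.manual_seed(0)
LATENT_SIZE, NUM_CLASSES, IMG_DIM = 10, 10, 3 * 8 * 8  # toy sizes (assumed)


class ToyGenerator(nn.Module):
    """Hypothetical tiny conditional generator working on flattened images."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_CLASSES, NUM_CLASSES)
        self.net = nn.Sequential(nn.Linear(LATENT_SIZE + NUM_CLASSES, IMG_DIM), nn.Tanh())

    def forward(self, noise, labels):
        return self.net(torch.cat([noise, self.embed(labels)], dim=1))


class ToyDiscriminator(nn.Module):
    """Hypothetical tiny conditional discriminator with a Sigmoid output."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_CLASSES, NUM_CLASSES)
        self.net = nn.Sequential(nn.Linear(IMG_DIM + NUM_CLASSES, 1), nn.Sigmoid())

    def forward(self, images, labels):
        return self.net(torch.cat([images, self.embed(labels)], dim=1))


def generator_step(gen, dis, optimizer_gen, loss_fn, real_labels):
    """One non-saturating generator update; returns the loss as a Python float."""
    batch_size = real_labels.size(0)
    optimizer_gen.zero_grad()
    noise = torch.randn(batch_size, LATENT_SIZE)
    fake_samples = gen(noise, real_labels)  # no detach: gradients must reach gen
    # The generator wants the discriminator to score its fakes as real (1).
    gen_loss = loss_fn(dis(fake_samples, real_labels), torch.ones(batch_size, 1))
    gen_loss.backward()
    optimizer_gen.step()  # only the generator's parameters are updated
    return gen_loss.item()


gen, dis = ToyGenerator(), ToyDiscriminator()
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
real_labels = torch.randint(0, NUM_CLASSES, (16,))
before = [p.detach().clone() for p in gen.parameters()]
gen_loss = generator_step(gen, dis, optimizer_gen, nn.BCELoss(), real_labels)
```

The backward pass also fills the discriminator's .grad fields; this is harmless as long as the discriminator's own update clears gradients with zero_grad() before its backward pass.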
CDCGANGenerator
A generator for mapping a latent space to a sample space.
__init__(self,
num_classes,
latent_size,
label_embed_size,
channels,
conv_dim)
add_module(self,
name: str,
module: Optional[ForwardRef('Module')]) -> None
Adds a child module to the current module.
apply(self: ~T,
fn: Callable[[ForwardRef('Module')], None]) -> ~T
Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).
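A common use of apply is DCGAN-style weight initialization. The DCGAN paper initializes all weights from a zero-centered normal distribution with standard deviation 0.02; the sketch below also draws BatchNorm scales from N(1, 0.02) and zeroes their biases, following the widespread PyTorch DCGAN tutorial convention. Whether CDCGAN's init_weights=True option does exactly this is not documented here, so treat `dcgan_weights_init` as a hypothetical helper.

```python
import torch
from torch import nn


def dcgan_weights_init(module: nn.Module) -> None:
    """Hypothetical DCGAN-style init: N(0, 0.02) conv weights, N(1, 0.02) BatchNorm scales."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, mean=1.0, std=0.02)
        nn.init.zeros_(module.bias)


model = nn.Sequential(
    nn.ConvTranspose2d(64, 64, 4, 2, 1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)
# apply() calls the function on every submodule and then on the model itself.
model.apply(dcgan_weights_init)
```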
bfloat16(self: ~T) -> ~T
Casts all floating point parameters and buffers to bfloat16 datatype.
buffers(self,
recurse: bool = True) -> Iterator[torch.Tensor]
Returns an iterator over module buffers.
children(self) -> Iterator[ForwardRef('Module')]
Returns an iterator over immediate children modules.
compile(self,
*args,
**kwargs)
Compiles this Module's forward using torch.compile.
cpu(self: ~T) -> ~T
Moves all model parameters and buffers to the CPU.
cuda(self: ~T,
device: Union[int, torch.device, None] = None) -> ~T
Moves all model parameters and buffers to the GPU.
double(self: ~T) -> ~T
Casts all floating point parameters and buffers to double datatype.
eval(self: ~T) -> ~T
Sets the module in evaluation mode.
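Switching to evaluation mode matters when sampling from a trained generator that contains BatchNorm layers: in eval mode they use their running statistics instead of per-batch statistics, so even a batch of one sample works (in training mode, BatchNorm1d rejects a single-sample batch). Combined with torch.no_grad(), no autograd graph is recorded while sampling. The generator below is a hypothetical stand-in.

```python
import torch
from torch import nn

# Hypothetical stand-in generator containing a BatchNorm layer.
generator = nn.Sequential(
    nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 4), nn.Tanh()
)

generator.eval()  # BatchNorm now uses running statistics
with torch.no_grad():  # skip gradient tracking while sampling
    samples = generator(torch.randn(1, 10))  # a single sample is fine in eval mode
generator.train()  # restore training mode before further updates
```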
extra_repr(self) -> str
Set the extra representation of the module
float(self: ~T) -> ~T
Casts all floating point parameters and buffers to float datatype.
forward(self,
noise,
labels=typing.Optional[])
Forward pass of CDCGANGenerator.
get_buffer(self,
target: str) -> 'Tensor'
Returns the buffer given by target if it exists, otherwise throws an error.
get_extra_state(self) -> Any
Returns any extra state to include in the module's state_dict. Implement this and a corresponding set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().
get_parameter(self,
target: str) -> 'Parameter'
Returns the parameter given by target if it exists, otherwise throws an error.
get_submodule(self,
target: str) -> 'Module'
Returns the submodule given by target if it exists, otherwise throws an error.
half(self: ~T) -> ~T
Casts all floating point parameters and buffers to half datatype.
ipu(self: ~T,
device: Union[int, torch.device, None] = None) -> ~T
Moves all model parameters and buffers to the IPU.
load_state_dict(self,
state_dict: Mapping[str, Any],
strict: bool = True,
assign: bool = False)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
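Saving and restoring a trained generator usually goes through state_dict and load_state_dict. The sketch below uses a small hypothetical nn.Sequential stand-in and an in-memory io.BytesIO buffer in place of a checkpoint file; for any nn.Module subclass the same calls apply, provided the rebuilt model is constructed with the same arguments.

```python
import io

import torch
from torch import nn


def build_generator() -> nn.Module:
    """Hypothetical stand-in architecture; must be identical when saving and loading."""
    return nn.Sequential(nn.Linear(15, 32), nn.ReLU(), nn.Linear(32, 3 * 8 * 8), nn.Tanh())


generator = build_generator()

# Save only the parameters and buffers (the buffer stands in for a file path).
buffer = io.BytesIO()
torch.save(generator.state_dict(), buffer)
buffer.seek(0)

# Rebuild the architecture, then load the weights; strict=True requires matching keys.
restored = build_generator()
result = restored.load_state_dict(torch.load(buffer), strict=True)
```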
modules(self) -> Iterator[ForwardRef('Module')]
Returns an iterator over all modules in the network.
named_buffers(self,
prefix: str = '',
recurse: bool = True,
remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]
Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]
Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
named_modules(self,
memo: Optional[Set[ForwardRef('Module')]] = None,
prefix: str = '',
remove_duplicate: bool = True)
Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
named_parameters(self,
prefix: str = '',
recurse: bool = True,
remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]
Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.
parameters(self,
recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]
Returns an iterator over module parameters.
register_backward_hook(self,
hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
Union[Tuple[torch.Tensor, ...], torch.Tensor]],
Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle
Registers a backward hook on the module.
register_buffer(self,
name: str,
tensor: Optional[torch.Tensor],
persistent: bool = True) -> None
Adds a buffer to the module.
register_forward_hook(self,
hook: Union[Callable[[~T, Tuple[Any, ...], Any],
Optional[Any]],
Callable[[~T, Tuple[Any, ...],
Dict[str, Any], Any],
Optional[Any]]],
*,
prepend: bool = False,
with_kwargs: bool = False,
always_call: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a forward hook on the module.
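Forward hooks are handy for checking the intermediate shapes of a generator while debugging its architecture. The sketch below registers a hook on each child of a small hypothetical ConvTranspose2d stack, records every output shape, and removes the hooks afterwards through the returned handles.

```python
import torch
from torch import nn

# Hypothetical stand-in generator body: (N, 15, 1, 1) -> (N, 3, 8, 8).
model = nn.Sequential(
    nn.ConvTranspose2d(15, 32, 4, 1, 0),
    nn.ReLU(),
    nn.ConvTranspose2d(32, 3, 4, 2, 1),
    nn.Tanh(),
)
shapes = {}


def record_shape(name):
    # Forward hooks receive (module, inputs, output) after each forward call.
    def hook(module, inputs, output):
        shapes[name] = tuple(output.shape)
    return hook


handles = [m.register_forward_hook(record_shape(n)) for n, m in model.named_children()]
model(torch.randn(2, 15, 1, 1))
for handle in handles:
    handle.remove()  # detach the hooks once the shapes are recorded
# shapes maps "0" -> (2, 32, 4, 4) and "2" -> (2, 3, 8, 8)
```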
register_forward_pre_hook(self,
hook: Union[Callable[[~T, Tuple[Any, ...]],
Optional[Any]],
Callable[[~T, Tuple[Any, ...],
Dict[str, Any]],
Optional[Tuple[Any, Dict[str, Any]]]]],
*,
prepend: bool = False,
with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a forward pre-hook on the module.
register_full_backward_hook(self,
hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
Union[Tuple[torch.Tensor, ...], torch.Tensor]],
Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
prepend: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a backward hook on the module.
register_full_backward_pre_hook(self,
hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
prepend: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a backward pre-hook on the module.
register_load_state_dict_post_hook(self,
hook)
Registers a post hook to be run after the module's load_state_dict is called.
register_module(self,
name: str,
module: Optional[ForwardRef('Module')]) -> None
Alias for add_module.
register_parameter(self,
name: str,
param: Optional[torch.nn.parameter.Parameter]) -> None
Adds a parameter to the module.
register_state_dict_pre_hook(self,
hook)
These hooks will be called with arguments self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.
requires_grad_(self: ~T,
requires_grad: bool = True) -> ~T
Changes whether autograd should record operations on parameters in this module.
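In GAN training, requires_grad_ is sometimes used to freeze the discriminator while the generator is updated, so that no gradients are computed for the discriminator's parameters. The sketch uses a hypothetical one-layer discriminator and a random tensor standing in for generator output.

```python
import torch
from torch import nn

# Hypothetical stand-in discriminator.
discriminator = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())

# Freeze: autograd stops recording operations on these parameters.
discriminator.requires_grad_(False)
x = torch.randn(4, 8, requires_grad=True)  # stands in for generator output
score = discriminator(x).mean()
score.backward()  # gradients still reach x, but not the frozen parameters

frozen_grads = [p.grad for p in discriminator.parameters()]
discriminator.requires_grad_(True)  # unfreeze before the discriminator's own update
```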
set_extra_state(self,
state: Any)
This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.
share_memory(self: ~T) -> ~T
See torch.Tensor.share_memory_.
state_dict(self,
*args,
destination=None,
prefix='',
keep_vars=False)
Returns a dictionary containing references to the whole state of the module.
to(self,
*args,
**kwargs)
Moves and/or casts the parameters and buffers.
to_empty(self: ~T,
*,
device: Union[str, torch.device],
recurse: bool = True) -> ~T
Moves the parameters and buffers to the specified device without copying storage.
train(self: ~T,
mode: bool = True) -> ~T
Sets the module in training mode.
type(self: ~T,
dst_type: Union[torch.dtype, str]) -> ~T
Casts all parameters and buffers to dst_type.
xpu(self: ~T,
device: Union[int, torch.device, None] = None) -> ~T
Moves all model parameters and buffers to the XPU.
zero_grad(self,
set_to_none: bool = True) -> None
Resets gradients of all model parameters. See the similar function under torch.optim.Optimizer for more context.
CDCGANDiscriminator
A discriminator for discerning real from generated images.
Output activation is Sigmoid.
__init__(self,
num_classes,
channels,
conv_dim,
image_size)
add_module(self,
name: str,
module: Optional[ForwardRef('Module')]) -> None
Adds a child module to the current module.
apply(self: ~T,
fn: Callable[[ForwardRef('Module')], None]) -> ~T
Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).
bfloat16(self: ~T) -> ~T
Casts all floating point parameters and buffers to bfloat16 datatype.
buffers(self,
recurse: bool = True) -> Iterator[torch.Tensor]
Returns an iterator over module buffers.
children(self) -> Iterator[ForwardRef('Module')]
Returns an iterator over immediate children modules.
compile(self,
*args,
**kwargs)
Compiles this Module's forward using torch.compile.
cpu(self: ~T) -> ~T
Moves all model parameters and buffers to the CPU.
cuda(self: ~T,
device: Union[int, torch.device, None] = None) -> ~T
Moves all model parameters and buffers to the GPU.
double(self: ~T) -> ~T
Casts all floating point parameters and buffers to double datatype.
eval(self: ~T) -> ~T
Sets the module in evaluation mode.
extra_repr(self) -> str
Set the extra representation of the module
float(self: ~T) -> ~T
Casts all floating point parameters and buffers to float datatype.
forward(self,
noise,
labels=typing.Optional[])
Forward pass of CDCGANDiscriminator.
get_buffer(self,
target: str) -> 'Tensor'
Returns the buffer given by target if it exists, otherwise throws an error.
get_extra_state(self) -> Any
Returns any extra state to include in the module's state_dict. Implement this and a corresponding set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().
get_parameter(self,
target: str) -> 'Parameter'
Returns the parameter given by target if it exists, otherwise throws an error.
get_submodule(self,
target: str) -> 'Module'
Returns the submodule given by target if it exists, otherwise throws an error.
half(self: ~T) -> ~T
Casts all floating point parameters and buffers to half datatype.
ipu(self: ~T,
device: Union[int, torch.device, None] = None) -> ~T
Moves all model parameters and buffers to the IPU.
load_state_dict(self,
state_dict: Mapping[str, Any],
strict: bool = True,
assign: bool = False)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
modules(self) -> Iterator[ForwardRef('Module')]
Returns an iterator over all modules in the network.
named_buffers(self,
prefix: str = '',
recurse: bool = True,
remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]
Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]
Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
named_modules(self,
memo: Optional[Set[ForwardRef('Module')]] = None,
prefix: str = '',
remove_duplicate: bool = True)
Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
named_parameters(self,
prefix: str = '',
recurse: bool = True,
remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]
Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.
parameters(self,
recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]
Returns an iterator over module parameters.
register_backward_hook(self,
hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
Union[Tuple[torch.Tensor, ...], torch.Tensor]],
Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle
Registers a backward hook on the module.
register_buffer(self,
name: str,
tensor: Optional[torch.Tensor],
persistent: bool = True) -> None
Adds a buffer to the module.
register_forward_hook(self,
hook: Union[Callable[[~T, Tuple[Any, ...], Any],
Optional[Any]],
Callable[[~T, Tuple[Any, ...],
Dict[str, Any], Any],
Optional[Any]]],
*,
prepend: bool = False,
with_kwargs: bool = False,
always_call: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a forward hook on the module.
register_forward_pre_hook(self,
hook: Union[Callable[[~T, Tuple[Any, ...]],
Optional[Any]],
Callable[[~T, Tuple[Any, ...],
Dict[str, Any]],
Optional[Tuple[Any, Dict[str, Any]]]]],
*,
prepend: bool = False,
with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a forward pre-hook on the module.
register_full_backward_hook(self,
hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
Union[Tuple[torch.Tensor, ...], torch.Tensor]],
Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
prepend: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a backward hook on the module.
register_full_backward_pre_hook(self,
hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
prepend: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a backward pre-hook on the module.
register_load_state_dict_post_hook(self,
hook)
Registers a post hook to be run after the module's load_state_dict is called.
register_module(self,
name: str,
module: Optional[ForwardRef('Module')]) -> None
Alias for add_module.
register_parameter(self,
name: str,
param: Optional[torch.nn.parameter.Parameter]) -> None
Adds a parameter to the module.
register_state_dict_pre_hook(self,
hook)
These hooks will be called with arguments self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.
requires_grad_(self: ~T,
requires_grad: bool = True) -> ~T
Changes whether autograd should record operations on parameters in this module.
set_extra_state(self,
state: Any)
This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.
share_memory(self: ~T) -> ~T
See torch.Tensor.share_memory_.
state_dict(self,
*args,
destination=None,
prefix='',
keep_vars=False)
Returns a dictionary containing references to the whole state of the module.
to(self,
*args,
**kwargs)
Moves and/or casts the parameters and buffers.
to_empty(self: ~T,
*,
device: Union[str, torch.device],
recurse: bool = True) -> ~T
Moves the parameters and buffers to the specified device without copying storage.
train(self: ~T,
mode: bool = True) -> ~T
Sets the module in training mode.
type(self: ~T,
dst_type: Union[torch.dtype, str]) -> ~T
Casts all parameters and buffers to dst_type.
xpu(self: ~T,
device: Union[int, torch.device, None] = None) -> ~T
Moves all model parameters and buffers to the XPU.
zero_grad(self,
set_to_none: bool = True) -> None
Resets gradients of all model parameters. See the similar function under torch.optim.Optimizer for more context.
Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks, by Radford et al. (2015).