
Adaptive Instance Normalization (AdaIN)


NETWORK ARCHITECTURE: AdaIN

The first few layers of a fixed (pre-trained) VGG-19 network encode the content and style images. An AdaIN layer performs style transfer in feature space, and a decoder is trained to invert the AdaIN output back to image space. The same VGG encoder is used to compute a content loss Lc and a style loss Ls.
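For reference, the AdaIN operation and the training objective can be written as below, following the formulation in the AdaIN paper. Here f is the VGG encoder, g the decoder, c and s the content and style images, μ and σ the per-channel mean and standard deviation computed over the spatial dimensions, φ_i the VGG layers used for the style loss, and λ the style-loss weight. Whether this library implements each term exactly this way is an assumption.

```latex
\begin{aligned}
\mathrm{AdaIN}(x, y) &= \sigma(y)\left(\frac{x - \mu(x)}{\sigma(x)}\right) + \mu(y) \\
t &= \mathrm{AdaIN}\big(f(c), f(s)\big), \qquad T(c, s) = g(t) \\
\mathcal{L} &= \mathcal{L}_c + \lambda \, \mathcal{L}_s \\
\mathcal{L}_c &= \big\lVert f(g(t)) - t \big\rVert_2 \\
\mathcal{L}_s &= \sum_{i=1}^{n} \big\lVert \mu(\phi_i(g(t))) - \mu(\phi_i(s)) \big\rVert_2
              + \sum_{i=1}^{n} \big\lVert \sigma(\phi_i(g(t))) - \sigma(\phi_i(s)) \big\rVert_2
\end{aligned}
```

Note that the content loss compares against the AdaIN output t rather than the raw content features f(c).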


Example

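# Imports
import torch
from torch.optim import Adam
from augmentare.methods.style_transfer import *

# Select the device to run on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the AdaIN network
model = ADAIN(device)

# Optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# Train the AdaIN network
# (train_loader is a DataLoader of training images defined beforehand)
loss_train = model.train_network(
    num_epochs=49,
    train_loader=train_loader,
    optimizer=optimizer,
)

# Generate a stylized image with AdaIN
# (content_tensor and style_tensor are image tensors prepared beforehand)
gen_image = model.adain_generate(content_tensor, style_tensor, alpha=1.0)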


ADAIN

Adaptive Instance Normalization (AdaIN) aligns the mean and variance of the content features with those of the style features. It achieves speed comparable to the fastest existing approach without being restricted to a predefined set of styles. The approach also supports flexible user controls such as content-style trade-off, style interpolation, and color and spatial control, all with a single feed-forward neural network.
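The alignment of statistics described above can be sketched in a few lines. This is a minimal NumPy illustration, not the library's code: it assumes a single feature map of shape (C, H, W) rather than a batched PyTorch tensor, and the `eps` stabilizer is an illustrative choice.

```python
import numpy as np


def adain(content_feat: np.ndarray, style_feat: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Sketch of AdaIN: give content features the per-channel mean/std of style features.

    Both inputs have shape (C, H, W); statistics are taken over the spatial axes H and W.
    """
    # Per-channel statistics of the content features
    c_mean = content_feat.mean(axis=(1, 2), keepdims=True)
    c_std = np.sqrt(content_feat.var(axis=(1, 2), keepdims=True) + eps)
    # Per-channel statistics of the style features
    s_mean = style_feat.mean(axis=(1, 2), keepdims=True)
    s_std = np.sqrt(style_feat.var(axis=(1, 2), keepdims=True) + eps)
    # Normalize the content, then rescale and shift with the style statistics
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean
```

AdaIN has no learnable parameters: the affine scale and shift come directly from the style input, which is why a single network can handle arbitrary styles.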

__init__(self,
         device)

adain_generate(self,
               content_image,
               style_image,
               alpha=1.0)

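Generates a single stylized image from a content image and a style image using the trained AdaIN network. alpha controls the content-style trade-off (1.0 applies the full style).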
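In the paper, the content-style trade-off is obtained by interpolating in feature space before decoding, i.e. decoding (1 - alpha) * f(c) + alpha * AdaIN(f(c), f(s)). The sketch below shows only that interpolation step in NumPy; it assumes the AdaIN output has already been computed, and it is an assumption that `adain_generate` applies `alpha` this way internally. The helper name `blend_features` is hypothetical.

```python
import numpy as np


def blend_features(content_feat: np.ndarray, adain_feat: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    """Hypothetical helper: interpolate between content features and AdaIN features.

    alpha = 0.0 keeps the content features unchanged; alpha = 1.0 uses the full AdaIN output.
    The result would then be passed to the decoder.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    return (1.0 - alpha) * content_feat + alpha * adain_feat
```

Interpolating features rather than output pixels lets the decoder produce a coherent image at every intermediate alpha.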


add_module(self,
           name: str,
           module: Optional[ForwardRef('Module')]) -> None

Adds a child module to the current module.


apply(self: ~T,
      fn: Callable[[ForwardRef('Module')], None]) -> ~T

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).


bfloat16(self: ~T) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.


buffers(self,
        recurse: bool = True) -> Iterator[torch.Tensor]

Returns an iterator over module buffers.


children(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over immediate children modules.


compile(self,
        *args,
        **kwargs)

Compiles this Module's forward using torch.compile.


cpu(self: ~T) -> ~T

Moves all model parameters and buffers to the CPU.


cuda(self: ~T,
     device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the GPU.


double(self: ~T) -> ~T

Casts all floating point parameters and buffers to double datatype.


eval(self: ~T) -> ~T

Sets the module in evaluation mode.


extra_repr(self) -> str

Sets the extra representation of the module.


float(self: ~T) -> ~T

Casts all floating point parameters and buffers to float datatype.


_forward_unimplemented(self,
                       *input: Any) -> None

Defines the computation performed at every call.


get_buffer(self,
           target: str) -> 'Tensor'

Returns the buffer given by target if it exists, otherwise throws an error.


get_extra_state(self) -> Any

Returns any extra state to include in the module's state_dict. Implement this and a corresponding set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().


get_parameter(self,
              target: str) -> 'Parameter'

Returns the parameter given by target if it exists, otherwise throws an error.


get_submodule(self,
              target: str) -> 'Module'

Returns the submodule given by target if it exists, otherwise throws an error.


half(self: ~T) -> ~T

Casts all floating point parameters and buffers to half datatype.


ipu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the IPU.


load_state_dict(self,
                state_dict: Mapping[str, Any],
                strict: bool = True,
                assign: bool = False)

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, the keys of state_dict must exactly match the keys returned by this module's torch.nn.Module.state_dict function.


modules(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over all modules in the network.


named_buffers(self,
              prefix: str = '',
              recurse: bool = True,
              remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.


named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.


named_modules(self,
              memo: Optional[Set[ForwardRef('Module')]] = None,
              prefix: str = '',
              remove_duplicate: bool = True)

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.


named_parameters(self,
                 prefix: str = '',
                 recurse: bool = True,
                 remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.


parameters(self,
           recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]

Returns an iterator over module parameters.


register_backward_hook(self,
                       hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                       Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                       Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_buffer(self,
                name: str,
                tensor: Optional[torch.Tensor],
                persistent: bool = True) -> None

Adds a buffer to the module.


register_forward_hook(self,
                      hook: Union[Callable[[~T, Tuple[Any, ...], Any],
                      Optional[Any]],
                      Callable[[~T, Tuple[Any, ...],
                      Dict[str, Any], Any],
                      Optional[Any]]],
                      *,
                      prepend: bool = False,
                      with_kwargs: bool = False,
                      always_call: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward hook on the module.


register_forward_pre_hook(self,
                          hook: Union[Callable[[~T, Tuple[Any, ...]],
                          Optional[Any]],
                          Callable[[~T, Tuple[Any, ...],
                          Dict[str, Any]],
                          Optional[Tuple[Any, Dict[str, Any]]]]],
                          *,
                          prepend: bool = False,
                          with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward pre-hook on the module.


register_full_backward_hook(self,
                            hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                            Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                            Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                            prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_full_backward_pre_hook(self,
                                hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                                Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                                prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward pre-hook on the module.


register_load_state_dict_post_hook(self,
                                   hook)

Registers a post hook to be run after module's load_state_dict is called.


register_module(self,
                name: str,
                module: Optional[ForwardRef('Module')]) -> None

Alias for add_module.


register_parameter(self,
                   name: str,
                   param: Optional[torch.nn.parameter.Parameter]) -> None

Adds a parameter to the module.


register_state_dict_pre_hook(self,
                             hook)

Registers a pre-hook for state_dict. The hooks are called with the arguments self, prefix, and keep_vars before state_dict is called on self, and can be used to perform pre-processing before the state_dict call is made.


requires_grad_(self: ~T,
               requires_grad: bool = True) -> ~T

Changes whether autograd should record operations on parameters in this module.


set_extra_state(self,
                state: Any)

This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.


share_memory(self: ~T) -> ~T

See torch.Tensor.share_memory_.


state_dict(self,
           *args,
           destination=None,
           prefix='',
           keep_vars=False)

Returns a dictionary containing references to the whole state of the module.


to(self,
   *args,
   **kwargs)

Moves and/or casts the parameters and buffers.


to_empty(self: ~T,
         *,
         device: Union[str, torch.device],
         recurse: bool = True) -> ~T

Moves the parameters and buffers to the specified device without copying storage.


train(self: ~T,
      mode: bool = True) -> ~T

Sets the module in training mode.


train_network(self,
              num_epochs,
              train_loader,
              optimizer,
              alpha=1.0,
              lamb=10)

Trains the AdaIN network and returns the losses.
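The objective optimized during training can be sketched as follows. This NumPy version assumes single (C, H, W) feature maps and follows the paper's Euclidean-norm formulation; many implementations use mean squared error instead, and which variant `train_network` uses is not stated here. Treating `lamb` as the style-loss weight λ is an assumption based on its name and its default of 10. All function names are hypothetical.

```python
import numpy as np


def channel_stats(feat: np.ndarray, eps: float = 1e-5):
    """Per-channel mean and std over the spatial axes of a (C, H, W) feature map."""
    mean = feat.mean(axis=(1, 2))
    std = np.sqrt(feat.var(axis=(1, 2)) + eps)
    return mean, std


def content_loss(decoded_feat: np.ndarray, target_feat: np.ndarray) -> float:
    """L_c: distance between the re-encoded output f(g(t)) and the AdaIN target t."""
    return float(np.linalg.norm(decoded_feat - target_feat))


def style_loss(decoded_layers, style_layers) -> float:
    """L_s: summed distances between channel means and between channel stds, per VGG layer."""
    total = 0.0
    for g_feat, s_feat in zip(decoded_layers, style_layers):
        g_mean, g_std = channel_stats(g_feat)
        s_mean, s_std = channel_stats(s_feat)
        total += np.linalg.norm(g_mean - s_mean) + np.linalg.norm(g_std - s_std)
    return float(total)


def total_loss(decoded_feat, target_feat, decoded_layers, style_layers, lamb: float = 10.0) -> float:
    """L = L_c + lamb * L_s (lamb assumed to be the style weight)."""
    return content_loss(decoded_feat, target_feat) + lamb * style_loss(decoded_layers, style_layers)
```

Matching only first- and second-order channel statistics keeps the style loss consistent with what the AdaIN layer itself transfers.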


type(self: ~T,
     dst_type: Union[torch.dtype, str]) -> ~T

Casts all parameters and buffers to dst_type.


xpu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the XPU.


zero_grad(self,
          set_to_none: bool = True) -> None

Resets the gradients of all model parameters. See the similar function under torch.optim.Optimizer for more context.


Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization by Xun Huang & Serge Belongie (2017).