Style Flow

View colab tutorial | View source | 📰 Paper

NETWORK ARCHITECTURE: Style Flow

Thanks to its invertible network structure, StyleFlow first projects the input images into feature space in the forward pass; in the backward pass, it applies the SAN module to perform a content-fixed feature transformation and then projects the features back into image space.
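This forward/backward symmetry relies on each transformation being exactly invertible. As an illustration (not the library's implementation), here is a minimal NumPy sketch of an additive coupling layer, the kind of building block used in normalizing flows, whose inverse is exact by construction:

```python
import numpy as np

def coupling_forward(x, shift_fn):
    """Additive coupling: split channels in two, transform one half
    conditioned on the other. Exactly invertible by construction."""
    x_a, x_b = np.split(x, 2, axis=0)
    y_b = x_b + shift_fn(x_a)          # x_a passes through unchanged
    return np.concatenate([x_a, y_b], axis=0)

def coupling_backward(y, shift_fn):
    """Inverse: recompute the shift from the untouched half and subtract."""
    y_a, y_b = np.split(y, 2, axis=0)
    x_b = y_b - shift_fn(y_a)
    return np.concatenate([y_a, x_b], axis=0)

shift = lambda t: np.tanh(t)           # any function works; it is never inverted
x = np.random.default_rng(0).normal(size=(4, 8, 8))
y = coupling_forward(x, shift)
print(np.allclose(coupling_backward(y, shift), x))   # True
```

Because the shift is always recomputed from the untouched half, the conditioning network itself never needs to be inverted, which is what makes deep invertible architectures practical.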

The blue arrows indicate the forward pass to extract the features, while the red arrows represent the backward pass to reconstruct the images. StyleFlow consists of a series of reversible blocks, where each block has three components: the Squeeze module, the Flow module, and the SAN module. A pre-trained VGG encoder is used for domain feature extraction.

  • Squeeze module: The Squeeze operation serves as an interconnection between blocks for reordering features. It reduces the spatial size of the feature map by first dividing the input feature into small patches along the spatial dimension and then concatenating the patches along the channel dimension.
  • Flow module: The Flow module consists of three reversible transformations: Actnorm Layer, 1x1 Convolution Layer, and Coupling Layer.
  • SAN module: The SAN module performs the content-fixed feature transformation: the content information of the features must be preserved before and after the style transformation.
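The Squeeze operation described above is a space-to-depth rearrangement. A minimal NumPy sketch (an illustration, not the library's implementation) with its exact inverse:

```python
import numpy as np

def squeeze(x, factor=2):
    """Space-to-depth: split spatial dims into factor x factor patches and
    stack them along the channel dimension (C, H, W) -> (C*f*f, H/f, W/f)."""
    c, h, w = x.shape
    x = x.reshape(c, h // factor, factor, w // factor, factor)
    x = x.transpose(0, 2, 4, 1, 3)          # (C, f, f, H/f, W/f)
    return x.reshape(c * factor * factor, h // factor, w // factor)

def unsqueeze(x, factor=2):
    """Inverse of squeeze, recovering the original spatial layout."""
    c, h, w = x.shape
    x = x.reshape(c // (factor * factor), factor, factor, h, w)
    x = x.transpose(0, 3, 1, 4, 2)          # (C/f^2, H/f, f, W/f, f)
    return x.reshape(c // (factor * factor), h * factor, w * factor)

x = np.arange(3 * 4 * 4, dtype=np.float32).reshape(3, 4, 4)
y = squeeze(x)
print(y.shape)                          # (12, 2, 2)
print(np.allclose(unsqueeze(y), x))     # True
```

Since the operation is a pure reordering, it loses no information and remains invertible, so it can safely sit between the reversible blocks.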

Example

# Augmentare Imports
import augmentare
from augmentare.methods.style_transfer import *

# Create StyleFlow method
vgg_path = '/home/vuong.nguyen/vuong/augmentare/augmentare/methods/style_transfer/model/vgg_normalised_flow.pth'
model = STYLEFLOW(in_channel=3, n_flow=15, n_block=2, vgg_path=vgg_path,
                  affine=False, conv_lu=False, keep_ratio=0.8, device=device)

# Training the StyleFlow network
loss_train = model.train_network(
    train_loader=train_loader,
    content_weight=0.1,
    style_weight=1,
    type_loss="TVLoss"
)

# Styled image by StyleFlow
gen_image = model.style_flow_generate(
    content_image=content_image,
    style_image=style_image
)

Notebooks

STYLEFLOW

StyleFlow class.

__init__(self,
         in_channel,
         n_flow,
         n_block,
         vgg_path,
         affine=True,
         conv_lu=True,
         keep_ratio=0.8,
         device='cpu')

add_module(self,
           name: str,
           module: Optional[ForwardRef('Module')]) -> None

Adds a child module to the current module.


apply(self: ~T,
      fn: Callable[[ForwardRef('Module')], None]) -> ~T

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).


bfloat16(self: ~T) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.


buffers(self,
        recurse: bool = True) -> Iterator[torch.Tensor]

Returns an iterator over module buffers.


children(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over immediate children modules.


compile(self,
        *args,
        **kwargs)

Compile this Module's forward using :func:torch.compile.


cpu(self: ~T) -> ~T

Moves all model parameters and buffers to the CPU.


cuda(self: ~T,
     device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the GPU.


double(self: ~T) -> ~T

Casts all floating point parameters and buffers to double datatype.


eval(self: ~T) -> ~T

Sets the module in evaluation mode.


extra_repr(self) -> str

Sets the extra representation of the module.


float(self: ~T) -> ~T

Casts all floating point parameters and buffers to float datatype.


_forward_unimplemented(self,
                       *input: Any) -> None

Defines the computation performed at every call.


get_buffer(self,
           target: str) -> 'Tensor'

Returns the buffer given by target if it exists, otherwise throws an error.


get_extra_state(self) -> Any

Returns any extra state to include in the module's state_dict. Implement this and a corresponding :func:set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().


get_parameter(self,
              target: str) -> 'Parameter'

Returns the parameter given by target if it exists, otherwise throws an error.


get_submodule(self,
              target: str) -> 'Module'

Returns the submodule given by target if it exists, otherwise throws an error.


half(self: ~T) -> ~T

Casts all floating point parameters and buffers to half datatype.


ipu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the IPU.


load_state_dict(self,
                state_dict: Mapping[str, Any],
                strict: bool = True,
                assign: bool = False)

Copies parameters and buffers from :attr:state_dict into this module and its descendants. If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.


modules(self) -> Iterator[ForwardRef('Module')]

Returns an iterator over all modules in the network.


named_buffers(self,
              prefix: str = '',
              recurse: bool = True,
              remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.


named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.


named_modules(self,
              memo: Optional[Set[ForwardRef('Module')]] = None,
              prefix: str = '',
              remove_duplicate: bool = True)

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.


named_parameters(self,
                 prefix: str = '',
                 recurse: bool = True,
                 remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.


parameters(self,
           recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]

Returns an iterator over module parameters.


register_backward_hook(self,
                       hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                       Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                       Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_buffer(self,
                name: str,
                tensor: Optional[torch.Tensor],
                persistent: bool = True) -> None

Adds a buffer to the module.


register_forward_hook(self,
                      hook: Union[Callable[[~T, Tuple[Any, ...], Any],
                      Optional[Any]],
                      Callable[[~T, Tuple[Any, ...],
                      Dict[str, Any], Any],
                      Optional[Any]]],
                      *,
                      prepend: bool = False,
                      with_kwargs: bool = False,
                      always_call: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward hook on the module.


register_forward_pre_hook(self,
                          hook: Union[Callable[[~T, Tuple[Any, ...]],
                          Optional[Any]],
                          Callable[[~T, Tuple[Any, ...],
                          Dict[str, Any]],
                          Optional[Tuple[Any, Dict[str, Any]]]]],
                          *,
                          prepend: bool = False,
                          with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a forward pre-hook on the module.


register_full_backward_hook(self,
                            hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
                            Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                            Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                            prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.


register_full_backward_pre_hook(self,
                                hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
                                Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
                                prepend: bool = False) -> torch.utils.hooks.RemovableHandle

Registers a backward pre-hook on the module.


register_load_state_dict_post_hook(self,
                                   hook)

Registers a post hook to be run after module's load_state_dict is called.


register_module(self,
                name: str,
                module: Optional[ForwardRef('Module')]) -> None

Alias for :func:add_module.


register_parameter(self,
                   name: str,
                   param: Optional[torch.nn.parameter.Parameter]) -> None

Adds a parameter to the module.


register_state_dict_pre_hook(self,
                             hook)

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.


requires_grad_(self: ~T,
               requires_grad: bool = True) -> ~T

Change if autograd should record operations on parameters in this module.


set_extra_state(self,
                state: Any)

This function is called from :func:load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding :func:get_extra_state for your module if you need to store extra state within its state_dict.


share_memory(self: ~T) -> ~T

See :meth:torch.Tensor.share_memory_


state_dict(self,
           *args,
           destination=None,
           prefix='',
           keep_vars=False)

Returns a dictionary containing references to the whole state of the module.


style_flow_generate(self,
                    content_image,
                    style_image)

Generates one styled image from a content image and a style image after the StyleFlow network has been trained.


to(self,
   *args,
   **kwargs)

Moves and/or casts the parameters and buffers.


to_empty(self: ~T,
         *,
         device: Union[str, torch.device],
         recurse: bool = True) -> ~T

Moves the parameters and buffers to the specified device without copying storage.


train(self: ~T,
      mode: bool = True) -> ~T

Sets the module in training mode.


train_network(self,
              train_loader,
              content_weight,
              style_weight,
              type_loss=None)

Train the StyleFlow network and return the losses.


type(self: ~T,
     dst_type: Union[torch.dtype, str]) -> ~T

Casts all parameters and buffers to :attr:dst_type.


xpu(self: ~T,
    device: Union[int, torch.device, None] = None) -> ~T

Moves all model parameters and buffers to the XPU.


zero_grad(self,
          set_to_none: bool = True) -> None

Resets gradients of all model parameters. See similar function under :class:torch.optim.Optimizer for more context.


StyleFlow For Content-Fixed Image to Image Translation by Weichen Fan et al. (2022).