Contrastive Coherence Preserving Loss for Versatile Style Transfer (CCPL)
NETWORK ARCHITECTURE: CCPL
Inspirations for CCPL: regions denoted by red boxes in the first frame ($R_A$ or $R'_A$) have the same location as the corresponding patches in the second frame, enclosed in yellow boxes ($R_B$ or $R'_B$). $R_C$ and $R'_C$ (in the blue boxes) are cropped from the first frame, but their style aligns with $R_B$ and $R'_B$. The difference between two patches is denoted by $D$ (for example, $D(R_A, R_B)$). The mutual information between $D(R_A, R_C)$ and $D(R'_A, R'_C)$, and between $D(R_A, R_B)$ and $D(R'_A, R'_B)$, is encouraged to be maximized to preserve consistency with the content source.
Details of CCPL: $C_f$ and $G_f$ represent the encoded features of a specific layer of the encoder $E$. ⊖ denotes vector subtraction, and SCE stands for softmax cross-entropy. The yellow dotted lines illustrate how the positive pair is produced.
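The loss described in the figure can be sketched in NumPy as follows. This is a minimal single-layer illustration under stated assumptions, not the augmentare or official CCPL implementation: it takes one content feature map $C_f$ and one generated feature map $G_f$ of shape (C, H, W), samples `num_s` anchor positions, pairs each anchor with a random one of its 8 neighbours, subtracts the neighbour vector from the anchor vector (the ⊖ operation), and applies softmax cross-entropy in which the difference at the same location in $C_f$ is the positive and the other sampled differences are the negatives. The function name `ccp_loss`, the temperature `tau=0.07`, and the absence of any projection head are assumptions.

```python
import numpy as np


def ccp_loss(content_feat, generated_feat, num_s=8, tau=0.07, seed=0):
    """Hypothetical single-layer CCP loss sketch.

    content_feat, generated_feat: arrays of shape (C, H, W), standing in for C_f and G_f.
    num_s: number of anchor positions to sample; tau: softmax temperature (assumed value).
    """
    rng = np.random.default_rng(seed)
    _, h, w = content_feat.shape
    # The 8 neighbour offsets around a position.
    offsets = np.array([(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)])

    # Sample distinct anchors away from the border so every neighbour lies inside the map.
    interior = np.array([(y, x) for y in range(1, h - 1) for x in range(1, w - 1)])
    anchors = interior[rng.choice(len(interior), size=num_s, replace=False)]
    neighbours = anchors + offsets[rng.integers(0, len(offsets), size=num_s)]
    ay, ax = anchors[:, 0], anchors[:, 1]
    ny, nx = neighbours[:, 0], neighbours[:, 1]

    # Difference vectors (anchor minus neighbour), shape (C, num_s): the ⊖ in the figure.
    d_c = content_feat[:, ay, ax] - content_feat[:, ny, nx]
    d_g = generated_feat[:, ay, ax] - generated_feat[:, ny, nx]

    # L2-normalise every difference vector (one per column).
    d_c = d_c / (np.linalg.norm(d_c, axis=0, keepdims=True) + 1e-8)
    d_g = d_g / (np.linalg.norm(d_g, axis=0, keepdims=True) + 1e-8)

    # logits[i, j]: similarity between generated difference i and content difference j.
    logits = (d_g.T @ d_c) / tau
    # Softmax cross-entropy where the same location (the diagonal) is the positive pair.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_prob)))
```

When the generated features preserve the local differences of the content features, the diagonal dominates and the loss approaches zero. Calling the function once per encoder layer and summing the results would correspond to the `num_l` setting in the example, again as an assumption.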
Example
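# Augmentare imports
import torch
import augmentare
from augmentare.methods.style_transfer import *

# Run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the CCPL method
# Replace this placeholder with the path to your local copy of the normalised VGG weights
vgg_path = "path/to/vgg_normalised_ccpl.pth"
model = CCPL(training_mode="pho", vgg_path=vgg_path, device=device)

# Train the CCPL network
# content_images and style_images are the training sets, assumed to be prepared beforehand
loss_train = model.train_network(content_images, style_images, num_s=8, num_l=3, max_iter=50000,
                                 content_weight=1.0, style_weight=10.0, ccp_weight=5.0)

# Style a content image with the trained CCPL network
# content_image and style_image are assumed to be prepared beforehand
gen_image = model.ccpl_generate(
    content_image, style_image,
    alpha=1.0, interpolation=False, preserve_color=True
)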
Notebooks
CCPL
CCPL class.
__init__(self,
training_mode,
vgg_path,
device)
add_module(self,
name: str,
module: Optional[ForwardRef('Module')]) -> None
Adds a child module to the current module.
apply(self: ~T,
fn: Callable[[ForwardRef('Module')], None]) -> ~T
Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).
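Since the class listing exposes the standard torch.nn.Module methods, apply can be used to re-initialize submodules. The sketch below runs on a small stand-in network rather than a CCPL instance; the helper name `init_weights` and the choice of Kaiming initialization are illustrative assumptions.

```python
import torch
from torch import nn


def init_weights(module):
    # Hypothetical helper: apply() calls it once for every submodule and for the root module.
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)


# Small stand-in network; not the CCPL model itself.
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
net.apply(init_weights)  # returns net itself, so calls can be chained
```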
bfloat16(self: ~T) -> ~T
Casts all floating point parameters and buffers to bfloat16 datatype.
buffers(self,
recurse: bool = True) -> Iterator[torch.Tensor]
Returns an iterator over module buffers.
ccpl_generate(self,
content_images,
style_images,
alpha=1.0,
interpolation=False,
preserve_color=True)
Generates one stylized image with the CCPL network after training.
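The preserve_color option is not explained on this page. In common AdaIN-style style-transfer scripts, color preservation recolors the style image with the content image's color statistics (a CORAL-style transform) before stylization. Whether augmentare's ccpl_generate does exactly this is an assumption; the NumPy sketch below only illustrates that technique, and `coral` and `_mat_sqrt` are hypothetical helper names.

```python
import numpy as np


def _mat_sqrt(m, power):
    # Matrix power of a symmetric positive-definite matrix via eigendecomposition.
    vals, vecs = np.linalg.eigh(m)
    return (vecs * vals ** power) @ vecs.T


def coral(source, target):
    """Hypothetical helper: match the color statistics of source (3, H, W) to target (3, H', W')."""
    src = source.reshape(3, -1)
    tgt = target.reshape(3, -1)
    src_mean, src_std = src.mean(axis=1, keepdims=True), src.std(axis=1, keepdims=True)
    tgt_mean, tgt_std = tgt.mean(axis=1, keepdims=True), tgt.std(axis=1, keepdims=True)

    # Per-channel standardization.
    src_norm = (src - src_mean) / src_std
    tgt_norm = (tgt - tgt_mean) / tgt_std
    # Channel covariances, with an identity term added for numerical stability.
    src_cov = src_norm @ src_norm.T + np.eye(3)
    tgt_cov = tgt_norm @ tgt_norm.T + np.eye(3)

    # Whiten with the source covariance, then re-color with the target covariance.
    transferred = _mat_sqrt(tgt_cov, 0.5) @ _mat_sqrt(src_cov, -0.5) @ src_norm
    return (transferred * tgt_std + tgt_mean).reshape(source.shape)
```

Under this assumption, color preservation would amount to calling `style = coral(style, content)` before running the style transfer.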
children(self) -> Iterator[ForwardRef('Module')]
Returns an iterator over immediate children modules.
compile(self,
*args,
**kwargs)
Compile this Module's forward using torch.compile.
cpu(self: ~T) -> ~T
Moves all model parameters and buffers to the CPU.
cuda(self: ~T,
device: Union[int, torch.device, None] = None) -> ~T
Moves all model parameters and buffers to the GPU.
double(self: ~T) -> ~T
Casts all floating point parameters and buffers to double datatype.
eval(self: ~T) -> ~T
Sets the module in evaluation mode.
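A typical inference pattern pairs eval() with torch.no_grad(). The sketch below uses a small stand-in network with dropout so that the effect of evaluation mode is visible; applying the same calls to a CCPL instance is an assumption based on its inherited nn.Module methods.

```python
import torch
from torch import nn

# Small stand-in network; dropout behaves differently in training and evaluation mode.
net = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1), nn.Dropout(p=0.5))
x = torch.randn(1, 3, 16, 16)

net.eval()              # dropout is disabled, so repeated calls give the same output
with torch.no_grad():   # no autograd graph is recorded during inference
    y1 = net(x)
    y2 = net(x)

net.train()             # switch back to training mode before further training
```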
extra_repr(self) -> str
Sets the extra representation of the module.
float(self: ~T) -> ~T
Casts all floating point parameters and buffers to float datatype.
_forward_unimplemented(self,
*input: Any) -> None
Defines the computation performed at every call.
get_buffer(self,
target: str) -> 'Tensor'
Returns the buffer given by target if it exists, otherwise throws an error.
get_extra_state(self) -> Any
Returns any extra state to include in the module's state_dict. Implement this and a corresponding set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().
get_parameter(self,
target: str) -> 'Parameter'
Returns the parameter given by target if it exists, otherwise throws an error.
get_submodule(self,
target: str) -> 'Module'
Returns the submodule given by target if it exists, otherwise throws an error.
half(self: ~T) -> ~T
Casts all floating point parameters and buffers to half datatype.
ipu(self: ~T,
device: Union[int, torch.device, None] = None) -> ~T
Moves all model parameters and buffers to the IPU.
load_state_dict(self,
state_dict: Mapping[str, Any],
strict: bool = True,
assign: bool = False)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's torch.nn.Module.state_dict function.
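Saving and restoring trained weights goes through state_dict() and load_state_dict(). The sketch below uses two nn.Linear stand-ins and an in-memory buffer in place of a checkpoint file; the file name mentioned in the comment is hypothetical, and using the same pattern for a trained CCPL model is an assumption based on its inherited nn.Module methods.

```python
import io

import torch
from torch import nn

src = nn.Linear(4, 2)   # stand-in for a trained model
dst = nn.Linear(4, 2)   # freshly constructed model with the same architecture

buffer = io.BytesIO()   # stands in for a hypothetical file such as "ccpl_checkpoint.pth"
torch.save(src.state_dict(), buffer)
buffer.seek(0)

state = torch.load(buffer)
# With strict=True, missing or unexpected keys raise an error instead of being ignored.
result = dst.load_state_dict(state, strict=True)
```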
modules(self) -> Iterator[ForwardRef('Module')]
Returns an iterator over all modules in the network.
named_buffers(self,
prefix: str = '',
recurse: bool = True,
remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]
Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]
Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
named_modules(self,
memo: Optional[Set[ForwardRef('Module')]] = None,
prefix: str = '',
remove_duplicate: bool = True)
Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
named_parameters(self,
prefix: str = '',
recurse: bool = True,
remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]
Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.
parameters(self,
recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]
Returns an iterator over module parameters.
register_backward_hook(self,
hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
Union[Tuple[torch.Tensor, ...], torch.Tensor]],
Union[None, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle
Registers a backward hook on the module.
register_buffer(self,
name: str,
tensor: Optional[torch.Tensor],
persistent: bool = True) -> None
Adds a buffer to the module.
register_forward_hook(self,
hook: Union[Callable[[~T, Tuple[Any, ...], Any],
Optional[Any]],
Callable[[~T, Tuple[Any, ...],
Dict[str, Any], Any],
Optional[Any]]],
*,
prepend: bool = False,
with_kwargs: bool = False,
always_call: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a forward hook on the module.
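Forward hooks are a convenient way to inspect intermediate activations, such as the encoder features $C_f$ and $G_f$ used by the loss. The sketch below captures the output of one layer of a small stand-in encoder; the layer choice, the `features` dictionary, and the `save_output` helper are illustrative assumptions, not how augmentare extracts its features.

```python
import torch
from torch import nn

# Small stand-in encoder; not the VGG encoder used by CCPL.
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)
features = {}


def save_output(name):
    # Hypothetical helper: forward hooks receive (module, inputs, output) after forward() runs.
    def hook(module, inputs, output):
        features[name] = output.detach()
    return hook


handle = net[1].register_forward_hook(save_output("relu1"))
out = net(torch.randn(2, 3, 32, 32))
handle.remove()  # detach the hook once it is no longer needed
```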
register_forward_pre_hook(self,
hook: Union[Callable[[~T, Tuple[Any, ...]],
Optional[Any]],
Callable[[~T, Tuple[Any, ...],
Dict[str, Any]],
Optional[Tuple[Any, Dict[str, Any]]]]],
*,
prepend: bool = False,
with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a forward pre-hook on the module.
register_full_backward_hook(self,
hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor],
Union[Tuple[torch.Tensor, ...], torch.Tensor]],
Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
prepend: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a backward hook on the module.
register_full_backward_pre_hook(self,
hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]],
Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
prepend: bool = False) -> torch.utils.hooks.RemovableHandle
Registers a backward pre-hook on the module.
register_load_state_dict_post_hook(self,
hook)
Registers a post hook to be run after the module's load_state_dict is called.
register_module(self,
name: str,
module: Optional[ForwardRef('Module')]) -> None
Alias for add_module.
register_parameter(self,
name: str,
param: Optional[torch.nn.parameter.Parameter]) -> None
Adds a parameter to the module.
register_state_dict_pre_hook(self,
hook)
These hooks will be called with arguments self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.
requires_grad_(self: ~T,
requires_grad: bool = True) -> ~T
Changes whether autograd should record operations on parameters in this module.
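A common use of requires_grad_ is freezing part of a model. The sketch below assumes, for illustration, an encoder that stays fixed while a decoder is trained, using small stand-in modules rather than the actual VGG encoder and decoder; `count_trainable` is a hypothetical helper.

```python
import torch
from torch import nn

encoder = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # stand-in for a fixed encoder
decoder = nn.Sequential(nn.Conv2d(8, 3, 3))             # stand-in for a trainable decoder

encoder.requires_grad_(False)  # freeze: autograd stops recording operations on these parameters


def count_trainable(module):
    # Hypothetical helper: number of scalar parameters that will receive gradients.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


frozen = count_trainable(encoder)     # 0 after freezing
trainable = count_trainable(decoder)  # 3 * 8 * 3 * 3 weights + 3 biases = 219
```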
set_extra_state(self,
state: Any)
This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.
share_memory(self: ~T) -> ~T
See torch.Tensor.share_memory_.
state_dict(self,
*args,
destination=None,
prefix='',
keep_vars=False)
Returns a dictionary containing references to the whole state of the module.
style_transfer(vgg_in,
decoder_in,
sct_in,
content,
style,
device,
alpha=1.0,
interpolation_weights=None)
Style transfer function that stylizes the content input with the style input.
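The alpha and interpolation_weights parameters are not described on this page. In AdaIN-style pipelines, alpha linearly blends the stylized features with the content features before decoding, and interpolation weights mix the stylized features obtained from several style images. The NumPy sketch below shows that blending under the assumption that style_transfer follows the same convention; `blend_features` is a hypothetical helper, and normalizing the weights to sum to 1 is also an assumption.

```python
import numpy as np


def blend_features(content_feat, stylized_feats, alpha=1.0, interpolation_weights=None):
    """Hypothetical helper: mix stylized and content features before decoding.

    stylized_feats: list of arrays, one per style image, each shaped like content_feat.
    """
    if interpolation_weights is None:
        stylized = stylized_feats[0]
    else:
        w = np.asarray(interpolation_weights, dtype=float)
        w = w / w.sum()  # normalize so the weights sum to 1
        stylized = sum(wi * f for wi, f in zip(w, stylized_feats))
    # alpha = 1.0 keeps the fully stylized features; alpha = 0.0 returns the content features.
    return alpha * stylized + (1.0 - alpha) * content_feat
```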
to(self,
*args,
**kwargs)
Moves and/or casts the parameters and buffers.
to_empty(self: ~T,
*,
device: Union[str, torch.device],
recurse: bool = True) -> ~T
Moves the parameters and buffers to the specified device without copying storage.
train(self: ~T,
mode: bool = True) -> ~T
Sets the module in training mode.
train_network(self,
content_set,
style_set,
num_s,
num_l,
max_iter,
content_weight,
style_weight,
ccp_weight)
Train the CCPL network and return the losses.
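The three weights passed to train_network balance the content, style, and CCP terms of the training objective, which the CCPL paper defines as a weighted sum of these losses. The sketch below assumes that the CCP term is summed over the num_l encoder layers; that detail of augmentare's internals is an assumption, and `total_loss` is a hypothetical name. The default weights mirror the values used in the example above.

```python
def total_loss(loss_c, loss_s, ccp_per_layer, content_weight=1.0, style_weight=10.0, ccp_weight=5.0):
    """Hypothetical helper: weighted sum of content, style and CCP losses.

    ccp_per_layer: CCP loss values, one per encoder layer (num_l entries assumed).
    """
    loss_ccp = sum(ccp_per_layer)
    return content_weight * loss_c + style_weight * loss_s + ccp_weight * loss_ccp
```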
type(self: ~T,
dst_type: Union[torch.dtype, str]) -> ~T
Casts all parameters and buffers to dst_type.
xpu(self: ~T,
device: Union[int, torch.device, None] = None) -> ~T
Moves all model parameters and buffers to the XPU.
zero_grad(self,
set_to_none: bool = True) -> None
Resets gradients of all model parameters. See the similar function under torch.optim.Optimizer for more context.
Contrastive Coherence Preserving Loss for Versatile Style Transfer by Zijie Wu et al. (2022).