"""
|
|
The network architectures is based on the implementation of CycleGAN and CUT
|
|
Original PyTorch repo of CycleGAN: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
|
|
Original PyTorch repo of CUT: https://github.com/taesungp/contrastive-unpaired-translation
|
|
Original CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
|
|
Original CUT paper: https://arxiv.org/pdf/2007.15651.pdf
|
|
We use the network architecture for our default modal image translation
|
|
"""

import functools
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class PixelNorm(nn.Module):
    """Normalize each pixel's feature vector to unit RMS across channels."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Build sinusoidal timestep embeddings, as in the Transformer/DDPM literature."""
    assert len(timesteps.shape) == 1  # expects a 1-D tensor of (integer) timesteps
    half_dim = embedding_dim // 2
    # The magic number 10000 for max_positions comes from the Transformer paper.
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
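
# Usage sketch (illustrative, not from the original code): for a batch of
# integer timesteps, the embedding is a (batch, embedding_dim) tensor, e.g.
#   t = torch.tensor([0, 10, 500])
#   emb = get_timestep_embedding(t, embedding_dim=256)  # -> shape (3, 256)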


##################################################################################
# Discriminator
##################################################################################
class D_NLayersMulti(nn.Module):
    """Multi-scale N-layer PatchGAN discriminator (num_D copies at successively downsampled scales)."""

    def __init__(self, input_nc, ndf=64, n_layers=3,
                 norm_layer=nn.BatchNorm2d, num_D=1):
        super(D_NLayersMulti, self).__init__()
        self.num_D = num_D
        if num_D == 1:
            layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
            self.model = nn.Sequential(*layers)
        else:
            layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
            self.add_module("model_0", nn.Sequential(*layers))
            self.down = nn.AvgPool2d(3, stride=2, padding=[1, 1],
                                     count_include_pad=False)
            for i in range(1, num_D):
                ndf_i = int(round(ndf / (2 ** i)))
                layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
                self.add_module("model_%d" % i, nn.Sequential(*layers))

    def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
                              stride=2, padding=padw),
                    nn.LeakyReLU(0.2, True)]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1,
                               kernel_size=kw, stride=1, padding=padw)]

        return sequence

    def forward(self, input):
        if self.num_D == 1:
            return self.model(input)
        result = []
        down = input
        for i in range(self.num_D):
            model = getattr(self, "model_%d" % i)
            result.append(model(down))
            if i != self.num_D - 1:
                down = self.down(down)
        return result
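
# Usage sketch (illustrative): with num_D=2 the discriminator returns one
# patch-logit map per scale, the second computed on a 2x-downsampled input:
#   D = D_NLayersMulti(input_nc=3, ndf=64, n_layers=3, num_D=2)
#   logits = D(torch.randn(1, 3, 256, 256))  # list of 2 tensors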


class ConvBlock_cond(nn.Module):
    """Conv -> (+ time embedding) -> norm -> LeakyReLU -> optional antialiased downsample."""

    def __init__(self, in_channel, out_channel, t_emb_dim, kernel_size=4, stride=1,
                 padding=1, norm_layer=None, downsample=True, use_bias=None):
        super().__init__()
        self.downsample = downsample
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=use_bias)

        if norm_layer is not None:
            self.use_norm = True
            self.norm = norm_layer(out_channel)
        else:
            self.use_norm = False
        self.act = nn.LeakyReLU(0.2, True)
        self.down = Downsample(out_channel)

        self.dense = nn.Linear(t_emb_dim, out_channel)

    def forward(self, input, t_emb):
        out = self.conv1(input)
        # Inject the timestep embedding as a per-channel bias.
        out += self.dense(t_emb)[..., None, None]
        if self.use_norm:
            out = self.norm(out)
        out = self.act(out)
        if self.downsample:
            out = self.down(out)
        return out
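
# Shape sketch (illustrative): a stride-1 conv followed by the blur-pool
# roughly halves the spatial resolution, e.g.
#   block = ConvBlock_cond(3, 64, t_emb_dim=256, kernel_size=4, stride=1, padding=1)
#   out = block(x, t_emb)  # (B, 3, 256, 256) -> conv (B, 64, 255, 255) -> (B, 64, 128, 128)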


class NLayerDiscriminator_ncsn(nn.Module):
    """Defines a time-conditioned PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator_ncsn, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        self.model_main = nn.ModuleList()
        kw = 4
        padw = 1
        # NOTE: the plain-conv `sequence` built under no_antialias is never
        # registered as a module; only the antialiased ConvBlock_cond path
        # below is actually used by forward().
        if no_antialias:
            sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        else:
            self.model_main.append(ConvBlock_cond(input_nc, ndf, 4 * ndf, kernel_size=kw,
                                                  stride=1, padding=padw, use_bias=use_bias))

        nf_mult = 1
        nf_mult_prev = 1

        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            if no_antialias:
                sequence += [
                    nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                    norm_layer(ndf * nf_mult),
                    nn.LeakyReLU(0.2, True)]
            else:
                self.model_main.append(
                    ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, 4 * ndf, kernel_size=kw,
                                   stride=1, padding=padw, use_bias=use_bias, norm_layer=norm_layer)
                )

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        self.model_main.append(
            ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, 4 * ndf, kernel_size=kw,
                           stride=1, padding=padw, use_bias=use_bias, norm_layer=norm_layer, downsample=False)
        )
        self.final_conv = nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        self.t_embed = TimestepEmbedding(
            embedding_dim=4 * ndf,
            hidden_dim=4 * ndf,
            output_dim=4 * ndf,
            act=nn.LeakyReLU(0.2),
        )

    def forward(self, input, t_emb, input2=None):
        """Standard forward."""
        t_emb = self.t_embed(t_emb)
        if input2 is not None:
            out = torch.cat([input, input2], dim=1)
        else:
            out = input
        for layer in self.model_main:
            out = layer(out, t_emb)
        return self.final_conv(out)
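
# Usage sketch (illustrative): conditioning on a batch of integer timesteps:
#   D = NLayerDiscriminator_ncsn(input_nc=3, ndf=64, n_layers=3)
#   t = torch.randint(0, 1000, (4,))
#   logits = D(torch.randn(4, 3, 256, 256), t)  # patch-logit map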


class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)


##################################################################################
# Generator
##################################################################################
class TimestepEmbedding(nn.Module):
    """Map integer timesteps to a learned embedding: sinusoidal features followed by an MLP."""

    def __init__(self, embedding_dim, hidden_dim, output_dim, act=nn.LeakyReLU(0.2)):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        self.main = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, output_dim),
            nn.LeakyReLU(0.2),
        )

    def forward(self, temp):
        temb = get_timestep_embedding(temp, self.embedding_dim)
        temb = self.main(temb)
        return temb


class AdaptiveLayer(nn.Module):
    """FiLM-style adaptive layer: scale and shift features with an affine map of a style code."""

    def __init__(self, in_channel, style_dim):
        super().__init__()

        self.style_net = nn.Linear(style_dim, in_channel * 2)

        # Initialize the bias so that gamma is centered at 1 and beta at 0.
        self.style_net.bias.data[:in_channel] = 1
        self.style_net.bias.data[in_channel:] = 0

    def forward(self, input, style):
        style = self.style_net(style).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, 1)

        out = gamma * input + beta

        return out
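
# Shape sketch (illustrative): for feat (B, C, H, W) and style code z (B, D),
#   ada = AdaptiveLayer(in_channel=C, style_dim=D)
#   out = ada(feat, z)  # per-channel gamma * feat + beta, shape (B, C, H, W)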


class ResnetGenerator_ncsn(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt the Torch code and ideas from Justin Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9,
                 padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert n_blocks >= 0
        super(ResnetGenerator_ncsn, self).__init__()
        self.opt = opt
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        self.ngf = ngf
        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            if no_antialias:
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True)]
            else:
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True),
                          Downsample(ngf * mult * 2)]
        self.model_res = nn.ModuleList()
        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add time- and style-conditioned ResNet blocks
            self.model_res += [ResnetBlock_cond(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                                use_dropout=use_dropout, use_bias=use_bias,
                                                temb_dim=4 * ngf, z_dim=4 * ngf)]

        model_upsample = []
        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            if no_antialias_up:
                model_upsample += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2,
                                                      padding=1, output_padding=1, bias=use_bias),
                                   norm_layer(int(ngf * mult / 2)),
                                   nn.ReLU(True)]
            else:
                model_upsample += [
                    Upsample(ngf * mult),
                    nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=1, bias=use_bias),
                    norm_layer(int(ngf * mult / 2)),
                    nn.ReLU(True)]
        model_upsample += [nn.ReflectionPad2d(3)]
        model_upsample += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model_upsample += [nn.Tanh()]

        self.model = nn.Sequential(*model)
        self.model_upsample = nn.Sequential(*model_upsample)

        # Mapping network for the latent z (MLP on a pixel-normalized code).
        mapping_layers = [PixelNorm(),
                          nn.Linear(self.ngf * 4, self.ngf * 4),
                          nn.LeakyReLU(0.2)]
        for _ in range(opt.n_mlp):
            mapping_layers.append(nn.Linear(self.ngf * 4, self.ngf * 4))
            mapping_layers.append(nn.LeakyReLU(0.2))
        self.z_transform = nn.Sequential(*mapping_layers)

        # MLP on top of the sinusoidal timestep embedding.
        modules_emb = [nn.Linear(self.ngf, self.ngf * 4)]
        nn.init.zeros_(modules_emb[-1].bias)
        modules_emb += [nn.LeakyReLU(0.2)]
        modules_emb += [nn.Linear(self.ngf * 4, self.ngf * 4)]
        nn.init.zeros_(modules_emb[-1].bias)
        modules_emb += [nn.LeakyReLU(0.2)]
        self.time_embed = nn.Sequential(*modules_emb)

    def forward(self, x, time_cond, z, layers=[], encode_only=False):
        z_embed = self.z_transform(z)
        temb = get_timestep_embedding(time_cond, self.ngf)
        time_embed = self.time_embed(temb)
        if len(layers) > 0:
            # Collect intermediate features from the encoder and the
            # conditional ResNet blocks (used for patch-wise losses).
            feat = x
            feats = []
            for layer_id, layer in enumerate(self.model):
                feat = layer(feat)
                if layer_id in layers:
                    feats.append(feat)

            for layer_id, layer in enumerate(self.model_res):
                feat = layer(feat, time_embed, z_embed)
                if layer_id + len(self.model) in layers:
                    feats.append(feat)
                if layer_id + len(self.model) == layers[-1] and encode_only:
                    return feats
            return feat, feats
        else:
            out = self.model(x)
            for layer in self.model_res:
                out = layer(out, time_embed, z_embed)
            out = self.model_upsample(out)
            return out
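
# Usage sketch (illustrative; assumes `import argparse`, and `opt` only needs
# the `n_mlp` attribute used above):
#   opt = argparse.Namespace(n_mlp=3)
#   G = ResnetGenerator_ncsn(3, 3, ngf=64, opt=opt)
#   t = torch.randint(0, 5, (2,))
#   z = torch.randn(2, 4 * 64)
#   fake = G(torch.randn(2, 3, 256, 256), t, z)  # -> (2, 3, 256, 256)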


##################################################################################
# Basic Blocks
##################################################################################
class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections.
        We construct a conv block with the build_conv_block function,
        and implement skip connections in the <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out


class ResnetBlock_cond(nn.Module):
    """Define a Resnet block conditioned on a timestep embedding and a style code."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections.
        We construct a conv block with the build_conv_block function,
        and implement skip connections in the <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock_cond, self).__init__()
        self.conv_block, self.adaptive, self.conv_fin = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not
            temb_dim (int)      -- dimension of the timestep embedding
            z_dim (int)         -- dimension of the style code

        Returns the first conv block, the adaptive (style) layer, and the final conv block.
        """
        self.conv_block = nn.ModuleList()
        self.conv_fin = nn.ModuleList()
        p = 0
        if padding_type == 'reflect':
            self.conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            self.conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        self.conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
        self.adaptive = AdaptiveLayer(dim, z_dim)
        self.conv_fin += [nn.ReLU(True)]
        if use_dropout:
            self.conv_fin += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            self.conv_fin += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            self.conv_fin += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        self.conv_fin += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        self.Dense_time = nn.Linear(temb_dim, dim)
        nn.init.zeros_(self.Dense_time.bias)

        self.style = nn.Linear(z_dim, dim * 2)
        # Initialize the bias so that gamma is centered at 1 and beta at 0.
        self.style.bias.data[:dim] = 1
        self.style.bias.data[dim:] = 0

        return self.conv_block, self.adaptive, self.conv_fin

    def forward(self, x, time_cond, z):
        """Forward function (with skip connections)."""
        time_input = self.Dense_time(time_cond)
        # Thread the activation through the first conv block, injecting the
        # timestep embedding as a per-channel bias after the first layer.
        out = x
        for n, layer in enumerate(self.conv_block):
            out = layer(out)
            if n == 0:
                out += time_input[:, :, None, None]
        out = self.adaptive(out, z)
        for layer in self.conv_fin:
            out = layer(out)
        out = x + out  # add skip connections
        return out
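
# Shape sketch (illustrative): for feat (B, 256, 64, 64), temb (B, 256) and
# z_embed (B, 256),
#   block = ResnetBlock_cond(256, 'reflect', nn.InstanceNorm2d, False, True,
#                            temb_dim=256, z_dim=256)
#   out = block(feat, temb, z_embed)  # -> (B, 256, 64, 64)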


###############################################################################
# Helper Functions
###############################################################################
def get_filter(filt_size=3):
    """Return a normalized 2-D binomial (Pascal's triangle) blur kernel."""
    if filt_size == 1:
        a = np.array([1., ])
    elif filt_size == 2:
        a = np.array([1., 1.])
    elif filt_size == 3:
        a = np.array([1., 2., 1.])
    elif filt_size == 4:
        a = np.array([1., 3., 3., 1.])
    elif filt_size == 5:
        a = np.array([1., 4., 6., 4., 1.])
    elif filt_size == 6:
        a = np.array([1., 5., 10., 10., 5., 1.])
    elif filt_size == 7:
        a = np.array([1., 6., 15., 20., 15., 6., 1.])
    else:
        raise NotImplementedError('filter size [%d] is not implemented' % filt_size)

    filt = torch.Tensor(a[:, None] * a[None, :])
    filt = filt / torch.sum(filt)

    return filt
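
# Example (illustrative): get_filter(3) is the outer product of [1, 2, 1]
# with itself, normalized to sum to 1:
#   [[1, 2, 1],
#    [2, 4, 2],
#    [1, 2, 1]] / 16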


class Downsample(nn.Module):
    """Anti-aliased downsampling: pad, blur with a binomial filter, then subsample."""

    def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
        super(Downsample, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)),
                          int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        filt = get_filter(filt_size=self.filt_size)
        # One copy of the blur kernel per channel (applied as a depthwise convolution).
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if self.filt_size == 1:
            if self.pad_off == 0:
                return inp[:, :, ::self.stride, ::self.stride]
            else:
                return self.pad(inp)[:, :, ::self.stride, ::self.stride]
        else:
            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
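
# Shape sketch (illustrative): Downsample(64)(torch.randn(1, 64, 128, 128))
# -> (1, 64, 64, 64); blurring before subsampling reduces aliasing.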


class Upsample2(nn.Module):
    """Plain interpolation-based upsampling."""

    def __init__(self, scale_factor, mode='nearest'):
        super().__init__()
        self.factor = scale_factor
        self.mode = mode

    def forward(self, x):
        return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)


class Upsample(nn.Module):
    """Anti-aliased upsampling via a transposed convolution with a blur kernel."""

    def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
        super(Upsample, self).__init__()
        self.filt_size = filt_size
        self.filt_odd = np.mod(filt_size, 2) == 1
        self.pad_size = int((filt_size - 1) / 2)
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        # Scale the filter by stride**2 so the upsampled signal keeps the input's magnitude.
        filt = get_filter(filt_size=self.filt_size) * (stride ** 2)
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)([1, 1, 1, 1])

    def forward(self, inp):
        ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride,
                                     padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:]
        if self.filt_odd:
            return ret_val
        else:
            return ret_val[:, :, :-1, :-1]


def get_pad_layer(pad_type):
    if pad_type in ['refl', 'reflect']:
        PadLayer = nn.ReflectionPad2d
    elif pad_type in ['repl', 'replicate']:
        PadLayer = nn.ReplicationPad2d
    elif pad_type == 'zero':
        PadLayer = nn.ZeroPad2d
    else:
        raise NotImplementedError('Pad type [%s] not recognized' % pad_type)
    return PadLayer


class Identity(nn.Module):
    def forward(self, x):
        return x


def get_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x):
            return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
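
# Usage sketch (illustrative):
#   norm_layer = get_norm_layer('instance')
#   D = NLayerDiscriminator_ncsn(input_nc=3, norm_layer=norm_layer)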


def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if debug:
                print(classname)
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net
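

if __name__ == '__main__':
    # Minimal smoke test (illustrative, not part of the original repo). The
    # `opt` namespace only needs the `n_mlp` attribute used by the generator.
    import argparse

    opt = argparse.Namespace(n_mlp=3)
    netG = ResnetGenerator_ncsn(3, 3, ngf=64, norm_layer=get_norm_layer('instance'),
                                n_blocks=2, opt=opt)
    netG = init_net(netG, init_type='xavier')
    netD = init_net(NLayerDiscriminator_ncsn(input_nc=3, ndf=64))

    x = torch.randn(2, 3, 128, 128)
    t = torch.randint(0, 4, (2,))
    z = torch.randn(2, 4 * 64)
    fake = netG(x, t, z)
    logits = netD(fake, t)
    print(fake.shape, logits.shape)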