""" The network architectures is based on the implementation of CycleGAN and CUT Original PyTorch repo of CycleGAN: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix Original PyTorch repo of CUT: https://github.com/taesungp/contrastive-unpaired-translation Original CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf Original CUT paper: https://arxiv.org/pdf/2007.15651.pdf We use the network architecture for our default modal image translation """ import torch import torch.nn as nn import torch.nn.functional as F import functools import numpy as np from torch.nn import init import math class PixelNorm(nn.Module): def __init__(self): super().__init__() def forward(self, input): return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 half_dim = embedding_dim // 2 # magic number 10000 is from transformers emb = math.log(max_positions) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = F.pad(emb, (0, 1), mode='constant') assert emb.shape == (timesteps.shape[0], embedding_dim) return emb ################################################################################## # Discriminator ################################################################################## class D_NLayersMulti(nn.Module): def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, num_D=1): super(D_NLayersMulti, self).__init__() # st() self.num_D = num_D if num_D == 1: layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) self.model = nn.Sequential(*layers) else: layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) self.add_module("model_0", nn.Sequential(*layers)) self.down = nn.AvgPool2d(3, stride=2, padding=[ 1, 1], count_include_pad=False) for i in range(1, num_D): ndf_i = int(round(ndf / (2**i))) layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer) self.add_module("model_%d" % i, nn.Sequential(*layers)) def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): kw = 4 padw = 1 sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): nf_mult_prev = nf_mult nf_mult = min(2**n, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] nf_mult_prev = nf_mult nf_mult = min(2**n_layers, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] return sequence def forward(self, input): if self.num_D == 1: return self.model(input) result = [] down = input for i in range(self.num_D): model = getattr(self, "model_%d" % i) result.append(model(down)) if i != self.num_D - 1: down = self.down(down) return result class ConvBlock_cond(nn.Module): def __init__(self, in_channel, out_channel,t_emb_dim, kernel_size=4,stride=1,padding=1,norm_layer=None,downsample=True,use_bias=None): super().__init__() self.downsample=downsample self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias) if norm_layer is not None: self.use_norm =True self.norm = norm_layer(out_channel) else: self.use_norm = False self.act = nn.LeakyReLU(0.2, True) self.down = Downsample(out_channel) self.dense= nn.Linear(t_emb_dim, out_channel) def forward(self, input,t_emb): out = self.conv1(input) out += self.dense(t_emb)[..., None, None] if self.use_norm: out = self.norm(out) out = self.act(out) if self.downsample: out = self.down(out) return out class NLayerDiscriminator_ncsn(nn.Module): """Defines a PatchGAN discriminator""" def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ super(NLayerDiscriminator_ncsn, self).__init__() if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d self.model_main = nn.ModuleList() kw = 4 padw = 1 if no_antialias: sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] else: self.model_main.append(ConvBlock_cond(input_nc, ndf, 4*ndf,kernel_size=kw, stride=1, padding=padw,use_bias=use_bias)) nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2 ** n, 8) if no_antialias: sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True)] else: self.model_main.append( ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, 4*ndf,kernel_size=kw, stride=1, padding=padw,use_bias=use_bias,norm_layer=norm_layer) ) nf_mult_prev = nf_mult nf_mult = min(2 ** n_layers, 8) self.model_main.append( ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult,4*ndf, kernel_size=kw, stride=1, padding=padw,use_bias=use_bias,norm_layer=norm_layer,downsample=False) ) self.final_conv =nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) self.t_embed = TimestepEmbedding( embedding_dim=4*ndf, hidden_dim=4*ndf, output_dim=4*ndf, act=nn.LeakyReLU(0.2), ) def forward(self, input,t_emb,input2=None): """Standard forward.""" t_emb = self.t_embed(t_emb) if input2 is not None: out = torch.cat([input,input2],dim=1) else: out = input for layer in self.model_main: out = layer(out,t_emb) return self.final_conv(out) class PixelDiscriminator(nn.Module): """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): """Construct a 1x1 PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer norm_layer -- normalization layer """ super(PixelDiscriminator, self).__init__() if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d self.net = [ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), nn.LeakyReLU(0.2, True), nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), norm_layer(ndf * 2), nn.LeakyReLU(0.2, True), nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] self.net = nn.Sequential(*self.net) def forward(self, input): """Standard forward.""" return self.net(input) ################################################################################## # Generator ################################################################################## class TimestepEmbedding(nn.Module): def __init__(self, embedding_dim, hidden_dim, output_dim, act=nn.LeakyReLU(0.2)): super().__init__() self.embedding_dim = embedding_dim self.output_dim = output_dim self.hidden_dim = hidden_dim self.main = nn.Sequential( nn.Linear(embedding_dim, hidden_dim), nn.LeakyReLU(0.2), nn.Linear(hidden_dim, output_dim), nn.LeakyReLU(0.2), # EqualLinear(hidden_dim, output_dim,bias_init = 0, activation='fused_lrelu') ) def forward(self, temp): temb = get_timestep_embedding(temp, self.embedding_dim) temb = self.main(temb) return temb def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 half_dim = embedding_dim // 2 # magic number 10000 is from transformers emb = math.log(max_positions) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = F.pad(emb, (0, 1), mode='constant') assert emb.shape == (timesteps.shape[0], embedding_dim) return emb class AdaptiveLayer(nn.Module): def __init__(self, in_channel, style_dim): super().__init__() self.style_net = nn.Linear(style_dim, in_channel * 2) self.style_net.bias.data[:in_channel] = 1 self.style_net.bias.data[in_channel:] = 0 def forward(self, input, style): style = self.style_net(style).unsqueeze(2).unsqueeze(3) gamma, beta = style.chunk(2, 1) out = gamma * input + beta return out class ResnetGenerator_ncsn(nn.Module): """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) """ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None): """Construct a Resnet-based generator Parameters: input_nc (int) -- the number of channels in input images output_nc (int) -- the number of channels in output images ngf (int) -- the number of filters in the last conv layer norm_layer -- normalization layer use_dropout (bool) -- if use dropout layers n_blocks (int) -- the number of ResNet blocks padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero """ assert(n_blocks >= 0) super(ResnetGenerator_ncsn, self).__init__() self.opt = opt if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.ReLU(True)] self.ngf = ngf n_downsampling = 2 for i in range(n_downsampling): # add downsampling layers mult = 2 ** i if no_antialias: model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf * mult * 2), nn.ReLU(True)] else: model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias), norm_layer(ngf * mult * 2), nn.ReLU(True), Downsample(ngf * mult * 2) # nn.AvgPool2d(kernel_size=2, stride=2) ] self.model_res = nn.ModuleList() mult = 2 ** n_downsampling for i in range(n_blocks): # add ResNet blocks self.model_res += [ResnetBlock_cond(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias,temb_dim=4*ngf,z_dim=4*ngf)] model_upsample = [] for i in range(n_downsampling): # add upsampling layers mult = 2 ** (n_downsampling - i) if no_antialias_up: model_upsample += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] else: model_upsample += [ Upsample(ngf * mult), # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=1, bias=use_bias), norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] model_upsample += [nn.ReflectionPad2d(3)] model_upsample += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] model_upsample += [nn.Tanh()] self.model = nn.Sequential(*model) self.model_upsample = nn.Sequential(*model_upsample) mapping_layers = [PixelNorm(), nn.Linear(self.ngf*4, self.ngf*4), nn.LeakyReLU(0.2)] for _ in range(opt.n_mlp): mapping_layers.append(nn.Linear(self.ngf*4, self.ngf*4)) mapping_layers.append(nn.LeakyReLU(0.2)) self.z_transform = nn.Sequential(*mapping_layers) modules_emb = [] modules_emb += [nn.Linear(self.ngf,self.ngf*4)] nn.init.zeros_(modules_emb[-1].bias) modules_emb += [nn.LeakyReLU(0.2)] modules_emb += [nn.Linear(self.ngf*4,self.ngf*4)] nn.init.zeros_(modules_emb[-1].bias) modules_emb += [nn.LeakyReLU(0.2)] self.time_embed = nn.Sequential(*modules_emb) def forward(self, x, time_cond,z,layers=[], encode_only=False): z_embed = self.z_transform(z) # print(z_embed.shape) temb = get_timestep_embedding(time_cond, self.ngf) time_embed = self.time_embed(temb) if len(layers) > 0: feat = x feats = [] for layer_id, layer in enumerate(self.model): feat = layer(feat) if layer_id in layers: feats.append(feat) for layer_id, layer in enumerate(self.model_res): feat = layer(feat,time_embed,z_embed) if layer_id+len(self.model) in layers: feats.append(feat) if layer_id+len(self.model) == layers[-1] and encode_only: return feats return feat, feats else: out = self.model(x) for layer in self.model_res: out = layer(out,time_embed,z_embed) out = self.model_upsample(out) return out ################################################################################## # Basic Blocks ################################################################################## class ResnetBlock(nn.Module): """Define a Resnet block""" def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): """Initialize the Resnet block A resnet block is a conv block with skip connections We construct a conv block with build_conv_block function, and implement skip connections in function. Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf """ super(ResnetBlock, self).__init__() self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): """Construct a convolutional block. Parameters: dim (int) -- the number of channels in the conv layer. padding_type (str) -- the name of padding layer: reflect | replicate | zero norm_layer -- normalization layer use_dropout (bool) -- if use dropout layers. use_bias (bool) -- if the conv layer uses bias or not Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) """ conv_block = [] p = 0 if padding_type == 'reflect': conv_block += [nn.ReflectionPad2d(1)] elif padding_type == 'replicate': conv_block += [nn.ReplicationPad2d(1)] elif padding_type == 'zero': p = 1 else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] if use_dropout: conv_block += [nn.Dropout(0.5)] p = 0 if padding_type == 'reflect': conv_block += [nn.ReflectionPad2d(1)] elif padding_type == 'replicate': conv_block += [nn.ReplicationPad2d(1)] elif padding_type == 'zero': p = 1 else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] return nn.Sequential(*conv_block) def forward(self, x): """Forward function (with skip connections)""" out = x + self.conv_block(x) # add skip connections return out class ResnetBlock_cond(nn.Module): """Define a Resnet block""" def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias,temb_dim,z_dim): """Initialize the Resnet block A resnet block is a conv block with skip connections We construct a conv block with build_conv_block function, and implement skip connections in function. Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf """ super(ResnetBlock_cond, self).__init__() self.conv_block,self.adaptive,self.conv_fin = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias,temb_dim,z_dim) def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias,temb_dim,z_dim): """Construct a convolutional block. Parameters: dim (int) -- the number of channels in the conv layer. padding_type (str) -- the name of padding layer: reflect | replicate | zero norm_layer -- normalization layer use_dropout (bool) -- if use dropout layers. use_bias (bool) -- if the conv layer uses bias or not Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) """ self.conv_block = nn.ModuleList() self.conv_fin = nn.ModuleList() p = 0 if padding_type == 'reflect': self.conv_block += [nn.ReflectionPad2d(1)] elif padding_type == 'replicate': self.conv_block += [nn.ReplicationPad2d(1)] elif padding_type == 'zero': p = 1 else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) self.conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] self.adaptive = AdaptiveLayer(dim,z_dim) self.conv_fin += [nn.ReLU(True)] if use_dropout: self.conv_fin += [nn.Dropout(0.5)] p = 0 if padding_type == 'reflect': self.conv_fin += [nn.ReflectionPad2d(1)] elif padding_type == 'replicate': self.conv_fin += [nn.ReplicationPad2d(1)] elif padding_type == 'zero': p = 1 else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) self.conv_fin += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] self.Dense_time = nn.Linear(temb_dim, dim) # self.Dense_time.weight.data = default_init()(self.Dense_time.weight.data.shape) nn.init.zeros_(self.Dense_time.bias) self.style = nn.Linear(z_dim, dim * 2) self.style.bias.data[:dim] = 1 self.style.bias.data[dim:] = 0 return self.conv_block,self.adaptive,self.conv_fin def forward(self, x,time_cond,z): time_input = self.Dense_time(time_cond) for n,layer in enumerate(self.conv_block): out = layer(x) if n==0: out += time_input[:, :, None, None] out = self.adaptive(out,z) for layer in self.conv_fin: out = layer(out) """Forward function (with skip connections)""" out = x + out # add skip connections return out ############################################################################### # Helper Functions ############################################################################### def get_filter(filt_size=3): if(filt_size == 1): a = np.array([1., ]) elif(filt_size == 2): a = np.array([1., 1.]) elif(filt_size == 3): a = np.array([1., 2., 1.]) elif(filt_size == 4): a = np.array([1., 3., 3., 1.]) elif(filt_size == 5): a = np.array([1., 4., 6., 4., 1.]) elif(filt_size == 6): a = np.array([1., 5., 10., 10., 5., 1.]) elif(filt_size == 7): a = np.array([1., 6., 15., 20., 15., 6., 1.]) filt = torch.Tensor(a[:, None] * a[None, :]) filt = filt / torch.sum(filt) return filt class Downsample(nn.Module): def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0): super(Downsample, self).__init__() self.filt_size = filt_size self.pad_off = pad_off self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] self.stride = stride self.off = int((self.stride - 1) / 2.) self.channels = channels filt = get_filter(filt_size=self.filt_size) self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) self.pad = get_pad_layer(pad_type)(self.pad_sizes) def forward(self, inp): if(self.filt_size == 1): if(self.pad_off == 0): return inp[:, :, ::self.stride, ::self.stride] else: return self.pad(inp)[:, :, ::self.stride, ::self.stride] else: return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) class Upsample2(nn.Module): def __init__(self, scale_factor, mode='nearest'): super().__init__() self.factor = scale_factor self.mode = mode def forward(self, x): return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode) class Upsample(nn.Module): def __init__(self, channels, pad_type='repl', filt_size=4, stride=2): super(Upsample, self).__init__() self.filt_size = filt_size self.filt_odd = np.mod(filt_size, 2) == 1 self.pad_size = int((filt_size - 1) / 2) self.stride = stride self.off = int((self.stride - 1) / 2.) self.channels = channels filt = get_filter(filt_size=self.filt_size) * (stride**2) self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) self.pad = get_pad_layer(pad_type)([1, 1, 1, 1]) def forward(self, inp): ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:] if(self.filt_odd): return ret_val else: return ret_val[:, :, :-1, :-1] def get_pad_layer(pad_type): if(pad_type in ['refl', 'reflect']): PadLayer = nn.ReflectionPad2d elif(pad_type in ['repl', 'replicate']): PadLayer = nn.ReplicationPad2d elif(pad_type == 'zero'): PadLayer = nn.ZeroPad2d else: print('Pad type [%s] not recognized' % pad_type) return PadLayer class Identity(nn.Module): def forward(self, x): return x def get_norm_layer(norm_type='instance'): """Return a normalization layer Parameters: norm_type (str) -- the name of the normalization layer: batch | instance | none For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. """ if norm_type == 'batch': norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) elif norm_type == 'instance': norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) elif norm_type == 'none': def norm_layer(x): return Identity() else: raise NotImplementedError('normalization layer [%s] is not found' % norm_type) return norm_layer def init_weights(net, init_type='normal', init_gain=0.02, debug=False): """Initialize network weights. Parameters: net (network) -- network to be initialized init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal init_gain (float) -- scaling factor for normal, xavier and orthogonal. We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might work better for some applications. Feel free to try yourself. """ def init_func(m): # define the initialization function classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if debug: print(classname) if init_type == 'normal': init.normal_(m.weight.data, 0.0, init_gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=init_gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=init_gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. init.normal_(m.weight.data, 1.0, init_gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) # apply the initialization function def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True): """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights Parameters: net (network) -- the network to be initialized init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal gain (float) -- scaling factor for normal, xavier and orthogonal. gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 Return an initialized network. """ if len(gpu_ids) > 0: assert(torch.cuda.is_available()) net.to(gpu_ids[0]) if initialize_weights: init_weights(net, init_type, init_gain=init_gain, debug=debug) return net