修改后的最新
This commit is contained in:
parent
7a6e856b4b
commit
2a0a56ac26
@ -118,8 +118,7 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
def forward(self, D_real, D_fake, real_scores, fake_scores):
|
def forward(self, D_real, D_fake, real_scores, fake_scores):
|
||||||
# 清空梯度缓存
|
# 清空梯度缓存
|
||||||
self.gradients_real.clear()
|
self.gradients_real.clear()
|
||||||
self.gradients_fake.clear()
|
self.gradients_fake.clear()
|
||||||
|
|
||||||
# 注册钩子
|
# 注册钩子
|
||||||
hook_real = lambda grad: self.gradients_real.append(grad.detach())
|
hook_real = lambda grad: self.gradients_real.append(grad.detach())
|
||||||
hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
|
hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
|
||||||
@ -138,10 +137,10 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake)
|
weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake)
|
||||||
|
|
||||||
# 计算加权损失
|
# 计算加权损失
|
||||||
loss_co_real = (weight_real * torch.log(real_scores + 1e-8)).mean()
|
loss_co_real = (weight_real * real_scores).mean()
|
||||||
loss_co_fake = (weight_fake * torch.log(1 - fake_scores + 1e-8)).mean()
|
loss_co_fake = (weight_fake * fake_scores).mean()
|
||||||
|
|
||||||
return -(loss_co_real + loss_co_fake), weight_real, weight_fake
|
return (loss_co_real + loss_co_fake), weight_real, weight_fake
|
||||||
|
|
||||||
class ContentAwareTemporalNorm(nn.Module):
|
class ContentAwareTemporalNorm(nn.Module):
|
||||||
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
|
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
|
||||||
@ -252,7 +251,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
BaseModel.__init__(self, opt)
|
BaseModel.__init__(self, opt)
|
||||||
|
|
||||||
# 指定需要打印的训练损失
|
# 指定需要打印的训练损失
|
||||||
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn']
|
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn',]
|
||||||
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
|
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
|
||||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||||
|
|
||||||
@ -368,40 +367,6 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||||
|
|
||||||
|
|
||||||
def tokens_concat(self, origin_tokens, adjacent_size):
|
|
||||||
adj_size = adjacent_size
|
|
||||||
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
|
|
||||||
S = int(math.sqrt(token_num))
|
|
||||||
if S * S != token_num:
|
|
||||||
print('Error! Not a square!')
|
|
||||||
token_map = origin_tokens.clone().reshape(B,S,S,C)
|
|
||||||
cut_patch_list = []
|
|
||||||
for i in range(0, S, adj_size):
|
|
||||||
for j in range(0, S, adj_size):
|
|
||||||
i_left = i
|
|
||||||
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
|
|
||||||
j_left = j
|
|
||||||
j_right = j + adj_size if j + adj_size <= S else S + 1
|
|
||||||
|
|
||||||
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
|
|
||||||
cut_patch= cut_patch.reshape(B,-1,C)
|
|
||||||
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
|
|
||||||
cut_patch_list.append(cut_patch)
|
|
||||||
|
|
||||||
|
|
||||||
result = torch.cat(cut_patch_list,dim=1)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def cat_results(self, origin_tokens, adj_size_list):
|
|
||||||
res_list = [origin_tokens]
|
|
||||||
for ad_s in adj_size_list:
|
|
||||||
cat_result = self.tokens_concat(origin_tokens, ad_s)
|
|
||||||
res_list.append(cat_result)
|
|
||||||
|
|
||||||
result = torch.cat(res_list, dim=1)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||||
|
|
||||||
@ -462,8 +427,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.real = torch.flip(self.real, [3])
|
self.real = torch.flip(self.real, [3])
|
||||||
self.realt = torch.flip(self.realt, [3])
|
self.realt = torch.flip(self.realt, [3])
|
||||||
|
|
||||||
self.fake_B0 = self.netG(self.real_A0, self.time, self.z_in)
|
self.fake_B0 = self.netG(self.real_A_noisy, self.time, self.z_in)
|
||||||
self.fake_B1 = self.netG(self.real_A1, self.time, self.z_in2)
|
self.fake_B1 = self.netG(self.real_A_noisy2, self.time, self.z_in2)
|
||||||
|
|
||||||
if self.opt.phase == 'train':
|
if self.opt.phase == 'train':
|
||||||
real_A0 = self.real_A0
|
real_A0 = self.real_A0
|
||||||
@ -507,8 +472,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True)
|
self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True)
|
||||||
|
|
||||||
self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT)
|
self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT)
|
||||||
|
self.loss_D_ViT = self.losscao* lambda_D_ViT
|
||||||
return self.losscao* lambda_D_ViT
|
return self.loss_D_ViT
|
||||||
|
|
||||||
def compute_E_loss(self):
|
def compute_E_loss(self):
|
||||||
"""计算判别器 E 的损失"""
|
"""计算判别器 E 的损失"""
|
||||||
|
|||||||
@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
python train.py \
|
python train.py \
|
||||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
||||||
--name UNIV_2 \
|
--name UNIV_1 \
|
||||||
--dataset_mode unaligned_double \
|
--dataset_mode unaligned_double \
|
||||||
--no_flip \
|
--no_flip \
|
||||||
--display_env UNIV \
|
--display_env UNIV \
|
||||||
@ -17,13 +17,14 @@ python train.py \
|
|||||||
--lambda_ctn 1.0 \
|
--lambda_ctn 1.0 \
|
||||||
--lambda_inc 1.0 \
|
--lambda_inc 1.0 \
|
||||||
--lr 0.00001 \
|
--lr 0.00001 \
|
||||||
--gpu_id 1 \
|
--gpu_id 0 \
|
||||||
|
--lambda_D_ViT 1 \
|
||||||
--nce_idt False \
|
--nce_idt False \
|
||||||
--netF mlp_sample \
|
--netF mlp_sample \
|
||||||
--flip_equivariance True \
|
--flip_equivariance True \
|
||||||
--eta_ratio 0.4 \
|
--eta_ratio 0.4 \
|
||||||
--tau 0.01 \
|
--tau 0.01 \
|
||||||
--num_timesteps 5 \
|
--num_timesteps 4 \
|
||||||
--input_nc 3 \
|
--input_nc 3 \
|
||||||
--n_epochs 400 \
|
--n_epochs 400 \
|
||||||
--n_epochs_decay 200 \
|
--n_epochs_decay 200 \
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user