diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index d0aee8d..1d8a85d 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -118,8 +118,7 @@ class ContentAwareOptimization(nn.Module): def forward(self, D_real, D_fake, real_scores, fake_scores): # 清空梯度缓存 self.gradients_real.clear() - self.gradients_fake.clear() - + self.gradients_fake.clear() # 注册钩子 hook_real = lambda grad: self.gradients_real.append(grad.detach()) hook_fake = lambda grad: self.gradients_fake.append(grad.detach()) @@ -138,10 +137,10 @@ class ContentAwareOptimization(nn.Module): weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake) # 计算加权损失 - loss_co_real = (weight_real * torch.log(real_scores + 1e-8)).mean() - loss_co_fake = (weight_fake * torch.log(1 - fake_scores + 1e-8)).mean() + loss_co_real = (weight_real * real_scores).mean() + loss_co_fake = (weight_fake * fake_scores).mean() - return -(loss_co_real + loss_co_fake), weight_real, weight_fake + return (loss_co_real + loss_co_fake), weight_real, weight_fake class ContentAwareTemporalNorm(nn.Module): def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0): @@ -252,7 +251,7 @@ class RomaUnsbModel(BaseModel): BaseModel.__init__(self, opt) # 指定需要打印的训练损失 - self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn'] + self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn',] self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0'] self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')] @@ -368,40 +367,6 @@ class RomaUnsbModel(BaseModel): self.image_paths = input['A_paths' if AtoB else 'B_paths'] - def tokens_concat(self, origin_tokens, adjacent_size): - adj_size = adjacent_size - B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2] - S = int(math.sqrt(token_num)) - if S * S != token_num: - print('Error! Not a square!') - token_map = origin_tokens.clone().reshape(B,S,S,C) - cut_patch_list = [] - for i in range(0, S, adj_size): - for j in range(0, S, adj_size): - i_left = i - i_right = i + adj_size + 1 if i + adj_size <= S else S + 1 - j_left = j - j_right = j + adj_size if j + adj_size <= S else S + 1 - - cut_patch = token_map[:, i_left:i_right, j_left: j_right, :] - cut_patch= cut_patch.reshape(B,-1,C) - cut_patch = torch.mean(cut_patch, dim=1, keepdim=True) - cut_patch_list.append(cut_patch) - - - result = torch.cat(cut_patch_list,dim=1) - return result - - def cat_results(self, origin_tokens, adj_size_list): - res_list = [origin_tokens] - for ad_s in adj_size_list: - cat_result = self.tokens_concat(origin_tokens, ad_s) - res_list.append(cat_result) - - result = torch.cat(res_list, dim=1) - - return result - def forward(self): """Run forward pass; called by both functions and .""" @@ -462,8 +427,8 @@ class RomaUnsbModel(BaseModel): self.real = torch.flip(self.real, [3]) self.realt = torch.flip(self.realt, [3]) - self.fake_B0 = self.netG(self.real_A0, self.time, self.z_in) - self.fake_B1 = self.netG(self.real_A1, self.time, self.z_in2) + self.fake_B0 = self.netG(self.real_A_noisy, self.time, self.z_in) + self.fake_B1 = self.netG(self.real_A_noisy2, self.time, self.z_in2) if self.opt.phase == 'train': real_A0 = self.real_A0 @@ -507,8 +472,8 @@ class RomaUnsbModel(BaseModel): self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True) self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT) - - return self.losscao* lambda_D_ViT + self.loss_D_ViT = self.losscao* lambda_D_ViT + return self.loss_D_ViT def compute_E_loss(self): """计算判别器 E 的损失""" diff --git a/scripts/train.sh b/scripts/train.sh index 6b117d9..fb83ff3 100755 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -7,7 +7,7 @@ python train.py \ --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ - --name UNIV_2 \ + --name UNIV_1 \ --dataset_mode unaligned_double \ --no_flip \ --display_env UNIV \ @@ -17,13 +17,14 @@ python train.py \ --lambda_ctn 1.0 \ --lambda_inc 1.0 \ --lr 0.00001 \ - --gpu_id 1 \ + --gpu_id 0 \ + --lambda_D_ViT 1 \ --nce_idt False \ --netF mlp_sample \ --flip_equivariance True \ --eta_ratio 0.4 \ --tau 0.01 \ - --num_timesteps 5 \ + --num_timesteps 4 \ --input_nc 3 \ --n_epochs 400 \ --n_epochs_decay 200 \