diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..337f343 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +checkpoints/ +*.log +*.pth +*.ckpt +__pycache__/ \ No newline at end of file diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index ae5f110..5f5b809 100644 Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 68ca586..6e4859d 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -65,6 +65,10 @@ class ContentAwareOptimization(nn.Module): super().__init__() self.lambda_inc = lambda_inc # 权重增强系数 self.eta_ratio = eta_ratio # 选择内容区域的比例 + + # 改为类成员变量,确保钩子函数可访问 + self.gradients_real = [] + self.gradients_fake = [] def compute_cosine_similarity(self, gradients): """ @@ -79,78 +83,65 @@ class ContentAwareOptimization(nn.Module): cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N] return cosine_sim - def generate_weight_map(self, gradients_fake): + def generate_weight_map(self, gradients_real, gradients_fake): """ 生成内容感知权重图 Args: - gradients_fake: [B, N, D] 生成图像判别器梯度 [2,3,256,256] + gradients_real: [B, N, D] 真实图像判别器梯度 + gradients_fake: [B, N, D] 生成图像判别器梯度 Returns: - weight_fake: [B, N] 生成图像权重图 [2,3,256] + weight_real: [B, N] 真实图像权重图 + weight_fake: [B, N] 生成图像权重图 """ + # 计算真实图像块的余弦相似度 + cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] 公式5 # 计算生成图像块的余弦相似度 cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N] - # 选择内容丰富的区域(余弦相似度最低的eta_ratio比例) - k = int(self.eta_ratio * cosine_fake.shape[1]) - - # 对生成图像生成权重图(同理) + # 选择内容丰富的区域(余弦相似度最低的eta_ratio比例) + k = int(self.eta_ratio * cosine_real.shape[1]) + + # 对真实图像生成权重图 + _, real_indices = torch.topk(-cosine_real, k, dim=1) # 选择最不相似的区域 + weight_real = torch.ones_like(cosine_real) + for b in range(cosine_real.shape[0]): + weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) #公式6 + + # 对生成图像生成权重图(同理) _, fake_indices = torch.topk(-cosine_fake, k, dim=1) weight_fake = torch.ones_like(cosine_fake) for b in range(cosine_fake.shape[0]): weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]])) - return weight_fake + return weight_real, weight_fake def forward(self, D_real, D_fake, real_scores, fake_scores): - """ - 计算内容感知对抗损失 - Args: - D_real: 判别器对真实图像的特征输出 [B, C, H, W] - D_fake: 判别器对生成图像的特征输出 [B, C, H, W] - real_scores: 真实图像的判别器预测 [B, N] (N=H*W) - fake_scores: 生成图像的判别器预测 [B, N] - Returns: - loss_co_adv: 内容感知对抗损失 - """ - B, C, H, W = D_real.shape - N = H * W + # 清空梯度缓存 + self.gradients_real.clear() + self.gradients_fake.clear() - # 注册钩子获取梯度 - gradients_real = [] - gradients_fake = [] - - def hook_real(grad): - gradients_real.append(grad.detach().view(B, N, -1)) - - def hook_fake(grad): - gradients_fake.append(grad.detach().view(B, N, -1)) + # 注册钩子 + hook_real = lambda grad: self.gradients_real.append(grad.detach()) + hook_fake = lambda grad: self.gradients_fake.append(grad.detach()) D_real.register_hook(hook_real) D_fake.register_hook(hook_fake) - # 计算原始对抗损失以触发梯度计算 - loss_real = torch.mean(torch.log(real_scores + 1e-8)) - loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8)) - # 添加与 D_real、D_fake 相关的 dummy 项,确保梯度传递 - loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum()) - total_loss = loss_real + loss_fake + loss_dummy - total_loss.backward(retain_graph=True) + # 触发梯度计算 + (real_scores.mean() + fake_scores.mean()).backward(retain_graph=True) - # 获取梯度数据 - gradients_real = gradients_real[0] # [B, N, D] - gradients_fake = gradients_fake[0] # [B, N, D] + # 获取梯度并调整维度 + grad_real = self.gradients_real[0] # [B, N, D] + grad_fake = self.gradients_fake[0] # 生成权重图 - self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake) + weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake) - # 应用权重到对抗损失 - loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8)) - loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8)) + # 计算加权损失 + loss_co_real = (weight_real * torch.log(real_scores + 1e-8)).mean() + loss_co_fake = (weight_fake * torch.log(1 - fake_scores + 1e-8)).mean() - # 计算并返回最终内容感知对抗损失 - loss_co_adv = -(loss_co_real + loss_co_fake) - - return loss_co_adv + return -(loss_co_real + loss_co_fake), weight_real, weight_fake class ContentAwareTemporalNorm(nn.Module): def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0): @@ -158,6 +149,33 @@ class ContentAwareTemporalNorm(nn.Module): self.gamma_stride = gamma_stride # 控制整体运动幅度 self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层 + def upsample_weight_map(self, weight_patch, target_size=(256, 256)): + """ + 将patch级别的权重图上采样到目标分辨率 + Args: + weight_patch: [B, 1, 24, 24] 来自ViT的patch权重图 + target_size: 目标分辨率 (H, W) + Returns: + weight_full: [B, 1, 256, 256] 上采样后的全分辨率权重图 + """ + # 使用双线性插值上采样 + B = weight_patch.shape[0] + weight_patch = weight_patch.view(B, 1, 24, 24) + + weight_full = F.interpolate( + weight_patch, + size=target_size, + mode='bilinear', + align_corners=False + ) + + # 对每个16x16的patch内部保持权重一致(可选) + # 通过平均池化再扩展,消除插值引入的渐变 + weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16) + weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest') + + return weight_full + def forward(self, weight_map): """ 生成内容感知光流 @@ -166,15 +184,16 @@ class ContentAwareTemporalNorm(nn.Module): Returns: F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) """ - #print(weight_map.shape) - B, _, H, W = weight_map.shape + # 上采样权重图到全分辨率 + weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384] # 1. 归一化权重图 # 保持区域相对强度,同时限制数值范围 - weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W] + weight_norm = F.normalize(weight_full, p=1, dim=(2,3)) # L1归一化 [B,1,H,W] - # 2. 生成高斯噪声(与光流场同尺寸) - z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W] + # 2. 生成高斯噪声 + B, _, H, W = weight_norm.shape + z = torch.randn(B, 2, H, W, device=weight_norm.device) # [B,2,H,W] # 3. 合成基础光流 # 将权重图扩展为2通道(x/y方向共享权重) @@ -437,8 +456,8 @@ class RomaUnsbModel(BaseModel): # ============ 第三步:拼接输入并执行网络推理 ============= bs = self.real_A0.size(0) - z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device) - z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device) + self.z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device) + self.z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device) # 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB self.real = self.real_A0 self.realt = self.real_A_noisy @@ -449,8 +468,8 @@ class RomaUnsbModel(BaseModel): self.real = torch.flip(self.real, [3]) self.realt = torch.flip(self.realt, [3]) - self.fake_B0 = self.netG(self.real_A0, self.time, z_in) - self.fake_B1 = self.netG(self.real_A1, self.time, z_in2) + self.fake_B0 = self.netG(self.real_A0, self.time, self.z_in) + self.fake_B1 = self.netG(self.real_A1, self.time, self.z_in2) if self.opt.phase == 'train': real_A0 = self.real_A0 @@ -476,28 +495,6 @@ class RomaUnsbModel(BaseModel): # [[1,576,768],[1,576,768],[1,576,768]] # [3,576,768] - ## 生成图像的梯度 - #fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0] - # - ## 梯度图 - #self.weight_fake = self.cao.generate_weight_map(fake_gradient) - # - ## 生成图像的CTN光流图 - #self.f_content = self.ctn(self.weight_fake) - # - ## 变换后的图片 - #self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content) - #self.warped_fake_B0 = warp(self.fake_B0,self.f_content) - # - ## 经过第二次生成器 - #self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in) - - #warped_fake_B0_2=self.warped_fake_B0_2 - #warped_fake_B0=self.warped_fake_B0 - #self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2) - #self.warped_fake_B0_resize = self.resize(warped_fake_B0) - #self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True) - #self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True) def compute_D_loss(self): #判别器还是没有改 @@ -505,25 +502,19 @@ class RomaUnsbModel(BaseModel): lambda_D_ViT = self.opt.lambda_D_ViT fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach() - fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach() + real_B0_tokens = self.mutil_real_B0_tokens[0] - real_B1_tokens = self.mutil_real_B1_tokens[0] - pre_fake0_ViT = self.netD_ViT(fake_B0_tokens) - pre_fake1_ViT = self.netD_ViT(fake_B1_tokens) + self.loss_D_fake_ViT = self.criterionGAN(pre_fake0_ViT, False) - self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT + pred_real0_ViT = self.netD_ViT(real_B0_tokens) + self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True) - pred_real0_ViT = self.netD_ViT(real_B0_tokens) - pred_real1_ViT = self.netD_ViT(real_B1_tokens) - self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT - - self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5 - - - return self.loss_D_ViT + self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT) + + return self.losscao* lambda_D_ViT def compute_E_loss(self): """计算判别器 E 的损失""" @@ -537,12 +528,28 @@ class RomaUnsbModel(BaseModel): def compute_G_loss(self): """计算生成器的 GAN 损失""" - + if self.opt.lambda_ctn > 0.0: + # 生成图像的CTN光流图 + self.f_content = self.ctn(self.weight_fake) + + # 变换后的图片 + self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content) + self.warped_fake_B0 = warp(self.fake_B0,self.f_content) + + # 经过第二次生成器 + self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, self.z_in) + + warped_fake_B0_2=self.warped_fake_B0_2 + warped_fake_B0=self.warped_fake_B0 + # 计算L2损失 + self.ctn_loss = F.mse_loss(warped_fake_B0_2, warped_fake_B0) + if self.opt.lambda_GAN > 0.0: pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0]) - self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN + self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() else: self.loss_G_GAN = 0.0 + self.loss_SB = 0 if self.opt.lambda_SB > 0.0: XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1) @@ -551,9 +558,9 @@ class RomaUnsbModel(BaseModel): bs = self.opt.batch_size # eq.9 - ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0) + ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - self.netE(XtXt_1, self.time, XtXt_2).mean() self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY - self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2) + self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2) if self.opt.lambda_global > 0.0: loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1) @@ -561,12 +568,10 @@ class RomaUnsbModel(BaseModel): else: loss_global = 0.0 - self.l2_loss = 0.0 - #if self.opt.lambda_ctn > 0.0: - # wapped_fake_B = warp(self.fake_B, self.f_content) # use updated self.f_content - # self.l2_loss = F.mse_loss(self.fake_B_2, wapped_fake_B) # complete the loss calculation - - self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global + self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \ + self.opt.lambda_SB * self.loss_SB + \ + self.opt.lambda_ctn * self.ctn_loss + \ + loss_global * self.opt.lambda_global return self.loss_G def calculate_attention_loss(self): diff --git a/scripts/train.sh b/scripts/train.sh index 9f429ab..0c016bf 100755 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -7,7 +7,7 @@ python train.py \ --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ - --name ROMA_UNSB_002 \ + --name ROMA_UNSB_003 \ --dataset_mode unaligned_double \ --no_flip \ --display_env ROMA \ diff --git a/train.py b/train.py index cfd728a..8cd245a 100644 --- a/train.py +++ b/train.py @@ -44,6 +44,7 @@ if __name__ == '__main__': model.setup(opt) # regular setup: load and print networks; create schedulers model.parallelize() model.set_input(data) # unpack data from dataset and apply preprocessing + #print('Call opt paras') model.optimize_parameters() # calculate loss functions, get gradients, update network weights if len(opt.gpu_ids) > 0: torch.cuda.synchronize()