From 8cd61d0503c96feb478ba7f0beae2fdb10809668 Mon Sep 17 00:00:00 2001 From: areszz <1031614818@qq.com> Date: Sat, 22 Feb 2025 15:23:52 +0800 Subject: [PATCH] first_change --- models/cnt.py | 192 +++++++++++++++++++++++++++++++++++++++++++ models/roma_model.py | 124 +++++++++++++++++++++------- models/self_build.py | 1 + 3 files changed, 289 insertions(+), 28 deletions(-) create mode 100644 models/cnt.py diff --git a/models/cnt.py b/models/cnt.py new file mode 100644 index 0000000..c4bd72b --- /dev/null +++ b/models/cnt.py @@ -0,0 +1,192 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import GaussianBlur + +def warp(image, flow): #warp操作 + """ + 基于光流的图像变形函数 + Args: + image: [B, C, H, W] 输入图像 + flow: [B, 2, H, W] 光流场(x/y方向位移) + Returns: + warped: [B, C, H, W] 变形后的图像 + """ + B, C, H, W = image.shape + # 生成网格坐标 + grid_x, grid_y = torch.meshgrid(torch.arange(W), torch.arange(H)) + grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device) # [2,H,W] + grid = grid.unsqueeze(0).repeat(B,1,1,1) # [B,2,H,W] + + # 应用光流位移(归一化到[-1,1]) + new_grid = grid + flow + new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0 # x方向 + new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0 # y方向 + new_grid = new_grid.permute(0,2,3,1) # [B,H,W,2] + + # 双线性插值 + return F.grid_sample(image, new_grid, align_corners=True) + +# 时序归一化损失计算 +def compute_ctn_loss(G, x, F_content): #公式10 + """ + 计算内容感知时序归一化损失 + Args: + G: 生成器 + x: 输入红外图像 [B,C,H,W] + F_content: 生成的光流场 [B,2,H,W] + """ + + # 生成可见光图像 + y_fake = G(x) # [B,3,H,W] + + # 对生成结果应用光流变形 + warped_fake = warp(y_fake, F_content) # [B,3,H,W] + + # 对输入应用相同光流后生成图像 + warped_x = warp(x, F_content) # [B,C,H,W] + y_fake_warped = G(warped_x) # [B,3,H,W] + + # 计算L2损失 + loss = F.mse_loss(warped_fake, y_fake_warped) + return loss + +class ContentAwareOptimization(nn.Module): + def __init__(self, lambda_inc=2.0, eta_ratio=0.4): + super().__init__() + self.lambda_inc = lambda_inc # 权重增强系数 + self.eta_ratio = eta_ratio # 选择内容区域的比例 + + def compute_cosine_similarity(self, gradients): + """ + 计算每个patch梯度与平均梯度的余弦相似度 + Args: + gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h) + Returns: + cosine_sim: [B, N] 每个patch的余弦相似度 + """ + mean_grad = torch.mean(gradients, dim=1, keepdim=True) # [B, 1, D] + # 计算余弦相似度 + cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N] + return cosine_sim + + def generate_weight_map(self, gradients_real, gradients_fake): + """ + 生成内容感知权重图 + Args: + gradients_real: [B, N, D] 真实图像判别器梯度 + gradients_fake: [B, N, D] 生成图像判别器梯度 + Returns: + weight_real: [B, N] 真实图像权重图 + weight_fake: [B, N] 生成图像权重图 + """ + # 计算真实图像块的余弦相似度 + cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] 公式5 + # 计算生成图像块的余弦相似度 + cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N] + + # 选择内容丰富的区域(余弦相似度最低的eta_ratio比例) + k = int(self.eta_ratio * cosine_real.shape[1]) + + # 对真实图像生成权重图 + _, real_indices = torch.topk(-cosine_real, k, dim=1) # 选择最不相似的区域 + weight_real = torch.ones_like(cosine_real) + for b in range(cosine_real.shape[0]): + weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) #公式6 + + # 对生成图像生成权重图(同理) + _, fake_indices = torch.topk(-cosine_fake, k, dim=1) + weight_fake = torch.ones_like(cosine_fake) + for b in range(cosine_fake.shape[0]): + weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]])) + + return weight_real, weight_fake + + def forward(self, D_real, D_fake, real_scores, fake_scores): + """ + 计算内容感知对抗损失 + Args: + D_real: 判别器对真实图像的特征输出 [B, C, H, W] + D_fake: 判别器对生成图像的特征输出 [B, C, H, W] + real_scores: 真实图像的判别器预测 [B, N] (N=H*W) + fake_scores: 生成图像的判别器预测 [B, N] + Returns: + loss_co_adv: 内容感知对抗损失 + """ + B, C, H, W = D_real.shape + N = H * W + + # 注册钩子获取梯度 + gradients_real = [] + gradients_fake = [] + + def hook_real(grad): + gradients_real.append(grad.detach().view(B, N, -1)) + + def hook_fake(grad): + gradients_fake.append(grad.detach().view(B, N, -1)) + + D_real.register_hook(hook_real) + D_fake.register_hook(hook_fake) + + # 计算原始对抗损失以触发梯度计算 + loss_real = torch.mean(torch.log(real_scores + 1e-8)) + loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8)) + # 添加与 D_real、D_fake 相关的 dummy 项,确保梯度传递 + loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum()) + total_loss = loss_real + loss_fake + loss_dummy + total_loss.backward(retain_graph=True) + + # 获取梯度数据 + gradients_real = gradients_real[0] # [B, N, D] + gradients_fake = gradients_fake[0] # [B, N, D] + + # 生成权重图 + self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake) + + # 应用权重到对抗损失 + loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8)) + loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8)) + + # 计算并返回最终内容感知对抗损失 + loss_co_adv = -(loss_co_real + loss_co_fake) + + return loss_co_adv + +class ContentAwareTemporalNorm(nn.Module): + def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0): + super().__init__() + self.gamma_stride = gamma_stride # 控制整体运动幅度 + self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层 + + def forward(self, weight_map): + """ + 生成内容感知光流 + Args: + weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块) + Returns: + F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) + """ + B, _, H, W = weight_map.shape + + # 1. 归一化权重图 + # 保持区域相对强度,同时限制数值范围 + weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W] + + # 2. 生成高斯噪声(与光流场同尺寸) + z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W] + + # 3. 合成基础光流 + # 将权重图扩展为2通道(x/y方向共享权重) + weight_expanded = weight_norm.expand(-1, 2, -1, -1) # [B,2,H,W] + F_raw = self.gamma_stride * weight_expanded * z # [B,2,H,W] #公式9 + + # 4. 平滑处理(保持结构连续性) + # 对每个通道独立进行高斯模糊 + F_smooth = self.smoother(F_raw) # [B,2,H,W] + + # 5. 动态范围调整(可选) + # 限制光流幅值,避免极端位移 + F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围 + + return F_content \ No newline at end of file diff --git a/models/roma_model.py b/models/roma_model.py index 48c307e..a3a7af7 100644 --- a/models/roma_model.py +++ b/models/roma_model.py @@ -3,6 +3,7 @@ import torch from .base_model import BaseModel from . import networks from .patchnce import PatchNCELoss +from .cnt import * import util.util as util import timm import time @@ -21,18 +22,21 @@ class ROMAModel(BaseModel): """ Configures options specific for CUT model """ parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field') + parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator') - parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency') parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator') parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))') parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency') parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency') + parser.add_argument('--lambda_inc', type=float, default=2.0, help='weight for Content Aware Optimization') + parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio for selecting content region') + parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers') parser.add_argument('--local_nums', type=int, default=256) parser.add_argument('--which_D_layer', type=int, default=-1) parser.add_argument('--side_length', type=int, default=7) - parser.set_defaults(pool_size=0) + parser.set_defaults(pool_size=0) opt, _ = parser.parse_known_args() @@ -42,13 +46,13 @@ class ROMAModel(BaseModel): BaseModel.__init__(self, opt) - self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial', 'motion'] + self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial'] self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1'] self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')] if self.isTrain: - self.model_names = ['G', 'D_ViT'] + self.model_names = ['G', 'D_ViT', 'G_2'] else: # during test time, only load G self.model_names = ['G'] @@ -62,7 +66,12 @@ class ROMAModel(BaseModel): self.netD_ViT = networks.MLPDiscriminator().to(self.device) self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device) + # From UNSB + self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) + # Deine another generator + self.netG_2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt) + self.norm = F.softmax self.resize = tfs.Resize(size=(384,384)) @@ -76,8 +85,13 @@ class ROMAModel(BaseModel): self.criterionL1 = torch.nn.L1Loss().to(self.device) self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2)) + self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D_ViT) + self.optimizers.append(self.optimizer_E) + + self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数 + self.ctn = ContentAwareTemporalNorm() #生成的伪光流场 def data_dependent_initialize(self, data): """ @@ -99,6 +113,13 @@ class ROMAModel(BaseModel): self.loss_D = self.compute_D_loss() self.loss_D.backward() self.optimizer_D_ViT.step() + + # update E + self.set_requires_grad(self.netE, True) + self.optimizer_E.zero_grad() + self.loss_E = self.compute_E_loss() + self.loss_E.backward() + self.optimizer_E.step() # update G self.set_requires_grad(self.netD_ViT, False) @@ -133,7 +154,7 @@ class ROMAModel(BaseModel): times = np.concatenate([np.zeros(1), times]) times = torch.tensor(times).float().cuda() self.times = times - bs = self.mutil_real_A0_tokens.size(0) + bs = self.real_A0.size(0) time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long() self.time_idx = time_idx @@ -149,17 +170,17 @@ class ROMAModel(BaseModel): scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1) # 对 Xt、Xt2 进行随机噪声更新 - Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \ - (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device) - time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long() - z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device) + Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \ + (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device) + time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long() + z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device) self.time = times[time_idx] Xt_1 = self.netG(Xt, self.time, z) - Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \ - (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device) - time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long() - z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device) + Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \ + (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device) + time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long() + z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device) Xt_12 = self.netG(Xt2, self.time, z) # 保存去噪后的中间结果 (real_A_noisy 等),供下一步做拼接 @@ -169,11 +190,11 @@ class ROMAModel(BaseModel): self.noisy_map = self.real_A_noisy - self.real_A # ============ 第三步:拼接输入并执行网络推理 ============= - bs = self.mutil_real_A0_tokens.size(0) - z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device) - z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device) + bs = self.real_A0.size(0) + z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.real_A0.device) + z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device) # 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB - self.real = self.mutil_real_A0_tokens + self.real = self.real_A0 self.realt = self.real_A_noisy if self.opt.flip_equivariance: @@ -206,6 +227,28 @@ class ROMAModel(BaseModel): self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) + + if self.opt.phase == 'train': + # 真实图像的梯度 + real_gradient = torch.autograd.grad(self.real_B.sum(), self.real_B, create_graph=True)[0] + # 生成图像的梯度 + fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0] + # 梯度图 + self.weight_real, self.weight_fake = self.cao.generate_weight_map(real_gradient, fake_gradient) + + # 生成图像的CTN光流图 + self.f_content = self.ctn(self.weight_fake) + + # 把前面生成后的图片再加上noisy_map + self.fake_B0_2 = self.fake_B0 + self.noisy_map + + # 变换后的图片 + wapped_fake_B0_2 = warp(self.fake_B0_2, self.f_content) + + # 经过第二次生成器 + self.fake_B0_2 = self.netG_2(wapped_fake_B0_2, self.time, z_in) + + def tokens_concat(self, origin_tokens, adjacent_size): adj_size = adjacent_size B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2] @@ -277,6 +320,18 @@ class ROMAModel(BaseModel): return self.loss_D_ViT + + def compute_E_loss(self): + """计算判别器 E 的损失""" + + XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1) + XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1) + temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean() + self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2 + + return self.loss_E + + def compute_G_loss(self): if self.opt.lambda_GAN > 0.0: @@ -291,22 +346,35 @@ class ROMAModel(BaseModel): else: self.loss_G_GAN_ViT = 0.0 + + self.loss_SB = 0 + if self.opt.lambda_SB > 0.0: + XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1) + XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1) + + bs = self.opt.batch_size + + # eq.9 + ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0) + self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY + self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2) + self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy2 - self.fake_B1) ** 2) + + if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0: self.loss_global, self.loss_spatial = self.calculate_attention_loss() else: self.loss_global, self.loss_spatial = 0.0, 0.0 - - if self.opt.lambda_motion > 0.0: - self.loss_motion = 0.0 - for real_A0_tokens, real_A1_tokens, fake_B0_tokens, fake_B1_tokens in zip(self.mutil_real_A0_tokens, self.mutil_real_A1_tokens, self.mutil_fake_B0_tokens, self.mutil_fake_B1_tokens): - A0_B1 = real_A0_tokens.bmm(fake_B1_tokens.permute(0,2,1)) - B0_A1 = fake_B0_tokens.bmm(real_A1_tokens.permute(0,2,1)) - cos_dis_global = F.cosine_similarity(A0_B1, B0_A1, dim=-1) - self.loss_motion += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean() + + + if self.opt.lambda_ctn > 0.0: + wapped_fake_B1 = warp(self.fake_B1, self.f_content) # use updated self.f_content + self.l2_loss = F.mse_loss(self.fake_B0_2, wapped_fake_B1) * self.opt.lambda_ctn else: - self.loss_motion = 0.0 - - self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.loss_motion + self.l2_loss = 0.0 + + + self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.l2_loss # include l2_loss in total loss return self.loss_G def calculate_attention_loss(self): diff --git a/models/self_build.py b/models/self_build.py index 1cb0f37..31d5deb 100644 --- a/models/self_build.py +++ b/models/self_build.py @@ -332,6 +332,7 @@ class CTNxModel(BaseModel): self.loss_D.backward() self.optimizer_D.step() + # update E self.set_requires_grad(self.netE, True) self.optimizer_E.zero_grad() self.loss_E = self.compute_E_loss()