import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from torchvision.transforms import GaussianBlur
except ImportError:  # only ContentAwareTemporalNorm needs torchvision
    GaussianBlur = None


def warp(image, flow):  # warp operation
    """Backward-warp an image with a dense optical-flow field.

    Output pixel (y, x) is bilinearly sampled from ``image`` at
    (y + flow[:, 1], x + flow[:, 0]); samples falling outside the image read
    as zero (``grid_sample``'s default padding mode).

    Args:
        image: [B, C, H, W] input image.
        flow: [B, 2, H, W] displacement in pixels; channel 0 is the x (width)
            offset, channel 1 the y (height) offset.

    Returns:
        warped: [B, C, H, W] warped image.
    """
    _, _, H, W = image.shape
    # Base pixel coordinates laid out as [H, W] to match ``flow``:
    # ``xs`` varies along the width, ``ys`` along the height.
    xs = torch.arange(W, device=image.device, dtype=flow.dtype).view(1, W).expand(H, W)
    ys = torch.arange(H, device=image.device, dtype=flow.dtype).view(H, 1).expand(H, W)

    # Apply the displacement and normalise to [-1, 1] as grid_sample expects
    # with align_corners=True; max(..., 1) avoids dividing by zero when a
    # spatial dimension is a single pixel.
    gx = 2.0 * (xs + flow[:, 0]) / max(W - 1, 1) - 1.0  # [B, H, W]
    gy = 2.0 * (ys + flow[:, 1]) / max(H - 1, 1) - 1.0  # [B, H, W]
    grid = torch.stack((gx, gy), dim=-1)  # [B, H, W, 2], (x, y) order

    # Bilinear interpolation.
    return F.grid_sample(image, grid, align_corners=True)


# Temporal normalization loss
def compute_ctn_loss(G, x, F_content):  # Eq. 10
    """Content-aware temporal normalization loss.

    Measures how far the generator is from commuting with the warp:
    ``mse(warp(G(x), F_content), G(warp(x, F_content)))``.

    Args:
        G: generator, called as ``G(x)``.
        x: input infrared image [B, C, H, W].
        F_content: generated flow field [B, 2, H, W].

    Returns:
        Scalar MSE loss tensor.
    """
    # Generate the visible-light image, then warp the result.
    y_fake = G(x)  # [B, 3, H, W]
    warped_fake = warp(y_fake, F_content)  # [B, 3, H, W]

    # Warp the input with the same flow, then generate.
    warped_x = warp(x, F_content)  # [B, C, H, W]
    y_fake_warped = G(warped_x)  # [B, 3, H, W]

    # L2 loss between the two paths.
    return F.mse_loss(warped_fake, y_fake_warped)


class ContentAwareOptimization(nn.Module):
    """Content-aware adversarial loss that up-weights content-rich patches.

    Patches whose discriminator-feature gradient is least aligned with the
    per-sample mean gradient (lowest cosine similarity) are treated as
    content-rich; the ``eta_ratio`` fraction of such patches gets its loss
    term multiplied by ``lambda_inc / |cos|``, all others keep weight 1.
    """

    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc  # weight boost factor
        self.eta_ratio = eta_ratio  # fraction of patches selected as content regions

    def compute_cosine_similarity(self, gradients):
        """Cosine similarity between each patch gradient and the mean gradient.

        Args:
            gradients: [B, N, D] per-patch gradients (N = H * W).

        Returns:
            cosine_sim: [B, N] cosine similarity of every patch.
        """
        mean_grad = torch.mean(gradients, dim=1, keepdim=True)  # [B, 1, D]
        return F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]

    def _boost_least_aligned(self, cosine):
        """Weight map [B, N]: 1 everywhere, lambda_inc / |cos| on the k least-aligned patches (Eq. 6)."""
        k = int(self.eta_ratio * cosine.shape[1])
        _, idx = torch.topk(-cosine, k, dim=1)  # k lowest similarities per sample
        boosted = self.lambda_inc / (1e-6 + torch.gather(cosine, 1, idx).abs())
        return torch.ones_like(cosine).scatter(1, idx, boosted)

    def generate_weight_map(self, gradients_real, gradients_fake):
        """Build content-aware weight maps for real and generated images.

        Args:
            gradients_real: [B, N, D] discriminator gradients for real images.
            gradients_fake: [B, N, D] discriminator gradients for generated images.

        Returns:
            weight_real: [B, N] weight map for real images.
            weight_fake: [B, N] weight map for generated images.
        """
        cosine_real = self.compute_cosine_similarity(gradients_real)  # [B, N], Eq. 5
        cosine_fake = self.compute_cosine_similarity(gradients_fake)  # [B, N]
        return self._boost_least_aligned(cosine_real), self._boost_least_aligned(cosine_fake)

    @staticmethod
    def _patch_gradients(grad):
        """Reshape a [B, C, H, W] gradient into per-patch vectors [B, H*W, C]."""
        b, c = grad.shape[:2]
        # Channels go last before flattening, so row n is the C-dim gradient at
        # spatial position n = h * W + w -- the same order as scores [B, H*W].
        return grad.detach().permute(0, 2, 3, 1).reshape(b, -1, c)

    def forward(self, D_real, D_fake, real_scores, fake_scores):
        """Compute the content-aware adversarial loss.

        Side effect: stores the weight maps as ``self.weight_real`` and
        ``self.weight_fake`` ([B, N] each).

        Args:
            D_real: discriminator feature output for real images [B, C, H, W].
            D_fake: discriminator feature output for generated images [B, C, H, W].
            real_scores: discriminator predictions for real images [B, N] (N = H * W).
            fake_scores: discriminator predictions for generated images [B, N].

        Returns:
            loss_co_adv: scalar content-aware adversarial loss.
        """
        # Unweighted adversarial objective, used only to obtain the gradients
        # with respect to the discriminator feature maps.
        loss_real = torch.mean(torch.log(real_scores + 1e-8))
        loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
        # Tiny term making the objective depend on D_real/D_fake directly, so
        # their gradients exist even if the scores do not depend on them.
        loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
        total_loss = loss_real + loss_fake + loss_dummy

        # autograd.grad returns only the feature-map gradients: unlike
        # .backward() it does not accumulate into parameter .grad fields and
        # leaves no hooks behind. retain_graph keeps the graph alive for the
        # weighted loss computed below.
        grad_real, grad_fake = torch.autograd.grad(
            total_loss, (D_real, D_fake), retain_graph=True
        )
        gradients_real = self._patch_gradients(grad_real)  # [B, N, C]
        gradients_fake = self._patch_gradients(grad_fake)  # [B, N, C]

        self.weight_real, self.weight_fake = self.generate_weight_map(
            gradients_real, gradients_fake
        )

        # Apply the weights to the adversarial loss.
        loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
        loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))
        loss_co_adv = -(loss_co_real + loss_co_fake)
        return loss_co_adv


class ContentAwareTemporalNorm(nn.Module):
    """Generate a smooth random flow field modulated by a content weight map (Eq. 9)."""

    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        if GaussianBlur is None:
            raise ImportError("ContentAwareTemporalNorm requires torchvision (GaussianBlur)")
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def forward(self, weight_map):
        """Generate a content-aware flow field.

        Args:
            weight_map: [B, 1, H, W] weight map (from the content-aware
                optimization module).

        Returns:
            F_content: [B, 2, H, W] flow field (x/y displacement), each value
                squashed into (-1, 1) by tanh.
        """
        B, _, H, W = weight_map.shape

        # 1. L1-normalize the weight map over the spatial dims, keeping the
        #    relative strength of regions while bounding the values.
        # NOTE(review): afterwards the absolute values sum to 1 over H*W
        #    entries (~1/(H*W) each), so with gamma_stride the displacements
        #    are very small -- confirm this scale is intended.
        weight_norm = F.normalize(weight_map, p=1, dim=(2, 3))  # [B, 1, H, W]

        # 2. Gaussian noise with the same size as the flow field.
        z = torch.randn(B, 2, H, W, device=weight_map.device)  # [B, 2, H, W]

        # 3. Raw flow: the x and y directions share the same weight map.
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B, 2, H, W]
        F_raw = self.gamma_stride * weight_expanded * z  # [B, 2, H, W], Eq. 9

        # 4. Smooth each channel with the Gaussian blur to keep structure continuous.
        F_smooth = self.smoother(F_raw)  # [B, 2, H, W]

        # 5. Bound the magnitude to avoid extreme displacements.
        F_content = torch.tanh(F_smooth)
        return F_content