import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
def warp(image, flow):
    """
    Warp an image with an optical-flow field.
    Args:
        image: [B, C, H, W] input image
        flow:  [B, 2, H, W] optical flow (x/y displacement in pixels)
    Returns:
        warped: [B, C, H, W] warped image
    """
    B, C, H, W = image.shape
    # Build a pixel-coordinate grid. meshgrid with indexing='ij' returns
    # [H, W] tensors, so unpack as (y, x) and stack as (x, y) to match the
    # channel order of the flow field and of grid_sample.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(H, device=image.device),
        torch.arange(W, device=image.device),
        indexing="ij",
    )
    grid = torch.stack((grid_x, grid_y), dim=0).float()  # [2, H, W]
    grid = grid.unsqueeze(0).expand(B, -1, -1, -1)       # [B, 2, H, W]
    # Displace the grid by the flow and normalize coordinates to [-1, 1]
    # (no in-place ops, so gradients can propagate through the flow)
    new_x = 2.0 * (grid[:, 0] + flow[:, 0]) / (W - 1) - 1.0  # x direction
    new_y = 2.0 * (grid[:, 1] + flow[:, 1]) / (H - 1) - 1.0  # y direction
    new_grid = torch.stack((new_x, new_y), dim=-1)  # [B, H, W, 2]
    # Bilinear sampling
    return F.grid_sample(image, new_grid, align_corners=True)
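
# A minimal sanity-check sketch (not part of the original file): with an
# all-zero flow field, warp() should reproduce the input almost exactly,
# since every pixel samples its own location. Tensor sizes are arbitrary.
def _demo_warp_identity():
    image = torch.rand(1, 3, 8, 10)
    flow = torch.zeros(1, 2, 8, 10)
    warped = warp(image, flow)
    # Bilinear sampling at exact grid points is an identity mapping
    assert torch.allclose(warped, image, atol=1e-5)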
# Content-aware temporal-normalization loss
def compute_ctn_loss(G, x, F_content):  # Eq. 10
    """
    Compute the content-aware temporal-normalization loss.
    Args:
        G: generator
        x: input infrared image [B, C, H, W]
        F_content: generated optical-flow field [B, 2, H, W]
    """
    # Translate the input to the visible domain
    y_fake = G(x)  # [B, 3, H, W]
    # Warp the generated image with the flow
    warped_fake = warp(y_fake, F_content)  # [B, 3, H, W]
    # Warp the input with the same flow, then translate it
    warped_x = warp(x, F_content)  # [B, C, H, W]
    y_fake_warped = G(warped_x)    # [B, 3, H, W]
    # L2 penalty between "translate then warp" and "warp then translate"
    loss = F.mse_loss(warped_fake, y_fake_warped)
    return loss
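
# A minimal usage sketch (assumed, not from the original file): any
# image-to-image generator fits here. A 1x1 convolution mapping a 1-channel
# infrared input to 3 output channels stands in for the real generator G.
def _demo_ctn_loss():
    G = nn.Conv2d(1, 3, kernel_size=1)           # toy stand-in generator
    x = torch.rand(2, 1, 16, 16)                 # dummy infrared batch
    F_content = 0.5 * torch.randn(2, 2, 16, 16)  # small random flow field
    loss = compute_ctn_loss(G, x, F_content)
    loss.backward()  # gradients reach G through both warped branches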
class ContentAwareOptimization(nn.Module):
    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc  # weight-amplification coefficient
        self.eta_ratio = eta_ratio    # fraction of patches treated as content-rich

    def compute_cosine_similarity(self, gradients):
        """
        Cosine similarity between each patch gradient and the mean gradient.
        Args:
            gradients: [B, N, D] per-patch discriminator gradients (N = h * w)
        Returns:
            cosine_sim: [B, N] cosine similarity of each patch
        """
        mean_grad = torch.mean(gradients, dim=1, keepdim=True)  # [B, 1, D]
        cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]
        return cosine_sim
    def generate_weight_map(self, gradients_real, gradients_fake):
        """
        Generate the content-aware weight maps.
        Args:
            gradients_real: [B, N, D] discriminator gradients on real images
            gradients_fake: [B, N, D] discriminator gradients on generated images
        Returns:
            weight_real: [B, N] weight map for real images
            weight_fake: [B, N] weight map for generated images
        """
        # Cosine similarity of real patches (Eq. 5)
        cosine_real = self.compute_cosine_similarity(gradients_real)  # [B, N]
        # Cosine similarity of generated patches
        cosine_fake = self.compute_cosine_similarity(gradients_fake)  # [B, N]
        # Select the content-rich patches: the eta_ratio fraction with the
        # lowest cosine similarity to the mean gradient
        k = max(1, int(self.eta_ratio * cosine_real.shape[1]))
        # Weight map for real images: amplify the least-similar patches (Eq. 6)
        _, real_indices = torch.topk(-cosine_real, k, dim=1)
        weight_real = torch.ones_like(cosine_real)
        selected_real = torch.gather(cosine_real, 1, real_indices)
        weight_real.scatter_(1, real_indices,
                             self.lambda_inc / (1e-6 + selected_real.abs()))
        # Weight map for generated images (same construction)
        _, fake_indices = torch.topk(-cosine_fake, k, dim=1)
        weight_fake = torch.ones_like(cosine_fake)
        selected_fake = torch.gather(cosine_fake, 1, fake_indices)
        weight_fake.scatter_(1, fake_indices,
                             self.lambda_inc / (1e-6 + selected_fake.abs()))
        return weight_real, weight_fake
    def forward(self, D_real, D_fake, real_scores, fake_scores):
        """
        Compute the content-aware adversarial loss.
        Args:
            D_real: discriminator features for real images [B, C, H, W]
            D_fake: discriminator features for generated images [B, C, H, W]
            real_scores: discriminator predictions for real images [B, N] (N = H * W)
            fake_scores: discriminator predictions for generated images [B, N]
        Returns:
            loss_co_adv: content-aware adversarial loss
        """
        B, C, H, W = D_real.shape
        N = H * W
        # Register hooks to capture per-patch gradients during backward
        gradients_real = []
        gradients_fake = []

        def hook_real(grad):
            # [B, C, H, W] -> [B, N, C]: one gradient vector per patch
            gradients_real.append(grad.detach().flatten(2).transpose(1, 2))

        def hook_fake(grad):
            gradients_fake.append(grad.detach().flatten(2).transpose(1, 2))

        handle_real = D_real.register_hook(hook_real)
        handle_fake = D_fake.register_hook(hook_fake)
        # Compute the plain adversarial loss to trigger the gradient hooks
        loss_real = torch.mean(torch.log(real_scores + 1e-8))
        loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
        # Dummy term involving D_real/D_fake so gradients reach the hooks
        loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
        total_loss = loss_real + loss_fake + loss_dummy
        total_loss.backward(retain_graph=True)
        handle_real.remove()
        handle_fake.remove()
        # Collect the captured gradients
        gradients_real = gradients_real[0]  # [B, N, C]
        gradients_fake = gradients_fake[0]  # [B, N, C]
        # Build the content-aware weight maps
        self.weight_real, self.weight_fake = self.generate_weight_map(
            gradients_real, gradients_fake)
        # Re-weight the adversarial loss per patch
        loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
        loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))
        # Final content-aware adversarial loss
        loss_co_adv = -(loss_co_real + loss_co_fake)
        return loss_co_adv
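
# A minimal usage sketch (assumed shapes, not from the original file): the
# module expects discriminator features plus per-patch scores in (0, 1), as
# a PatchGAN head followed by a sigmoid would produce. Random tensors stand
# in for both here.
def _demo_content_aware_optimization():
    B, C, H, W = 2, 8, 4, 4
    cao = ContentAwareOptimization(lambda_inc=2.0, eta_ratio=0.4)
    D_real = torch.rand(B, C, H, W, requires_grad=True)  # stand-in features
    D_fake = torch.rand(B, C, H, W, requires_grad=True)
    real_scores = torch.sigmoid(torch.randn(B, H * W))   # per-patch predictions
    fake_scores = torch.sigmoid(torch.randn(B, H * W))
    loss = cao(D_real, D_fake, real_scores, fake_scores)
    # cao.weight_real / cao.weight_fake now hold the per-patch weights:
    # lambda_inc / |cos| on the selected content-rich patches, 1.0 elsewhere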
class ContentAwareTemporalNorm(nn.Module):
    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def forward(self, weight_map):
        """
        Generate a content-aware optical-flow field.
        Args:
            weight_map: [B, 1, H, W] weight map (from the content-aware
                optimization module)
        Returns:
            F_content: [B, 2, H, W] generated flow field (x/y displacement)
        """
        B, _, H, W = weight_map.shape
        # 1. L1-normalize the weight map: keeps the relative strength of each
        #    region while bounding the overall magnitude
        weight_norm = weight_map / (weight_map.abs().sum(dim=(2, 3), keepdim=True) + 1e-8)  # [B,1,H,W]
        # 2. Gaussian noise with the same size as the flow field
        z = torch.randn(B, 2, H, W, device=weight_map.device)  # [B, 2, H, W]
        # 3. Raw flow: expand the weight map to 2 channels (x/y share weights)
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B, 2, H, W]
        F_raw = self.gamma_stride * weight_expanded * z  # [B, 2, H, W], Eq. 9
        # 4. Gaussian smoothing per channel to keep the flow spatially coherent
        F_smooth = self.smoother(F_raw)  # [B, 2, H, W]
        # 5. Optional dynamic-range control: bound displacements to [-1, 1]
        F_content = torch.tanh(F_smooth)
        return F_content
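
# An end-to-end sketch (assumed wiring, not from the original file): a weight
# map reshaped to [B, 1, H, W] drives ContentAwareTemporalNorm to synthesize a
# flow field, which then feeds compute_ctn_loss. The shapes and the toy
# generator are illustrative only; in training the weight map would come from
# ContentAwareOptimization (e.g. cao.weight_fake).
if __name__ == "__main__":
    B, H, W = 2, 32, 32
    G = nn.Conv2d(1, 3, kernel_size=1)  # toy stand-in generator
    x = torch.rand(B, 1, H, W)          # dummy infrared batch
    weight_map = torch.rand(B, 1, H, W)  # stand-in content-aware weight map
    catn = ContentAwareTemporalNorm(gamma_stride=0.1)
    F_content = catn(weight_map)         # [B, 2, H, W]
    loss_ctn = compute_ctn_loss(G, x, F_content)
    print("CTN loss:", loss_ctn.item())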