Save a version

bishe 2025-03-18 20:14:59 +08:00
parent f98c285950
commit 537cb050a5


@@ -13,6 +13,7 @@ import util.util as util
+from torchvision.transforms import transforms as tfs
 
 def warp(image, flow):  # warp operation
     """
     Optical-flow-based image warping function
@@ -37,76 +38,77 @@ def warp(image, flow):  # warp operation
     # bilinear sampling
     return F.grid_sample(image, new_grid, align_corners=True)
 
+# temporal-normalization loss
 def compute_ctn_loss(G, x, F_content):  # Eq. 10
     """
     Compute the content-aware temporal normalization loss
     Args:
         G: generator
         x: input infrared image [B, C, H, W]
         F_content: generated optical flow field [B, 2, H, W]
     """
     # generate the visible-light image
     y_fake = G(x)  # [B, 3, H, W]
     # warp the generated result with the optical flow
     warped_fake = warp(y_fake, F_content)  # [B, 3, H, W]
     # apply the same flow to the input, then generate
     warped_x = warp(x, F_content)  # [B, C, H, W]
     y_fake_warped = G(warped_x)  # [B, 3, H, W]
     # L2 loss between the two results
     loss = F.mse_loss(warped_fake, y_fake_warped)
     return loss
 class ContentAwareOptimization(nn.Module):
     def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
         super().__init__()
-        self.lambda_inc = lambda_inc
-        self.eta_ratio = eta_ratio
-        self.gradients = []  # changed to a single gradient list for generality
-        self.criterionGAN = networks.GANLoss('lsgan').cuda()
+        self.lambda_inc = lambda_inc  # weight increment for content-rich regions
+        self.eta_ratio = eta_ratio  # fraction of patches selected as content-rich
+        self.criterionGAN = networks.GANLoss('lsgan').cuda()  # LSGAN loss
 
-    def compute_cosine_similarity(self, gradients):
-        mean_grad = torch.mean(gradients, dim=1, keepdim=True)
-        return F.cosine_similarity(gradients, mean_grad, dim=2)
-
-    def generate_weight_map(self, gradients):
-        cosine = self.compute_cosine_similarity(gradients)
-        k = int(self.eta_ratio * cosine.shape[1])
-        _, indices = torch.topk(-cosine, k, dim=1)
-        weights = torch.ones_like(cosine)
-        weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
-        return weights
-
-    def forward(self, features, scores, target):
-        """
-        Args:
-            features: feature tensor (the discriminator's real/fake features, or the generator's fake features)
-            scores: the discriminator's predicted scores for the features
-            target: target label (True = should be judged real, False = should be judged fake)
-        Returns:
-            loss: weighted GAN loss
-            weight: generated weight map
-        """
-        self.gradients.clear()
-        # register a gradient hook
-        hook = lambda grad: self.gradients.append(grad.detach())
-        features.register_hook(hook)
+    def compute_cosine_similarity(self, grad_patch, grad_mean):
+        """
+        Compute the cosine similarity between each patch gradient and the overall mean gradient
+        Args:
+            grad_patch: [B, 1, H, W], per-patch gradients (w.r.t. scores)
+            grad_mean: [B, 1], overall mean gradient
+        Returns:
+            cosine: [B, 1, H, W], cosine similarity δ_i
+        """
+        B, _, H, W = grad_patch.shape
+        grad_patch = grad_patch.view(B, 1, -1)  # [B, 1, H*W]
+        grad_mean = grad_mean.unsqueeze(-1)  # [B, 1, 1]
+        # cosine similarity
+        cosine = F.cosine_similarity(grad_patch, grad_mean, dim=1)  # [B, H*W]
+        return cosine.view(B, 1, H, W)
+
+    def generate_weight_map(self, cosine):
+        """
+        Generate the weight map from the cosine similarity
+        Args:
+            cosine: [B, 1, H, W], cosine similarity δ_i
+        Returns:
+            weights: [B, 1, H, W], weight map w_i
+        """
+        B, _, H, W = cosine.shape
+        cosine_flat = cosine.view(B, -1)  # [B, H*W]
+        k = int(self.eta_ratio * cosine_flat.size(1))  # take an eta_ratio fraction of the patches
+        _, indices = torch.topk(-cosine_flat, k, dim=1)  # the k patches deviating most from the mean
+        weights = torch.ones_like(cosine_flat)
+        for b in range(B):
+            selected_cosine = cosine_flat[b, indices[b]]
+            weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
+        return weights.view(B, 1, H, W)
+
+    def forward(self, scores, target):
+        """
+        Forward pass: compute the weighted GAN loss
+        Args:
+            scores: [B, 1, H, W], the discriminator's predicted scores
+            target: target label (True / False)
+        Returns:
+            weighted_loss: weighted GAN loss
+            weight: weight map [B, 1, H, W]
+        """
         # raw GAN loss
         loss = self.criterionGAN(scores, target)
-        # trigger the gradient computation
-        scores.mean().backward(retain_graph=True)
-        # fetch the captured feature gradients and reshape
-        grad = self.gradients[0].flatten(1)  # [B, N, D] -> [B, N*D]
-        weight = self.generate_weight_map(grad.view(*features.shape))
-        # weighted GAN loss
-        loss = torch.mean(weight * self.criterionGAN(scores, target))
-        return loss, weight
+        # gradient of the loss w.r.t. the score map
+        grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]  # [B, C, H, W]
+        # overall mean gradient
+        grad_mean = torch.mean(grad_scores, dim=(2, 3))  # [B, 1]
+        # cosine similarity δ_i (Eq. 5)
+        cosine = self.compute_cosine_similarity(grad_scores, grad_mean)  # [B, 1, H, W]
+        # weight map w_i (Eq. 6)
+        weight = self.generate_weight_map(cosine)
+        # apply the weights to the loss (partial implementation of Eq. 7)
+        weighted_loss = torch.mean(weight * self.criterionGAN(scores, target))
+        return weighted_loss, weight
 
 class ContentAwareTemporalNorm(nn.Module):
     def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
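
For orientation, here is a minimal standalone sketch of the re-weighting that the new forward() performs, with a plain MSE standing in for the repo-specific networks.GANLoss and a dummy 30x30 PatchGAN score map (both assumptions), and the per-batch loop replaced by a vectorized scatter_:

import torch
import torch.nn.functional as F

def content_aware_weights(scores, target_is_real, lambda_inc=2.0, eta_ratio=0.4):
    # LSGAN-style loss: MSE against an all-ones (real) or all-zeros (fake) target map.
    target = torch.ones_like(scores) if target_is_real else torch.zeros_like(scores)
    loss = F.mse_loss(scores, target)
    # Per-patch gradient of the loss w.r.t. the score map, as in forward().
    grad = torch.autograd.grad(loss, scores, retain_graph=True)[0]  # [B, 1, H, W]
    grad_mean = grad.mean(dim=(2, 3))                               # [B, 1]
    # Cosine similarity of each scalar patch gradient with the mean gradient;
    # for scalars this reduces to sign agreement in {-1, 0, +1}.
    cosine = F.cosine_similarity(grad.flatten(2), grad_mean.unsqueeze(-1), dim=1)  # [B, H*W]
    k = int(eta_ratio * cosine.size(1))
    _, idx = torch.topk(-cosine, k, dim=1)  # the k patches deviating most from the mean
    weights = torch.ones_like(cosine)
    weights.scatter_(1, idx, lambda_inc / (torch.exp(cosine.gather(1, idx).abs()) + 1e-6))
    return weights.view_as(scores)

scores = torch.randn(2, 1, 30, 30, requires_grad=True)  # stand-in for a PatchGAN score map
w = content_aware_weights(scores, target_is_real=True)
print(w.shape)                                           # torch.Size([2, 1, 30, 30])
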
@@ -115,31 +117,14 @@ class ContentAwareTemporalNorm(nn.Module):
         self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer
 
     def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
-        """
-        Upsample the patch-level weight map to the target resolution
-        Args:
-            weight_patch: [B, 1, 24, 24] patch-level weight map from the ViT
-            target_size: target resolution (H, W)
-        Returns:
-            weight_full: [B, 1, 256, 256] upsampled full-resolution weight map
-        """
-        # upsample with bilinear interpolation
-        B = weight_patch.shape[0]
-        weight_patch = weight_patch.view(B, 1, 24, 24)
+        # weight_patch: [B, 1, 30, 30], from the PatchGAN
         weight_full = F.interpolate(
             weight_patch,
             size=target_size,
-            mode='bilinear',
+            mode='bilinear',  # or 'nearest', depending on what is needed
            align_corners=False
         )
-        # keep the weight constant inside each 16x16 patch (optional):
-        # average-pool then re-expand to remove the gradient introduced by interpolation
-        weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
-        weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
-        return weight_full
+        return weight_full  # [B, 1, 256, 256]
 
     def forward(self, weight_map):
         """
@@ -175,6 +160,7 @@ class ContentAwareTemporalNorm(nn.Module):
         return F_content
 
+
 class RomaUnsbModel(BaseModel):
     @staticmethod
     def modify_commandline_options(parser, is_train=True):
@@ -327,21 +313,22 @@ class RomaUnsbModel(BaseModel):
         # handle real_B0 and fake_B0
         real_B0_tokens = self.mutil_real_B0_tokens[0]
-        pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)
+        pred_real0 = self.netD_ViT(real_B0_tokens)
+        print(pred_real0.shape)
         fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
-        pred_fake0, fake_features0 = self.netD_ViT(fake_B0_tokens)
+        pred_fake0 = self.netD_ViT(fake_B0_tokens)
-        loss_real0, self.weight_real0 = self.cao(real_features0, pred_real0, True)
-        loss_fake0, self.weight_fake0 = self.cao(fake_features0, pred_fake0, False)
+        loss_real0, self.weight_real0 = self.cao(pred_real0, True)
+        loss_fake0, self.weight_fake0 = self.cao(pred_fake0, False)
 
         # handle real_B1 and fake_B1
         real_B1_tokens = self.mutil_real_B1_tokens[0]
-        pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)
+        pred_real1 = self.netD_ViT(real_B1_tokens)
         fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
-        pred_fake1, fake_features1 = self.netD_ViT(fake_B1_tokens)
+        pred_fake1 = self.netD_ViT(fake_B1_tokens)
-        loss_real1, self.weight_real1 = self.cao(real_features1, pred_real1, True)
-        loss_fake1, self.weight_fake1 = self.cao(fake_features1, pred_fake1, False)
+        loss_real1, self.weight_real1 = self.cao(pred_real1, True)
+        loss_fake1, self.weight_fake1 = self.cao(pred_fake1, False)
 
         # combined loss
         self.loss_D_ViT = (loss_real0 + loss_fake0 + loss_real1 + loss_fake1) * 0.25 * lambda_D_ViT
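
Tying the hunks back to compute_ctn_loss (Eq. 10): the loss is zero exactly when warping the generated frame matches generating the warped input. A self-contained sanity check with a toy channel-repeating "generator" and zero flow, where that identity holds by construction (the warp below is a plain grid_sample re-implementation with flow given in normalized grid units, not the repo's version, whose grid construction sits outside this diff):

import torch
import torch.nn.functional as F

def warp(image, flow):
    # Base sampling grid in [-1, 1], displaced by the flow field.
    B, _, H, W = image.shape
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing='ij')
    base = torch.stack((xs, ys), dim=-1).unsqueeze(0).expand(B, -1, -1, -1)  # [B, H, W, 2]
    new_grid = base + flow.permute(0, 2, 3, 1)  # flow: [B, 2, H, W] -> [B, H, W, 2]
    return F.grid_sample(image, new_grid, align_corners=True)

G = lambda x: x.repeat(1, 3, 1, 1)  # toy "generator": 1-channel in, 3-channel out
x = torch.rand(1, 1, 64, 64)        # dummy infrared frame
flow = torch.zeros(1, 2, 64, 64)    # zero flow: warp is the identity
loss = F.mse_loss(warp(G(x), flow), G(warp(x, flow)))
print(loss.item())                  # ~0.0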