# roma_unsb/models/roma_unsb_model.py
import numpy as np
import math
import timm
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
from torchvision.transforms import transforms as tfs


def warp(image, flow):
    """
    Flow-based image warping.

    Args:
        image: [B, C, H, W] input image
        flow:  [B, 2, H, W] optical flow field (x/y displacements, in pixels)
    Returns:
        warped: [B, C, H, W] warped image
    """
    B, C, H, W = image.shape
    # Build the pixel coordinate grid (x along width, y along height).
    grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2, H, W]
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B, 2, H, W]
    # Apply the flow displacement, then normalize coordinates to [-1, 1].
    new_grid = grid + flow
    new_x = 2.0 * new_grid[:, 0] / (W - 1) - 1.0  # x direction
    new_y = 2.0 * new_grid[:, 1] / (H - 1) - 1.0  # y direction
    new_grid = torch.stack((new_x, new_y), dim=-1)  # [B, H, W, 2]
    # Bilinear sampling.
    return F.grid_sample(image, new_grid, align_corners=True)
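

# Minimal sanity-check sketch (not called anywhere in the model): with an
# all-zero flow field, warp() samples every pixel at its original location,
# so the output should match the input up to interpolation error.
def _warp_identity_check():
    image = torch.rand(1, 3, 8, 8)
    zero_flow = torch.zeros(1, 2, 8, 8)
    warped = warp(image, zero_flow)
    assert torch.allclose(warped, image, atol=1e-5)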


# Temporal-normalization loss (Eq. 10 in the paper).
def compute_ctn_loss(G, x, F_content):
    """
    Content-aware temporal normalization loss.

    Args:
        G: generator
        x: input infrared image [B, C, H, W]
        F_content: generated optical flow field [B, 2, H, W]
    """
    # Translate to the visible domain.
    y_fake = G(x)  # [B, 3, H, W]
    # Warp the generated result with the flow.
    warped_fake = warp(y_fake, F_content)  # [B, 3, H, W]
    # Warp the input with the same flow, then translate.
    warped_x = warp(x, F_content)  # [B, C, H, W]
    y_fake_warped = G(warped_x)  # [B, 3, H, W]
    # L2 distance between "generate-then-warp" and "warp-then-generate".
    loss = F.mse_loss(warped_fake, y_fake_warped)
    return loss
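

# Usage sketch (illustrative only): any image-to-image module can stand in
# for the generator here; a 1x1 conv and a small random flow field suffice
# to exercise the loss.
def _ctn_loss_example():
    G = nn.Conv2d(3, 3, kernel_size=1)  # toy stand-in for the generator
    x = torch.rand(2, 3, 32, 32)
    flow = 0.5 * torch.randn(2, 2, 32, 32)
    return compute_ctn_loss(G, x, flow)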


class ContentAwareOptimization(nn.Module):
    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc
        self.eta_ratio = eta_ratio
        self.gradients_real = []
        self.gradients_fake = []

    def compute_cosine_similarity(self, gradients):
        # Cosine similarity of each token gradient against the mean gradient.
        mean_grad = torch.mean(gradients, dim=1, keepdim=True)
        return F.cosine_similarity(gradients, mean_grad, dim=2)

    def generate_weight_map(self, gradients_real, gradients_fake):
        # Per-token cosine similarity to the mean gradient.
        cosine_real = self.compute_cosine_similarity(gradients_real)
        cosine_fake = self.compute_cosine_similarity(gradients_fake)

        # Build the weight map: the eta_ratio fraction of tokens whose
        # gradients deviate most from the mean are up-weighted.
        def _get_weights(cosine):
            k = int(self.eta_ratio * cosine.shape[1])
            _, indices = torch.topk(-cosine, k, dim=1)
            weights = torch.ones_like(cosine)
            weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
            return weights

        weight_real = _get_weights(cosine_real)
        weight_fake = _get_weights(cosine_fake)
        return weight_real, weight_fake
    def forward(self, D_real, D_fake, real_scores, fake_scores):
        # Clear the gradient caches from the previous step.
        self.gradients_real.clear()
        self.gradients_fake.clear()
        # Instantiate the LSGAN criterion once instead of on every call.
        if not hasattr(self, 'criterionGAN'):
            self.criterionGAN = networks.GANLoss('lsgan').cuda()
        # Register hooks to capture gradients w.r.t. the token features.
        hook_real = lambda grad: self.gradients_real.append(grad.detach())
        hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
        D_real.register_hook(hook_real)
        D_fake.register_hook(hook_fake)
        # Trigger the gradient computation (keep the graph for the real backward pass).
        (real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)
        # The captured gradients have the same [B, N, D] shape as the token features.
        grad_real = self.gradients_real[0].view(*D_real.shape)
        grad_fake = self.gradients_fake[0].view(*D_fake.shape)
        # Generate the per-token weight maps.
        weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake)
        # Apply the weights to the GAN loss (Eq. 7 in the paper).
        loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores, True))
        loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores, False))
        # Combined content-aware adversarial loss (minimized by the discriminator
        # under the LSGAN objective).
        loss_co_adv = (loss_co_real + loss_co_fake) * 0.5
        return loss_co_adv, weight_real, weight_fake
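

# Weight-map sketch on random data (illustrative only): feeding synthetic
# "gradients" of shape [B, N, D] directly into generate_weight_map shows that
# roughly eta_ratio of the N tokens receive weights above 1.
def _weight_map_example():
    cao = ContentAwareOptimization(lambda_inc=2.0, eta_ratio=0.4)
    grads = torch.randn(2, 576, 64)  # [B, tokens, feature dim]
    w_real, w_fake = cao.generate_weight_map(grads, grads.clone())
    return w_real, w_fake  # each [2, 576], ~40% of entries up-weighted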


class ContentAwareTemporalNorm(nn.Module):
    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
        """
        Upsample a patch-level weight map to the target resolution.

        Args:
            weight_patch: [B, 1, 24, 24] patch weight map from the ViT
            target_size: target resolution (H, W)
        Returns:
            weight_full: [B, 1, 256, 256] upsampled full-resolution weight map
        """
        # Bilinear upsampling.
        B = weight_patch.shape[0]
        weight_patch = weight_patch.view(B, 1, 24, 24)
        weight_full = F.interpolate(
            weight_patch,
            size=target_size,
            mode='bilinear',
            align_corners=False
        )
        # Optionally keep the weight constant inside each 16x16 patch:
        # average-pool then re-expand to remove the ramps introduced by interpolation.
        weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
        weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
        return weight_full
    def forward(self, weight_map):
        """
        Generate a content-aware optical flow field.

        Args:
            weight_map: [B, 1, H, W] weight map (from the content-aware optimization module)
        Returns:
            F_content: [B, 2, H, W] generated flow field (x/y displacements)
        """
        # Upsample the weight map to full resolution.
        weight_full = self.upsample_weight_map(weight_map)  # [B, 1, 256, 256]
        # 1. Normalize the weight map: preserve the relative strength of each
        #    region while bounding the value range.
        weight_norm = F.normalize(weight_full, p=1, dim=(2, 3))  # L1 normalization [B, 1, H, W]
        # 2. Sample Gaussian noise.
        B, _, H, W = weight_norm.shape
        z = torch.randn(B, 2, H, W, device=weight_norm.device)  # [B, 2, H, W]
        # 3. Synthesize the raw flow (Eq. 9): expand the weight map to two
        #    channels so the x/y directions share the same weights.
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B, 2, H, W]
        F_raw = self.gamma_stride * weight_expanded * z  # [B, 2, H, W]
        # 4. Gaussian smoothing to keep the flow structurally continuous.
        F_smooth = self.smoother(F_raw)  # [B, 2, H, W]
        # 5. Optional dynamic-range adjustment: bound the flow magnitude to
        #    [-1, 1] to avoid extreme displacements.
        F_content = torch.tanh(F_smooth)
        return F_content
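

# End-to-end sketch (illustrative only): a random [B, 576] token weight map
# becomes a smooth flow field, which then warps a 256x256 image into a
# pseudo "next frame". 576 = 24*24 patches of a 384x384 ViT input with
# patch size 16.
def _ctn_flow_example():
    ctn = ContentAwareTemporalNorm(gamma_stride=0.1)
    weight_map = torch.rand(1, 576)  # per-token weights from the discriminator
    flow = ctn(weight_map)           # [1, 2, 256, 256]
    image = torch.rand(1, 3, 256, 256)
    return warp(image, flow)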


class RomaUnsbModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Configure the CTNx model-specific options."""
        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for global structural consistency')
        parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for local structural consistency')
        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
        parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
        parser.add_argument('--side_length', type=int, default=7)
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
        parser.add_argument('--gamma_stride', type=float, default=20, help='scale of the content-aware flow magnitude')
        parser.add_argument('--atten_layers', type=str, default='5', help='compute cross-similarity on which layers')
        parser.add_argument('--tau', type=float, default=0.01, help='entropy parameter')
        parser.add_argument('--num_timesteps', type=int, default=5, help='number of generation timesteps (NFE) at test time')
        parser.add_argument('--n_mlp', type=int, default=3, help='number of MLP layers')
        opt, _ = parser.parse_known_args()
        return parser
    def __init__(self, opt):
        """Initialize the CTNx model."""
        BaseModel.__init__(self, opt)
        # Training losses to print/log.
        self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial', 'ctn']
        self.visual_names = ['real_A0', 'fake_B0', 'real_B0', 'real_A1', 'fake_B1', 'real_B1']
        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
        if self.opt.phase == 'test':
            self.visual_names = ['real']
            for NFE in range(self.opt.num_timesteps):
                fake_name = 'fake_' + str(NFE + 1)
                self.visual_names.append(fake_name)
        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
        if self.isTrain:
            self.model_names = ['G', 'D_ViT']
        else:
            self.model_names = ['G']
        # Create the networks.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
        if self.isTrain:
            self.resize = tfs.Resize(size=(384, 384), antialias=True)
            self.netD_ViT = networks.MLPDiscriminator().to(self.device)
            # Pre-trained ViT used as a frozen token extractor.
            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
            # Loss functions and optimizers.
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware adversarial loss
            self.ctn = ContentAwareTemporalNorm()  # pseudo optical-flow generator
    def data_dependent_initialize(self, data):
        """No data-dependent initialization is needed: this model defines no
        feature network whose weights depend on a first forward pass."""
        pass
    def optimize_parameters(self):
        # Forward pass.
        self.forward()
        self.netG.train()
        self.netD_ViT.train()
        # Update D.
        self.set_requires_grad(self.netD_ViT, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()
        # Update G.
        self.set_requires_grad(self.netD_ViT, False)
        self.optimizer_G.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
        self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
        self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
        self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
    def forward(self):
        """Run the forward pass; called by both <optimize_parameters> and <test>."""
        self.fake_B0 = self.netG(self.real_A0)
        self.fake_B1 = self.netG(self.real_A1)
        if self.opt.isTrain:
            # Resize everything to the ViT input resolution (384x384).
            self.real_A0_resize = self.resize(self.real_A0)
            self.real_A1_resize = self.resize(self.real_A1)
            real_B0 = self.resize(self.real_B0)
            real_B1 = self.resize(self.real_B1)
            self.fake_B0_resize = self.resize(self.fake_B0)
            self.fake_B1_resize = self.resize(self.fake_B1)
            # Extract ViT tokens for all six images at the chosen attention layers.
            self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
            self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
            self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
            self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
    def compute_D_loss(self):
        """Calculate the GAN loss with content-aware optimization."""
        lambda_D_ViT = self.opt.lambda_D_ViT
        loss_cao = 0.0
        # Discriminator scores and features for the real frames.
        real_B0_tokens = self.mutil_real_B0_tokens[0]
        pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)
        real_B1_tokens = self.mutil_real_B1_tokens[0]
        pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)
        # Scores and features for the (detached) fake frames.
        pre_fake0, fake_features0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach())
        pre_fake1, fake_features1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach())
        loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
            D_real=real_features0,
            D_fake=fake_features0,
            real_scores=pred_real0,
            fake_scores=pre_fake0
        )
        loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
            D_real=real_features1,
            D_fake=fake_features1,
            real_scores=pred_real1,
            fake_scores=pre_fake1
        )
        loss_cao += loss_cao0 + loss_cao1
        # Combined discriminator loss.
        self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT
        return self.loss_D_ViT
    def compute_G_loss(self):
        """Calculate the generator losses."""
        if self.opt.lambda_ctn > 0.0:
            # Content-aware flow fields derived from the discriminator weight maps.
            self.f_content0 = self.ctn(self.weight_fake0.detach())
            self.f_content1 = self.ctn(self.weight_fake1.detach())
            # Warp the inputs and the generated frames with the same flow.
            self.warped_real_A0 = warp(self.real_A0, self.f_content0)
            self.warped_real_A1 = warp(self.real_A1, self.f_content1)
            self.warped_fake_B0 = warp(self.fake_B0, self.f_content0)
            self.warped_fake_B1 = warp(self.fake_B1, self.f_content1)
            # Second generator pass on the warped inputs.
            self.warped_fake_B0_2 = self.netG(self.warped_real_A0)
            self.warped_fake_B1_2 = self.netG(self.warped_real_A1)
            # L2 loss between "warp-then-generate" and "generate-then-warp".
            self.loss_ctn0 = F.mse_loss(self.warped_fake_B0_2, self.warped_fake_B0)
            self.loss_ctn1 = F.mse_loss(self.warped_fake_B1_2, self.warped_fake_B1)
            self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1) * 0.5
        else:
            self.loss_ctn = 0.0
        if self.opt.lambda_GAN > 0.0:
            pred_fake0, _ = self.netD_ViT(self.mutil_fake_B0_tokens[0])
            pred_fake1, _ = self.netD_ViT(self.mutil_fake_B1_tokens[0])
            self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
            self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
            self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1) * 0.5
        else:
            self.loss_G_GAN = 0.0
        if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
            self.loss_global, self.loss_spatial = self.calculate_attention_loss()
        else:
            self.loss_global, self.loss_spatial = 0.0, 0.0
        self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
                      self.opt.lambda_ctn * self.loss_ctn + \
                      self.opt.lambda_global * self.loss_global + \
                      self.opt.lambda_spatial * self.loss_spatial
        return self.loss_G
    def calculate_attention_loss(self):
        if self.opt.lambda_global > 0.0:
            # Global structural consistency over the full token sets.
            loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) \
                        + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens)
            loss_global *= 0.5
        else:
            loss_global = 0.0
        if self.opt.lambda_spatial > 0.0:
            # Local structural consistency over a random subset of token windows.
            local_nums = self.opt.local_nums
            tokens_cnt = 576  # 24x24 patches for a 384x384 input with patch size 16
            local_id = np.random.permutation(tokens_cnt)
            local_id = local_id[:min(local_nums, tokens_cnt)]
            mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) \
                         + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
            loss_spatial *= 0.5
        else:
            loss_spatial = 0.0
        return loss_global, loss_spatial
    def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
        loss = 0.0
        n_layers = len(self.atten_layers)
        for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
            # Cross-similarity matrices in both directions.
            src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))
            tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))
            # Penalize deviation of their row-wise cosine similarity from 1.
            cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
            loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
        # Average over the attention layers (divide once, after accumulation).
        loss = loss / n_layers
        return loss
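

# Standalone sketch of the cross-similarity objective above (illustrative
# only): for identical source/target tokens, src_tgt equals tgt_src, the
# row-wise cosine similarity is exactly 1, and the L1 penalty vanishes.
def _similarity_example():
    tokens = torch.rand(1, 576, 768)  # [B, N, D] ViT tokens
    src_tgt = tokens.bmm(tokens.permute(0, 2, 1))
    tgt_src = tokens.bmm(tokens.permute(0, 2, 1))
    cos = F.cosine_similarity(src_tgt, tgt_src, dim=-1)  # all ones
    return F.l1_loss(torch.ones_like(cos), cos)          # 0.0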