Compare commits

...

3 Commits

Author SHA1 Message Date
bishe
77468f16f9 cptrans复现完成 2025-03-22 15:27:19 +08:00
bishe
c173e29ea6 final_version 2025-03-18 21:12:32 +08:00
bishe
537cb050a5 保存一个版本 2025-03-18 20:14:59 +08:00
6 changed files with 495 additions and 118 deletions

View File

@ -1411,13 +1411,12 @@ class MLPDiscriminator(nn.Module):
self.activation = nn.GELU()
self.linear2 = nn.Linear(hid_feat, out_feat)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
features = self.linear1(x) # 中间特征,即 D_real 或 D_fake
x = self.activation(features)
x = self.linear1(x)
x = self.activation(x)
x = self.dropout(x)
scores = self.linear2(x) # 最终分数,即 real_scores 或 fake_scores
return scores, features
x = self.linear2(x)
return self.dropout(x)
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""

View File

@ -13,6 +13,7 @@ import util.util as util
from torchvision.transforms import transforms as tfs
def warp(image, flow): #warp操作
"""
基于光流的图像变形函数
@ -37,76 +38,74 @@ def warp(image, flow): #warp操作
# 双线性插值
return F.grid_sample(image, new_grid, align_corners=True)
# 时序归一化损失计算
def compute_ctn_loss(G, x, F_content): #公式10
"""
计算内容感知时序归一化损失
Args:
G: 生成器
x: 输入红外图像 [B,C,H,W]
F_content: 生成的光流场 [B,2,H,W]
"""
# 生成可见光图像
y_fake = G(x) # [B,3,H,W]
# 对生成结果应用光流变形
warped_fake = warp(y_fake, F_content) # [B,3,H,W]
# 对输入应用相同光流后生成图像
warped_x = warp(x, F_content) # [B,C,H,W]
y_fake_warped = G(warped_x) # [B,3,H,W]
# 计算L2损失
loss = F.mse_loss(warped_fake, y_fake_warped)
return loss
class ContentAwareOptimization(nn.Module):
def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
super().__init__()
self.lambda_inc = lambda_inc
self.eta_ratio = eta_ratio
self.gradients = [] # 修改为单一梯度列表,通用性更强
self.criterionGAN = networks.GANLoss('lsgan').cuda()
self.lambda_inc = lambda_inc # 控制内容丰富区域的权重增量
self.eta_ratio = eta_ratio # 选择内容丰富区域的比例
self.criterionGAN = networks.GANLoss('lsgan').cuda() # 使用 LSGAN 损失
def compute_cosine_similarity(self, gradients):
mean_grad = torch.mean(gradients, dim=1, keepdim=True)
return F.cosine_similarity(gradients, mean_grad, dim=2)
def compute_cosine_similarity(self, grad_patch, grad_mean):
"""
计算每个 token 梯度与整体平均梯度的余弦相似度
Args:
grad_patch: [B, N, D]每个 token 的梯度来自 scores
grad_mean: [B, D]整体平均梯度
Returns:
cosine: [B, N]余弦相似度 δ_i
"""
# 对每个 token 计算余弦相似度
cosine = F.cosine_similarity(grad_patch, grad_mean.unsqueeze(1), dim=2) # [B, N]
return cosine
def generate_weight_map(self, gradients):
cosine = self.compute_cosine_similarity(gradients)
k = int(self.eta_ratio * cosine.shape[1])
_, indices = torch.topk(-cosine, k, dim=1)
def generate_weight_map(self, cosine):
"""
根据余弦相似度生成权重图
Args:
cosine: [B, N]余弦相似度 δ_i
Returns:
weights: [B, N]权重图 w_i
"""
B, N = cosine.shape
k = int(self.eta_ratio * N) # 选择 eta_ratio 比例的 token
_, indices = torch.topk(-cosine, k, dim=1) # 选择偏离最大的 k 个 token
weights = torch.ones_like(cosine)
weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
for b in range(B):
selected_cosine = cosine[b, indices[b]]
weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
return weights
def forward(self, features, scores, target):
def forward(self, scores, target):
"""
前向传播计算加权后的 GAN 损失
Args:
features: 特征张量可以是判别器的 real/fake 特征或生成器的 fake 特征
scores: 判别器对特征的预测得分
target: 目标标签True 表示希望判为真False 表示希望判为假
scores: [B, N, D]判别器的预测得分
target: 目标标签True False
Returns:
loss: 加权后的 GAN 损失
weight: 生成的权重图
weighted_loss: 加权后的 GAN 损失
weight: 权重图 [B, N]
"""
self.gradients.clear()
# 注册梯度钩子
hook = lambda grad: self.gradients.append(grad.detach())
features.register_hook(hook)
# 计算原始 GAN 损失(假设 criterionGAN 返回 [B, N] 的损失分布)
loss = self.criterionGAN(scores, target)
# 触发梯度计算
scores.mean().backward(retain_graph=True)
# 捕获 scores 的梯度,形状为 [B, N, D]
grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]
# 获取梯度并调整维度
grad = self.gradients[0].flatten(1) # [B, N, D] → [B, N*D]
weight = self.generate_weight_map(grad.view(*features.shape))
# 计算整体平均梯度(在 N 维度上求均值)
grad_mean = torch.mean(grad_scores, dim=1) # [B, D]
# 计算余弦相似度 δ_i
cosine = self.compute_cosine_similarity(grad_scores, grad_mean) # [B, N]
# 生成权重图 w_i
weight = self.generate_weight_map(cosine) # [B, N]
# 计算加权后的 GAN 损失
weighted_loss = torch.mean(weight * self.criterionGAN(scores, target))
return weighted_loss, weight
# 计算加权 GAN 损失
loss = torch.mean(weight * self.criterionGAN(scores, target))
return loss, weight
class ContentAwareTemporalNorm(nn.Module):
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
@ -115,66 +114,52 @@ class ContentAwareTemporalNorm(nn.Module):
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
"""
将patch级别的权重图上采样到目标分辨率
Args:
weight_patch: [B, 1, 24, 24] 来自ViT的patch权重图
target_size: 目标分辨率 (H, W)
Returns:
weight_full: [B, 1, 256, 256] 上采样后的全分辨率权重图
"""
# 使用双线性插值上采样
B = weight_patch.shape[0]
weight_patch = weight_patch.view(B, 1, 24, 24)
# weight_patch: [B, 1, H, W] 来自转换后的 weight_map
weight_full = F.interpolate(
weight_patch,
size=target_size,
mode='bilinear',
mode='bilinear', # 或 'nearest',根据需求选择
align_corners=False
)
# 对每个16x16的patch内部保持权重一致可选
# 通过平均池化再扩展,消除插值引入的渐变
weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
return weight_full
def forward(self, weight_map):
"""
生成内容感知光流
Args:
weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块)
weight_map: [B, N] 权重图(来自 ContentAwareOptimization)其中 N=576
Returns:
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
"""
B = weight_map.shape[0]
N = weight_map.shape[1]
# 假设 N 为完全平方数,计算边长(例如 576 -> 24x24
side = int(math.sqrt(N))
weight_map_2d = weight_map.view(B, 1, side, side) # 转换为 [B, 1, side, side]
# 上采样权重图到全分辨率
weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384]
weight_full = self.upsample_weight_map(weight_map_2d) # [B, 1, 256, 256](例如)
# 1. 归一化权重图
# 保持区域相对强度,同时限制数值范围
weight_norm = F.normalize(weight_full, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
# 归一化权重图L1归一化
weight_norm = F.normalize(weight_full, p=1, dim=(2,3))
# 2. 生成高斯噪声
# 生成高斯噪声
B, _, H, W = weight_norm.shape
z = torch.randn(B, 2, H, W, device=weight_norm.device) # [B,2,H,W]
z = torch.randn(B, 2, H, W, device=weight_norm.device)
# 3. 合成基础光流
# 将权重图扩展为2通道(x/y方向共享权重)
weight_expanded = weight_norm.expand(-1, 2, -1, -1) # [B,2,H,W]
F_raw = self.gamma_stride * weight_expanded * z # [B,2,H,W] #公式9
# 合成基础光流
weight_expanded = weight_norm.expand(-1, 2, -1, -1)
F_raw = self.gamma_stride * weight_expanded * z
# 4. 平滑处理(保持结构连续性)
# 对每个通道独立进行高斯模糊
F_smooth = self.smoother(F_raw) # [B,2,H,W]
# 平滑处理
F_smooth = self.smoother(F_raw)
# 5. 动态范围调整(可选)
# 限制光流幅值,避免极端位移
F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围
# 动态范围调整
F_content = torch.tanh(F_smooth)
return F_content
class RomaUnsbModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
@ -327,21 +312,21 @@ class RomaUnsbModel(BaseModel):
# 处理 real_B0 和 fake_B0
real_B0_tokens = self.mutil_real_B0_tokens[0]
pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)
pred_real0 = self.netD_ViT(real_B0_tokens)
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
pred_fake0, fake_features0 = self.netD_ViT(fake_B0_tokens)
pred_fake0 = self.netD_ViT(fake_B0_tokens)
loss_real0, self.weight_real0 = self.cao(real_features0, pred_real0, True)
loss_fake0, self.weight_fake0 = self.cao(fake_features0, pred_fake0, False)
loss_real0, self.weight_real0 = self.cao( pred_real0, True)
loss_fake0, self.weight_fake0 = self.cao( pred_fake0, False)
# 处理 real_B1 和 fake_B1
real_B1_tokens = self.mutil_real_B1_tokens[0]
pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)
pred_real1 = self.netD_ViT(real_B1_tokens)
fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
pred_fake1, fake_features1 = self.netD_ViT(fake_B1_tokens)
pred_fake1 = self.netD_ViT(fake_B1_tokens)
loss_real1, self.weight_real1 = self.cao(real_features1, pred_real1, True)
loss_fake1, self.weight_fake1 = self.cao(fake_features1, pred_fake1, False)
loss_real1, self.weight_real1 = self.cao( pred_real1, True)
loss_fake1, self.weight_fake1 = self.cao( pred_fake1, False)
# 综合损失
self.loss_D_ViT = (loss_real0 + loss_fake0 + loss_real1 + loss_fake1) * 0.25 * lambda_D_ViT
@ -380,8 +365,8 @@ class RomaUnsbModel(BaseModel):
# 计算 GAN 损失(引入 ContentAwareOptimization
if self.opt.lambda_GAN > 0.0:
pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens[0])
pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens[0])
pred_fake0 = self.netD_ViT(self.mutil_fake_B0_tokens[0])
pred_fake1 = self.netD_ViT(self.mutil_fake_B1_tokens[0])
self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5

View File

@ -0,0 +1,391 @@
import numpy as np
import math
import timm
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
from torchvision.transforms import transforms as tfs
def warp(image, flow): #warp操作
"""
基于光流的图像变形函数
Args:
image: [B, C, H, W] 输入图像
flow: [B, 2, H, W] 光流场(x/y方向位移)
Returns:
warped: [B, C, H, W] 变形后的图像
"""
B, C, H, W = image.shape
# 生成网格坐标
grid_x, grid_y = torch.meshgrid(torch.arange(W), torch.arange(H))
grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device) # [2,H,W]
grid = grid.unsqueeze(0).repeat(B,1,1,1) # [B,2,H,W]
# 应用光流位移(归一化到[-1,1])
new_grid = grid + flow
new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0 # x方向
new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0 # y方向
new_grid = new_grid.permute(0,2,3,1) # [B,H,W,2]
# 双线性插值
return F.grid_sample(image, new_grid, align_corners=True)
class ContentAwareOptimization(nn.Module):
def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
super().__init__()
self.lambda_inc = lambda_inc # 控制内容丰富区域的权重增量
self.eta_ratio = eta_ratio # 选择内容丰富区域的比例
self.criterionGAN = networks.GANLoss('lsgan').cuda() # 使用 LSGAN 损失
def compute_cosine_similarity(self, grad_patch, grad_mean):
"""
计算每个 token 梯度与整体平均梯度的余弦相似度
Args:
grad_patch: [B, N, D]每个 token 的梯度来自 scores
grad_mean: [B, D]整体平均梯度
Returns:
cosine: [B, N]余弦相似度 δ_i
"""
# 对每个 token 计算余弦相似度
cosine = F.cosine_similarity(grad_patch, grad_mean.unsqueeze(1), dim=2) # [B, N]
return cosine
def generate_weight_map(self, cosine):
"""
根据余弦相似度生成权重图
Args:
cosine: [B, N]余弦相似度 δ_i
Returns:
weights: [B, N]权重图 w_i
"""
B, N = cosine.shape
k = int(self.eta_ratio * N) # 选择 eta_ratio 比例的 token
_, indices = torch.topk(-cosine, k, dim=1) # 选择偏离最大的 k 个 token
weights = torch.ones_like(cosine)
for b in range(B):
selected_cosine = cosine[b, indices[b]]
weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
return weights
def forward(self, scores, target):
"""
前向传播计算加权后的 GAN 损失
Args:
scores: [B, N, D]判别器的预测得分
target: 目标标签True False
Returns:
weighted_loss: 加权后的 GAN 损失
weight: 权重图 [B, N]
"""
# 计算原始 GAN 损失(假设 criterionGAN 返回 [B, N] 的损失分布)
loss = self.criterionGAN(scores, target)
# 捕获 scores 的梯度,形状为 [B, N, D]
grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]
# 计算整体平均梯度(在 N 维度上求均值)
grad_mean = torch.mean(grad_scores, dim=1) # [B, D]
# 计算余弦相似度 δ_i
cosine = self.compute_cosine_similarity(grad_scores, grad_mean) # [B, N]
# 生成权重图 w_i
weight = self.generate_weight_map(cosine) # [B, N]
# 计算加权后的 GAN 损失
weighted_loss = torch.mean(weight * self.criterionGAN(scores, target))
return weighted_loss, weight
class ContentAwareTemporalNorm(nn.Module):
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
super().__init__()
self.gamma_stride = gamma_stride # 控制整体运动幅度
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
# weight_patch: [B, 1, H, W] 来自转换后的 weight_map
weight_full = F.interpolate(
weight_patch,
size=target_size,
mode='bilinear', # 或 'nearest',根据需求选择
align_corners=False
)
return weight_full
def forward(self, weight_map):
"""
生成内容感知光流
Args:
weight_map: [B, N] 权重图(来自 ContentAwareOptimization)其中 N=576
Returns:
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
"""
B = weight_map.shape[0]
N = weight_map.shape[1]
# 假设 N 为完全平方数,计算边长(例如 576 -> 24x24
side = int(math.sqrt(N))
weight_map_2d = weight_map.view(B, 1, side, side) # 转换为 [B, 1, side, side]
# 上采样权重图到全分辨率
weight_full = self.upsample_weight_map(weight_map_2d) # [B, 1, 256, 256](例如)
# 归一化权重图L1归一化
weight_norm = F.normalize(weight_full, p=1, dim=(2,3))
# 生成高斯噪声
B, _, H, W = weight_norm.shape
z = torch.randn(B, 2, H, W, device=weight_norm.device)
# 合成基础光流
weight_expanded = weight_norm.expand(-1, 2, -1, -1)
F_raw = self.gamma_stride * weight_expanded * z
# 平滑处理
F_smooth = self.smoother(F_raw)
# 动态范围调整
F_content = torch.tanh(F_smooth)
return F_content
class RomaUnsbSingleModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""配置 CTNx 模型的特定选项"""
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
parser.add_argument('--side_length', type=int, default=7)
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
opt, _ = parser.parse_known_args()
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
self.visual_names = ['real_A', 'fake_B', 'real_B']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.isTrain:
self.model_names = ['G', 'D_ViT']
else: # during test time, only load G
self.model_names = ['G']
# define networks (both generator and discriminator)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
if self.isTrain:
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
# self.netPreViT = timm.create_model("vit_base_patch32_384",pretrained=True).to(self.device)
self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
self.resize = tfs.Resize(size=(384,384))
# self.resize = tfs.Resize(size=(224, 224))
# define loss functions
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionL1 = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D_ViT)
self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数
self.ctn = ContentAwareTemporalNorm() #生成的伪光流
def data_dependent_initialize(self, data):
"""
The feature network netF is defined in terms of the shape of the intermediate, extracted
features of the encoder portion of netG. Because of this, the weights of netF are
initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
"""
pass
def optimize_parameters(self):
# forward
self.forward()
# update D
self.set_requires_grad(self.netD_ViT, True)
self.optimizer_D_ViT.zero_grad()
self.loss_D = self.compute_D_loss()
self.loss_D.backward()
self.optimizer_D_ViT.step()
# update G
self.set_requires_grad(self.netD_ViT, False)
self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss()
self.loss_G.backward()
self.optimizer_G.step()
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A = input['A' if AtoB else 'B'].to(self.device)
self.real_B = input['B' if AtoB else 'A'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG(self.real_A)
if self.opt.isTrain:
real_A = self.real_A
real_B = self.real_B
fake_B = self.fake_B
self.real_A_resize = self.resize(real_A)
real_B = self.resize(real_B)
self.fake_B_resize = self.resize(fake_B)
self.mutil_real_A_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B_tokens = self.netPreViT(real_B, self.atten_layers, get_tokens=True)
self.mutil_fake_B_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True)
def compute_D_loss(self):
"""Calculate GAN loss for the discriminator"""
lambda_D_ViT = self.opt.lambda_D_ViT
fake_B_tokens = self.mutil_fake_B_tokens[0].detach()
real_B_tokens = self.mutil_real_B_tokens[0]
pre_fake_ViT = self.netD_ViT(fake_B_tokens)
pred_real_ViT = self.netD_ViT(real_B_tokens)
self.loss_D_real_ViT , self.weight_real = self.cao(pred_real_ViT, True)
self.loss_D_fake_ViT , self.weight_fake = self.cao(pre_fake_ViT, False)
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5* lambda_D_ViT
return self.loss_D_ViT
def compute_G_loss(self):
if self.opt.lambda_ctn > 0.0:
# 生成光流图(使用判别器的权重)
self.f_content = self.ctn(self.weight_fake.detach())
# 变换后的图片
self.warped_real_A = warp(self.real_A, self.f_content)
self.warped_fake_B = warp(self.fake_B, self.f_content)
# 第二次生成
self.warped_fake_B2 = self.netG(self.warped_real_A)
# 计算损失
self.loss_ctn = self.criterionL1(self.warped_fake_B, self.warped_fake_B2) * self.opt.lambda_ctn
else:
self.loss_ctn = 0.0
# if self.opt.lambda_GAN > 0.0:
# fake_B_tokens = self.mutil_fake_B_tokens[0]
# pred_fake_ViT = self.netD_ViT(fake_B_tokens)
# self.loss_G_GAN = self.criterionGAN(pred_fake_ViT, True) * self.opt.lambda_GAN
# else:
# self.loss_G_GAN = 0.0
if self.opt.lambda_GAN > 0.0:
fake_B_tokens = self.mutil_fake_B_tokens[0]
pred_fake_ViT = self.netD_ViT(fake_B_tokens)
self.loss_G_fake_ViT , self.weight_real = self.cao(pred_fake_ViT, True)
self.loss_G_GAN = self.loss_G_fake_ViT * self.opt.lambda_GAN
else:
self.loss_G_GAN = 0.0
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
else:
self.loss_global, self.loss_spatial = 0.0, 0.0
self.loss_G = self.loss_G_GAN + self.loss_global + self.loss_spatial + self.loss_ctn
return self.loss_G
def calculate_attention_loss(self):
n_layers = len(self.atten_layers)
mutil_real_A_tokens = self.mutil_real_A_tokens
mutil_fake_B_tokens = self.mutil_fake_B_tokens
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(mutil_real_A_tokens, mutil_fake_B_tokens)
else:
loss_global = 0.0
if self.opt.lambda_spatial > 0.0:
loss_spatial = 0.0
local_nums = self.opt.local_nums
tokens_cnt = 576
local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A_local_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B_local_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A_local_tokens, mutil_fake_B_local_tokens)
else:
loss_spatial = 0.0
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0
n_layers = len(self.atten_layers)
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
loss = loss / n_layers
return loss

View File

@ -1,2 +1,2 @@
# CUDA_VISIBLE_DEVICES=0 python test.py --dataroot /path/of/test_dataset --checkpoints_dir ./checkpoints --name train1 --model roma_single --num_test 10000 --epoch latest
python test.py --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Single/Monitor --checkpoints_dir ./checkpoints --name cp_4 --model roma_single --num_test 4132 --epoch 150 --gpu_id 2
python test.py --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Single/Monitor --checkpoints_dir ./checkpoints --name cp_2 --model roma_single --num_test 4132 --epoch 120 --gpu_id 1

View File

@ -1,16 +1,16 @@
python train.py \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
--name cp_5 \
--dataset_mode unaligned_double \
--display_env CP \
--model roma_unsb \
--lambda_ctn 0 \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Single/Monitor \
--name cp_2 \
--dataset_mode unaligned \
--display_env NEWCP \
--model roma_unsb_single \
--lambda_ctn 10 \
--lambda_inc 8.0 \
--lambda_global 6.0 \
--lambda_spatial 6.0 \
--gamma_stride 20 \
--lr 0.000005 \
--gpu_id 0 \
--lr 0.000002 \
--gpu_id 1 \
--eta_ratio 0.4 \
--n_epochs 100 \
--n_epochs_decay 100 \
@ -19,3 +19,5 @@ python train.py \
# cp3 加了--lambda_inc 8.0 --gpu_id 2
# cp4 在cp3的基础上把梯度增强给到了生成器中的ganloss --gpu_id 1
# cp5 在cp3的基础上--lambda_ctn 0 --gpu_id 0.--lr 0.000005
# # newcp1 重新调整了光流算法,并且弄成单帧的脚本了,这一次是最终的复现了。--gpu_id 0
# # newcp2 把梯度图对loss的影响同样加到了G_GAN中。

View File

@ -33,7 +33,7 @@ if __name__ == '__main__':
if total_iters % opt.print_freq == 0:
t_data = iter_start_time - iter_data_time
batch_size = data["A0"].size(0)
batch_size = data["A"].size(0)
total_iters += batch_size
epoch_iter += batch_size
if len(opt.gpu_ids) > 0: