Compare commits

...

5 Commits

Author SHA1 Message Date
bishe
e8e483fbf8 EDIT_DOWN 2025-02-26 22:07:11 +08:00
bishe
3c4d53377c EDIT_DOWN 2025-02-26 22:07:06 +08:00
bishe
6a2761be99 without cnt running 002 2025-02-24 23:35:03 +08:00
bishe
c2e6cfe0b1 running without cnt named 001 2025-02-24 23:10:23 +08:00
bishe
4af0d7463d withoutCNT 2025-02-24 23:00:25 +08:00
7 changed files with 122 additions and 286 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
checkpoints/
*.log
*.pth
*.ckpt
__pycache__/

View File

@ -1,70 +0,0 @@
================ Training Loss (Sun Feb 23 15:46:44 2025) ================
================ Training Loss (Sun Feb 23 15:52:29 2025) ================
================ Training Loss (Sun Feb 23 16:00:07 2025) ================
================ Training Loss (Sun Feb 23 16:02:40 2025) ================
================ Training Loss (Sun Feb 23 16:05:19 2025) ================
================ Training Loss (Sun Feb 23 16:06:44 2025) ================
================ Training Loss (Sun Feb 23 16:09:38 2025) ================
================ Training Loss (Sun Feb 23 16:44:56 2025) ================
================ Training Loss (Sun Feb 23 16:49:46 2025) ================
================ Training Loss (Sun Feb 23 16:51:03 2025) ================
================ Training Loss (Sun Feb 23 16:51:23 2025) ================
================ Training Loss (Sun Feb 23 18:04:02 2025) ================
================ Training Loss (Sun Feb 23 18:04:39 2025) ================
================ Training Loss (Sun Feb 23 18:05:17 2025) ================
================ Training Loss (Sun Feb 23 18:06:40 2025) ================
================ Training Loss (Sun Feb 23 18:11:48 2025) ================
================ Training Loss (Sun Feb 23 18:13:31 2025) ================
================ Training Loss (Sun Feb 23 18:14:11 2025) ================
================ Training Loss (Sun Feb 23 18:14:29 2025) ================
================ Training Loss (Sun Feb 23 18:16:27 2025) ================
================ Training Loss (Sun Feb 23 18:16:44 2025) ================
================ Training Loss (Sun Feb 23 18:20:39 2025) ================
================ Training Loss (Sun Feb 23 18:21:44 2025) ================
================ Training Loss (Sun Feb 23 18:35:27 2025) ================
================ Training Loss (Sun Feb 23 18:39:21 2025) ================
================ Training Loss (Sun Feb 23 18:40:15 2025) ================
================ Training Loss (Sun Feb 23 18:41:15 2025) ================
================ Training Loss (Sun Feb 23 18:47:46 2025) ================
================ Training Loss (Sun Feb 23 18:48:36 2025) ================
================ Training Loss (Sun Feb 23 18:50:20 2025) ================
================ Training Loss (Sun Feb 23 18:51:50 2025) ================
================ Training Loss (Sun Feb 23 18:58:45 2025) ================
================ Training Loss (Sun Feb 23 18:59:52 2025) ================
================ Training Loss (Sun Feb 23 19:03:05 2025) ================
================ Training Loss (Sun Feb 23 19:03:57 2025) ================
================ Training Loss (Sun Feb 23 21:11:47 2025) ================
================ Training Loss (Sun Feb 23 21:17:10 2025) ================
================ Training Loss (Sun Feb 23 21:20:14 2025) ================
================ Training Loss (Sun Feb 23 21:29:03 2025) ================
================ Training Loss (Sun Feb 23 21:34:57 2025) ================
================ Training Loss (Sun Feb 23 21:35:26 2025) ================
================ Training Loss (Sun Feb 23 22:28:43 2025) ================
================ Training Loss (Sun Feb 23 22:29:04 2025) ================
================ Training Loss (Sun Feb 23 22:29:52 2025) ================
================ Training Loss (Sun Feb 23 22:30:40 2025) ================
================ Training Loss (Sun Feb 23 22:33:48 2025) ================
================ Training Loss (Sun Feb 23 22:39:16 2025) ================
================ Training Loss (Sun Feb 23 22:39:48 2025) ================
================ Training Loss (Sun Feb 23 22:41:34 2025) ================
================ Training Loss (Sun Feb 23 22:42:01 2025) ================
================ Training Loss (Sun Feb 23 22:44:17 2025) ================
================ Training Loss (Sun Feb 23 22:45:53 2025) ================
================ Training Loss (Sun Feb 23 22:46:48 2025) ================
================ Training Loss (Sun Feb 23 22:47:42 2025) ================
================ Training Loss (Sun Feb 23 22:49:44 2025) ================
================ Training Loss (Sun Feb 23 22:50:29 2025) ================
================ Training Loss (Sun Feb 23 22:51:47 2025) ================
================ Training Loss (Sun Feb 23 22:55:56 2025) ================
================ Training Loss (Sun Feb 23 22:56:19 2025) ================
================ Training Loss (Sun Feb 23 22:57:58 2025) ================
================ Training Loss (Sun Feb 23 22:59:09 2025) ================
================ Training Loss (Sun Feb 23 23:02:36 2025) ================
================ Training Loss (Sun Feb 23 23:03:56 2025) ================
================ Training Loss (Sun Feb 23 23:09:21 2025) ================
================ Training Loss (Sun Feb 23 23:10:05 2025) ================
================ Training Loss (Sun Feb 23 23:11:43 2025) ================
================ Training Loss (Sun Feb 23 23:12:41 2025) ================
================ Training Loss (Sun Feb 23 23:13:05 2025) ================
================ Training Loss (Sun Feb 23 23:13:59 2025) ================
================ Training Loss (Sun Feb 23 23:14:59 2025) ================

View File

@ -1,87 +0,0 @@
----------------- Options ---------------
atten_layers: 5
batch_size: 1
beta1: 0.5
beta2: 0.999
checkpoints_dir: ./checkpoints
continue_train: False
crop_size: 256
dataroot: /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor [default: placeholder]
dataset_mode: unaligned_double [default: unaligned]
direction: AtoB
display_env: ROMA [default: main]
display_freq: 50
display_id: None
display_ncols: 4
display_port: 8097
display_server: http://localhost
display_winsize: 256
easy_label: experiment_name
epoch: latest
epoch_count: 1
eta_ratio: 0.1
evaluation_freq: 5000
flip_equivariance: False
gan_mode: lsgan
gpu_ids: 0
init_gain: 0.02
init_type: xavier
input_nc: 3
isTrain: True [default: None]
lambda_D_ViT: 1.0
lambda_GAN: 8.0 [default: 1.0]
lambda_NCE: 8.0 [default: 1.0]
lambda_SB: 0.1
lambda_ctn: 1.0
lambda_global: 1.0
lambda_inc: 1.0
lmda_1: 0.1
load_size: 286
lr: 1e-05 [default: 0.0002]
lr_decay_iters: 50
lr_policy: linear
max_dataset_size: inf
model: roma_unsb [default: cut]
n_epochs: 100
n_epochs_decay: 100
n_layers_D: 3
n_mlp: 3
name: ROMA_UNSB_001 [default: experiment_name]
nce_T: 0.07
nce_idt: False [default: True]
nce_includes_all_negatives_from_minibatch: False
nce_layers: 0,4,8,12,16
ndf: 64
netD: basic_cond
netF: mlp_sample
netF_nc: 256
netG: resnet_9blocks_cond
ngf: 64
no_antialias: False
no_antialias_up: False
no_dropout: True
no_flip: True [default: False]
no_html: False
normD: instance
normG: instance
num_patches: 256
num_threads: 4
num_timesteps: 10 [default: 5]
output_nc: 3
phase: train
pool_size: 0
preprocess: resize_and_crop
pretrained_name: None
print_freq: 100
random_scale_max: 3.0
save_by_iter: False
save_epoch_freq: 5
save_latest_freq: 5000
serial_batches: False
stylegan2_G_num_downsampling: 1
suffix:
tau: 0.01
update_html_freq: 1000
use_idt: False
verbose: False
----------------- End -------------------

View File

@ -66,6 +66,10 @@ class ContentAwareOptimization(nn.Module):
self.lambda_inc = lambda_inc # 权重增强系数
self.eta_ratio = eta_ratio # 选择内容区域的比例
# 改为类成员变量,确保钩子函数可访问
self.gradients_real = []
self.gradients_fake = []
def compute_cosine_similarity(self, gradients):
"""
计算每个patch梯度与平均梯度的余弦相似度
@ -79,78 +83,65 @@ class ContentAwareOptimization(nn.Module):
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
return cosine_sim
def generate_weight_map(self, gradients_fake):
def generate_weight_map(self, gradients_real, gradients_fake):
"""
生成内容感知权重图
Args:
gradients_fake: [B, N, D] 生成图像判别器梯度 [2,3,256,256]
gradients_real: [B, N, D] 真实图像判别器梯度
gradients_fake: [B, N, D] 生成图像判别器梯度
Returns:
weight_fake: [B, N] 生成图像权重图 [2,3,256]
weight_real: [B, N] 真实图像权重图
weight_fake: [B, N] 生成图像权重图
"""
# 计算真实图像块的余弦相似度
cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] 公式5
# 计算生成图像块的余弦相似度
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
# 选择内容丰富的区域(余弦相似度最低的eta_ratio比例)
k = int(self.eta_ratio * cosine_fake.shape[1])
# 选择内容丰富的区域余弦相似度最低的eta_ratio比例
k = int(self.eta_ratio * cosine_real.shape[1])
# 对生成图像生成权重图(同理)
# 对真实图像生成权重图
_, real_indices = torch.topk(-cosine_real, k, dim=1) # 选择最不相似的区域
weight_real = torch.ones_like(cosine_real)
for b in range(cosine_real.shape[0]):
weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) #公式6
# 对生成图像生成权重图(同理)
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
weight_fake = torch.ones_like(cosine_fake)
for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
return weight_fake
return weight_real, weight_fake
def forward(self, D_real, D_fake, real_scores, fake_scores):
"""
计算内容感知对抗损失
Args:
D_real: 判别器对真实图像的特征输出 [B, C, H, W]
D_fake: 判别器对生成图像的特征输出 [B, C, H, W]
real_scores: 真实图像的判别器预测 [B, N] (N=H*W)
fake_scores: 生成图像的判别器预测 [B, N]
Returns:
loss_co_adv: 内容感知对抗损失
"""
B, C, H, W = D_real.shape
N = H * W
# 清空梯度缓存
self.gradients_real.clear()
self.gradients_fake.clear()
# 注册钩子获取梯度
gradients_real = []
gradients_fake = []
def hook_real(grad):
gradients_real.append(grad.detach().view(B, N, -1))
def hook_fake(grad):
gradients_fake.append(grad.detach().view(B, N, -1))
# 注册钩子
hook_real = lambda grad: self.gradients_real.append(grad.detach())
hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
D_real.register_hook(hook_real)
D_fake.register_hook(hook_fake)
# 计算原始对抗损失以触发梯度计算
loss_real = torch.mean(torch.log(real_scores + 1e-8))
loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
# 添加与 D_real、D_fake 相关的 dummy 项,确保梯度传递
loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
total_loss = loss_real + loss_fake + loss_dummy
total_loss.backward(retain_graph=True)
# 触发梯度计算
(real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)
# 获取梯度数据
gradients_real = gradients_real[0] # [B, N, D]
gradients_fake = gradients_fake[0] # [B, N, D]
# 获取梯度并调整维度
grad_real = self.gradients_real[0] # [B, N, D]
grad_fake = self.gradients_fake[0]
# 生成权重图
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake)
# 应用权重到对抗损失
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))
# 计算加权损失
loss_co_real = (weight_real * torch.log(real_scores + 1e-8)).mean()
loss_co_fake = (weight_fake * torch.log(1 - fake_scores + 1e-8)).mean()
# 计算并返回最终内容感知对抗损失
loss_co_adv = -(loss_co_real + loss_co_fake)
return loss_co_adv
return -(loss_co_real + loss_co_fake), weight_real, weight_fake
class ContentAwareTemporalNorm(nn.Module):
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
@ -158,6 +149,33 @@ class ContentAwareTemporalNorm(nn.Module):
self.gamma_stride = gamma_stride # 控制整体运动幅度
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
"""
将patch级别的权重图上采样到目标分辨率
Args:
weight_patch: [B, 1, 24, 24] 来自ViT的patch权重图
target_size: 目标分辨率 (H, W)
Returns:
weight_full: [B, 1, 256, 256] 上采样后的全分辨率权重图
"""
# 使用双线性插值上采样
B = weight_patch.shape[0]
weight_patch = weight_patch.view(B, 1, 24, 24)
weight_full = F.interpolate(
weight_patch,
size=target_size,
mode='bilinear',
align_corners=False
)
# 对每个16x16的patch内部保持权重一致可选
# 通过平均池化再扩展,消除插值引入的渐变
weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
return weight_full
def forward(self, weight_map):
"""
生成内容感知光流
@ -166,15 +184,16 @@ class ContentAwareTemporalNorm(nn.Module):
Returns:
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
"""
print(weight_map.shape)
B, _, H, W = weight_map.shape
# 上采样权重图到全分辨率
weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384]
# 1. 归一化权重图
# 保持区域相对强度,同时限制数值范围
weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
weight_norm = F.normalize(weight_full, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
# 2. 生成高斯噪声(与光流场同尺寸)
z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W]
# 2. 生成高斯噪声
B, _, H, W = weight_norm.shape
z = torch.randn(B, 2, H, W, device=weight_norm.device) # [B,2,H,W]
# 3. 合成基础光流
# 将权重图扩展为2通道(x/y方向共享权重)
@ -204,23 +223,19 @@ class RomaUnsbModel(BaseModel):
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
type=util.str2bool, nargs='?', const=True, default=False,
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
parser.add_argument('--netF_nc', type=int, default=256)
parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
parser.add_argument('--lmda_1', type=float, default=0.1)
parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
parser.add_argument('--flip_equivariance',
type=util.str2bool, nargs='?', const=True, default=False,
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
@ -243,9 +258,8 @@ class RomaUnsbModel(BaseModel):
BaseModel.__init__(self, opt)
# 指定需要打印的训练损失
self.loss_names = ['G_GAN_1', 'D_real_1', 'D_fake_1', 'G_1', 'NCE_1', 'SB_1',
'G_2']
self.visual_names = ['real_A', 'real_A_noisy', 'fake_B', 'real_B']
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB']
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.opt.phase == 'test':
@ -262,11 +276,9 @@ class RomaUnsbModel(BaseModel):
if self.isTrain:
self.model_names = ['G', 'D_ViT', 'E']
else:
self.model_names = ['G']
print(f'input_nc = {self.opt.input_nc}')
# 创建网络
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
@ -284,9 +296,6 @@ class RomaUnsbModel(BaseModel):
# 定义损失函数
self.criterionL1 = torch.nn.L1Loss().to(self.device)
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionNCE = []
for nce_layer in self.nce_layers:
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
self.criterionIdt = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
@ -447,8 +456,8 @@ class RomaUnsbModel(BaseModel):
# ============ 第三步:拼接输入并执行网络推理 =============
bs = self.real_A0.size(0)
z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
self.z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
self.z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
# 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB
self.real = self.real_A0
self.realt = self.real_A_noisy
@ -459,10 +468,8 @@ class RomaUnsbModel(BaseModel):
self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3])
print(f'fake_B0: {self.real_A0.shape}, fake_B1: {self.real_A1.shape}')
self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')
self.fake_B0 = self.netG(self.real_A0, self.time, self.z_in)
self.fake_B1 = self.netG(self.real_A1, self.time, self.z_in2)
if self.opt.phase == 'train':
real_A0 = self.real_A0
@ -488,28 +495,6 @@ class RomaUnsbModel(BaseModel):
# [[1,576,768],[1,576,768],[1,576,768]]
# [3,576,768]
## 生成图像的梯度
#fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
#
## 梯度图
#self.weight_fake = self.cao.generate_weight_map(fake_gradient)
#
## 生成图像的CTN光流图
#self.f_content = self.ctn(self.weight_fake)
#
## 变换后的图片
#self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
#self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
#
## 经过第二次生成器
#self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
#warped_fake_B0_2=self.warped_fake_B0_2
#warped_fake_B0=self.warped_fake_B0
#self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
#self.warped_fake_B0_resize = self.resize(warped_fake_B0)
#self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
#self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
def compute_D_loss(self): #判别器还是没有改
@ -517,30 +502,23 @@ class RomaUnsbModel(BaseModel):
lambda_D_ViT = self.opt.lambda_D_ViT
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
real_B0_tokens = self.mutil_real_B0_tokens[0]
real_B1_tokens = self.mutil_real_B1_tokens[0]
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
self.loss_D_fake_ViT = self.criterionGAN(pre_fake0_ViT, False)
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True)
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT)
return self.loss_D_ViT
return self.losscao* lambda_D_ViT
def compute_E_loss(self):
"""计算判别器 E 的损失"""
print(f'resl_A_noisy: {self.real_A_noisy.shape} \n fake_B0: {self.fake_B0.shape}')
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
@ -550,12 +528,28 @@ class RomaUnsbModel(BaseModel):
def compute_G_loss(self):
"""计算生成器的 GAN 损失"""
if self.opt.lambda_ctn > 0.0:
# 生成图像的CTN光流图
self.f_content = self.ctn(self.weight_fake)
# 变换后的图片
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
# 经过第二次生成器
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, self.z_in)
warped_fake_B0_2=self.warped_fake_B0_2
warped_fake_B0=self.warped_fake_B0
# 计算L2损失
self.ctn_loss = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
if self.opt.lambda_GAN > 0.0:
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean()
else:
self.loss_G_GAN = 0.0
self.loss_SB = 0
if self.opt.lambda_SB > 0.0:
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
@ -564,9 +558,9 @@ class RomaUnsbModel(BaseModel):
bs = self.opt.batch_size
# eq.9
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - self.netE(XtXt_1, self.time, XtXt_2).mean()
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
@ -574,12 +568,10 @@ class RomaUnsbModel(BaseModel):
else:
loss_global = 0.0
self.l2_loss = 0.0
#if self.opt.lambda_ctn > 0.0:
# wapped_fake_B = warp(self.fake_B, self.f_content) # use updated self.f_content
# self.l2_loss = F.mse_loss(self.fake_B_2, wapped_fake_B) # complete the loss calculation
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
self.opt.lambda_SB * self.loss_SB + \
self.opt.lambda_ctn * self.ctn_loss + \
loss_global * self.opt.lambda_global
return self.loss_G
def calculate_attention_loss(self):

View File

@ -7,27 +7,22 @@
python train.py \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
--name ROMA_UNSB_001 \
--name ROMA_UNSB_003 \
--dataset_mode unaligned_double \
--no_flip \
--display_env ROMA \
--model roma_unsb \
--lambda_GAN 8.0 \
--lambda_GAN 1.0 \
--lambda_NCE 8.0 \
--lambda_SB 0.1 \
--lambda_SB 1.0 \
--lambda_ctn 1.0 \
--lambda_inc 1.0 \
--lr 0.00001 \
--gpu_id 0 \
--nce_idt False \
--nce_layers 0,4,8,12,16 \
--netF mlp_sample \
--netF_nc 256 \
--nce_T 0.07 \
--lmda_1 0.1 \
--num_patches 256 \
--flip_equivariance False \
--eta_ratio 0.1 \
--flip_equivariance True \
--eta_ratio 0.4 \
--tau 0.01 \
--num_timesteps 10 \
--num_timesteps 4 \
--input_nc 3

View File

@ -44,6 +44,7 @@ if __name__ == '__main__':
model.setup(opt) # regular setup: load and print networks; create schedulers
model.parallelize()
model.set_input(data) # unpack data from dataset and apply preprocessing
#print('Call opt paras')
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
if len(opt.gpu_ids) > 0:
torch.cuda.synchronize()