Compare commits

..

10 Commits

Author SHA1 Message Date
bishe
2a321918c0 add file: with_logist_dataset.py 2025-03-27 00:09:38 +08:00
bishe
c6cb68e700 尝试在每一步都给判别器看,但是速度太慢了 2025-03-07 18:43:06 +08:00
bishe
76fcec26e8 exp8 版本 2025-03-07 10:13:25 +08:00
bishe
2a0a56ac26 修改后的最新 2025-02-27 18:00:41 +08:00
bishe
7a6e856b4b running UNIV 2025-02-26 22:24:17 +08:00
bishe
e8e483fbf8 EDIT_DOWN 2025-02-26 22:07:11 +08:00
bishe
3c4d53377c EDIT_DOWN 2025-02-26 22:07:06 +08:00
bishe
6a2761be99 without cnt running 002 2025-02-24 23:35:03 +08:00
bishe
c2e6cfe0b1 running without cnt named 001 2025-02-24 23:10:23 +08:00
bishe
4af0d7463d withoutCNT 2025-02-24 23:00:25 +08:00
13 changed files with 348 additions and 840 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
checkpoints/
*.log
*.pth
*.ckpt
__pycache__/

View File

@ -1,80 +0,0 @@
================ Training Loss (Sun Feb 23 15:46:44 2025) ================
================ Training Loss (Sun Feb 23 15:52:29 2025) ================
================ Training Loss (Sun Feb 23 16:00:07 2025) ================
================ Training Loss (Sun Feb 23 16:02:40 2025) ================
================ Training Loss (Sun Feb 23 16:05:19 2025) ================
================ Training Loss (Sun Feb 23 16:06:44 2025) ================
================ Training Loss (Sun Feb 23 16:09:38 2025) ================
================ Training Loss (Sun Feb 23 16:44:56 2025) ================
================ Training Loss (Sun Feb 23 16:49:46 2025) ================
================ Training Loss (Sun Feb 23 16:51:03 2025) ================
================ Training Loss (Sun Feb 23 16:51:23 2025) ================
================ Training Loss (Sun Feb 23 18:04:02 2025) ================
================ Training Loss (Sun Feb 23 18:04:39 2025) ================
================ Training Loss (Sun Feb 23 18:05:17 2025) ================
================ Training Loss (Sun Feb 23 18:06:40 2025) ================
================ Training Loss (Sun Feb 23 18:11:48 2025) ================
================ Training Loss (Sun Feb 23 18:13:31 2025) ================
================ Training Loss (Sun Feb 23 18:14:11 2025) ================
================ Training Loss (Sun Feb 23 18:14:29 2025) ================
================ Training Loss (Sun Feb 23 18:16:27 2025) ================
================ Training Loss (Sun Feb 23 18:16:44 2025) ================
================ Training Loss (Sun Feb 23 18:20:39 2025) ================
================ Training Loss (Sun Feb 23 18:21:44 2025) ================
================ Training Loss (Sun Feb 23 18:35:27 2025) ================
================ Training Loss (Sun Feb 23 18:39:21 2025) ================
================ Training Loss (Sun Feb 23 18:40:15 2025) ================
================ Training Loss (Sun Feb 23 18:41:15 2025) ================
================ Training Loss (Sun Feb 23 18:47:46 2025) ================
================ Training Loss (Sun Feb 23 18:48:36 2025) ================
================ Training Loss (Sun Feb 23 18:50:20 2025) ================
================ Training Loss (Sun Feb 23 18:51:50 2025) ================
================ Training Loss (Sun Feb 23 18:58:45 2025) ================
================ Training Loss (Sun Feb 23 18:59:52 2025) ================
================ Training Loss (Sun Feb 23 19:03:05 2025) ================
================ Training Loss (Sun Feb 23 19:03:57 2025) ================
================ Training Loss (Sun Feb 23 21:11:47 2025) ================
================ Training Loss (Sun Feb 23 21:17:10 2025) ================
================ Training Loss (Sun Feb 23 21:20:14 2025) ================
================ Training Loss (Sun Feb 23 21:29:03 2025) ================
================ Training Loss (Sun Feb 23 21:34:57 2025) ================
================ Training Loss (Sun Feb 23 21:35:26 2025) ================
================ Training Loss (Sun Feb 23 22:28:43 2025) ================
================ Training Loss (Sun Feb 23 22:29:04 2025) ================
================ Training Loss (Sun Feb 23 22:29:52 2025) ================
================ Training Loss (Sun Feb 23 22:30:40 2025) ================
================ Training Loss (Sun Feb 23 22:33:48 2025) ================
================ Training Loss (Sun Feb 23 22:39:16 2025) ================
================ Training Loss (Sun Feb 23 22:39:48 2025) ================
================ Training Loss (Sun Feb 23 22:41:34 2025) ================
================ Training Loss (Sun Feb 23 22:42:01 2025) ================
================ Training Loss (Sun Feb 23 22:44:17 2025) ================
================ Training Loss (Sun Feb 23 22:45:53 2025) ================
================ Training Loss (Sun Feb 23 22:46:48 2025) ================
================ Training Loss (Sun Feb 23 22:47:42 2025) ================
================ Training Loss (Sun Feb 23 22:49:44 2025) ================
================ Training Loss (Sun Feb 23 22:50:29 2025) ================
================ Training Loss (Sun Feb 23 22:51:47 2025) ================
================ Training Loss (Sun Feb 23 22:55:56 2025) ================
================ Training Loss (Sun Feb 23 22:56:19 2025) ================
================ Training Loss (Sun Feb 23 22:57:58 2025) ================
================ Training Loss (Sun Feb 23 22:59:09 2025) ================
================ Training Loss (Sun Feb 23 23:02:36 2025) ================
================ Training Loss (Sun Feb 23 23:03:56 2025) ================
================ Training Loss (Sun Feb 23 23:09:21 2025) ================
================ Training Loss (Sun Feb 23 23:10:05 2025) ================
================ Training Loss (Sun Feb 23 23:11:43 2025) ================
================ Training Loss (Sun Feb 23 23:12:41 2025) ================
================ Training Loss (Sun Feb 23 23:13:05 2025) ================
================ Training Loss (Sun Feb 23 23:13:59 2025) ================
================ Training Loss (Sun Feb 23 23:14:59 2025) ================
================ Training Loss (Mon Feb 24 21:53:50 2025) ================
================ Training Loss (Mon Feb 24 21:54:16 2025) ================
================ Training Loss (Mon Feb 24 21:54:50 2025) ================
================ Training Loss (Mon Feb 24 21:55:31 2025) ================
================ Training Loss (Mon Feb 24 21:56:10 2025) ================
================ Training Loss (Mon Feb 24 22:09:38 2025) ================
================ Training Loss (Mon Feb 24 22:10:16 2025) ================
================ Training Loss (Mon Feb 24 22:12:46 2025) ================
================ Training Loss (Mon Feb 24 22:13:04 2025) ================
================ Training Loss (Mon Feb 24 22:14:04 2025) ================

View File

@ -1,88 +0,0 @@
----------------- Options ---------------
adj_size_list: [2, 4, 6, 8, 12]
atten_layers: 1,3,5
batch_size: 1
beta1: 0.5
beta2: 0.999
checkpoints_dir: ./checkpoints
continue_train: False
crop_size: 256
dataroot: /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor [default: placeholder]
dataset_mode: unaligned_double [default: unaligned]
direction: AtoB
display_env: ROMA [default: main]
display_freq: 50
display_id: None
display_ncols: 4
display_port: 8097
display_server: http://localhost
display_winsize: 256
easy_label: experiment_name
epoch: latest
epoch_count: 1
eta_ratio: 0.1
evaluation_freq: 5000
flip_equivariance: False
gan_mode: lsgan
gpu_ids: 0
init_gain: 0.02
init_type: xavier
input_nc: 3
isTrain: True [default: None]
lambda_D_ViT: 1.0
lambda_GAN: 8.0 [default: 1.0]
lambda_NCE: 8.0 [default: 1.0]
lambda_SB: 0.1
lambda_ctn: 1.0
lambda_global: 1.0
lambda_inc: 1.0
lmda_1: 0.1
load_size: 286
lr: 1e-05 [default: 0.0002]
lr_decay_iters: 50
lr_policy: linear
max_dataset_size: inf
model: roma_unsb [default: cut]
n_epochs: 100
n_epochs_decay: 100
n_layers_D: 3
n_mlp: 3
name: ROMA_UNSB_001 [default: experiment_name]
nce_T: 0.07
nce_idt: False [default: True]
nce_includes_all_negatives_from_minibatch: False
nce_layers: 0,4,8,12,16
ndf: 64
netD: basic_cond
netF: mlp_sample
netF_nc: 256
netG: resnet_9blocks_cond
ngf: 64
no_antialias: False
no_antialias_up: False
no_dropout: True
no_flip: True [default: False]
no_html: False
normD: instance
normG: instance
num_patches: 256
num_threads: 4
num_timesteps: 10 [default: 5]
output_nc: 3
phase: train
pool_size: 0
preprocess: resize_and_crop
pretrained_name: None
print_freq: 100
random_scale_max: 3.0
save_by_iter: False
save_epoch_freq: 5
save_latest_freq: 5000
serial_batches: False
stylegan2_G_num_downsampling: 1
suffix:
tau: 0.01
update_html_freq: 1000
use_idt: False
verbose: False
----------------- End -------------------

View File

@ -13,7 +13,7 @@ import os.path
IMG_EXTENSIONS = [ IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG', '.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
'.tif', '.TIF', '.tiff', '.TIFF', '.tif', '.TIF', '.tiff', '.TIFF', '.pth',
] ]

View File

@ -0,0 +1,86 @@
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
from glob import glob
import torch
class UnalignedDataset(BaseDataset):
"""
This dataset class can load unaligned/unpaired datasets.
It requires two directories to host training images from domain A '/path/to/data/trainA'
and from domain B '/path/to/data/trainB' respectively.
You can train the model with the dataset flag '--dataroot /path/to/data'.
Similarly, you need to prepare two directories:
'/path/to/data/testA' and '/path/to/data/testB' during test time.
"""
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseDataset.__init__(self, opt)
self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
self.dir_A_logi = '/home/openxs/kunyu/datasets/InfraredCity-Lite/Single/Monitor/trainA_dino'
if opt.phase == "test" and not os.path.exists(self.dir_A) \
and os.path.exists(os.path.join(opt.dataroot, "valA")):
self.dir_A = os.path.join(opt.dataroot, "valA")
self.dir_B = os.path.join(opt.dataroot, "valB")
self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_logi_paths = sorted(make_dataset(self.dir_A_logi, opt.max_dataset_size))
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
"""
A_path = self.A_paths[index % self.A_size] # make sure index is within then range
A_logi_path = self.A_logi_paths[index % self.A_size]
if self.opt.serial_batches: # make sure index is within then range
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
# shape: [1, 150, 256, 256]
A_logi = torch.load(A_logi_path, map_location=f'cuda:{self.opt.gpu_id}')
# Apply image transformation
# For FastCUT mode, if in finetuning phase (learning rate is decaying),
# do not perform resize-crop data augmentation of CycleGAN.
# print('current_epoch', self.current_epoch)
is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
transform = get_transform(modified_opt)
A = transform(A_img)
B = transform(B_img)
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_logi': A_logi, 'A_logi_paths': A_logi_path}
def __len__(self):
"""Return the total number of images in the dataset.
As we have two datasets with potentially different number of images,
we take a maximum of
"""
return max(self.A_size, self.B_size)

View File

@ -2,6 +2,7 @@ import numpy as np
import math import math
import timm import timm
import torch import torch
import torchvision.models as models
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision.transforms import GaussianBlur from torchvision.transforms import GaussianBlur
@ -60,156 +61,69 @@ def compute_ctn_loss(G, x, F_content): #公式10
loss = F.mse_loss(warped_fake, y_fake_warped) loss = F.mse_loss(warped_fake, y_fake_warped)
return loss return loss
class ContentAwareOptimization(nn.Module):
class ContentAwareOptimization(nn.Module):
def __init__(self, lambda_inc=2.0, eta_ratio=0.4): def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
super().__init__() super().__init__()
self.lambda_inc = lambda_inc # 权重增强系数 self.lambda_inc = lambda_inc
self.eta_ratio = eta_ratio # 选择内容区域的比例 self.eta_ratio = eta_ratio
self.gradients_real = []
def compute_cosine_similarity(self, gradients): self.gradients_fake = []
"""
计算每个patch梯度与平均梯度的余弦相似度
Args:
gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h)
Returns:
cosine_sim: [B, N] 每个patch的余弦相似度
"""
mean_grad = torch.mean(gradients, dim=1, keepdim=True) # [B, 1, D]
# 计算余弦相似度
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
return cosine_sim
def generate_weight_map(self, gradients_fake, feature_shape):
"""
生成内容感知权重图修正空间维度
Args:
gradients_real: [B, N, D] 真实图像判别器梯度
gradients_fake: [B, N, D] 生成图像判别器梯度
feature_shape: tuple [H, W] 判别器输出的特征图尺寸
Returns:
weight_real: [B, 1, H, W] 真实图像权重图
weight_fake: [B, 1, H, W] 生成图像权重图
"""
H, W = feature_shape
N = H * W
# 计算余弦相似度(与原代码相同)
def compute_cosine_similarity(self, gradients):
mean_grad = torch.mean(gradients, dim=1, keepdim=True)
return F.cosine_similarity(gradients, mean_grad, dim=2)
def generate_weight_map(self, gradients_real, gradients_fake):
# 计算余弦相似度
cosine_real = self.compute_cosine_similarity(gradients_real)
cosine_fake = self.compute_cosine_similarity(gradients_fake) cosine_fake = self.compute_cosine_similarity(gradients_fake)
# 生成权重图(与原代码相同) # 生成权重图(优化实现)
k = int(self.eta_ratio * cosine_fake.shape[1]) def _get_weights(cosine):
_, fake_indices = torch.topk(-cosine_fake, k, dim=1) k = int(self.eta_ratio * cosine.shape[1])
weight_fake = torch.ones_like(cosine_fake) _, indices = torch.topk(-cosine, k, dim=1)
weights = torch.ones_like(cosine)
for b in range(cosine_fake.shape[0]): weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]])) return weights
# 重建空间维度 --------------------------------------------------
# 将权重从[B, N]转换为[B, H, W]
#print(f"Shape of weight_fake before view: {weight_fake.shape}")
#print(f"Shape of cosine_fake: {cosine_fake.shape}")
#print(f"H: {H}, W: {W}, N: {N}")
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
return weight_fake weight_real = _get_weights(cosine_real)
weight_fake = _get_weights(cosine_fake)
def compute_cosine_similarity_image(self, gradients): return weight_real, weight_fake
"""
计算每个空间位置梯度与平均梯度的余弦相似度 (图像版本)
Args:
gradients: [B, C, H, W] 判别器输出的梯度
Returns:
cosine_sim: [B, H, W] 每个空间位置的余弦相似度
"""
# 将空间维度展平,以便计算所有空间位置的平均梯度
B, C, H, W = gradients.shape
gradients_reshaped = gradients.view(B, C, H * W) # [B, C, N] where N = H*W
gradients_transposed = gradients_reshaped.transpose(1, 2) # [B, N, C] 将C放到最后一维方便计算空间位置的平均梯度
mean_grad = torch.mean(gradients_transposed, dim=1, keepdim=True) # [B, 1, C] 在空间位置维度上求平均,得到平均梯度 [B, 1, C]
# mean_grad 现在是所有空间位置的平均梯度,形状为 [B, 1, C]
# 为了计算余弦相似度,我们需要将 mean_grad 扩展到与 gradients_transposed 相同的空间维度
mean_grad_expanded = mean_grad.expand(-1, H * W, -1) # [B, N, C]
# 计算余弦相似度dim=2 表示在特征维度 (C) 上计算
cosine_sim = F.cosine_similarity(gradients_transposed, mean_grad_expanded, dim=2) # [B, N]
# 将 cosine_sim 重新reshape回 [B, H, W]
cosine_sim = cosine_sim.view(B, H, W)
return cosine_sim
def generate_weight_map_image(self, gradients_fake, feature_shape):
"""
生成内容感知权重图修正空间维度 - 图像版本
Args:
gradients_fake: [B, C, H, W] 生成图像判别器梯度
feature_shape: tuple [H, W] 判别器输出的特征图尺寸
Returns:
weight_fake: [B, 1, H, W] 生成图像权重图
"""
H, W = feature_shape
# 计算余弦相似度(图像版本)
cosine_fake = self.compute_cosine_similarity_image(gradients_fake) # [B, H, W]
# 生成权重图与原代码相同但现在cosine_fake是[B, H, W]
k = int(self.eta_ratio * H * W) # k 仍然是基于总的空间位置数量计算
_, fake_indices = torch.topk(-cosine_fake.view(cosine_fake.shape[0], -1), k, dim=1) # 将 cosine_fake 展平为 [B, N] 以使用 topk
weight_fake = torch.ones_like(cosine_fake).view(cosine_fake.shape[0], -1) # 初始化权重图,并展平为 [B, N]
for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake.view(cosine_fake.shape[0], -1)[b, fake_indices[b]]))
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # 重新 reshape 为 [B, H, W],并添加通道维度变为 [B, 1, H, W]
return weight_fake
def forward(self, D_real, D_fake, real_scores, fake_scores): def forward(self, D_real, D_fake, real_scores, fake_scores):
""" # 清空梯度缓存
计算内容感知对抗损失 self.gradients_real.clear()
Args: self.gradients_fake.clear()
D_real: 判别器对真实图像的特征输出 [B, C, H, W] self.criterionGAN=networks.GANLoss('lsgan').cuda()
D_fake: 判别器对生成图像的特征输出 [B, C, H, W] # 注册钩子捕获梯度
real_scores: 真实图像的判别器预测 [B, N] (N=H*W) hook_real = lambda grad: self.gradients_real.append(grad.detach())
fake_scores: 生成图像的判别器预测 [B, N] hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
Returns:
loss_co_adv: 内容感知对抗损失
"""
B, C, H, W = D_real.shape
N = H * W
shape_hw = [H, W]
# 注册钩子获取梯度
gradients_real = []
gradients_fake = []
def hook_real(grad):
gradients_real.append(grad.detach().view(B, N, -1))
def hook_fake(grad):
gradients_fake.append(grad.detach().view(B, N, -1))
D_real.register_hook(hook_real) D_real.register_hook(hook_real)
D_fake.register_hook(hook_fake) D_fake.register_hook(hook_fake)
# 计算原始对抗损失以触发梯度计算 # 触发梯度计算(保留计算图)
loss_real = torch.mean(torch.log(real_scores + 1e-8)) (real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)
loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
# 添加与 D_real、D_fake 相关的 dummy 项,确保梯度传递 # 获取梯度并调整维度
loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum()) grad_real = self.gradients_real[0].flatten(1) # [B, N, D] → [B, N*D]
total_loss = loss_real + loss_fake + loss_dummy grad_fake = self.gradients_fake[0].flatten(1)
total_loss.backward(retain_graph=True)
# 获取梯度数据
gradients_real = gradients_real[1] # [B, N, D]
gradients_fake = gradients_fake[1] # [B, N, D]
# 生成权重图 # 生成权重图
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw ) weight_real, weight_fake = self.generate_weight_map(
grad_real.view(*D_real.shape),
grad_fake.view(*D_fake.shape)
)
# 应用权重到对抗损失 # 正确应用权重到对数概率论文公式7
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8)) loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores , True))
loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8)) loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores , False))
# 计算并返回最终内容感知对抗损失 # 总损失(注意符号:判别器需最大化该损失)
loss_co_adv = -(loss_co_real + loss_co_fake) loss_co_adv = (loss_co_real + loss_co_fake)*0.5
return loss_co_adv return loss_co_adv, weight_real, weight_fake
class ContentAwareTemporalNorm(nn.Module): class ContentAwareTemporalNorm(nn.Module):
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0): def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
@ -217,6 +131,33 @@ class ContentAwareTemporalNorm(nn.Module):
self.gamma_stride = gamma_stride # 控制整体运动幅度 self.gamma_stride = gamma_stride # 控制整体运动幅度
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层 self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
"""
将patch级别的权重图上采样到目标分辨率
Args:
weight_patch: [B, 1, 24, 24] 来自ViT的patch权重图
target_size: 目标分辨率 (H, W)
Returns:
weight_full: [B, 1, 256, 256] 上采样后的全分辨率权重图
"""
# 使用双线性插值上采样
B = weight_patch.shape[0]
weight_patch = weight_patch.view(B, 1, 24, 24)
weight_full = F.interpolate(
weight_patch,
size=target_size,
mode='bilinear',
align_corners=False
)
# 对每个16x16的patch内部保持权重一致可选
# 通过平均池化再扩展,消除插值引入的渐变
weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
return weight_full
def forward(self, weight_map): def forward(self, weight_map):
""" """
生成内容感知光流 生成内容感知光流
@ -225,15 +166,16 @@ class ContentAwareTemporalNorm(nn.Module):
Returns: Returns:
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
""" """
print(weight_map.shape) # 上采样权重图到全分辨率
B, _, H, W = weight_map.shape weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384]
# 1. 归一化权重图 # 1. 归一化权重图
# 保持区域相对强度,同时限制数值范围 # 保持区域相对强度,同时限制数值范围
weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W] weight_norm = F.normalize(weight_full, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
# 2. 生成高斯噪声(与光流场同尺寸) # 2. 生成高斯噪声
z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W] B, _, H, W = weight_norm.shape
z = torch.randn(B, 2, H, W, device=weight_norm.device) # [B,2,H,W]
# 3. 合成基础光流 # 3. 合成基础光流
# 将权重图扩展为2通道(x/y方向共享权重) # 将权重图扩展为2通道(x/y方向共享权重)
@ -248,7 +190,7 @@ class ContentAwareTemporalNorm(nn.Module):
# 限制光流幅值,避免极端位移 # 限制光流幅值,避免极端位移
F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围 F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围
return F_content return F_content
class RomaUnsbModel(BaseModel): class RomaUnsbModel(BaseModel):
@staticmethod @staticmethod
@ -256,44 +198,32 @@ class RomaUnsbModel(BaseModel):
"""配置 CTNx 模型的特定选项""" """配置 CTNx 模型的特定选项"""
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))') parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss') parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm') parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator') parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency') parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
parser.add_argument('--side_length', type=int, default=7)
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))') parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
parser.add_argument('--nce_includes_all_negatives_from_minibatch', parser.add_argument('--nce_includes_all_negatives_from_minibatch',
type=util.str2bool, nargs='?', const=True, default=False, type=util.str2bool, nargs='?', const=True, default=False,
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.') help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map') parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
parser.add_argument('--netF_nc', type=int, default=256)
parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss') parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
parser.add_argument('--lmda_1', type=float, default=0.1)
parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
parser.add_argument('--flip_equivariance',
type=util.str2bool, nargs='?', const=True, default=False,
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers') parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter') parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer') parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers') parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
parser.set_defaults(pool_size=0) # no image pooling
opt, _ = parser.parse_known_args() opt, _ = parser.parse_known_args()
# 直接设置为 sb 模式
parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
return parser return parser
@ -302,11 +232,11 @@ class RomaUnsbModel(BaseModel):
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# 指定需要打印的训练损失 # 指定需要打印的训练损失
self.loss_names = ['G_GAN_1', 'D_real_1', 'D_fake_1', 'G_1', 'NCE_1', 'SB_1', self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
'G_2'] self.visual_names = ['real_A0', 'fake_B0_1','fake_B0', 'real_B0','real_A1', 'fake_B1_1', 'fake_B1', 'real_B1']
self.visual_names = ['real_A', 'real_A_noisy', 'fake_B', 'real_B']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')] self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.opt.phase == 'test': if self.opt.phase == 'test':
self.visual_names = ['real'] self.visual_names = ['real']
for NFE in range(self.opt.num_timesteps): for NFE in range(self.opt.num_timesteps):
@ -314,24 +244,18 @@ class RomaUnsbModel(BaseModel):
self.visual_names.append(fake_name) self.visual_names.append(fake_name)
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')] self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
if opt.nce_idt and self.isTrain:
self.loss_names += ['NCE_Y']
self.visual_names += ['idt_B']
if self.isTrain: if self.isTrain:
self.model_names = ['G', 'D_ViT', 'E'] self.model_names = ['G', 'D_ViT']
else: else:
self.model_names = ['G'] self.model_names = ['G']
print(f'input_nc = {self.opt.input_nc}')
# 创建网络 # 创建网络
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt) self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
if self.isTrain: if self.isTrain:
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
self.resize = tfs.Resize(size=(384,384), antialias=True) self.resize = tfs.Resize(size=(384,384), antialias=True)
@ -343,14 +267,9 @@ class RomaUnsbModel(BaseModel):
# 定义损失函数 # 定义损失函数
self.criterionL1 = torch.nn.L1Loss().to(self.device) self.criterionL1 = torch.nn.L1Loss().to(self.device)
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionNCE = []
for nce_layer in self.nce_layers:
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
self.criterionIdt = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizers = [self.optimizer_G, self.optimizer_D]
self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]
self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数 self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数
self.ctn = ContentAwareTemporalNorm() #生成的伪光流 self.ctn = ContentAwareTemporalNorm() #生成的伪光流
@ -362,19 +281,6 @@ class RomaUnsbModel(BaseModel):
initialized at the first feedforward pass with some input images. initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call. Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
""" """
#bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
#self.set_input(data)
#self.real_A = self.real_A[:bs_per_gpu]
#self.real_B = self.real_B[:bs_per_gpu]
#self.forward() # compute fake images: G(A)
#if self.opt.isTrain:
#
# self.compute_G_loss().backward()
# self.compute_D_loss().backward()
# self.compute_E_loss().backward()
# if self.opt.lambda_NCE > 0.0:
# self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
# self.optimizers.append(self.optimizer_F)
pass pass
def optimize_parameters(self): def optimize_parameters(self):
@ -382,7 +288,6 @@ class RomaUnsbModel(BaseModel):
self.forward() self.forward()
self.netG.train() self.netG.train()
self.netE.train()
self.netD_ViT.train() self.netD_ViT.train()
# update D # update D
@ -392,19 +297,9 @@ class RomaUnsbModel(BaseModel):
self.loss_D.backward() self.loss_D.backward()
self.optimizer_D.step() self.optimizer_D.step()
# update E
self.set_requires_grad(self.netE, True)
self.optimizer_E.zero_grad()
self.loss_E = self.compute_E_loss()
self.loss_E.backward()
self.optimizer_E.step()
# update G # update G
self.set_requires_grad(self.netD_ViT, False) self.set_requires_grad(self.netD_ViT, False)
self.set_requires_grad(self.netE, False)
self.optimizer_G.zero_grad() self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss() self.loss_G = self.compute_G_loss()
self.loss_G.backward() self.loss_G.backward()
self.optimizer_G.step() self.optimizer_G.step()
@ -423,38 +318,7 @@ class RomaUnsbModel(BaseModel):
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device) self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths'] self.image_paths = input['A_paths' if AtoB else 'B_paths']
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
S = int(math.sqrt(token_num))
if S * S != token_num:
print('Error! Not a square!')
token_map = origin_tokens.clone().reshape(B,S,S,C)
cut_patch_list = []
for i in range(0, S, adj_size):
for j in range(0, S, adj_size):
i_left = i
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
j_left = j
j_right = j + adj_size if j + adj_size <= S else S + 1
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
cut_patch= cut_patch.reshape(B,-1,C)
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list,dim=1)
return result
def cat_results(self, origin_tokens, adj_size_list):
res_list = [origin_tokens]
for ad_s in adj_size_list:
cat_result = self.tokens_concat(origin_tokens, ad_s)
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
@ -471,7 +335,9 @@ class RomaUnsbModel(BaseModel):
bs = self.real_A0.size(0) bs = self.real_A0.size(0)
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long() time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
self.time_idx = time_idx self.time_idx = time_idx
self.fake_B0_list = []
self.fake_B1_list = []
with torch.no_grad(): with torch.no_grad():
self.netG.eval() self.netG.eval()
# ============ 第二步:对 real_A / real_A2 进行多步随机生成过程 ============ # ============ 第二步:对 real_A / real_A2 进行多步随机生成过程 ============
@ -488,36 +354,23 @@ class RomaUnsbModel(BaseModel):
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device) (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long() time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device) z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
self.time = times[time_idx] time = times[time_idx]
Xt_1 = self.netG(Xt, self.time, z) Xt_1 = self.netG(Xt.detach(), time, z)
Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \ Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device) (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long() time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device) z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
Xt_12 = self.netG(Xt2, self.time, z) Xt_12 = self.netG(Xt2.detach(), time, z)
self.fake_B0_list.append(Xt_1)
# 保存去噪后的中间结果 (real_A_noisy 等),供下一步做拼接 self.fake_B1_list.append(Xt_12)
self.real_A_noisy = Xt.detach()
self.real_A_noisy2 = Xt2.detach()
# ============ 第三步:拼接输入并执行网络推理 =============
bs = self.real_A0.size(0)
z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
# 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB
self.real = self.real_A0
self.realt = self.real_A_noisy
if self.opt.flip_equivariance:
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
if self.flipped_for_equivariance:
self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3])
self.fake_B0 = self.netG(self.real_A0, self.time, z_in) self.fake_B0_1 = self.fake_B0_list[0]
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2) self.fake_B1_1 = self.fake_B0_list[0]
self.fake_B0 = self.fake_B0_list[-1]
self.fake_B1 = self.fake_B1_list[-1]
self.z_in = z
self.z_in2 = z
if self.opt.phase == 'train': if self.opt.phase == 'train':
real_A0 = self.real_A0 real_A0 = self.real_A0
real_A1 = self.real_A1 real_A1 = self.real_A1
@ -525,6 +378,16 @@ class RomaUnsbModel(BaseModel):
real_B1 = self.real_B1 real_B1 = self.real_B1
fake_B0 = self.fake_B0 fake_B0 = self.fake_B0
fake_B1 = self.fake_B1 fake_B1 = self.fake_B1
self.mutil_fake_B0_tokens_list = []
self.mutil_fake_B1_tokens_list = []
for fake_B0_t in self.fake_B0_list:
fake_B0_t_resize = self.resize(fake_B0_t) # 调整到 ViT 输入尺寸
tokens = self.netPreViT(fake_B0_t_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens_list.append(tokens)
for fake_B1_t in self.fake_B1_list:
fake_B1_t_resize = self.resize(fake_B1_t)
tokens = self.netPreViT(fake_B1_t_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens_list.append(tokens)
self.real_A0_resize = self.resize(real_A0) self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1) self.real_A1_resize = self.resize(real_A1)
@ -532,119 +395,110 @@ class RomaUnsbModel(BaseModel):
real_B1 = self.resize(real_B1) real_B1 = self.resize(real_B1)
self.fake_B0_resize = self.resize(fake_B0) self.fake_B0_resize = self.resize(fake_B0)
self.fake_B1_resize = self.resize(fake_B1) self.fake_B1_resize = self.resize(fake_B1)
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True) self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True) self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True) self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True) self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
# [[1,576,768],[1,576,768],[1,576,768]] # [[1,576,768],[1,576,768],[1,576,768]]
# [3,576,768] # [3,576,768]
#self.mutil_real_A0_tokens = self.cat_results(self.mutil_real_A0_tokens[0], self.opt.adj_size_list)
#print(f'self.mutil_real_A0_tokens[0]:{self.mutil_real_A0_tokens[0].shape}')
shape_hw = list(self.real_A0_resize.shape[2:4])
# 生成图像的梯度
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
# 梯度图
self.weight_fake = self.cao.generate_weight_map_image(fake_gradient, shape_hw)
# 生成图像的CTN光流图
self.f_content = self.ctn(self.weight_fake)
# 变换后的图片
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
# 经过第二次生成器
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
# warped_fake_B0_2=self.warped_fake_B0_2
# warped_fake_B0=self.warped_fake_B0
# self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
# self.warped_fake_B0_resize = self.resize(warped_fake_B0)
# self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
# self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
def compute_D_loss(self): #判别器还是没有改
"""Calculate GAN loss for the discriminator"""
def compute_D_loss(self):
"""Calculate GAN loss with Content-Aware Optimization"""
lambda_D_ViT = self.opt.lambda_D_ViT lambda_D_ViT = self.opt.lambda_D_ViT
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach() loss_cao = 0.0
real_B0_tokens = self.mutil_real_B0_tokens[0] real_B0_tokens = self.mutil_real_B0_tokens[0]
pred_real0, real_features0 = self.netD_ViT(real_B0_tokens) # scores, features
real_B1_tokens = self.mutil_real_B1_tokens[0] real_B1_tokens = self.mutil_real_B1_tokens[0]
pred_real1, real_features1 = self.netD_ViT(real_B1_tokens) # scores, features
for fake0_token, fake1_token in zip(self.mutil_fake_B0_tokens_list, self.mutil_fake_B1_tokens_list):
pre_fake0, fake_features0 = self.netD_ViT(fake0_token[0].detach())
pre_fake1, fake_features1 = self.netD_ViT(fake1_token[0].detach())
loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
D_real=real_features0,
D_fake=fake_features0,
real_scores=pred_real0,
fake_scores=pre_fake0
)
loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
D_real=real_features1,
D_fake=fake_features1,
real_scores=pred_real1,
fake_scores=pre_fake1
)
loss_cao += loss_cao0 + loss_cao1
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens) # ===== 综合损失 =====
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens) total_steps = len(self.fake_B0_list)
self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT/ total_steps
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
pred_real0_ViT = self.netD_ViT(real_B0_tokens) # 记录损失值供可视化
pred_real1_ViT = self.netD_ViT(real_B1_tokens) # self.loss_D_real = loss_D_real.item()
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT # self.loss_D_fake = loss_D_fake.item()
# self.loss_cao = (loss_cao0 + loss_cao1).item() * 0.5
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
return self.loss_D_ViT return self.loss_D_ViT
def compute_E_loss(self):
"""计算判别器 E 的损失"""
print(f'resl_A_noisy: {self.real_A_noisy.shape} \n fake_B0: {self.fake_B0.shape}')
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
return self.loss_E
def compute_G_loss(self): def compute_G_loss(self):
"""计算生成器的 GAN 损失""" """计算生成器的 GAN 损失"""
if self.opt.lambda_ctn > 0.0:
# 生成图像的CTN光流图
self.f_content0 = self.ctn(self.weight_fake0)
self.f_content1 = self.ctn(self.weight_fake1)
# 变换后的图片
self.warped_real_A0 = warp(self.real_A0, self.f_content0)
self.warped_real_A1 = warp(self.real_A1, self.f_content1)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content0)
self.warped_fake_B1 = warp(self.fake_B1,self.f_content1)
# 经过第二次生成器
self.warped_fake_B0_2 = self.netG(self.warped_real_A0, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in)
self.warped_fake_B1_2 = self.netG(self.warped_real_A1, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in2)
warped_fake_B0_2=self.warped_fake_B0_2
warped_fake_B1_2=self.warped_fake_B1_2
warped_fake_B0=self.warped_fake_B0
warped_fake_B1=self.warped_fake_B1
# 计算L2损失
self.loss_ctn0 = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
self.loss_ctn1 = F.mse_loss(warped_fake_B1_2, warped_fake_B1)
self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1)*0.5
if self.opt.lambda_GAN > 0.0: if self.opt.lambda_GAN > 0.0:
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens_list[-1][0])
pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens_list[-1][0])
self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5
else: else:
self.loss_G_GAN = 0.0 self.loss_G_GAN = 0.0
self.loss_SB = 0
if self.opt.lambda_SB > 0.0:
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
bs = self.opt.batch_size
# eq.9 if self.opt.lambda_global or self.opt.lambda_spatial > 0.0:
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0) self.loss_global, self.loss_spatial = self.calculate_attention_loss()
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
loss_global *= 0.5
else: else:
loss_global = 0.0 self.loss_global, self.loss_spatial = 0.0, 0.0
self.l2_loss = 0.0 self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
if self.opt.lambda_l2 > 0.0: self.opt.lambda_ctn * self.loss_ctn + \
wapped_fake_B = warp(self.fake_B0, self.f_content) # use updated self.f_content self.loss_global * self.opt.lambda_global+\
self.l2_loss = F.mse_loss(self.warped_fake_B0_2, wapped_fake_B) # complete the loss calculation self.loss_spatial * self.opt.lambda_spatial
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
return self.loss_G return self.loss_G
def calculate_attention_loss(self): def calculate_attention_loss(self):
n_layers = len(self.atten_layers) n_layers = len(self.atten_layers)
mutil_real_A0_tokens = self.mutil_real_A0_tokens mutil_real_A0_tokens = self.mutil_real_A0_tokens
mutil_real_A1_tokens = self.mutil_real_A1_tokens mutil_real_A1_tokens = self.mutil_real_A1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1]
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1]
if self.opt.lambda_global > 0.0: if self.opt.lambda_global > 0.0:
@ -661,20 +515,19 @@ class RomaUnsbModel(BaseModel):
local_id = np.random.permutation(tokens_cnt) local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))] local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens) loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
loss_spatial *= 0.5 loss_spatial *= 0.5
else: else:
loss_spatial = 0.0 loss_spatial = 0.0
return loss_global , loss_spatial
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens): def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0 loss = 0.0
n_layers = len(self.atten_layers) n_layers = len(self.atten_layers)
@ -688,5 +541,3 @@ class RomaUnsbModel(BaseModel):
loss = loss / n_layers loss = loss / n_layers
return loss return loss

View File

@ -31,7 +31,7 @@ class TrainOptions(BaseOptions):
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
# training parameters # training parameters
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero') parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')

301
roma.py
View File

@ -1,301 +0,0 @@
import numpy as np
import torch
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
import timm
import time
import torch.nn.functional as F
import sys
from functools import partial
import torch.nn as nn
import math
from torchvision.transforms import transforms as tfs
class ROMAModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
""" Configures options specific for CUT model
"""
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
parser.add_argument('--local_nums', type=int, default=256)
parser.add_argument('--which_D_layer', type=int, default=-1)
parser.add_argument('--side_length', type=int, default=7)
parser.set_defaults(pool_size=0)
opt, _ = parser.parse_known_args()
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial', 'motion']
self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.isTrain:
self.model_names = ['G', 'D_ViT']
else: # during test time, only load G
self.model_names = ['G']
# define networks (both generator and discriminator)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
if self.isTrain:
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
self.norm = F.softmax
self.resize = tfs.Resize(size=(384,384))
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionNCE = []
for atten_layer in self.atten_layers:
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
self.criterionL1 = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D_ViT)
def data_dependent_initialize(self, data):
"""
The feature network netF is defined in terms of the shape of the intermediate, extracted
features of the encoder portion of netG. Because of this, the weights of netF are
initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
"""
pass
def optimize_parameters(self):
# forward
self.forward()
# update D
self.set_requires_grad(self.netD_ViT, True)
self.optimizer_D_ViT.zero_grad()
self.loss_D = self.compute_D_loss()
self.loss_D.backward()
self.optimizer_D_ViT.step()
# update G
self.set_requires_grad(self.netD_ViT, False)
self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss()
self.loss_G.backward()
self.optimizer_G.step()
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B0 = self.netG(self.real_A0)
self.fake_B1 = self.netG(self.real_A1)
if self.opt.isTrain:
real_A0 = self.real_A0
real_A1 = self.real_A1
real_B0 = self.real_B0
real_B1 = self.real_B1
fake_B0 = self.fake_B0
fake_B1 = self.fake_B1
self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1)
real_B0 = self.resize(real_B0)
real_B1 = self.resize(real_B1)
self.fake_B0_resize = self.resize(fake_B0)
self.fake_B1_resize = self.resize(fake_B1)
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
S = int(math.sqrt(token_num))
if S * S != token_num:
print('Error! Not a square!')
token_map = origin_tokens.clone().reshape(B,S,S,C)
cut_patch_list = []
for i in range(0, S, adj_size):
for j in range(0, S, adj_size):
i_left = i
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
j_left = j
j_right = j + adj_size if j + adj_size <= S else S + 1
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
cut_patch= cut_patch.reshape(B,-1,C)
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list,dim=1)
return result
def cat_results(self, origin_tokens, adj_size_list):
res_list = [origin_tokens]
for ad_s in adj_size_list:
cat_result = self.tokens_concat(origin_tokens, ad_s)
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
def compute_D_loss(self):
"""Calculate GAN loss for the discriminator"""
lambda_D_ViT = self.opt.lambda_D_ViT
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer].detach()
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer].detach()
real_B0_tokens = self.mutil_real_B0_tokens[self.opt.which_D_layer]
real_B1_tokens = self.mutil_real_B1_tokens[self.opt.which_D_layer]
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
real_B0_tokens = self.cat_results(real_B0_tokens, self.opt.adj_size_list)
real_B1_tokens = self.cat_results(real_B1_tokens, self.opt.adj_size_list)
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
return self.loss_D_ViT
def compute_G_loss(self):
if self.opt.lambda_GAN > 0.0:
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer]
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer]
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN
else:
self.loss_G_GAN_ViT = 0.0
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
else:
self.loss_global, self.loss_spatial = 0.0, 0.0
if self.opt.lambda_motion > 0.0:
self.loss_motion = 0.0
for real_A0_tokens, real_A1_tokens, fake_B0_tokens, fake_B1_tokens in zip(self.mutil_real_A0_tokens, self.mutil_real_A1_tokens, self.mutil_fake_B0_tokens, self.mutil_fake_B1_tokens):
A0_B1 = real_A0_tokens.bmm(fake_B1_tokens.permute(0,2,1))
B0_A1 = fake_B0_tokens.bmm(real_A1_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(A0_B1, B0_A1, dim=-1)
self.loss_motion += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
else:
self.loss_motion = 0.0
self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.loss_motion
return self.loss_G
def calculate_attention_loss(self):
n_layers = len(self.atten_layers)
mutil_real_A0_tokens = self.mutil_real_A0_tokens
mutil_real_A1_tokens = self.mutil_real_A1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
loss_global *= 0.5
else:
loss_global = 0.0
if self.opt.lambda_spatial > 0.0:
loss_spatial = 0.0
local_nums = self.opt.local_nums
tokens_cnt = 576
local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
loss_spatial *= 0.5
else:
loss_spatial = 0.0
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0
n_layers = len(self.atten_layers)
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
loss = loss / n_layers
return loss

View File

@ -7,27 +7,29 @@
python train.py \ python train.py \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
--name ROMA_UNSB_001 \ --name UNIV_5 \
--dataset_mode unaligned_double \ --dataset_mode unaligned_double \
--no_flip \ --display_env UNIV \
--display_env ROMA \
--model roma_unsb \ --model roma_unsb \
--lambda_GAN 8.0 \ --lambda_SB 1.0 \
--lambda_NCE 8.0 \ --lambda_ctn 10 \
--lambda_SB 0.1 \
--lambda_ctn 1.0 \
--lambda_inc 1.0 \ --lambda_inc 1.0 \
--lr 0.00001 \ --lambda_global 6.0 \
--gpu_id 0 \ --gamma_stride 20 \
--lr 0.000002 \
--gpu_id 1 \
--nce_idt False \ --nce_idt False \
--nce_layers 0,4,8,12,16 \
--netF mlp_sample \ --netF mlp_sample \
--netF_nc 256 \ --eta_ratio 0.4 \
--nce_T 0.07 \
--lmda_1 0.1 \
--num_patches 256 \
--flip_equivariance False \
--eta_ratio 0.1 \
--tau 0.01 \ --tau 0.01 \
--num_timesteps 10 \ --num_timesteps 5 \
--input_nc 3 --input_nc 3 \
--n_epochs 400 \
--n_epochs_decay 200 \
# exp1 num_timesteps=4 (已停)
# exp2 num_timesteps=5 (已停)
# exp3 --num_timesteps 5,--lambda_inc 8 --gamma_stride 20,--lambda_global 6.0,--lambda_ctn 10, --lr 0.000002 (已停)
# exp4 --num_timesteps 5,--lambda_inc 8 --gamma_stride 20,--lambda_global 6.0,--lambda_ctn 10, --lr 0.000002, ET_XY=self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time_idx, XtXt_2).reshape(-1), dim=0) ,并把GAN,CTN loss考虑到了A1和B1 (已停)
# exp5 基于 exp4 ,修改了 self.loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, self.mutil_fake_B1_tokens) ,gpu_id 1 (已停)
# 上面几个实验效果都不好实验结果都已经删除了开的新的train_sbiv 对代码进行了调整,效果变得更好了。

32
scripts/train_sbiv.sh Executable file
View File

@ -0,0 +1,32 @@
#!/bin/sh
# Train for video mode
#CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned_double --no_flip --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --lambda_motion 1.0 --atten_layers 1,3,5 --lr 0.00001
# Train for image mode
#CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --atten_layers 1,3,5 --lr 0.00001
python train.py \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
--name SBIV_1 \
--dataset_mode unaligned_double \
--display_env SBIV2 \
--model roma_unsb \
--lambda_ctn 10 \
--lambda_inc 1.0 \
--lambda_global 8.0 \
--lambda_spatial 8.0 \
--gamma_stride 20 \
--lr 0.000001 \
--gpu_id 0 \
--eta_ratio 0.3 \
--tau 0.01 \
--num_timesteps 3 \
--input_nc 3 \
--n_epochs 400 \
--n_epochs_decay 200 \
# exp6 num_timesteps=4 gpu_id 0基于 exp5 ,exp1 已停) (已停)
# exp7 num_timesteps=3 gpu_id 0 基于 exp6 (已停)
# # exp8 num_timesteps=4 gpu_id 1 ,修改了训练判别器的loss以及ctnloss基于exp6
# # exp9 num_timesteps=3 gpu_id 2 ,(基于 exp8
# # # exp10 num_timesteps=4 gpu_id 0 , --name SBIV_1 ,让判别器看到了每一个时间步的输出修改了训练判别器的loss以及ctnloss基于exp9

View File

@ -44,6 +44,7 @@ if __name__ == '__main__':
model.setup(opt) # regular setup: load and print networks; create schedulers model.setup(opt) # regular setup: load and print networks; create schedulers
model.parallelize() model.parallelize()
model.set_input(data) # unpack data from dataset and apply preprocessing model.set_input(data) # unpack data from dataset and apply preprocessing
#print('Call opt paras')
model.optimize_parameters() # calculate loss functions, get gradients, update network weights model.optimize_parameters() # calculate loss functions, get gradients, update network weights
if len(opt.gpu_ids) > 0: if len(opt.gpu_ids) > 0:
torch.cuda.synchronize() torch.cuda.synchronize()