Compare commits
4 Commits
c6cb68e700
...
e9c0f5ffcb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9c0f5ffcb | ||
|
|
e0dc08030c | ||
|
|
997fdd3770 | ||
|
|
14ba81514f |
Binary file not shown.
Binary file not shown.
@ -1401,23 +1401,31 @@ class UnetSkipConnectionBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MLPDiscriminator(nn.Module):
|
class MLPDiscriminator(nn.Module):
|
||||||
def __init__(self, in_feat=768, hid_feat = 768, out_feat = 768, dropout = 0.):
|
def __init__(self, in_feat=768, hid_feat=512, out_feat=768, num_heads=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not hid_feat:
|
# 自注意力层,加入Dropout
|
||||||
hid_feat = in_feat
|
self.attention = nn.MultiheadAttention(embed_dim=in_feat, num_heads=num_heads, dropout=0.1)
|
||||||
if not out_feat:
|
# 加深加宽的MLP,加入Dropout
|
||||||
out_feat = in_feat
|
self.mlp = nn.Sequential(
|
||||||
self.linear1 = nn.Linear(in_feat, hid_feat)
|
nn.Linear(in_feat, hid_feat), # 768 -> 512
|
||||||
self.activation = nn.GELU()
|
nn.ReLU(),
|
||||||
self.linear2 = nn.Linear(hid_feat, out_feat)
|
nn.Dropout(0.3),
|
||||||
self.dropout = nn.Dropout(dropout)
|
nn.Linear(hid_feat, hid_feat * 2), # 512 -> 1024
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(hid_feat * 2, hid_feat), # 1024 -> 512
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(hid_feat, out_feat), # 512 -> 768
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(out_feat, 1) # 768 -> 1
|
||||||
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.linear1(x)
|
attn_output, attn_weights = self.attention(x, x, x) # [B, N, D], [B, N, N]
|
||||||
x = self.activation(x)
|
attn_weights = attn_weights.mean(dim=1) # [B, N]
|
||||||
x = self.dropout(x)
|
pred = self.mlp(attn_output.mean(dim=1)) # [B, 1]
|
||||||
x = self.linear2(x)
|
return pred, attn_weights
|
||||||
return self.dropout(x)
|
|
||||||
|
|
||||||
|
|
||||||
class NLayerDiscriminator(nn.Module):
|
class NLayerDiscriminator(nn.Module):
|
||||||
|
|||||||
@ -67,62 +67,39 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.lambda_inc = lambda_inc
|
self.lambda_inc = lambda_inc
|
||||||
self.eta_ratio = eta_ratio
|
self.eta_ratio = eta_ratio
|
||||||
self.gradients_real = []
|
|
||||||
self.gradients_fake = []
|
|
||||||
|
|
||||||
|
|
||||||
def compute_cosine_similarity(self, gradients):
|
|
||||||
mean_grad = torch.mean(gradients, dim=1, keepdim=True)
|
|
||||||
return F.cosine_similarity(gradients, mean_grad, dim=2)
|
|
||||||
|
|
||||||
def generate_weight_map(self, gradients_real, gradients_fake):
|
|
||||||
# 计算余弦相似度
|
|
||||||
cosine_real = self.compute_cosine_similarity(gradients_real)
|
|
||||||
cosine_fake = self.compute_cosine_similarity(gradients_fake)
|
|
||||||
|
|
||||||
# 生成权重图(优化实现)
|
|
||||||
def _get_weights(cosine):
|
|
||||||
k = int(self.eta_ratio * cosine.shape[1])
|
|
||||||
_, indices = torch.topk(-cosine, k, dim=1)
|
|
||||||
weights = torch.ones_like(cosine)
|
|
||||||
weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
|
|
||||||
return weights
|
|
||||||
|
|
||||||
weight_real = _get_weights(cosine_real)
|
|
||||||
weight_fake = _get_weights(cosine_fake)
|
|
||||||
return weight_real, weight_fake
|
|
||||||
|
|
||||||
def forward(self, D_real, D_fake, real_scores, fake_scores):
|
|
||||||
# 清空梯度缓存
|
|
||||||
self.gradients_real.clear()
|
|
||||||
self.gradients_fake.clear()
|
|
||||||
self.criterionGAN=networks.GANLoss('lsgan').cuda()
|
self.criterionGAN=networks.GANLoss('lsgan').cuda()
|
||||||
# 注册钩子捕获梯度
|
|
||||||
hook_real = lambda grad: self.gradients_real.append(grad.detach())
|
|
||||||
hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
|
|
||||||
D_real.register_hook(hook_real)
|
|
||||||
D_fake.register_hook(hook_fake)
|
|
||||||
|
|
||||||
# 触发梯度计算(保留计算图)
|
def generate_weight_map(self, attn_real, attn_fake):
|
||||||
(real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)
|
# attn_real, attn_fake: [B, N],自注意力权重
|
||||||
|
# 归一化注意力权重
|
||||||
|
weight_real = F.normalize(attn_real, p=1, dim=1) # [B, N]
|
||||||
|
weight_fake = F.normalize(attn_fake, p=1, dim=1) # [B, N]
|
||||||
|
|
||||||
|
# 对真实图像权重处理
|
||||||
|
k = int(self.eta_ratio * weight_real.shape[1])
|
||||||
|
values_real, indices_real = torch.topk(weight_real, k, dim=1)
|
||||||
|
weight_real_enhanced = torch.ones_like(weight_real)
|
||||||
|
weight_real_enhanced.scatter_(1, indices_real, self.lambda_inc / (values_real + 1e-6))
|
||||||
|
# 对生成图像权重处理
|
||||||
|
values_fake, indices_fake = torch.topk(weight_fake, k, dim=1)
|
||||||
|
weight_fake_enhanced = torch.ones_like(weight_fake)
|
||||||
|
weight_fake_enhanced.scatter_(1, indices_fake, self.lambda_inc / (values_fake + 1e-6))
|
||||||
|
|
||||||
|
return weight_real_enhanced, weight_fake_enhanced
|
||||||
|
|
||||||
# 获取梯度并调整维度
|
def forward(self,real_scores, fake_scores, attn_real, attn_fake):
|
||||||
grad_real = self.gradients_real[0].flatten(1) # [B, N, D] → [B, N*D]
|
# real_scores, fake_scores: 判别器预测得分 [B, 1]
|
||||||
grad_fake = self.gradients_fake[0].flatten(1)
|
# attn_real, attn_fake: 自注意力权重 [B, N]
|
||||||
|
|
||||||
# 生成权重图
|
# 生成权重图
|
||||||
weight_real, weight_fake = self.generate_weight_map(
|
weight_real, weight_fake = self.generate_weight_map(attn_real, attn_fake)
|
||||||
grad_real.view(*D_real.shape),
|
|
||||||
grad_fake.view(*D_fake.shape)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 正确应用权重到对数概率(论文公式7)
|
# 应用权重到 GAN 损失
|
||||||
loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores , True))
|
loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores, True))
|
||||||
loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores , False))
|
loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores, False))
|
||||||
|
|
||||||
# 总损失(注意符号:判别器需最大化该损失)
|
|
||||||
loss_co_adv = (loss_co_real + loss_co_fake)*0.5
|
|
||||||
|
|
||||||
|
# 总损失
|
||||||
|
loss_co_adv = (loss_co_real + loss_co_fake) * 0.5
|
||||||
return loss_co_adv, weight_real, weight_fake
|
return loss_co_adv, weight_real, weight_fake
|
||||||
|
|
||||||
class ContentAwareTemporalNorm(nn.Module):
|
class ContentAwareTemporalNorm(nn.Module):
|
||||||
@ -132,18 +109,19 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
|
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
|
||||||
|
|
||||||
def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
|
def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
|
||||||
"""
|
# 如果 weight_patch 是 [N, 1] 形状(例如 [576, 1]),添加批次维度
|
||||||
将patch级别的权重图上采样到目标分辨率
|
if weight_patch.dim() == 2 and weight_patch.shape[1] == 1:
|
||||||
Args:
|
weight_patch = weight_patch.unsqueeze(0) # 变为 [1, 576, 1]
|
||||||
weight_patch: [B, 1, 24, 24] 来自ViT的patch权重图
|
|
||||||
target_size: 目标分辨率 (H, W)
|
|
||||||
Returns:
|
|
||||||
weight_full: [B, 1, 256, 256] 上采样后的全分辨率权重图
|
|
||||||
"""
|
|
||||||
# 使用双线性插值上采样
|
|
||||||
B = weight_patch.shape[0]
|
|
||||||
weight_patch = weight_patch.view(B, 1, 24, 24)
|
|
||||||
|
|
||||||
|
# 获取调整后的形状
|
||||||
|
B, N, _ = weight_patch.shape # 例如 B=1, N=576
|
||||||
|
if N != 576:
|
||||||
|
raise ValueError(f"预期 patch 数量 N=576 (24x24),但实际得到 N={N}")
|
||||||
|
|
||||||
|
# 重塑为 [B, 1, 24, 24]
|
||||||
|
weight_patch = weight_patch.view(B, 1, 24, 24) # [1, 1, 24, 24]
|
||||||
|
|
||||||
|
# 使用双线性插值上采样到目标大小
|
||||||
weight_full = F.interpolate(
|
weight_full = F.interpolate(
|
||||||
weight_patch,
|
weight_patch,
|
||||||
size=target_size,
|
size=target_size,
|
||||||
@ -151,8 +129,7 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
align_corners=False
|
align_corners=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 对每个16x16的patch内部保持权重一致(可选)
|
# 可选:保持每个 16x16 patch 内部权重一致
|
||||||
# 通过平均池化再扩展,消除插值引入的渐变
|
|
||||||
weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
|
weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
|
||||||
weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
|
weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
|
||||||
|
|
||||||
@ -167,6 +144,7 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
||||||
"""
|
"""
|
||||||
# 上采样权重图到全分辨率
|
# 上采样权重图到全分辨率
|
||||||
|
|
||||||
weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384]
|
weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384]
|
||||||
|
|
||||||
# 1. 归一化权重图
|
# 1. 归一化权重图
|
||||||
@ -198,7 +176,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
"""配置 CTNx 模型的特定选项"""
|
"""配置 CTNx 模型的特定选项"""
|
||||||
|
|
||||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
||||||
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
|
|
||||||
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
||||||
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
||||||
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
||||||
@ -206,14 +184,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
||||||
parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
|
parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
|
||||||
parser.add_argument('--side_length', type=int, default=7)
|
parser.add_argument('--side_length', type=int, default=7)
|
||||||
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
|
|
||||||
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
|
|
||||||
type=util.str2bool, nargs='?', const=True, default=False,
|
|
||||||
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
|
|
||||||
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
|
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
|
||||||
|
|
||||||
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
|
|
||||||
|
|
||||||
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
|
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
|
||||||
parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
|
parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
|
||||||
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
|
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
|
||||||
@ -233,7 +205,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
# 指定需要打印的训练损失
|
# 指定需要打印的训练损失
|
||||||
self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
|
self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
|
||||||
self.visual_names = ['real_A0', 'fake_B0_1','fake_B0', 'real_B0','real_A1', 'fake_B1_1', 'fake_B1', 'real_B1']
|
self.visual_names = ['real_A0', 'fake_B0', 'real_B0','real_A1', 'fake_B1', 'real_B1']
|
||||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||||
|
|
||||||
|
|
||||||
@ -253,7 +225,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
# 创建网络
|
# 创建网络
|
||||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||||
|
|
||||||
|
|
||||||
if self.isTrain:
|
if self.isTrain:
|
||||||
|
|
||||||
@ -321,120 +293,52 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||||
|
self.fake_B0 = self.netG(self.real_A0)
|
||||||
|
self.fake_B1 = self.netG(self.real_A1)
|
||||||
|
|
||||||
# ============ 第一步:对 real_A / real_A2 进行多步随机生成过程 ============
|
if self.opt.isTrain:
|
||||||
tau = self.opt.tau
|
|
||||||
T = self.opt.num_timesteps
|
|
||||||
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
|
|
||||||
times = np.cumsum(incs)
|
|
||||||
times = times / times[-1]
|
|
||||||
times = 0.5 * times[-1] + 0.5 * times #[0.5,1]
|
|
||||||
times = np.concatenate([np.zeros(1), times])
|
|
||||||
times = torch.tensor(times).float().cuda()
|
|
||||||
self.times = times
|
|
||||||
bs = self.real_A0.size(0)
|
|
||||||
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
|
|
||||||
self.time_idx = time_idx
|
|
||||||
self.fake_B0_list = []
|
|
||||||
self.fake_B1_list = []
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
self.netG.eval()
|
|
||||||
# ============ 第二步:对 real_A / real_A2 进行多步随机生成过程 ============
|
|
||||||
for t in range(self.time_idx.int().item() + 1):
|
|
||||||
# 计算增量 delta 与 inter/scale,用于每个时间步的插值等
|
|
||||||
if t > 0:
|
|
||||||
delta = times[t] - times[t - 1]
|
|
||||||
denom = times[-1] - times[t - 1]
|
|
||||||
inter = (delta / denom).reshape(-1, 1, 1, 1)
|
|
||||||
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
|
|
||||||
|
|
||||||
# 对 Xt、Xt2 进行随机噪声更新
|
|
||||||
Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
|
|
||||||
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
|
|
||||||
time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
|
|
||||||
z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
|
|
||||||
time = times[time_idx]
|
|
||||||
Xt_1 = self.netG(Xt.detach(), time, z)
|
|
||||||
|
|
||||||
Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
|
|
||||||
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
|
|
||||||
time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
|
|
||||||
z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
|
|
||||||
Xt_12 = self.netG(Xt2.detach(), time, z)
|
|
||||||
self.fake_B0_list.append(Xt_1)
|
|
||||||
self.fake_B1_list.append(Xt_12)
|
|
||||||
|
|
||||||
self.fake_B0_1 = self.fake_B0_list[0]
|
|
||||||
self.fake_B1_1 = self.fake_B0_list[0]
|
|
||||||
self.fake_B0 = self.fake_B0_list[-1]
|
|
||||||
self.fake_B1 = self.fake_B1_list[-1]
|
|
||||||
self.z_in = z
|
|
||||||
self.z_in2 = z
|
|
||||||
if self.opt.phase == 'train':
|
|
||||||
real_A0 = self.real_A0
|
real_A0 = self.real_A0
|
||||||
real_A1 = self.real_A1
|
real_A1 = self.real_A1
|
||||||
real_B0 = self.real_B0
|
real_B0 = self.real_B0
|
||||||
real_B1 = self.real_B1
|
real_B1 = self.real_B1
|
||||||
fake_B0 = self.fake_B0
|
fake_B0 = self.fake_B0
|
||||||
fake_B1 = self.fake_B1
|
fake_B1 = self.fake_B1
|
||||||
self.mutil_fake_B0_tokens_list = []
|
|
||||||
self.mutil_fake_B1_tokens_list = []
|
|
||||||
for fake_B0_t in self.fake_B0_list:
|
|
||||||
fake_B0_t_resize = self.resize(fake_B0_t) # 调整到 ViT 输入尺寸
|
|
||||||
tokens = self.netPreViT(fake_B0_t_resize, self.atten_layers, get_tokens=True)
|
|
||||||
self.mutil_fake_B0_tokens_list.append(tokens)
|
|
||||||
for fake_B1_t in self.fake_B1_list:
|
|
||||||
fake_B1_t_resize = self.resize(fake_B1_t)
|
|
||||||
tokens = self.netPreViT(fake_B1_t_resize, self.atten_layers, get_tokens=True)
|
|
||||||
self.mutil_fake_B1_tokens_list.append(tokens)
|
|
||||||
|
|
||||||
self.real_A0_resize = self.resize(real_A0)
|
self.real_A0_resize = self.resize(real_A0)
|
||||||
self.real_A1_resize = self.resize(real_A1)
|
self.real_A1_resize = self.resize(real_A1)
|
||||||
real_B0 = self.resize(real_B0)
|
real_B0 = self.resize(real_B0)
|
||||||
real_B1 = self.resize(real_B1)
|
real_B1 = self.resize(real_B1)
|
||||||
self.fake_B0_resize = self.resize(fake_B0)
|
self.fake_B0_resize = self.resize(fake_B0)
|
||||||
self.fake_B1_resize = self.resize(fake_B1)
|
self.fake_B1_resize = self.resize(fake_B1)
|
||||||
|
|
||||||
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
|
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
||||||
# [[1,576,768],[1,576,768],[1,576,768]]
|
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||||
# [3,576,768]
|
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||||
|
|
||||||
|
|
||||||
def compute_D_loss(self):
|
def compute_D_loss(self):
|
||||||
"""Calculate GAN loss with Content-Aware Optimization"""
|
"""Calculate GAN loss with Content-Aware Optimization"""
|
||||||
lambda_D_ViT = self.opt.lambda_D_ViT
|
lambda_D_ViT = self.opt.lambda_D_ViT
|
||||||
|
|
||||||
loss_cao = 0.0
|
pred_real0, attn_real0 = self.netD_ViT(self.mutil_real_B0_tokens[0]) # scores, features
|
||||||
real_B0_tokens = self.mutil_real_B0_tokens[0]
|
pred_real1, attn_real1 = self.netD_ViT(self.mutil_real_B1_tokens[0]) # scores, features
|
||||||
pred_real0, real_features0 = self.netD_ViT(real_B0_tokens) # scores, features
|
|
||||||
real_B1_tokens = self.mutil_real_B1_tokens[0]
|
|
||||||
pred_real1, real_features1 = self.netD_ViT(real_B1_tokens) # scores, features
|
|
||||||
|
|
||||||
for fake0_token, fake1_token in zip(self.mutil_fake_B0_tokens_list, self.mutil_fake_B1_tokens_list):
|
pred_fake0, attn_fake0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach())
|
||||||
pre_fake0, fake_features0 = self.netD_ViT(fake0_token[0].detach())
|
pred_fake1, attn_fake1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach())
|
||||||
pre_fake1, fake_features1 = self.netD_ViT(fake1_token[0].detach())
|
loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
|
||||||
loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
|
real_scores=pred_real0,
|
||||||
D_real=real_features0,
|
fake_scores=pred_fake0,
|
||||||
D_fake=fake_features0,
|
attn_real=attn_real0,
|
||||||
real_scores=pred_real0,
|
attn_fake=attn_fake0
|
||||||
fake_scores=pre_fake0
|
)
|
||||||
)
|
loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
|
||||||
loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
|
real_scores=pred_real1,
|
||||||
D_real=real_features1,
|
fake_scores=pred_fake1,
|
||||||
D_fake=fake_features1,
|
attn_real=attn_real1,
|
||||||
real_scores=pred_real1,
|
attn_fake=attn_fake1
|
||||||
fake_scores=pre_fake1
|
)
|
||||||
)
|
|
||||||
loss_cao += loss_cao0 + loss_cao1
|
|
||||||
|
|
||||||
|
self.loss_D_ViT = (loss_cao0 + loss_cao1) * 0.5 * lambda_D_ViT
|
||||||
# ===== 综合损失 =====
|
|
||||||
total_steps = len(self.fake_B0_list)
|
|
||||||
self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT/ total_steps
|
|
||||||
|
|
||||||
|
|
||||||
# 记录损失值供可视化
|
# 记录损失值供可视化
|
||||||
@ -448,8 +352,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
"""计算生成器的 GAN 损失"""
|
"""计算生成器的 GAN 损失"""
|
||||||
if self.opt.lambda_ctn > 0.0:
|
if self.opt.lambda_ctn > 0.0:
|
||||||
# 生成图像的CTN光流图
|
# 生成图像的CTN光流图
|
||||||
self.f_content0 = self.ctn(self.weight_fake0)
|
self.f_content0 = self.ctn(self.weight_fake0.detach())
|
||||||
self.f_content1 = self.ctn(self.weight_fake1)
|
self.f_content1 = self.ctn(self.weight_fake1.detach())
|
||||||
|
|
||||||
# 变换后的图片
|
# 变换后的图片
|
||||||
self.warped_real_A0 = warp(self.real_A0, self.f_content0)
|
self.warped_real_A0 = warp(self.real_A0, self.f_content0)
|
||||||
@ -458,8 +362,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.warped_fake_B1 = warp(self.fake_B1,self.f_content1)
|
self.warped_fake_B1 = warp(self.fake_B1,self.f_content1)
|
||||||
|
|
||||||
# 经过第二次生成器
|
# 经过第二次生成器
|
||||||
self.warped_fake_B0_2 = self.netG(self.warped_real_A0, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in)
|
self.warped_fake_B0_2 = self.netG(self.warped_real_A0)
|
||||||
self.warped_fake_B1_2 = self.netG(self.warped_real_A1, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in2)
|
self.warped_fake_B1_2 = self.netG(self.warped_real_A1)
|
||||||
|
|
||||||
warped_fake_B0_2=self.warped_fake_B0_2
|
warped_fake_B0_2=self.warped_fake_B0_2
|
||||||
warped_fake_B1_2=self.warped_fake_B1_2
|
warped_fake_B1_2=self.warped_fake_B1_2
|
||||||
@ -472,8 +376,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
if self.opt.lambda_GAN > 0.0:
|
if self.opt.lambda_GAN > 0.0:
|
||||||
|
|
||||||
pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens_list[-1][0])
|
pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens[0])
|
||||||
pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens_list[-1][0])
|
pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens[0])
|
||||||
self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
|
self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
|
||||||
self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
|
self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
|
||||||
self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5
|
self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5
|
||||||
@ -497,8 +401,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
n_layers = len(self.atten_layers)
|
n_layers = len(self.atten_layers)
|
||||||
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
||||||
mutil_real_A1_tokens = self.mutil_real_A1_tokens
|
mutil_real_A1_tokens = self.mutil_real_A1_tokens
|
||||||
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1]
|
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
|
||||||
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1]
|
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
|
||||||
|
|
||||||
|
|
||||||
if self.opt.lambda_global > 0.0:
|
if self.opt.lambda_global > 0.0:
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@ -36,7 +36,7 @@ class BaseOptions():
|
|||||||
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
|
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
|
||||||
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
|
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
|
||||||
parser.add_argument('--netD', type=str, default='basic_cond', choices=['basic_cond', 'basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
|
parser.add_argument('--netD', type=str, default='basic_cond', choices=['basic_cond', 'basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
|
||||||
parser.add_argument('--netG', type=str, default='resnet_9blocks_cond', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
|
parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
|
||||||
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
|
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
|
||||||
parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
|
parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
|
||||||
parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
|
parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
|
||||||
|
|||||||
@ -17,7 +17,7 @@ python train.py \
|
|||||||
--lambda_global 6.0 \
|
--lambda_global 6.0 \
|
||||||
--gamma_stride 20 \
|
--gamma_stride 20 \
|
||||||
--lr 0.000002 \
|
--lr 0.000002 \
|
||||||
--gpu_id 1 \
|
--gpu_id 0 \
|
||||||
--nce_idt False \
|
--nce_idt False \
|
||||||
--netF mlp_sample \
|
--netF mlp_sample \
|
||||||
--eta_ratio 0.4 \
|
--eta_ratio 0.4 \
|
||||||
|
|||||||
20
scripts/traincp.sh
Normal file
20
scripts/traincp.sh
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
python train.py \
|
||||||
|
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
||||||
|
--name cp_3 \
|
||||||
|
--dataset_mode unaligned_double \
|
||||||
|
--display_env CP \
|
||||||
|
--model roma_unsb \
|
||||||
|
--lambda_ctn 10 \
|
||||||
|
--lambda_inc 8.0 \
|
||||||
|
--eta_ratio 0.4 \
|
||||||
|
--lambda_global 6.0 \
|
||||||
|
--lambda_spatial 6.0 \
|
||||||
|
--gamma_stride 20 \
|
||||||
|
--lr 0.00002 \
|
||||||
|
--gpu_id 3 \
|
||||||
|
--eta_ratio 0.4 \
|
||||||
|
--n_epochs 100 \
|
||||||
|
--n_epochs_decay 100 \
|
||||||
|
# cp1 复现cptrans的效果 --lr 0.000001
|
||||||
|
# cp2 修了一下cp1的代码,--lr 0.000002
|
||||||
|
## cp3 将梯度加强修改为attention加强,--lr 0.000005,--lambda_inc 8.0,--gpu_id 3(基于cp2的sh)
|
||||||
Loading…
x
Reference in New Issue
Block a user