在判别器中引入attention

This commit is contained in:
bishe 2025-03-09 23:30:05 +08:00
parent e0dc08030c
commit e9c0f5ffcb
3 changed files with 82 additions and 99 deletions

View File

@ -1401,23 +1401,32 @@ class UnetSkipConnectionBlock(nn.Module):
class MLPDiscriminator(nn.Module): class MLPDiscriminator(nn.Module):
def __init__(self, in_feat=768, hid_feat = 768, out_feat = 768, dropout = 0.): def __init__(self, in_feat=768, hid_feat=512, out_feat=768, num_heads=1):
super().__init__() super().__init__()
if not hid_feat: # 自注意力层加入Dropout
hid_feat = in_feat self.attention = nn.MultiheadAttention(embed_dim=in_feat, num_heads=num_heads, dropout=0.1)
if not out_feat: # 加深加宽的MLP加入Dropout
out_feat = in_feat self.mlp = nn.Sequential(
self.linear1 = nn.Linear(in_feat, hid_feat) nn.Linear(in_feat, hid_feat), # 768 -> 512
self.activation = nn.GELU() nn.ReLU(),
self.linear2 = nn.Linear(hid_feat, out_feat) nn.Dropout(0.3),
self.dropout = nn.Dropout(dropout) nn.Linear(hid_feat, hid_feat * 2), # 512 -> 1024
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hid_feat * 2, hid_feat), # 1024 -> 512
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hid_feat, out_feat), # 512 -> 768
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(out_feat, 1) # 768 -> 1
)
def forward(self, x): def forward(self, x):
features = self.linear1(x) # 中间特征,即 D_real 或 D_fake attn_output, attn_weights = self.attention(x, x, x) # [B, N, D], [B, N, N]
x = self.activation(features) attn_weights = attn_weights.mean(dim=1) # [B, N]
x = self.dropout(x) pred = self.mlp(attn_output.mean(dim=1)) # [B, 1]
scores = self.linear2(x) # 最终分数,即 real_scores 或 fake_scores return pred, attn_weights
return scores, features
class NLayerDiscriminator(nn.Module): class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator""" """Defines a PatchGAN discriminator"""

View File

@ -67,62 +67,39 @@ class ContentAwareOptimization(nn.Module):
super().__init__() super().__init__()
self.lambda_inc = lambda_inc self.lambda_inc = lambda_inc
self.eta_ratio = eta_ratio self.eta_ratio = eta_ratio
self.gradients_real = []
self.gradients_fake = []
def compute_cosine_similarity(self, gradients):
mean_grad = torch.mean(gradients, dim=1, keepdim=True)
return F.cosine_similarity(gradients, mean_grad, dim=2)
def generate_weight_map(self, gradients_real, gradients_fake):
# 计算余弦相似度
cosine_real = self.compute_cosine_similarity(gradients_real)
cosine_fake = self.compute_cosine_similarity(gradients_fake)
# 生成权重图(优化实现)
def _get_weights(cosine):
k = int(self.eta_ratio * cosine.shape[1])
_, indices = torch.topk(-cosine, k, dim=1)
weights = torch.ones_like(cosine)
weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
return weights
weight_real = _get_weights(cosine_real)
weight_fake = _get_weights(cosine_fake)
return weight_real, weight_fake
def forward(self, D_real, D_fake, real_scores, fake_scores):
# 清空梯度缓存
self.gradients_real.clear()
self.gradients_fake.clear()
self.criterionGAN=networks.GANLoss('lsgan').cuda() self.criterionGAN=networks.GANLoss('lsgan').cuda()
# 注册钩子捕获梯度
hook_real = lambda grad: self.gradients_real.append(grad.detach())
hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
D_real.register_hook(hook_real)
D_fake.register_hook(hook_fake)
# 触发梯度计算(保留计算图) def generate_weight_map(self, attn_real, attn_fake):
(real_scores.mean() + fake_scores.mean()).backward(retain_graph=True) # attn_real, attn_fake: [B, N],自注意力权重
# 归一化注意力权重
weight_real = F.normalize(attn_real, p=1, dim=1) # [B, N]
weight_fake = F.normalize(attn_fake, p=1, dim=1) # [B, N]
# 获取梯度并调整维度 # 对真实图像权重处理
grad_real = self.gradients_real[0].flatten(1) # [B, N, D] → [B, N*D] k = int(self.eta_ratio * weight_real.shape[1])
grad_fake = self.gradients_fake[0].flatten(1) values_real, indices_real = torch.topk(weight_real, k, dim=1)
weight_real_enhanced = torch.ones_like(weight_real)
weight_real_enhanced.scatter_(1, indices_real, self.lambda_inc / (values_real + 1e-6))
# 对生成图像权重处理
values_fake, indices_fake = torch.topk(weight_fake, k, dim=1)
weight_fake_enhanced = torch.ones_like(weight_fake)
weight_fake_enhanced.scatter_(1, indices_fake, self.lambda_inc / (values_fake + 1e-6))
return weight_real_enhanced, weight_fake_enhanced
def forward(self,real_scores, fake_scores, attn_real, attn_fake):
# real_scores, fake_scores: 判别器预测得分 [B, 1]
# attn_real, attn_fake: 自注意力权重 [B, N]
# 生成权重图 # 生成权重图
weight_real, weight_fake = self.generate_weight_map( weight_real, weight_fake = self.generate_weight_map(attn_real, attn_fake)
grad_real.view(*D_real.shape),
grad_fake.view(*D_fake.shape)
)
# 正确应用权重到对数概率论文公式7 # 应用权重到 GAN 损失
loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores , True)) loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores, True))
loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores , False)) loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores, False))
# 总损失(注意符号:判别器需最大化该损失)
loss_co_adv = (loss_co_real + loss_co_fake)*0.5
# 总损失
loss_co_adv = (loss_co_real + loss_co_fake) * 0.5
return loss_co_adv, weight_real, weight_fake return loss_co_adv, weight_real, weight_fake
class ContentAwareTemporalNorm(nn.Module): class ContentAwareTemporalNorm(nn.Module):
@ -132,18 +109,19 @@ class ContentAwareTemporalNorm(nn.Module):
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层 self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
def upsample_weight_map(self, weight_patch, target_size=(256, 256)): def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
""" # 如果 weight_patch 是 [N, 1] 形状(例如 [576, 1]),添加批次维度
将patch级别的权重图上采样到目标分辨率 if weight_patch.dim() == 2 and weight_patch.shape[1] == 1:
Args: weight_patch = weight_patch.unsqueeze(0) # 变为 [1, 576, 1]
weight_patch: [B, 1, 24, 24] 来自ViT的patch权重图
target_size: 目标分辨率 (H, W)
Returns:
weight_full: [B, 1, 256, 256] 上采样后的全分辨率权重图
"""
# 使用双线性插值上采样
B = weight_patch.shape[0]
weight_patch = weight_patch.view(B, 1, 24, 24)
# 获取调整后的形状
B, N, _ = weight_patch.shape # 例如 B=1, N=576
if N != 576:
raise ValueError(f"预期 patch 数量 N=576 (24x24),但实际得到 N={N}")
# 重塑为 [B, 1, 24, 24]
weight_patch = weight_patch.view(B, 1, 24, 24) # [1, 1, 24, 24]
# 使用双线性插值上采样到目标大小
weight_full = F.interpolate( weight_full = F.interpolate(
weight_patch, weight_patch,
size=target_size, size=target_size,
@ -151,8 +129,7 @@ class ContentAwareTemporalNorm(nn.Module):
align_corners=False align_corners=False
) )
# 对每个16x16的patch内部保持权重一致可选 # 可选:保持每个 16x16 patch 内部权重一致
# 通过平均池化再扩展,消除插值引入的渐变
weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16) weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest') weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
@ -167,6 +144,7 @@ class ContentAwareTemporalNorm(nn.Module):
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
""" """
# 上采样权重图到全分辨率 # 上采样权重图到全分辨率
weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384] weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384]
# 1. 归一化权重图 # 1. 归一化权重图
@ -342,31 +320,25 @@ class RomaUnsbModel(BaseModel):
"""Calculate GAN loss with Content-Aware Optimization""" """Calculate GAN loss with Content-Aware Optimization"""
lambda_D_ViT = self.opt.lambda_D_ViT lambda_D_ViT = self.opt.lambda_D_ViT
loss_cao = 0.0 pred_real0, attn_real0 = self.netD_ViT(self.mutil_real_B0_tokens[0]) # scores, features
real_B0_tokens = self.mutil_real_B0_tokens[0] pred_real1, attn_real1 = self.netD_ViT(self.mutil_real_B1_tokens[0]) # scores, features
pred_real0, real_features0 = self.netD_ViT(real_B0_tokens) # scores, features
real_B1_tokens = self.mutil_real_B1_tokens[0]
pred_real1, real_features1 = self.netD_ViT(real_B1_tokens) # scores, features
pre_fake0, fake_features0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach()) pred_fake0, attn_fake0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach())
pre_fake1, fake_features1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach()) pred_fake1, attn_fake1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach())
loss_cao0, self.weight_real0, self.weight_fake0 = self.cao( loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
D_real=real_features0,
D_fake=fake_features0,
real_scores=pred_real0, real_scores=pred_real0,
fake_scores=pre_fake0 fake_scores=pred_fake0,
attn_real=attn_real0,
attn_fake=attn_fake0
) )
loss_cao1, self.weight_real1, self.weight_fake1 = self.cao( loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
D_real=real_features1,
D_fake=fake_features1,
real_scores=pred_real1, real_scores=pred_real1,
fake_scores=pre_fake1 fake_scores=pred_fake1,
attn_real=attn_real1,
attn_fake=attn_fake1
) )
loss_cao += loss_cao0 + loss_cao1
self.loss_D_ViT = (loss_cao0 + loss_cao1) * 0.5 * lambda_D_ViT
# ===== 综合损失 =====
self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT
# 记录损失值供可视化 # 记录损失值供可视化

View File

@ -1,18 +1,20 @@
python train.py \ python train.py \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
--name cp_2 \ --name cp_3 \
--dataset_mode unaligned_double \ --dataset_mode unaligned_double \
--display_env CP \ --display_env CP \
--model roma_unsb \ --model roma_unsb \
--lambda_ctn 10 \ --lambda_ctn 10 \
--lambda_inc 1.0 \ --lambda_inc 8.0 \
--eta_ratio 0.4 \
--lambda_global 6.0 \ --lambda_global 6.0 \
--lambda_spatial 6.0 \ --lambda_spatial 6.0 \
--gamma_stride 20 \ --gamma_stride 20 \
--lr 0.000002 \ --lr 0.00002 \
--gpu_id 0 \ --gpu_id 3 \
--eta_ratio 0.4 \ --eta_ratio 0.4 \
--n_epochs 100 \ --n_epochs 100 \
--n_epochs_decay 100 \ --n_epochs_decay 100 \
# cp1 复现cptrans的效果 --lr 0.000001 # cp1 复现cptrans的效果 --lr 0.000001
# cp2 修了一下cp1的代码--lr 0.000002 # cp2 修了一下cp1的代码--lr 0.000002
## cp3 将梯度加强修改为attention加强--lr 0.000005,--lambda_inc 8.0,--gpu_id 3(基于cp2的sh)