diff --git a/models/networks.py b/models/networks.py
index 3c29522..1cdf916 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -1401,23 +1401,32 @@ class UnetSkipConnectionBlock(nn.Module):
 class MLPDiscriminator(nn.Module):
-    def __init__(self, in_feat=768, hid_feat = 768, out_feat = 768, dropout = 0.):
+    def __init__(self, in_feat=768, hid_feat=512, out_feat=768, num_heads=1):
         super().__init__()
-        if not hid_feat:
-            hid_feat = in_feat
-        if not out_feat:
-            out_feat = in_feat
-        self.linear1 = nn.Linear(in_feat, hid_feat)
-        self.activation = nn.GELU()
-        self.linear2 = nn.Linear(hid_feat, out_feat)
-        self.dropout = nn.Dropout(dropout)
-
+        # Self-attention layer with dropout; batch_first so inputs are [B, N, D]
+        self.attention = nn.MultiheadAttention(embed_dim=in_feat, num_heads=num_heads, dropout=0.1, batch_first=True)
+        # Deeper and wider MLP head with dropout
+        self.mlp = nn.Sequential(
+            nn.Linear(in_feat, hid_feat),       # 768 -> 512
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(hid_feat, hid_feat * 2),  # 512 -> 1024
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(hid_feat * 2, hid_feat),  # 1024 -> 512
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(hid_feat, out_feat),      # 512 -> 768
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(out_feat, 1)              # 768 -> 1
+        )
     def forward(self, x):
-        features = self.linear1(x)  # intermediate features, i.e. D_real or D_fake
-        x = self.activation(features)
-        x = self.dropout(x)
-        scores = self.linear2(x)  # final scores, i.e. real_scores or fake_scores
-        return scores, features
+        attn_output, attn_weights = self.attention(x, x, x)  # [B, N, D], [B, N, N]
+        attn_weights = attn_weights.mean(dim=1)              # [B, N]
+        pred = self.mlp(attn_output.mean(dim=1))             # [B, 1]
+        return pred, attn_weights
+
 class NLayerDiscriminator(nn.Module):
     """Defines a PatchGAN discriminator"""
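
The discriminator now returns a (score, attention map) pair instead of (score, features). A minimal shape smoke test for the reworked head, as a sketch: it assumes batch-first ViT tokens of shape [B, N, 768] with N=576 patches (a 24x24 grid), and the input values are dummies.

import torch
from models.networks import MLPDiscriminator

disc = MLPDiscriminator(in_feat=768, hid_feat=512, out_feat=768, num_heads=1)
tokens = torch.randn(2, 576, 768)  # [B, N, D] dummy ViT tokens
pred, attn = disc(tokens)          # [B, 1] score, [B, N] per-token attention
assert pred.shape == (2, 1) and attn.shape == (2, 576)

Note that batch_first=True is required for this shape convention; without it nn.MultiheadAttention treats dim 0 as the sequence axis and the comments in forward would not hold.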
diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py
index bdc98f4..bd5d9fc 100644
--- a/models/roma_unsb_model.py
+++ b/models/roma_unsb_model.py
@@ -67,62 +67,39 @@ class ContentAwareOptimization(nn.Module):
         super().__init__()
         self.lambda_inc = lambda_inc
         self.eta_ratio = eta_ratio
-        self.gradients_real = []
-        self.gradients_fake = []
-
-
-    def compute_cosine_similarity(self, gradients):
-        mean_grad = torch.mean(gradients, dim=1, keepdim=True)
-        return F.cosine_similarity(gradients, mean_grad, dim=2)
-
-    def generate_weight_map(self, gradients_real, gradients_fake):
-        # compute cosine similarity
-        cosine_real = self.compute_cosine_similarity(gradients_real)
-        cosine_fake = self.compute_cosine_similarity(gradients_fake)
-
-        # generate the weight maps (optimized implementation)
-        def _get_weights(cosine):
-            k = int(self.eta_ratio * cosine.shape[1])
-            _, indices = torch.topk(-cosine, k, dim=1)
-            weights = torch.ones_like(cosine)
-            weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
-            return weights
-
-        weight_real = _get_weights(cosine_real)
-        weight_fake = _get_weights(cosine_fake)
-        return weight_real, weight_fake
-
-    def forward(self, D_real, D_fake, real_scores, fake_scores):
-        # clear the gradient caches
-        self.gradients_real.clear()
-        self.gradients_fake.clear()
         self.criterionGAN=networks.GANLoss('lsgan').cuda()
-        # register hooks to capture the gradients
-        hook_real = lambda grad: self.gradients_real.append(grad.detach())
-        hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
-        D_real.register_hook(hook_real)
-        D_fake.register_hook(hook_fake)
-        # trigger the gradient computation (retain the graph)
-        (real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)
+    def generate_weight_map(self, attn_real, attn_fake):
+        # attn_real, attn_fake: [B, N] self-attention weights
+        # L1-normalize the attention weights
+        weight_real = F.normalize(attn_real, p=1, dim=1)  # [B, N]
+        weight_fake = F.normalize(attn_fake, p=1, dim=1)  # [B, N]
+
+        # enhance the weights for real images
+        k = int(self.eta_ratio * weight_real.shape[1])
+        values_real, indices_real = torch.topk(weight_real, k, dim=1)
+        weight_real_enhanced = torch.ones_like(weight_real)
+        weight_real_enhanced.scatter_(1, indices_real, self.lambda_inc / (values_real + 1e-6))
+        # enhance the weights for generated images
+        values_fake, indices_fake = torch.topk(weight_fake, k, dim=1)
+        weight_fake_enhanced = torch.ones_like(weight_fake)
+        weight_fake_enhanced.scatter_(1, indices_fake, self.lambda_inc / (values_fake + 1e-6))
+
+        return weight_real_enhanced, weight_fake_enhanced
-        # fetch the gradients and reshape
-        grad_real = self.gradients_real[0].flatten(1)  # [B, N, D] -> [B, N*D]
-        grad_fake = self.gradients_fake[0].flatten(1)
+    def forward(self, real_scores, fake_scores, attn_real, attn_fake):
+        # real_scores, fake_scores: discriminator scores [B, 1]
+        # attn_real, attn_fake: self-attention weights [B, N]
         # generate the weight maps
-        weight_real, weight_fake = self.generate_weight_map(
-            grad_real.view(*D_real.shape),
-            grad_fake.view(*D_fake.shape)
-        )
+        weight_real, weight_fake = self.generate_weight_map(attn_real, attn_fake)
-        # apply the weights to the log-probabilities (Eq. 7 in the paper)
-        loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores , True))
-        loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores , False))
-
-        # total loss (note the sign: the discriminator maximizes this)
-        loss_co_adv = (loss_co_real + loss_co_fake)*0.5
+        # apply the weights to the GAN loss
+        loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores, True))
+        loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores, False))
+        # total loss
+        loss_co_adv = (loss_co_real + loss_co_fake) * 0.5
         return loss_co_adv, weight_real, weight_fake

 class ContentAwareTemporalNorm(nn.Module):
@@ -132,18 +109,19 @@ class ContentAwareTemporalNorm(nn.Module):
         self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

     def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
-        """
-        Upsample a patch-level weight map to the target resolution
-        Args:
-            weight_patch: [B, 1, 24, 24] patch weight map from the ViT
-            target_size: target resolution (H, W)
-        Returns:
-            weight_full: [B, 1, 256, 256] upsampled full-resolution weight map
-        """
-        # upsample with bilinear interpolation
-        B = weight_patch.shape[0]
-        weight_patch = weight_patch.view(B, 1, 24, 24)
+        # if weight_patch is [N, 1] (e.g. [576, 1]), add a batch dimension
+        if weight_patch.dim() == 2 and weight_patch.shape[1] == 1:
+            weight_patch = weight_patch.unsqueeze(0)  # becomes [1, 576, 1]
+        # read off the adjusted shape
+        B, N, _ = weight_patch.shape  # e.g. B=1, N=576
+        if N != 576:
+            raise ValueError(f"expected N=576 patches (24x24), got N={N}")
+
+        # reshape to [B, 1, 24, 24]
+        weight_patch = weight_patch.view(B, 1, 24, 24)
+
+        # bilinear upsampling to the target size
         weight_full = F.interpolate(
             weight_patch,
             size=target_size,
@@ -151,8 +129,7 @@ class ContentAwareTemporalNorm(nn.Module):
             align_corners=False
         )
-        # keep weights constant within each 16x16 patch (optional)
-        # average-pool then re-expand to remove interpolation gradients
+        # optional: keep weights constant within each 16x16 patch
         weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
         weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
@@ -167,6 +144,7 @@ class ContentAwareTemporalNorm(nn.Module):
         F_content: [B, 2, H, W] generated optical-flow field (x/y displacements)
         """
         # upsample the weight map to full resolution
+        weight_full = self.upsample_weight_map(weight_map)  # [B,1,384,384]

         # 1. Normalize the weight map
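
For reference, the upsampling path above end to end, as a sketch: it assumes the [576, 1] per-image weight shape handled by upsample_weight_map and the (384, 384) target noted in forward's comment; at that size the 16-pixel average pooling aligns exactly with the 24x24 patch grid.

import torch
import torch.nn.functional as F

w = torch.rand(576, 1)                 # one weight per patch of the 24x24 ViT grid
w = w.unsqueeze(0).view(1, 1, 24, 24)  # [B, 1, 24, 24]
full = F.interpolate(w, size=(384, 384), mode='bilinear', align_corners=False)
# make the map constant inside each 16x16 pixel block again
full = F.avg_pool2d(full, kernel_size=16, stride=16)         # [1, 1, 24, 24]
full = F.interpolate(full, scale_factor=16, mode='nearest')  # [1, 1, 384, 384]
assert full.shape == (1, 1, 384, 384)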
@@ -342,31 +320,25 @@ class RomaUnsbModel(BaseModel):
         """Calculate GAN loss with Content-Aware Optimization"""
         lambda_D_ViT = self.opt.lambda_D_ViT
-        loss_cao = 0.0
-        real_B0_tokens = self.mutil_real_B0_tokens[0]
-        pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)  # scores, features
-        real_B1_tokens = self.mutil_real_B1_tokens[0]
-        pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)  # scores, features
+        pred_real0, attn_real0 = self.netD_ViT(self.mutil_real_B0_tokens[0])  # scores, attention weights
+        pred_real1, attn_real1 = self.netD_ViT(self.mutil_real_B1_tokens[0])  # scores, attention weights

-        pre_fake0, fake_features0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach())
-        pre_fake1, fake_features1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach())
+        pred_fake0, attn_fake0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach())
+        pred_fake1, attn_fake1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach())

         loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
-            D_real=real_features0,
-            D_fake=fake_features0,
             real_scores=pred_real0,
-            fake_scores=pre_fake0
+            fake_scores=pred_fake0,
+            attn_real=attn_real0,
+            attn_fake=attn_fake0
         )
         loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
-            D_real=real_features1,
-            D_fake=fake_features1,
             real_scores=pred_real1,
-            fake_scores=pre_fake1
+            fake_scores=pred_fake1,
+            attn_real=attn_real1,
+            attn_fake=attn_fake1
         )
-        loss_cao += loss_cao0 + loss_cao1
-
-        # ===== combined loss =====
-        self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT
+        self.loss_D_ViT = (loss_cao0 + loss_cao1) * 0.5 * lambda_D_ViT

         # log the loss values for visualization
diff --git a/scripts/traincp.sh b/scripts/traincp.sh
index 5b0625c..c35cd60 100644
--- a/scripts/traincp.sh
+++ b/scripts/traincp.sh
@@ -1,18 +1,19 @@
 python train.py \
     --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
-    --name cp_2 \
+    --name cp_3 \
     --dataset_mode unaligned_double \
     --display_env CP \
     --model roma_unsb \
     --lambda_ctn 10 \
-    --lambda_inc 1.0 \
+    --lambda_inc 8.0 \
     --lambda_global 6.0 \
    --lambda_spatial 6.0 \
     --gamma_stride 20 \
-    --lr 0.000002 \
-    --gpu_id 0 \
+    --lr 0.00002 \
+    --gpu_id 3 \
     --eta_ratio 0.4 \
     --n_epochs 100 \
     --n_epochs_decay 100 \

 # cp1: reproduce the cptrans results, --lr 0.000001
-# cp2: fixed up the cp1 code, --lr 0.000002
\ No newline at end of file
+# cp2: fixed up the cp1 code, --lr 0.000002
+## cp3: switch the gradient-based enhancement to attention-based enhancement, --lr 0.000005, --lambda_inc 8.0, --gpu_id 3 (based on the cp2 script)
\ No newline at end of file
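
The "attention-based enhancement" mentioned in the cp3 note is the generate_weight_map logic from the roma_unsb_model.py hunk: L1-normalize the per-token attention, then boost the eta_ratio fraction of highest-attention tokens by lambda_inc / value. A standalone sketch with small made-up numbers (eta_ratio=0.4, lambda_inc=8.0, N=8 tokens):

import torch
import torch.nn.functional as F

attn = torch.rand(1, 8)                # [B, N] raw attention, B=1, N=8
w = F.normalize(attn, p=1, dim=1)      # each row now sums to 1
k = int(0.4 * w.shape[1])              # eta_ratio = 0.4 -> top 3 of 8 tokens
values, idx = torch.topk(w, k, dim=1)  # highest-attention tokens
enhanced = torch.ones_like(w)          # unselected tokens keep weight 1
enhanced.scatter_(1, idx, 8.0 / (values + 1e-6))  # lambda_inc boost

Since the normalized attention values are on the order of 1/N, the boosted weights are much larger than the baseline 1, so the weighted GAN loss is dominated by the selected tokens.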