diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index a0b7682..e54b36f 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -97,9 +97,13 @@ class ContentAwareOptimization(nn.Module): cosine_fake = self.compute_cosine_similarity(gradients_fake) # 生成权重图(与原代码相同) - k = int(self.eta_ratio * N) + k = int(self.eta_ratio * cosine_fake.shape[1]) + _, fake_indices = torch.topk(-cosine_fake, k, dim=1) weight_fake = torch.ones_like(cosine_fake) - + + for b in range(cosine_fake.shape[0]): + weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]])) + # 重建空间维度 -------------------------------------------------- # 将权重从[B, N]转换为[B, H, W] weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]