This commit is contained in:
bishe 2025-02-24 21:13:36 +08:00
parent 55b9db967a
commit 7af2de920c

View File

@ -97,9 +97,13 @@ class ContentAwareOptimization(nn.Module):
cosine_fake = self.compute_cosine_similarity(gradients_fake)
# 生成权重图(与原代码相同)
k = int(self.eta_ratio * N)
k = int(self.eta_ratio * cosine_fake.shape[1])
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
weight_fake = torch.ones_like(cosine_fake)
for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
# 重建空间维度 --------------------------------------------------
# 将权重从[B, N]转换为[B, H, W]
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]