From e67b0f2511c0fa0e340f9cb7e58f6bc538844e85 Mon Sep 17 00:00:00 2001 From: bishe <123456789@163.com> Date: Mon, 24 Feb 2025 21:28:21 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=80=E6=96=B0=E7=9A=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/roma_unsb_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index e54b36f..3563ddf 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -123,7 +123,7 @@ class ContentAwareOptimization(nn.Module): """ B, C, H, W = D_real.shape N = H * W - + shape_hw = [h, w] # 注册钩子获取梯度 gradients_real = [] gradients_fake = [] @@ -150,7 +150,7 @@ class ContentAwareOptimization(nn.Module): gradients_fake = gradients_fake[0] # [B, N, D] # 生成权重图 - self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake) + self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw ) # 应用权重到对抗损失 loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8)) @@ -496,12 +496,12 @@ class RomaUnsbModel(BaseModel): self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) # [[1,576,768],[1,576,768],[1,576,768]] # [3,576,768] - + shape_hw = list(self.real_A0_resize.shape[2:4]) # 生成图像的梯度 fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0] # 梯度图 - self.weight_fake = self.cao.generate_weight_map(fake_gradient) + self.weight_fake = self.cao.generate_weight_map(fake_gradient,shape_hw) # 生成图像的CTN光流图 self.f_content = self.ctn(self.weight_fake)