最新的修改
This commit is contained in:
parent
67151c73f7
commit
55b9db967a
@ -60,7 +60,7 @@ def compute_ctn_loss(G, x, F_content): #公式10
|
||||
loss = F.mse_loss(warped_fake, y_fake_warped)
|
||||
return loss
|
||||
|
||||
class ContentAwareOptimization(nn.Module):
|
||||
class ContentAwareOptimization(nn.Module):
|
||||
def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
|
||||
super().__init__()
|
||||
self.lambda_inc = lambda_inc # 权重增强系数
|
||||
@ -79,25 +79,30 @@ class ContentAwareOptimization(nn.Module):
|
||||
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
|
||||
return cosine_sim
|
||||
|
||||
def generate_weight_map(self, gradients_fake):
|
||||
def generate_weight_map(self, gradients_fake, feature_shape):
|
||||
"""
|
||||
生成内容感知权重图
|
||||
生成内容感知权重图(修正空间维度)
|
||||
Args:
|
||||
gradients_fake: [B, N, D] 生成图像判别器梯度 [2,3,256,256]
|
||||
gradients_real: [B, N, D] 真实图像判别器梯度
|
||||
gradients_fake: [B, N, D] 生成图像判别器梯度
|
||||
feature_shape: tuple [H, W] 判别器输出的特征图尺寸
|
||||
Returns:
|
||||
weight_fake: [B, N] 生成图像权重图 [2,3,256]
|
||||
weight_real: [B, 1, H, W] 真实图像权重图
|
||||
weight_fake: [B, 1, H, W] 生成图像权重图
|
||||
"""
|
||||
# 计算生成图像块的余弦相似度
|
||||
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
|
||||
H, W = feature_shape
|
||||
N = H * W
|
||||
|
||||
# 选择内容丰富的区域(余弦相似度最低的eta_ratio比例)
|
||||
k = int(self.eta_ratio * cosine_fake.shape[1])
|
||||
|
||||
# 对生成图像生成权重图(同理)
|
||||
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
||||
# 计算余弦相似度(与原代码相同)
|
||||
cosine_fake = self.compute_cosine_similarity(gradients_fake)
|
||||
|
||||
# 生成权重图(与原代码相同)
|
||||
k = int(self.eta_ratio * N)
|
||||
weight_fake = torch.ones_like(cosine_fake)
|
||||
for b in range(cosine_fake.shape[0]):
|
||||
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
|
||||
|
||||
# 重建空间维度 --------------------------------------------------
|
||||
# 将权重从[B, N]转换为[B, H, W]
|
||||
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
|
||||
|
||||
return weight_fake
|
||||
|
||||
@ -488,28 +493,28 @@ class RomaUnsbModel(BaseModel):
|
||||
# [[1,576,768],[1,576,768],[1,576,768]]
|
||||
# [3,576,768]
|
||||
|
||||
## 生成图像的梯度
|
||||
#fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||
#
|
||||
## 梯度图
|
||||
#self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
||||
#
|
||||
## 生成图像的CTN光流图
|
||||
#self.f_content = self.ctn(self.weight_fake)
|
||||
#
|
||||
## 变换后的图片
|
||||
#self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
||||
#self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
||||
#
|
||||
## 经过第二次生成器
|
||||
#self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
||||
# 生成图像的梯度
|
||||
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||
|
||||
#warped_fake_B0_2=self.warped_fake_B0_2
|
||||
#warped_fake_B0=self.warped_fake_B0
|
||||
#self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
|
||||
#self.warped_fake_B0_resize = self.resize(warped_fake_B0)
|
||||
#self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||
#self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
|
||||
# 梯度图
|
||||
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
||||
|
||||
# 生成图像的CTN光流图
|
||||
self.f_content = self.ctn(self.weight_fake)
|
||||
|
||||
# 变换后的图片
|
||||
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
||||
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
||||
|
||||
# 经过第二次生成器
|
||||
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
||||
|
||||
# warped_fake_B0_2=self.warped_fake_B0_2
|
||||
# warped_fake_B0=self.warped_fake_B0
|
||||
# self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
|
||||
# self.warped_fake_B0_resize = self.resize(warped_fake_B0)
|
||||
# self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||
# self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
|
||||
|
||||
|
||||
def compute_D_loss(self): #判别器还是没有改
|
||||
@ -575,9 +580,9 @@ class RomaUnsbModel(BaseModel):
|
||||
loss_global = 0.0
|
||||
|
||||
self.l2_loss = 0.0
|
||||
#if self.opt.lambda_ctn > 0.0:
|
||||
# wapped_fake_B = warp(self.fake_B, self.f_content) # use updated self.f_content
|
||||
# self.l2_loss = F.mse_loss(self.fake_B_2, wapped_fake_B) # complete the loss calculation
|
||||
if self.opt.lambda_l2 > 0.0:
|
||||
wapped_fake_B = warp(self.fake_B0, self.f_content) # use updated self.f_content
|
||||
self.l2_loss = F.mse_loss(self.warped_fake_B0_2, wapped_fake_B) # complete the loss calculation
|
||||
|
||||
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
|
||||
return self.loss_G
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user