debug
This commit is contained in:
parent
687559866d
commit
8a081af0a3
@ -39,3 +39,8 @@
|
||||
================ Training Loss (Sun Feb 23 21:29:03 2025) ================
|
||||
================ Training Loss (Sun Feb 23 21:34:57 2025) ================
|
||||
================ Training Loss (Sun Feb 23 21:35:26 2025) ================
|
||||
================ Training Loss (Sun Feb 23 22:28:43 2025) ================
|
||||
================ Training Loss (Sun Feb 23 22:29:04 2025) ================
|
||||
================ Training Loss (Sun Feb 23 22:29:52 2025) ================
|
||||
================ Training Loss (Sun Feb 23 22:30:40 2025) ================
|
||||
================ Training Loss (Sun Feb 23 22:33:48 2025) ================
|
||||
|
||||
Binary file not shown.
@ -395,8 +395,6 @@ class RomaUnsbModel(BaseModel):
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def forward(self):
|
||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||
|
||||
@ -462,16 +460,36 @@ class RomaUnsbModel(BaseModel):
|
||||
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
|
||||
|
||||
if self.opt.phase == 'train':
|
||||
real_A0 = self.real_A0
|
||||
real_A1 = self.real_A1
|
||||
real_B0 = self.real_B0
|
||||
real_B1 = self.real_B1
|
||||
fake_B0 = self.fake_B0
|
||||
fake_B1 = self.fake_B1
|
||||
|
||||
self.real_A0_resize = self.resize(real_A0)
|
||||
self.real_A1_resize = self.resize(real_A1)
|
||||
real_B0 = self.resize(real_B0)
|
||||
real_B1 = self.resize(real_B1)
|
||||
self.fake_B0_resize = self.resize(fake_B0)
|
||||
self.fake_B1_resize = self.resize(fake_B1)
|
||||
|
||||
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||
# [[1,576,768],[1,576,768],[1,576,768]]
|
||||
# [3,576,768]
|
||||
|
||||
# 生成图像的梯度
|
||||
print(f'self.fake_B0: {self.fake_B0.shape}')
|
||||
fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
|
||||
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||
|
||||
# 梯度图
|
||||
print(f'fake_gradient: {fake_gradient.shape}')
|
||||
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
||||
|
||||
# 生成图像的CTN光流图
|
||||
print(f'weight_fake: {self.weight_fake.shape}')
|
||||
self.f_content = self.ctn(self.weight_fake)
|
||||
|
||||
# 变换后的图片
|
||||
@ -481,34 +499,14 @@ class RomaUnsbModel(BaseModel):
|
||||
# 经过第二次生成器
|
||||
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
||||
|
||||
if self.opt.isTrain:
|
||||
real_A0 = self.real_A0
|
||||
real_A1 = self.real_A1
|
||||
real_B0 = self.real_B0
|
||||
real_B1 = self.real_B1
|
||||
fake_B0 = self.fake_B0
|
||||
fake_B1 = self.fake_B1
|
||||
warped_fake_B0_2=self.warped_fake_B0_2
|
||||
warped_fake_B0=self.warped_fake_B0
|
||||
|
||||
self.real_A0_resize = self.resize(real_A0)
|
||||
self.real_A1_resize = self.resize(real_A1)
|
||||
real_B0 = self.resize(real_B0)
|
||||
real_B1 = self.resize(real_B1)
|
||||
self.fake_B0_resize = self.resize(fake_B0)
|
||||
self.fake_B1_resize = self.resize(fake_B1)
|
||||
self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
|
||||
self.warped_fake_B0_resize = self.resize(warped_fake_B0)
|
||||
|
||||
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
|
||||
|
||||
|
||||
def compute_D_loss(self):
|
||||
"""计算判别器的 GAN 损失"""
|
||||
|
||||
@ -526,8 +524,8 @@ class RomaUnsbModel(BaseModel):
|
||||
def compute_E_loss(self):
|
||||
"""计算判别器 E 的损失"""
|
||||
|
||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B.detach()], dim=1)
|
||||
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2.detach()], dim=1)
|
||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
|
||||
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
|
||||
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
|
||||
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
|
||||
|
||||
@ -536,10 +534,10 @@ class RomaUnsbModel(BaseModel):
|
||||
def compute_G_loss(self):
|
||||
"""计算生成器的 GAN 损失"""
|
||||
|
||||
bs = self.mutil_real_A0_tokens.size(0)
|
||||
bs = self.real_A0.size(0)
|
||||
tau = self.opt.tau
|
||||
|
||||
fake = self.fake_B
|
||||
fake = self.fake_B0
|
||||
std = torch.rand(size=[1]).item() * self.opt.std
|
||||
|
||||
if self.opt.lambda_GAN > 0.0:
|
||||
@ -549,8 +547,8 @@ class RomaUnsbModel(BaseModel):
|
||||
self.loss_G_GAN = 0.0
|
||||
self.loss_SB = 0
|
||||
if self.opt.lambda_SB > 0.0:
|
||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1)
|
||||
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1)
|
||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
|
||||
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
|
||||
|
||||
bs = self.opt.batch_size
|
||||
|
||||
@ -560,7 +558,7 @@ class RomaUnsbModel(BaseModel):
|
||||
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)
|
||||
|
||||
if self.opt.lambda_global > 0.0:
|
||||
loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens)
|
||||
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
|
||||
loss_global *= 0.5
|
||||
else:
|
||||
loss_global = 0.0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user