use kun's ctn method
This commit is contained in:
parent
6705075876
commit
0639032b6c
@ -33,3 +33,4 @@
|
|||||||
================ Training Loss (Sun Feb 23 18:59:52 2025) ================
|
================ Training Loss (Sun Feb 23 18:59:52 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 19:03:05 2025) ================
|
================ Training Loss (Sun Feb 23 19:03:05 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 19:03:57 2025) ================
|
================ Training Loss (Sun Feb 23 19:03:57 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 21:11:47 2025) ================
|
||||||
|
|||||||
Binary file not shown.
@ -527,26 +527,17 @@ class RomaUnsbModel(BaseModel):
|
|||||||
setattr(self, "fake_"+str(t+1), Xt_1)
|
setattr(self, "fake_"+str(t+1), Xt_1)
|
||||||
|
|
||||||
if self.opt.phase == 'train':
|
if self.opt.phase == 'train':
|
||||||
print(f'real_B0.shape = {real_B0.shape} fake_B0.shape = {self.fake_B0.shape}')
|
|
||||||
print(f"self.real_B0.requires_grad: {real_B0.requires_grad}")
|
|
||||||
# 真实图像的梯度
|
|
||||||
real_gradient = torch.autograd.grad(real_B0.sum(), real_B0, create_graph=True)[0]
|
|
||||||
# 生成图像的梯度
|
# 生成图像的梯度
|
||||||
fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
|
fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
|
||||||
# 梯度图
|
# 梯度图
|
||||||
self.weight_real, self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
||||||
|
|
||||||
# 生成图像的CTN光流图
|
# 生成图像的CTN光流图
|
||||||
self.f_content = self.ctn(self.weight_fake)
|
self.f_content = self.ctn(self.weight_fake)
|
||||||
|
|
||||||
# 把前面生成后的图片再加上noisy_map
|
|
||||||
self.fake_B_2 = self.fake_B + self.noisy_map
|
|
||||||
|
|
||||||
# 变换后的图片
|
# 变换后的图片
|
||||||
wapped_fake_B = warp(self.fake_B, self.f_content)
|
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
||||||
|
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
||||||
# 经过第二次生成器
|
# 经过第二次生成器
|
||||||
self.fake_B_2 = self.netG(wapped_fake_B, self.time, z_in)
|
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
||||||
|
|
||||||
def compute_D_loss(self):
|
def compute_D_loss(self):
|
||||||
"""计算判别器的 GAN 损失"""
|
"""计算判别器的 GAN 损失"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user