diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt index 1af94c6..2bc6f9f 100644 --- a/checkpoints/ROMA_UNSB_001/loss_log.txt +++ b/checkpoints/ROMA_UNSB_001/loss_log.txt @@ -39,3 +39,8 @@ ================ Training Loss (Sun Feb 23 21:29:03 2025) ================ ================ Training Loss (Sun Feb 23 21:34:57 2025) ================ ================ Training Loss (Sun Feb 23 21:35:26 2025) ================ +================ Training Loss (Sun Feb 23 22:28:43 2025) ================ +================ Training Loss (Sun Feb 23 22:29:04 2025) ================ +================ Training Loss (Sun Feb 23 22:29:52 2025) ================ +================ Training Loss (Sun Feb 23 22:30:40 2025) ================ +================ Training Loss (Sun Feb 23 22:33:48 2025) ================ diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index 322c840..e819f44 100644 Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 702e3b8..ec2fe8d 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -395,8 +395,6 @@ class RomaUnsbModel(BaseModel): return result - - def forward(self): """Run forward pass; called by both functions and .""" @@ -462,34 +460,12 @@ class RomaUnsbModel(BaseModel): self.fake_B1 = self.netG(self.real_A1, self.time, z_in2) if self.opt.phase == 'train': - # 生成图像的梯度 - print(f'self.fake_B0: {self.fake_B0.shape}') - fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0] - - # 梯度图 - print(f'fake_gradient: {fake_gradient.shape}') - self.weight_fake = self.cao.generate_weight_map(fake_gradient) - - # 生成图像的CTN光流图 - print(f'weight_fake: {self.weight_fake.shape}') - self.f_content = self.ctn(self.weight_fake) - - # 变换后的图片 - self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content) - self.warped_fake_B0 = warp(self.fake_B0,self.f_content) - - # 经过第二次生成器 - self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in) - - if self.opt.isTrain: real_A0 = self.real_A0 real_A1 = self.real_A1 real_B0 = self.real_B0 real_B1 = self.real_B1 fake_B0 = self.fake_B0 fake_B1 = self.fake_B1 - warped_fake_B0_2=self.warped_fake_B0_2 - warped_fake_B0=self.warped_fake_B0 self.real_A0_resize = self.resize(real_A0) self.real_A1_resize = self.resize(real_A1) @@ -497,8 +473,6 @@ class RomaUnsbModel(BaseModel): real_B1 = self.resize(real_B1) self.fake_B0_resize = self.resize(fake_B0) self.fake_B1_resize = self.resize(fake_B1) - self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2) - self.warped_fake_B0_resize = self.resize(warped_fake_B0) self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True) self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True) @@ -506,8 +480,32 @@ class RomaUnsbModel(BaseModel): self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True) self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) + # [[1,576,768],[1,576,768],[1,576,768]] + # [3,576,768] + + # 生成图像的梯度 + fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0] + + # 梯度图 + self.weight_fake = self.cao.generate_weight_map(fake_gradient) + + # 生成图像的CTN光流图 + self.f_content = self.ctn(self.weight_fake) + + # 变换后的图片 + self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content) + self.warped_fake_B0 = warp(self.fake_B0,self.f_content) + + # 经过第二次生成器 + self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in) + + warped_fake_B0_2=self.warped_fake_B0_2 + warped_fake_B0=self.warped_fake_B0 + self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2) + self.warped_fake_B0_resize = self.resize(warped_fake_B0) self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True) + def compute_D_loss(self): """计算判别器的 GAN 损失""" @@ -526,8 +524,8 @@ class RomaUnsbModel(BaseModel): def compute_E_loss(self): """计算判别器 E 的损失""" - XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B.detach()], dim=1) - XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2.detach()], dim=1) + XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1) + XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1) temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean() self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2 @@ -536,10 +534,10 @@ class RomaUnsbModel(BaseModel): def compute_G_loss(self): """计算生成器的 GAN 损失""" - bs = self.mutil_real_A0_tokens.size(0) + bs = self.real_A0.size(0) tau = self.opt.tau - fake = self.fake_B + fake = self.fake_B0 std = torch.rand(size=[1]).item() * self.opt.std if self.opt.lambda_GAN > 0.0: @@ -549,8 +547,8 @@ class RomaUnsbModel(BaseModel): self.loss_G_GAN = 0.0 self.loss_SB = 0 if self.opt.lambda_SB > 0.0: - XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1) - XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1) + XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1) + XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1) bs = self.opt.batch_size @@ -560,7 +558,7 @@ class RomaUnsbModel(BaseModel): self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2) if self.opt.lambda_global > 0.0: - loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens) + loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1) loss_global *= 0.5 else: loss_global = 0.0