From 8a081af0a32e201e869202787cb6d4a33914b565 Mon Sep 17 00:00:00 2001 From: bishe <123456789@163.com> Date: Sun, 23 Feb 2025 22:40:34 +0800 Subject: [PATCH] debug --- checkpoints/ROMA_UNSB_001/loss_log.txt | 5 ++ .../roma_unsb_model.cpython-39.pyc | Bin 19326 -> 19227 bytes models/roma_unsb_model.py | 64 +++++++++--------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt index 1af94c6..2bc6f9f 100644 --- a/checkpoints/ROMA_UNSB_001/loss_log.txt +++ b/checkpoints/ROMA_UNSB_001/loss_log.txt @@ -39,3 +39,8 @@ ================ Training Loss (Sun Feb 23 21:29:03 2025) ================ ================ Training Loss (Sun Feb 23 21:34:57 2025) ================ ================ Training Loss (Sun Feb 23 21:35:26 2025) ================ +================ Training Loss (Sun Feb 23 22:28:43 2025) ================ +================ Training Loss (Sun Feb 23 22:29:04 2025) ================ +================ Training Loss (Sun Feb 23 22:29:52 2025) ================ +================ Training Loss (Sun Feb 23 22:30:40 2025) ================ +================ Training Loss (Sun Feb 23 22:33:48 2025) ================ diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index 322c8409ea6b81e7c09a160bbe385515d67dab0e..e819f4454b60684604d42e7a80e4749c23d9717d 100644 GIT binary patch delta 824 zcmZWn&rcIU6yDd3q#e3T4NFTQX++ef2_>`$)cz8ow6r3a=%rP%S=d@BTFJ7sq^Tzhh4-xX&i6iKa&f{3+51F-W+D{FvQNf$oMehOHtgmA7u%%JL#piiq<3Q zvg*3Rg*>!A^AeuW@z~nP)ew9om9j4DcQ(tlZLHkuT3OM{az#_NZiZH@!&njMtN1bv zCmJ*z-=crw6LdeZM>moIZ%MeJY%;I9QLgD)A8?!|(niwP(?GiYjJ`;!G3e=r!wqgA z7`QsT-~|Z-S3bat7#{HOlEe1nzO|An0sOX}rym0lD?WSI0Z#g6?)?a{?{V4Ci%=rB zl-p~nBCqQ;Er?&WCeY^mBap0M<_Ew}DrSQoXQzcV_tJ&z!~exE*>{kz_Hu^+GQFSw M#1m)DDi%(|Kcf-hBLDyZ delta 942 zcmZXSOH31C5XZmSB9^7QEwq$ZQY)aO2S`DMvMmomKokWZF@eik*rl*wOO{nkvnvOT z7lV>~2^XTY2{%t>Z<-hn=*6p##CY(Lc=F)Mqd5De+7fs2&HVQ_voo`s-QI%VTOcL{ zfhTy?v^-C#nWthi4|`M&gyMUEVJcsX(4V>MJ@H{Id8ZJ;m z#|so!DWMk%)LEq1;bMW}swH$nsklNS{_ObWo9XaeMqf?pN_d&hckjR!b;o8dWyKYB zX|8=Pxv0v6Q7L>i0v|oIx|%dpIi)9;=d)gxL#-^RX+sKUg;jMSHE+lcm%>=jr&L2W zGK*?@B{FH>k4-^Y3jav@Qd*|pdM0Qz9;Kbq1A0Ba677BpXc8r j>|rncI(+wkcx&Vp^x28gF90(%H1^KracAsbW5?hR8E*vL diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 702e3b8..ec2fe8d 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -395,8 +395,6 @@ class RomaUnsbModel(BaseModel): return result - - def forward(self): """Run forward pass; called by both functions and .""" @@ -462,34 +460,12 @@ class RomaUnsbModel(BaseModel): self.fake_B1 = self.netG(self.real_A1, self.time, z_in2) if self.opt.phase == 'train': - # 生成图像的梯度 - print(f'self.fake_B0: {self.fake_B0.shape}') - fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0] - - # 梯度图 - print(f'fake_gradient: {fake_gradient.shape}') - self.weight_fake = self.cao.generate_weight_map(fake_gradient) - - # 生成图像的CTN光流图 - print(f'weight_fake: {self.weight_fake.shape}') - self.f_content = self.ctn(self.weight_fake) - - # 变换后的图片 - self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content) - self.warped_fake_B0 = warp(self.fake_B0,self.f_content) - - # 经过第二次生成器 - self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in) - - if self.opt.isTrain: real_A0 = self.real_A0 real_A1 = self.real_A1 real_B0 = self.real_B0 real_B1 = self.real_B1 fake_B0 = self.fake_B0 fake_B1 = self.fake_B1 - warped_fake_B0_2=self.warped_fake_B0_2 - warped_fake_B0=self.warped_fake_B0 self.real_A0_resize = self.resize(real_A0) self.real_A1_resize = self.resize(real_A1) @@ -497,8 +473,6 @@ class RomaUnsbModel(BaseModel): real_B1 = self.resize(real_B1) self.fake_B0_resize = self.resize(fake_B0) self.fake_B1_resize = self.resize(fake_B1) - self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2) - self.warped_fake_B0_resize = self.resize(warped_fake_B0) self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True) self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True) @@ -506,8 +480,32 @@ class RomaUnsbModel(BaseModel): self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True) self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) + # [[1,576,768],[1,576,768],[1,576,768]] + # [3,576,768] + + # 生成图像的梯度 + fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0] + + # 梯度图 + self.weight_fake = self.cao.generate_weight_map(fake_gradient) + + # 生成图像的CTN光流图 + self.f_content = self.ctn(self.weight_fake) + + # 变换后的图片 + self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content) + self.warped_fake_B0 = warp(self.fake_B0,self.f_content) + + # 经过第二次生成器 + self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in) + + warped_fake_B0_2=self.warped_fake_B0_2 + warped_fake_B0=self.warped_fake_B0 + self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2) + self.warped_fake_B0_resize = self.resize(warped_fake_B0) self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True) + def compute_D_loss(self): """计算判别器的 GAN 损失""" @@ -526,8 +524,8 @@ class RomaUnsbModel(BaseModel): def compute_E_loss(self): """计算判别器 E 的损失""" - XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B.detach()], dim=1) - XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2.detach()], dim=1) + XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1) + XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1) temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean() self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2 @@ -536,10 +534,10 @@ class RomaUnsbModel(BaseModel): def compute_G_loss(self): """计算生成器的 GAN 损失""" - bs = self.mutil_real_A0_tokens.size(0) + bs = self.real_A0.size(0) tau = self.opt.tau - fake = self.fake_B + fake = self.fake_B0 std = torch.rand(size=[1]).item() * self.opt.std if self.opt.lambda_GAN > 0.0: @@ -549,8 +547,8 @@ class RomaUnsbModel(BaseModel): self.loss_G_GAN = 0.0 self.loss_SB = 0 if self.opt.lambda_SB > 0.0: - XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1) - XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1) + XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1) + XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1) bs = self.opt.batch_size @@ -560,7 +558,7 @@ class RomaUnsbModel(BaseModel): self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2) if self.opt.lambda_global > 0.0: - loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens) + loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1) loss_global *= 0.5 else: loss_global = 0.0