diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt
index 741a618..1af94c6 100644
--- a/checkpoints/ROMA_UNSB_001/loss_log.txt
+++ b/checkpoints/ROMA_UNSB_001/loss_log.txt
@@ -34,3 +34,8 @@
 ================ Training Loss (Sun Feb 23 19:03:05 2025) ================
 ================ Training Loss (Sun Feb 23 19:03:57 2025) ================
 ================ Training Loss (Sun Feb 23 21:11:47 2025) ================
+================ Training Loss (Sun Feb 23 21:17:10 2025) ================
+================ Training Loss (Sun Feb 23 21:20:14 2025) ================
+================ Training Loss (Sun Feb 23 21:29:03 2025) ================
+================ Training Loss (Sun Feb 23 21:34:57 2025) ================
+================ Training Loss (Sun Feb 23 21:35:26 2025) ================
diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc
index 0c9cd4b..322c840 100644
Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ
diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py
index 05cfaa7..702e3b8 100644
--- a/models/roma_unsb_model.py
+++ b/models/roma_unsb_model.py
@@ -79,13 +79,13 @@ class ContentAwareOptimization(nn.Module):
         cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]
         return cosine_sim
 
-    def generate_weight_map(self, gradients_fake):
+    def generate_weight_map(self, gradients_fake):
         """
         Generate the content-aware weight map.
         Args:
-            gradients_fake: [B, N, D] discriminator gradients of the generated image
+            gradients_fake: [B, N, D] discriminator gradients of the generated image, observed shape [2, 3, 256, 256]
         Returns:
-            weight_fake: [B, N] weight map of the generated image
+            weight_fake: [B, N] weight map of the generated image, observed shape [2, 3, 256]
         """
         # Compute the cosine similarity of the generated image patches
         cosine_fake = self.compute_cosine_similarity(gradients_fake)  # [B, N]
@@ -398,28 +398,9 @@ class RomaUnsbModel(BaseModel):
 
     def forward(self):
-        """Execute the forward pass to generate the output images"""
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
-        if self.opt.isTrain:
-            print(f'before resize: {self.real_A0.shape}')
-            real_A0 = self.resize(self.real_A0)
-            real_A1 = self.resize(self.real_A1)
-            real_B0 = self.resize(self.real_B0).requires_grad_(True)
-            real_B1 = self.resize(self.real_B1).requires_grad_(True)
-            # Run the ViT
-
-            print(f'before vit: {real_A0.shape}')
-            self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True)
-            self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True)
-
-            print(f'before cat: len = {len(self.mutil_real_A0_tokens)}\n{self.mutil_real_A0_tokens[0].shape}')
-            self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=0).unsqueeze(0).to(self.device)
-            self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=0).unsqueeze(0).to(self.device)
-
-        # Run the SB module once
-
-        # ============ Step 1: initialize the timesteps and the time index ============
-        # Compute times and pick the current time_idx (sampled at random to represent the current timestep)
+        # ============ Step 1: multi-step stochastic generation over real_A / real_A2 ============
         tau = self.opt.tau
         T = self.opt.num_timesteps
         incs = np.array([0] + [1/(i+1) for i in range(T-1)])
@@ -429,7 +410,7 @@ class RomaUnsbModel(BaseModel):
         times = np.concatenate([np.zeros(1), times])
         times = torch.tensor(times).float().cuda()
         self.times = times
-        bs = self.mutil_real_A0_tokens.size(0)
+        bs = self.real_A0.size(0)
         time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
         self.time_idx = time_idx
@@ -444,34 +425,30 @@ class RomaUnsbModel(BaseModel):
                 inter = (delta / denom).reshape(-1, 1, 1, 1)
                 scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
 
-            print(f'before noisy: {self.mutil_real_A0_tokens.shape}')
             # Apply the stochastic noise update to Xt and Xt2
-            Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
-                (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
-            time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
-            z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
+            Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
+                (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
+            time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
+            z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
             self.time = times[time_idx]
             Xt_1 = self.netG(Xt, self.time, z)
 
-            Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
-                (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
-            time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
-            z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+            Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
+                (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
+            time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
+            z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
             Xt_12 = self.netG(Xt2, self.time, z)
 
         # Save the denoised intermediate results (real_A_noisy, etc.) for the concatenation step below
         self.real_A_noisy = Xt.detach()
         self.real_A_noisy2 = Xt2.detach()
-        # Save the noisy map
-        print(f'after noisy map: {self.real_A_noisy.shape}')
-        self.noisy_map = self.real_A_noisy - self.mutil_real_A0_tokens
 
         # ============ Step 3: concatenate the inputs and run network inference =============
-        bs = self.mutil_real_A0_tokens.size(0)
-        z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
-        z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+        bs = self.real_A0.size(0)
+        z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.real_A0.device)
+        z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
 
         # Concatenate real_A and real_B (when nce_idt=True), and treat real_A_noisy and XtB the same way
-        self.real = self.mutil_real_A0_tokens
+        self.real = self.real_A0
         self.realt = self.real_A_noisy
 
         if self.opt.flip_equivariance:
@@ -479,65 +456,58 @@ class RomaUnsbModel(BaseModel):
             if self.flipped_for_equivariance:
                 self.real = torch.flip(self.real, [3])
                 self.realt = torch.flip(self.realt, [3])
+
+        self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
+        self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
 
-        # Use netG to generate the final fake, fake_B2, etc.
-        self.fake_B = self.netG(self.realt, self.time, z_in)
-        self.fake_B2 = self.netG(self.real, self.time, z_in2)
-
-        self.fake_B = self.resize(self.fake_B)
-        self.fake_B2 = self.resize(self.fake_B2)
-
-        self.fake_B0 = self.fake_B
-        self.fake_B1 = self.fake_B2
-
-        # Run the ViT
-        self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B, self.atten_layers, get_tokens=True)
-        self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B2, self.atten_layers, get_tokens=True)
-
-        # ============ Step 4: multi-pass sampling in inference mode ============
-        if self.opt.phase == 'test':
-            tau = self.opt.tau
-            T = self.opt.num_timesteps
-            incs = np.array([0] + [1/(i+1) for i in range(T-1)])
-            times = np.cumsum(incs)
-            times = times / times[-1]
-            times = 0.5 * times[-1] + 0.5 * times
-            times = np.concatenate([np.zeros(1), times])
-            times = torch.tensor(times).float().cuda()
-            self.times = times
-            bs = self.real.size(0)
-            time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
-            self.time_idx = time_idx
-            visuals = []
-            with torch.no_grad():
-                self.netG.eval()
-                for t in range(self.opt.num_timesteps):
-                    if t > 0:
-                        delta = times[t] - times[t-1]
-                        denom = times[-1] - times[t-1]
-                        inter = (delta / denom).reshape(-1, 1, 1, 1)
-                        scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
-                    Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
-                    time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
-                    time = times[time_idx]
-                    z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
-                    Xt_1 = self.netG(Xt, time_idx, z)
-
-                    setattr(self, "fake_" + str(t+1), Xt_1)
 
         if self.opt.phase == 'train':
             # Gradients w.r.t. the generated image
+            print(f'self.fake_B0: {self.fake_B0.shape}')
             fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
+            # Gradient map
+            print(f'fake_gradient: {fake_gradient.shape}')
             self.weight_fake = self.cao.generate_weight_map(fake_gradient)
+            # CTN optical-flow map for the generated image
+            print(f'weight_fake: {self.weight_fake.shape}')
             self.f_content = self.ctn(self.weight_fake)
+            # The warped images
             self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
             self.warped_fake_B0 = warp(self.fake_B0, self.f_content)
+            # Pass through the generator a second time
             self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
+
+        if self.opt.isTrain:
+            real_A0 = self.real_A0
+            real_A1 = self.real_A1
+            real_B0 = self.real_B0
+            real_B1 = self.real_B1
+            fake_B0 = self.fake_B0
+            fake_B1 = self.fake_B1
+            warped_fake_B0_2 = self.warped_fake_B0_2
+            warped_fake_B0 = self.warped_fake_B0
+
+            self.real_A0_resize = self.resize(real_A0)
+            self.real_A1_resize = self.resize(real_A1)
+            real_B0 = self.resize(real_B0)
+            real_B1 = self.resize(real_B1)
+            self.fake_B0_resize = self.resize(fake_B0)
+            self.fake_B1_resize = self.resize(fake_B1)
+            self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
+            self.warped_fake_B0_resize = self.resize(warped_fake_B0)
+
+            self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
+            self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
+            self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
+            self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
+            self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
+            self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
+            self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
+            self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
 
     def compute_D_loss(self):
         """Compute the GAN loss for the discriminator"""
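
The sampling schedule that forward() builds (in both the training path and the removed test-phase block) is self-contained enough to exercise on its own. Below is a minimal standalone sketch of that schedule and of one stochastic interpolation step; T = 5, tau = 0.01, and the 1x3x256x256 stand-in tensors are illustrative assumptions, not values taken from the repository's options.

    import numpy as np
    import torch

    T, tau = 5, 0.01  # assumed stand-ins for opt.num_timesteps / opt.tau

    # Harmonic increments, normalized to [0, 1], compressed into [0.5, 1.0],
    # with a t=0 entry prepended -- the same recipe as in forward().
    incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])
    times = np.cumsum(incs)
    times = times / times[-1]
    times = 0.5 * times[-1] + 0.5 * times
    times = np.concatenate([np.zeros(1), times])
    times = torch.tensor(times).float()

    # One update for t > 0: blend the current state Xt toward the previous
    # generator output Xt_1 and inject Gaussian noise scaled by tau.
    t = 2
    delta = times[t] - times[t - 1]
    denom = times[-1] - times[t - 1]
    inter = delta / denom
    scale = delta * (1 - delta / denom)
    Xt = torch.randn(1, 3, 256, 256)    # stand-in for the current noisy sample
    Xt_1 = torch.randn(1, 3, 256, 256)  # stand-in for netG(Xt, time, z)
    Xt = (1 - inter) * Xt + inter * Xt_1 + (scale * tau).sqrt() * torch.randn_like(Xt)

At t = 0 the model instead seeds Xt with real_A0 directly, which is why the in-loop expression guards on (t == 0).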
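Likewise, the ContentAwareOptimization hunk keeps cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) as the core of generate_weight_map. A hedged sketch of just that similarity computation, assuming mean_grad is the mean over the patch axis (the hunk does not show how mean_grad is formed):

    import torch
    import torch.nn.functional as F

    def compute_cosine_similarity(gradients: torch.Tensor) -> torch.Tensor:
        # gradients: [B, N, D] per-patch gradient vectors.
        # Assumption: mean_grad is the mean gradient across the N patches.
        mean_grad = gradients.mean(dim=1, keepdim=True)          # [B, 1, D]
        return F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]

    g = torch.randn(2, 196, 768)               # e.g. 196 ViT patches, dim 768
    print(compute_cosine_similarity(g).shape)  # torch.Size([2, 196])

Patches whose gradients align with the mean direction score near 1 and outlier patches score lower, which gives generate_weight_map a per-patch signal to build the weight map from.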