use kun's forward method

bishe 2025-02-23 22:26:04 +08:00
parent 0639032b6c
commit 687559866d
3 changed files with 62 additions and 87 deletions

View File

@@ -34,3 +34,8 @@
================ Training Loss (Sun Feb 23 19:03:05 2025) ================
================ Training Loss (Sun Feb 23 19:03:57 2025) ================
================ Training Loss (Sun Feb 23 21:11:47 2025) ================
================ Training Loss (Sun Feb 23 21:17:10 2025) ================
================ Training Loss (Sun Feb 23 21:20:14 2025) ================
================ Training Loss (Sun Feb 23 21:29:03 2025) ================
================ Training Loss (Sun Feb 23 21:34:57 2025) ================
================ Training Loss (Sun Feb 23 21:35:26 2025) ================

View File

@@ -79,13 +79,13 @@ class ContentAwareOptimization(nn.Module):
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
return cosine_sim
def generate_weight_map(self, gradients_fake):
"""
Generate the content-aware weight map
Args:
gradients_fake: [B, N, D] discriminator gradients for the generated images
gradients_fake: [B, N, D] discriminator gradients for the generated images [2,3,256,256]
Returns:
weight_fake: [B, N] weight map for the generated images
weight_fake: [B, N] weight map for the generated images [2,3,256]
"""
# Compute cosine similarity for the generated-image patches
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
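For reference, a minimal standalone sketch of the cosine-similarity weighting documented in this hunk, assuming per-patch gradients shaped [B, N, D]; the final softmax normalization is an assumption, since only the similarity step is visible here.

import torch
import torch.nn.functional as F

def cosine_weight_map(gradients_fake: torch.Tensor) -> torch.Tensor:
    """Weight each patch by how well its gradient agrees with the mean gradient.
    gradients_fake: [B, N, D] per-patch gradients; returns a [B, N] weight map.
    The softmax over patches is an assumption -- this hunk only shows the
    cosine-similarity step."""
    mean_grad = gradients_fake.mean(dim=1, keepdim=True)                # [B, 1, D]
    cosine_sim = F.cosine_similarity(gradients_fake, mean_grad, dim=2)  # [B, N]
    return torch.softmax(cosine_sim, dim=1)                             # [B, N]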
@@ -398,28 +398,9 @@ class RomaUnsbModel(BaseModel):
def forward(self):
"""执行前向传递以生成输出图像"""
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
if self.opt.isTrain:
print(f'before resize: {self.real_A0.shape}')
real_A0 = self.resize(self.real_A0)
real_A1 = self.resize(self.real_A1)
real_B0 = self.resize(self.real_B0).requires_grad_(True)
real_B1 = self.resize(self.real_B1).requires_grad_(True)
# Run through the ViT
print(f'before vit: {real_A0.shape}')
self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True)
print(f'before cat: len = {len(self.mutil_real_A0_tokens)}\n{self.mutil_real_A0_tokens[0].shape}')
self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=0).unsqueeze(0).to(self.device)
self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=0).unsqueeze(0).to(self.device)
# Run the SB module once
# ============ Step 1: initialize the time steps and the time index ============
# Compute times and pick the current time_idx (sampled at random to represent the current time step)
# ============ Step 1: run the multi-step stochastic generation process on real_A / real_A2 ============
tau = self.opt.tau
T = self.opt.num_timesteps
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
@@ -429,7 +410,7 @@ class RomaUnsbModel(BaseModel):
times = np.concatenate([np.zeros(1), times])
times = torch.tensor(times).float().cuda()
self.times = times
bs = self.mutil_real_A0_tokens.size(0)
bs = self.real_A0.size(0)
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
self.time_idx = time_idx
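As a quick reference, the schedule built above can be reproduced standalone (a sketch of the same arithmetic, with .cuda() omitted); for num_timesteps T=4 it yields [0.0, 0.5, 0.7727, 0.9091, 1.0].

import numpy as np
import torch

def build_time_schedule(T: int) -> torch.Tensor:
    """Harmonic increments, normalized to [0, 1], squashed into [0.5, 1.0],
    with a leading zero prepended -- mirrors the times computation above."""
    incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])
    times = np.cumsum(incs)
    times = times / times[-1]               # last entry becomes 1
    times = 0.5 * times[-1] + 0.5 * times   # squash into [0.5, 1.0]
    times = np.concatenate([np.zeros(1), times])
    return torch.tensor(times).float()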
@@ -444,34 +425,30 @@ class RomaUnsbModel(BaseModel):
inter = (delta / denom).reshape(-1, 1, 1, 1)
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
print(f'before noisy: {self.mutil_real_A0_tokens.shape}')
# Apply the stochastic noise update to Xt and Xt2
Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
self.time = times[time_idx]
Xt_1 = self.netG(Xt, self.time, z)
Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
Xt_12 = self.netG(Xt2, self.time, z)
# Save the denoised intermediate results (real_A_noisy, etc.) for concatenation in the next step
self.real_A_noisy = Xt.detach()
self.real_A_noisy2 = Xt2.detach()
# Save the noisy_map
print(f'after noisy map: {self.real_A_noisy.shape}')
self.noisy_map = self.real_A_noisy - self.mutil_real_A0_tokens
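The loop above repeats one stochastic interpolation step per timestep; a minimal sketch of that single step, with delta, denom and tau taken as given scalars and x shaped like the image tensors used after this commit.

import torch

def sb_refine_step(x_t: torch.Tensor, x_prev_pred: torch.Tensor,
                   delta: float, denom: float, tau: float) -> torch.Tensor:
    """One re-noising step as in the loop above: move x_t toward the previous
    generator prediction and add noise scaled by the remaining time."""
    inter = delta / denom
    scale = delta * (1 - delta / denom)
    noise = torch.randn_like(x_t)
    return (1 - inter) * x_t + inter * x_prev_pred.detach() + (scale * tau) ** 0.5 * noise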
# ============ Step 3: concatenate the inputs and run network inference =============
bs = self.mutil_real_A0_tokens.size(0)
z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
bs = self.real_A0.size(0)
z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.real_A0.device)
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
# Concatenate real_A and real_B (when nce_idt=True), and handle real_A_noisy and XtB the same way
self.real = self.mutil_real_A0_tokens
self.real = self.real_A0
self.realt = self.real_A_noisy
if self.opt.flip_equivariance:
@@ -480,65 +457,58 @@ class RomaUnsbModel(BaseModel):
self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3])
# Use netG to produce the final fake, fake_B2, etc.
self.fake_B = self.netG(self.realt, self.time, z_in)
self.fake_B2 = self.netG(self.real, self.time, z_in2)
self.fake_B = self.resize(self.fake_B)
self.fake_B2 = self.resize(self.fake_B2)
self.fake_B0 = self.fake_B
self.fake_B1 = self.fake_B2
# Run through the ViT
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B2, self.atten_layers, get_tokens=True)
# ============ Step 4: multi-step sampling in inference mode ============
if self.opt.phase == 'test':
tau = self.opt.tau
T = self.opt.num_timesteps
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
times = np.cumsum(incs)
times = times / times[-1]
times = 0.5 * times[-1] + 0.5 * times
times = np.concatenate([np.zeros(1),times])
times = torch.tensor(times).float().cuda()
self.times = times
bs = self.real.size(0)
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
self.time_idx = time_idx
visuals = []
with torch.no_grad():
self.netG.eval()
for t in range(self.opt.num_timesteps):
if t > 0:
delta = times[t] - times[t-1]
denom = times[-1] - times[t-1]
inter = (delta / denom).reshape(-1,1,1,1)
scale = (delta * (1 - delta / denom)).reshape(-1,1,1,1)
Xt = self.mutil_real_A0_tokens if (t == 0) else (1-inter) * Xt + inter * Xt_1.detach() + (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
time = times[time_idx]
z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
Xt_1 = self.netG(Xt, time_idx, z)
setattr(self, "fake_"+str(t+1), Xt_1)
self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
if self.opt.phase == 'train':
# Gradients of the generated image
print(f'self.fake_B0: {self.fake_B0.shape}')
fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
# Gradient map
print(f'fake_gradient: {fake_gradient.shape}')
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
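The generate_weight_map docstring describes gradients_fake as discriminator gradients for the generated images, whereas the call above differentiates fake_B0.sum() with respect to fake_B0 itself; a hedged sketch of obtaining discriminator gradients is shown below (the netD callable is an assumption and does not appear in this hunk).

import torch

def discriminator_gradients(netD, fake_images: torch.Tensor) -> torch.Tensor:
    """Sketch: gradients of a scalar discriminator score w.r.t. the fake images.
    netD is an assumed discriminator callable; it is not part of this diff."""
    fake_images = fake_images.detach().requires_grad_(True)
    score = netD(fake_images).sum()
    # create_graph=True mirrors the call above and keeps the graph for backprop
    return torch.autograd.grad(score, fake_images, create_graph=True)[0]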
# CTN optical-flow map of the generated image
print(f'weight_fake: {self.weight_fake.shape}')
self.f_content = self.ctn(self.weight_fake)
# Warped images
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
self.warped_fake_B0 = warp(self.fake_B0, self.f_content)
# Pass through the generator a second time
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
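The warp call above is not defined in this diff; a typical flow-based implementation, offered only as a hedged sketch under the assumption that f_content holds per-pixel offsets shaped [B, 2, H, W], would look like this.

import torch
import torch.nn.functional as F

def flow_warp(img: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Warp img ([B, C, H, W]) by a pixel-offset flow field ([B, 2, H, W]).
    This is a generic sketch, not necessarily the warp used in this repo."""
    _, _, h, w = img.shape
    # Base sampling grid in pixel coordinates (x first, then y).
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    grid = torch.stack((xs, ys), dim=0).float().to(img.device)   # [2, H, W]
    grid = grid.unsqueeze(0) + flow                              # [B, 2, H, W]
    # Normalize to [-1, 1] as required by grid_sample.
    grid_x = 2.0 * grid[:, 0] / max(w - 1, 1) - 1.0
    grid_y = 2.0 * grid[:, 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((grid_x, grid_y), dim=3)                  # [B, H, W, 2]
    return F.grid_sample(img, grid, align_corners=True)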
if self.opt.isTrain:
real_A0 = self.real_A0
real_A1 = self.real_A1
real_B0 = self.real_B0
real_B1 = self.real_B1
fake_B0 = self.fake_B0
fake_B1 = self.fake_B1
warped_fake_B0_2 = self.warped_fake_B0_2
warped_fake_B0 = self.warped_fake_B0
self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1)
real_B0 = self.resize(real_B0)
real_B1 = self.resize(real_B1)
self.fake_B0_resize = self.resize(fake_B0)
self.fake_B1_resize = self.resize(fake_B1)
self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
self.warped_fake_B0_resize = self.resize(warped_fake_B0)
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
def compute_D_loss(self):
"""计算判别器的 GAN 损失"""