first_change

areszz 2025-02-22 15:23:52 +08:00
parent f88bc5b8f6
commit 8cd61d0503
3 changed files with 289 additions and 28 deletions

models/cnt.py (new file, 192 lines)

@@ -0,0 +1,192 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
def warp(image, flow):
    """
    Flow-based image warping.
    Args:
        image: [B, C, H, W] input image
        flow: [B, 2, H, W] optical flow field (x/y displacements)
    Returns:
        warped: [B, C, H, W] warped image
    """
    B, C, H, W = image.shape
    # Build the base sampling grid in pixel coordinates.
    grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2,H,W]
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B,2,H,W]
    # Apply the flow displacement, then normalize coordinates to [-1,1] for grid_sample.
    new_grid = grid + flow
    new_x = 2.0 * new_grid[:, 0, :, :] / (W - 1) - 1.0  # x direction
    new_y = 2.0 * new_grid[:, 1, :, :] / (H - 1) - 1.0  # y direction
    new_grid = torch.stack((new_x, new_y), dim=3)  # [B,H,W,2]
    # Bilinear interpolation at the displaced coordinates.
    return F.grid_sample(image, new_grid, align_corners=True)
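# Sanity check (a sketch, not part of the commit): a zero flow field should
# reproduce the input up to interpolation error, e.g.
#   img = torch.randn(1, 3, 8, 8)
#   assert torch.allclose(warp(img, torch.zeros(1, 2, 8, 8)), img, atol=1e-5)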
# Temporal-normalization loss (Eq. 10)
def compute_ctn_loss(G, x, F_content):
    """
    Content-aware temporal-normalization loss.
    Args:
        G: generator
        x: input infrared image [B,C,H,W]
        F_content: generated optical flow field [B,2,H,W]
    """
    # Translate the input to the visible domain.
    y_fake = G(x)  # [B,3,H,W]
    # Warp the generated image with the flow.
    warped_fake = warp(y_fake, F_content)  # [B,3,H,W]
    # Warp the input with the same flow, then translate it.
    warped_x = warp(x, F_content)  # [B,C,H,W]
    y_fake_warped = G(warped_x)  # [B,3,H,W]
    # L2 distance between translate-then-warp and warp-then-translate.
    loss = F.mse_loss(warped_fake, y_fake_warped)
    return loss
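# Intended usage (a sketch; assumes G maps [B,C,H,W] infrared inputs to
# [B,3,H,W] visible images, and weight_map is a [B,1,H,W] tensor):
#   F_content = ContentAwareTemporalNorm()(weight_map)   # [B,2,H,W] pseudo flow
#   loss_ctn = compute_ctn_loss(netG, real_ir, F_content)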
class ContentAwareOptimization(nn.Module):
    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc  # weight-amplification coefficient
        self.eta_ratio = eta_ratio    # fraction of patches selected as content-rich

    def compute_cosine_similarity(self, gradients):
        """
        Cosine similarity between each patch gradient and the mean gradient.
        Args:
            gradients: [B, N, D] per-patch gradients from the discriminator (N = w*h)
        Returns:
            cosine_sim: [B, N] per-patch cosine similarity
        """
        mean_grad = torch.mean(gradients, dim=1, keepdim=True)  # [B, 1, D]
        cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]
        return cosine_sim

    def generate_weight_map(self, gradients_real, gradients_fake):
        """
        Generate the content-aware weight maps.
        Args:
            gradients_real: [B, N, D] discriminator gradients for real images
            gradients_fake: [B, N, D] discriminator gradients for generated images
        Returns:
            weight_real: [B, N] weight map for real images
            weight_fake: [B, N] weight map for generated images
        """
        # Per-patch cosine similarities (Eq. 5).
        cosine_real = self.compute_cosine_similarity(gradients_real)  # [B, N]
        cosine_fake = self.compute_cosine_similarity(gradients_fake)  # [B, N]
        # Select the content-rich regions: the eta_ratio fraction with the lowest similarity.
        k = int(self.eta_ratio * cosine_real.shape[1])
        # Weight map for real images (Eq. 6).
        _, real_indices = torch.topk(-cosine_real, k, dim=1)  # least-similar patches
        weight_real = torch.ones_like(cosine_real)
        for b in range(cosine_real.shape[0]):
            weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]]))
        # Weight map for generated images (same procedure).
        _, fake_indices = torch.topk(-cosine_fake, k, dim=1)
        weight_fake = torch.ones_like(cosine_fake)
        for b in range(cosine_fake.shape[0]):
            weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
        return weight_real, weight_fake
    def forward(self, D_real, D_fake, real_scores, fake_scores):
        """
        Content-aware adversarial loss.
        Args:
            D_real: discriminator features for real images [B, C, H, W]
            D_fake: discriminator features for generated images [B, C, H, W]
            real_scores: discriminator predictions for real images [B, N] (N = H*W)
            fake_scores: discriminator predictions for generated images [B, N]
        Returns:
            loss_co_adv: content-aware adversarial loss
        """
        B, C, H, W = D_real.shape
        N = H * W
        # Register hooks to capture the feature gradients.
        gradients_real = []
        gradients_fake = []

        def hook_real(grad):
            # Reshape [B,C,H,W] gradients to [B,N,C] patch gradients.
            gradients_real.append(grad.detach().flatten(2).permute(0, 2, 1))

        def hook_fake(grad):
            gradients_fake.append(grad.detach().flatten(2).permute(0, 2, 1))

        D_real.register_hook(hook_real)
        D_fake.register_hook(hook_fake)
        # Plain adversarial loss, computed only to trigger the backward pass.
        loss_real = torch.mean(torch.log(real_scores + 1e-8))
        loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
        # Dummy term involving D_real / D_fake so that their gradients are populated.
        loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
        total_loss = loss_real + loss_fake + loss_dummy
        total_loss.backward(retain_graph=True)
        # Collect the captured gradients.
        gradients_real = gradients_real[0]  # [B, N, D]
        gradients_fake = gradients_fake[0]  # [B, N, D]
        # Build the weight maps.
        self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
        # Re-weight the adversarial loss.
        loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
        loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))
        # Final content-aware adversarial loss.
        loss_co_adv = -(loss_co_real + loss_co_fake)
        return loss_co_adv
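# Call pattern (a sketch; shapes follow the docstrings above):
#   cao = ContentAwareOptimization(lambda_inc=2.0, eta_ratio=0.4)
#   loss_co_adv = cao(D_real_feats, D_fake_feats, real_scores, fake_scores)
#   # afterwards cao.weight_fake ([B, N]) can be reshaped to [B, 1, H, W]
#   # and fed to ContentAwareTemporalNorm below.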
class ContentAwareTemporalNorm(nn.Module):
    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def forward(self, weight_map):
        """
        Generate the content-aware optical flow.
        Args:
            weight_map: [B, 1, H, W] weight map from the content-aware optimization module
        Returns:
            F_content: [B, 2, H, W] generated optical flow field (x/y displacements)
        """
        B, _, H, W = weight_map.shape
        # 1. L1-normalize the weight map: keeps the relative strength of each
        #    region while bounding the value range.
        weight_norm = F.normalize(weight_map, p=1, dim=(2, 3))  # [B,1,H,W]
        # 2. Gaussian noise with the same size as the flow field.
        z = torch.randn(B, 2, H, W, device=weight_map.device)  # [B,2,H,W]
        # 3. Raw flow: expand the weight map to two channels (x/y share weights).
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B,2,H,W]
        F_raw = self.gamma_stride * weight_expanded * z  # [B,2,H,W], Eq. 9
        # 4. Gaussian smoothing, applied per channel, for structural continuity.
        F_smooth = self.smoother(F_raw)  # [B,2,H,W]
        # 5. Optional dynamic-range adjustment: bound the flow magnitude to
        #    avoid extreme displacements.
        F_content = torch.tanh(F_smooth)  # squashed into [-1,1]
        return F_content
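if __name__ == "__main__":
    # Minimal smoke test (a sketch: the generator is stubbed with a 1x1 conv,
    # not the project's netG, and the weight map is random).
    G = nn.Conv2d(1, 3, kernel_size=1)
    x = torch.randn(2, 1, 64, 64)           # toy infrared batch
    weight_map = torch.rand(2, 1, 64, 64)   # stand-in content weight map
    F_content = ContentAwareTemporalNorm()(weight_map)  # [2,2,64,64] pseudo flow
    print(compute_ctn_loss(G, x, F_content).item())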


@@ -3,6 +3,7 @@ import torch
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
from .cnt import *
import util.util as util
import timm
import time
@@ -21,12 +22,15 @@ class ROMAModel(BaseModel):
""" Configures options specific for CUT model """ Configures options specific for CUT model
""" """
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field') parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator') parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator') parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))') parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency') parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency') parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
parser.add_argument('--lambda_inc', type=float, default=2.0, help='weight for Content Aware Optimization')
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio for selecting content region')
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers') parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
parser.add_argument('--local_nums', type=int, default=256) parser.add_argument('--local_nums', type=int, default=256)
parser.add_argument('--which_D_layer', type=int, default=-1) parser.add_argument('--which_D_layer', type=int, default=-1)
@@ -42,13 +46,13 @@ class ROMAModel(BaseModel):
        BaseModel.__init__(self, opt)
-       self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial', 'motion']
+       self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial']
        self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1']
        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
        if self.isTrain:
-           self.model_names = ['G', 'D_ViT']
+           self.model_names = ['G', 'D_ViT', 'G_2']
        else:  # during test time, only load G
            self.model_names = ['G']
@@ -62,6 +66,11 @@ class ROMAModel(BaseModel):
            self.netD_ViT = networks.MLPDiscriminator().to(self.device)
            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
            # From UNSB
            self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
            # Define another generator
            self.netG_2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
            self.norm = F.softmax
@@ -76,8 +85,13 @@ class ROMAModel(BaseModel):
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
            self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_ViT)
            self.optimizers.append(self.optimizer_E)
            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware adversarial loss
            self.ctn = ContentAwareTemporalNorm()  # generates the pseudo optical-flow field
    def data_dependent_initialize(self, data):
        """
@@ -100,6 +114,13 @@ class ROMAModel(BaseModel):
        self.loss_D.backward()
        self.optimizer_D_ViT.step()
        # update E
        self.set_requires_grad(self.netE, True)
        self.optimizer_E.zero_grad()
        self.loss_E = self.compute_E_loss()
        self.loss_E.backward()
        self.optimizer_E.step()
        # update G
        self.set_requires_grad(self.netD_ViT, False)
        self.optimizer_G.zero_grad()
@@ -133,7 +154,7 @@ class ROMAModel(BaseModel):
        times = np.concatenate([np.zeros(1), times])
        times = torch.tensor(times).float().cuda()
        self.times = times
-       bs = self.mutil_real_A0_tokens.size(0)
+       bs = self.real_A0.size(0)
        time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
        self.time_idx = time_idx
@@ -149,17 +170,17 @@ class ROMAModel(BaseModel):
            scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
            # Apply the random-noise update to Xt and Xt2
-           Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
-               (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
+           Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
+               (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
-           time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
+           time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
-           z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
+           z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
            self.time = times[time_idx]
            Xt_1 = self.netG(Xt, self.time, z)
-           Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
-               (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
+           Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
+               (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
-           time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
+           time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
-           z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+           z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
            Xt_12 = self.netG(Xt2, self.time, z)
        # Save the denoised intermediate results (real_A_noisy, etc.) for concatenation in the next step
@@ -169,11 +190,11 @@ class ROMAModel(BaseModel):
        self.noisy_map = self.real_A_noisy - self.real_A
        # ============ Step 3: concatenate inputs and run the network =============
-       bs = self.mutil_real_A0_tokens.size(0)
+       bs = self.real_A0.size(0)
-       z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
+       z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.real_A0.device)
-       z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+       z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
        # Concatenate real_A and real_B (when nce_idt=True); treat real_A_noisy and XtB the same way
-       self.real = self.mutil_real_A0_tokens
+       self.real = self.real_A0
        self.realt = self.real_A_noisy
        if self.opt.flip_equivariance:
@@ -206,6 +227,28 @@ class ROMAModel(BaseModel):
        self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
        self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
        if self.opt.phase == 'train':
            # Gradient with respect to the real image
            real_gradient = torch.autograd.grad(self.real_B.sum(), self.real_B, create_graph=True)[0]
            # Gradient with respect to the generated image
            fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0]
            # Weight maps from the gradients
            self.weight_real, self.weight_fake = self.cao.generate_weight_map(real_gradient, fake_gradient)
            # CTN optical flow for the generated image
            self.f_content = self.ctn(self.weight_fake)
            # Add the noisy_map back onto the previously generated image
            self.fake_B0_2 = self.fake_B0 + self.noisy_map
            # Warp the image with the content-aware flow
            warped_fake_B0_2 = warp(self.fake_B0_2, self.f_content)
            # Pass it through the second generator
            self.fake_B0_2 = self.netG_2(warped_fake_B0_2, self.time, z_in)
    def tokens_concat(self, origin_tokens, adjacent_size):
        adj_size = adjacent_size
        B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
@@ -277,6 +320,18 @@ class ROMAModel(BaseModel):
        return self.loss_D_ViT

    def compute_E_loss(self):
        """Compute the loss for discriminator E."""
        XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
        XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
        temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
        self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
        return self.loss_E
    def compute_G_loss(self):
        if self.opt.lambda_GAN > 0.0:
@@ -291,22 +346,35 @@ class ROMAModel(BaseModel):
        else:
            self.loss_G_GAN_ViT = 0.0

        self.loss_SB = 0
        if self.opt.lambda_SB > 0.0:
            XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
            XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
            bs = self.opt.batch_size
            # eq.9
            ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
            self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
            self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
            self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy2 - self.fake_B1) ** 2)

        if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
            self.loss_global, self.loss_spatial = self.calculate_attention_loss()
        else:
            self.loss_global, self.loss_spatial = 0.0, 0.0

        if self.opt.lambda_motion > 0.0:
            self.loss_motion = 0.0
            for real_A0_tokens, real_A1_tokens, fake_B0_tokens, fake_B1_tokens in zip(self.mutil_real_A0_tokens, self.mutil_real_A1_tokens, self.mutil_fake_B0_tokens, self.mutil_fake_B1_tokens):
                A0_B1 = real_A0_tokens.bmm(fake_B1_tokens.permute(0, 2, 1))
                B0_A1 = fake_B0_tokens.bmm(real_A1_tokens.permute(0, 2, 1))
                cos_dis_global = F.cosine_similarity(A0_B1, B0_A1, dim=-1)
                self.loss_motion += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
        else:
            self.loss_motion = 0.0

-       self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.loss_motion
+       if self.opt.lambda_ctn > 0.0:
+           warped_fake_B1 = warp(self.fake_B1, self.f_content)  # use the updated self.f_content
+           self.l2_loss = F.mse_loss(self.fake_B0_2, warped_fake_B1) * self.opt.lambda_ctn
+       else:
+           self.l2_loss = 0.0
+       self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.l2_loss  # include l2_loss in the total loss
        return self.loss_G
    def calculate_attention_loss(self):


@@ -332,6 +332,7 @@ class CTNxModel(BaseModel):
        self.loss_D.backward()
        self.optimizer_D.step()
        # update E
        self.set_requires_grad(self.netE, True)
        self.optimizer_E.zero_grad()
        self.loss_E = self.compute_E_loss()