EDIT_DOWN
This commit is contained in:
parent
3c4d53377c
commit
e8e483fbf8
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
checkpoints/
|
||||||
|
*.log
|
||||||
|
*.pth
|
||||||
|
*.ckpt
|
||||||
|
__pycache__/
|
||||||
Binary file not shown.
@ -66,6 +66,10 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
self.lambda_inc = lambda_inc # 权重增强系数
|
self.lambda_inc = lambda_inc # 权重增强系数
|
||||||
self.eta_ratio = eta_ratio # 选择内容区域的比例
|
self.eta_ratio = eta_ratio # 选择内容区域的比例
|
||||||
|
|
||||||
|
# 改为类成员变量,确保钩子函数可访问
|
||||||
|
self.gradients_real = []
|
||||||
|
self.gradients_fake = []
|
||||||
|
|
||||||
def compute_cosine_similarity(self, gradients):
|
def compute_cosine_similarity(self, gradients):
|
||||||
"""
|
"""
|
||||||
计算每个patch梯度与平均梯度的余弦相似度
|
计算每个patch梯度与平均梯度的余弦相似度
|
||||||
@ -79,78 +83,65 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
|
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
|
||||||
return cosine_sim
|
return cosine_sim
|
||||||
|
|
||||||
def generate_weight_map(self, gradients_fake):
|
def generate_weight_map(self, gradients_real, gradients_fake):
|
||||||
"""
|
"""
|
||||||
生成内容感知权重图
|
生成内容感知权重图
|
||||||
Args:
|
Args:
|
||||||
gradients_fake: [B, N, D] 生成图像判别器梯度 [2,3,256,256]
|
gradients_real: [B, N, D] 真实图像判别器梯度
|
||||||
|
gradients_fake: [B, N, D] 生成图像判别器梯度
|
||||||
Returns:
|
Returns:
|
||||||
weight_fake: [B, N] 生成图像权重图 [2,3,256]
|
weight_real: [B, N] 真实图像权重图
|
||||||
|
weight_fake: [B, N] 生成图像权重图
|
||||||
"""
|
"""
|
||||||
|
# 计算真实图像块的余弦相似度
|
||||||
|
cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] 公式5
|
||||||
# 计算生成图像块的余弦相似度
|
# 计算生成图像块的余弦相似度
|
||||||
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
|
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
|
||||||
|
|
||||||
# 选择内容丰富的区域(余弦相似度最低的eta_ratio比例)
|
# 选择内容丰富的区域(余弦相似度最低的eta_ratio比例)
|
||||||
k = int(self.eta_ratio * cosine_fake.shape[1])
|
k = int(self.eta_ratio * cosine_real.shape[1])
|
||||||
|
|
||||||
# 对生成图像生成权重图(同理)
|
# 对真实图像生成权重图
|
||||||
|
_, real_indices = torch.topk(-cosine_real, k, dim=1) # 选择最不相似的区域
|
||||||
|
weight_real = torch.ones_like(cosine_real)
|
||||||
|
for b in range(cosine_real.shape[0]):
|
||||||
|
weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) #公式6
|
||||||
|
|
||||||
|
# 对生成图像生成权重图(同理)
|
||||||
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
||||||
weight_fake = torch.ones_like(cosine_fake)
|
weight_fake = torch.ones_like(cosine_fake)
|
||||||
for b in range(cosine_fake.shape[0]):
|
for b in range(cosine_fake.shape[0]):
|
||||||
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
|
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
|
||||||
|
|
||||||
return weight_fake
|
return weight_real, weight_fake
|
||||||
|
|
||||||
def forward(self, D_real, D_fake, real_scores, fake_scores):
|
def forward(self, D_real, D_fake, real_scores, fake_scores):
|
||||||
"""
|
# 清空梯度缓存
|
||||||
计算内容感知对抗损失
|
self.gradients_real.clear()
|
||||||
Args:
|
self.gradients_fake.clear()
|
||||||
D_real: 判别器对真实图像的特征输出 [B, C, H, W]
|
|
||||||
D_fake: 判别器对生成图像的特征输出 [B, C, H, W]
|
|
||||||
real_scores: 真实图像的判别器预测 [B, N] (N=H*W)
|
|
||||||
fake_scores: 生成图像的判别器预测 [B, N]
|
|
||||||
Returns:
|
|
||||||
loss_co_adv: 内容感知对抗损失
|
|
||||||
"""
|
|
||||||
B, C, H, W = D_real.shape
|
|
||||||
N = H * W
|
|
||||||
|
|
||||||
# 注册钩子获取梯度
|
# 注册钩子
|
||||||
gradients_real = []
|
hook_real = lambda grad: self.gradients_real.append(grad.detach())
|
||||||
gradients_fake = []
|
hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
|
||||||
|
|
||||||
def hook_real(grad):
|
|
||||||
gradients_real.append(grad.detach().view(B, N, -1))
|
|
||||||
|
|
||||||
def hook_fake(grad):
|
|
||||||
gradients_fake.append(grad.detach().view(B, N, -1))
|
|
||||||
|
|
||||||
D_real.register_hook(hook_real)
|
D_real.register_hook(hook_real)
|
||||||
D_fake.register_hook(hook_fake)
|
D_fake.register_hook(hook_fake)
|
||||||
|
|
||||||
# 计算原始对抗损失以触发梯度计算
|
# 触发梯度计算
|
||||||
loss_real = torch.mean(torch.log(real_scores + 1e-8))
|
(real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)
|
||||||
loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
|
|
||||||
# 添加与 D_real、D_fake 相关的 dummy 项,确保梯度传递
|
|
||||||
loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
|
|
||||||
total_loss = loss_real + loss_fake + loss_dummy
|
|
||||||
total_loss.backward(retain_graph=True)
|
|
||||||
|
|
||||||
# 获取梯度数据
|
# 获取梯度并调整维度
|
||||||
gradients_real = gradients_real[0] # [B, N, D]
|
grad_real = self.gradients_real[0] # [B, N, D]
|
||||||
gradients_fake = gradients_fake[0] # [B, N, D]
|
grad_fake = self.gradients_fake[0]
|
||||||
|
|
||||||
# 生成权重图
|
# 生成权重图
|
||||||
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
|
weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake)
|
||||||
|
|
||||||
# 应用权重到对抗损失
|
# 计算加权损失
|
||||||
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
|
loss_co_real = (weight_real * torch.log(real_scores + 1e-8)).mean()
|
||||||
loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))
|
loss_co_fake = (weight_fake * torch.log(1 - fake_scores + 1e-8)).mean()
|
||||||
|
|
||||||
# 计算并返回最终内容感知对抗损失
|
return -(loss_co_real + loss_co_fake), weight_real, weight_fake
|
||||||
loss_co_adv = -(loss_co_real + loss_co_fake)
|
|
||||||
|
|
||||||
return loss_co_adv
|
|
||||||
|
|
||||||
class ContentAwareTemporalNorm(nn.Module):
|
class ContentAwareTemporalNorm(nn.Module):
|
||||||
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
|
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
|
||||||
@ -158,6 +149,33 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
self.gamma_stride = gamma_stride # 控制整体运动幅度
|
self.gamma_stride = gamma_stride # 控制整体运动幅度
|
||||||
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
|
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
|
||||||
|
|
||||||
|
def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
|
||||||
|
"""
|
||||||
|
将patch级别的权重图上采样到目标分辨率
|
||||||
|
Args:
|
||||||
|
weight_patch: [B, 1, 24, 24] 来自ViT的patch权重图
|
||||||
|
target_size: 目标分辨率 (H, W)
|
||||||
|
Returns:
|
||||||
|
weight_full: [B, 1, 256, 256] 上采样后的全分辨率权重图
|
||||||
|
"""
|
||||||
|
# 使用双线性插值上采样
|
||||||
|
B = weight_patch.shape[0]
|
||||||
|
weight_patch = weight_patch.view(B, 1, 24, 24)
|
||||||
|
|
||||||
|
weight_full = F.interpolate(
|
||||||
|
weight_patch,
|
||||||
|
size=target_size,
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 对每个16x16的patch内部保持权重一致(可选)
|
||||||
|
# 通过平均池化再扩展,消除插值引入的渐变
|
||||||
|
weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
|
||||||
|
weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
|
||||||
|
|
||||||
|
return weight_full
|
||||||
|
|
||||||
def forward(self, weight_map):
|
def forward(self, weight_map):
|
||||||
"""
|
"""
|
||||||
生成内容感知光流
|
生成内容感知光流
|
||||||
@ -166,15 +184,16 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
||||||
"""
|
"""
|
||||||
#print(weight_map.shape)
|
# 上采样权重图到全分辨率
|
||||||
B, _, H, W = weight_map.shape
|
weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384]
|
||||||
|
|
||||||
# 1. 归一化权重图
|
# 1. 归一化权重图
|
||||||
# 保持区域相对强度,同时限制数值范围
|
# 保持区域相对强度,同时限制数值范围
|
||||||
weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
|
weight_norm = F.normalize(weight_full, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
|
||||||
|
|
||||||
# 2. 生成高斯噪声(与光流场同尺寸)
|
# 2. 生成高斯噪声
|
||||||
z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W]
|
B, _, H, W = weight_norm.shape
|
||||||
|
z = torch.randn(B, 2, H, W, device=weight_norm.device) # [B,2,H,W]
|
||||||
|
|
||||||
# 3. 合成基础光流
|
# 3. 合成基础光流
|
||||||
# 将权重图扩展为2通道(x/y方向共享权重)
|
# 将权重图扩展为2通道(x/y方向共享权重)
|
||||||
@ -437,8 +456,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
# ============ 第三步:拼接输入并执行网络推理 =============
|
# ============ 第三步:拼接输入并执行网络推理 =============
|
||||||
bs = self.real_A0.size(0)
|
bs = self.real_A0.size(0)
|
||||||
z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
|
self.z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
|
||||||
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
|
self.z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
|
||||||
# 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB
|
# 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB
|
||||||
self.real = self.real_A0
|
self.real = self.real_A0
|
||||||
self.realt = self.real_A_noisy
|
self.realt = self.real_A_noisy
|
||||||
@ -449,8 +468,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.real = torch.flip(self.real, [3])
|
self.real = torch.flip(self.real, [3])
|
||||||
self.realt = torch.flip(self.realt, [3])
|
self.realt = torch.flip(self.realt, [3])
|
||||||
|
|
||||||
self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
|
self.fake_B0 = self.netG(self.real_A0, self.time, self.z_in)
|
||||||
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
|
self.fake_B1 = self.netG(self.real_A1, self.time, self.z_in2)
|
||||||
|
|
||||||
if self.opt.phase == 'train':
|
if self.opt.phase == 'train':
|
||||||
real_A0 = self.real_A0
|
real_A0 = self.real_A0
|
||||||
@ -476,28 +495,6 @@ class RomaUnsbModel(BaseModel):
|
|||||||
# [[1,576,768],[1,576,768],[1,576,768]]
|
# [[1,576,768],[1,576,768],[1,576,768]]
|
||||||
# [3,576,768]
|
# [3,576,768]
|
||||||
|
|
||||||
## 生成图像的梯度
|
|
||||||
#fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
|
||||||
#
|
|
||||||
## 梯度图
|
|
||||||
#self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
|
||||||
#
|
|
||||||
## 生成图像的CTN光流图
|
|
||||||
#self.f_content = self.ctn(self.weight_fake)
|
|
||||||
#
|
|
||||||
## 变换后的图片
|
|
||||||
#self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
|
||||||
#self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
|
||||||
#
|
|
||||||
## 经过第二次生成器
|
|
||||||
#self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
|
||||||
|
|
||||||
#warped_fake_B0_2=self.warped_fake_B0_2
|
|
||||||
#warped_fake_B0=self.warped_fake_B0
|
|
||||||
#self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
|
|
||||||
#self.warped_fake_B0_resize = self.resize(warped_fake_B0)
|
|
||||||
#self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
|
|
||||||
#self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_D_loss(self): #判别器还是没有改
|
def compute_D_loss(self): #判别器还是没有改
|
||||||
@ -505,25 +502,19 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
lambda_D_ViT = self.opt.lambda_D_ViT
|
lambda_D_ViT = self.opt.lambda_D_ViT
|
||||||
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
|
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
|
||||||
fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
|
|
||||||
|
|
||||||
real_B0_tokens = self.mutil_real_B0_tokens[0]
|
real_B0_tokens = self.mutil_real_B0_tokens[0]
|
||||||
real_B1_tokens = self.mutil_real_B1_tokens[0]
|
|
||||||
|
|
||||||
|
|
||||||
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
|
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
|
||||||
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
|
self.loss_D_fake_ViT = self.criterionGAN(pre_fake0_ViT, False)
|
||||||
|
|
||||||
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
|
|
||||||
|
|
||||||
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
|
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
|
||||||
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
|
self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True)
|
||||||
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
|
|
||||||
|
|
||||||
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
|
self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT)
|
||||||
|
|
||||||
|
return self.losscao* lambda_D_ViT
|
||||||
return self.loss_D_ViT
|
|
||||||
|
|
||||||
def compute_E_loss(self):
|
def compute_E_loss(self):
|
||||||
"""计算判别器 E 的损失"""
|
"""计算判别器 E 的损失"""
|
||||||
@ -537,12 +528,28 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
def compute_G_loss(self):
|
def compute_G_loss(self):
|
||||||
"""计算生成器的 GAN 损失"""
|
"""计算生成器的 GAN 损失"""
|
||||||
|
if self.opt.lambda_ctn > 0.0:
|
||||||
|
# 生成图像的CTN光流图
|
||||||
|
self.f_content = self.ctn(self.weight_fake)
|
||||||
|
|
||||||
|
# 变换后的图片
|
||||||
|
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
||||||
|
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
||||||
|
|
||||||
|
# 经过第二次生成器
|
||||||
|
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, self.z_in)
|
||||||
|
|
||||||
|
warped_fake_B0_2=self.warped_fake_B0_2
|
||||||
|
warped_fake_B0=self.warped_fake_B0
|
||||||
|
# 计算L2损失
|
||||||
|
self.ctn_loss = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
|
||||||
|
|
||||||
if self.opt.lambda_GAN > 0.0:
|
if self.opt.lambda_GAN > 0.0:
|
||||||
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
|
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
|
||||||
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
|
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean()
|
||||||
else:
|
else:
|
||||||
self.loss_G_GAN = 0.0
|
self.loss_G_GAN = 0.0
|
||||||
|
|
||||||
self.loss_SB = 0
|
self.loss_SB = 0
|
||||||
if self.opt.lambda_SB > 0.0:
|
if self.opt.lambda_SB > 0.0:
|
||||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
|
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
|
||||||
@ -551,9 +558,9 @@ class RomaUnsbModel(BaseModel):
|
|||||||
bs = self.opt.batch_size
|
bs = self.opt.batch_size
|
||||||
|
|
||||||
# eq.9
|
# eq.9
|
||||||
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
|
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - self.netE(XtXt_1, self.time, XtXt_2).mean()
|
||||||
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
|
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
|
||||||
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
|
self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
|
||||||
|
|
||||||
if self.opt.lambda_global > 0.0:
|
if self.opt.lambda_global > 0.0:
|
||||||
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
|
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
|
||||||
@ -561,12 +568,10 @@ class RomaUnsbModel(BaseModel):
|
|||||||
else:
|
else:
|
||||||
loss_global = 0.0
|
loss_global = 0.0
|
||||||
|
|
||||||
self.l2_loss = 0.0
|
self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
|
||||||
#if self.opt.lambda_ctn > 0.0:
|
self.opt.lambda_SB * self.loss_SB + \
|
||||||
# wapped_fake_B = warp(self.fake_B, self.f_content) # use updated self.f_content
|
self.opt.lambda_ctn * self.ctn_loss + \
|
||||||
# self.l2_loss = F.mse_loss(self.fake_B_2, wapped_fake_B) # complete the loss calculation
|
loss_global * self.opt.lambda_global
|
||||||
|
|
||||||
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
|
|
||||||
return self.loss_G
|
return self.loss_G
|
||||||
|
|
||||||
def calculate_attention_loss(self):
|
def calculate_attention_loss(self):
|
||||||
|
|||||||
@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
python train.py \
|
python train.py \
|
||||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
||||||
--name ROMA_UNSB_002 \
|
--name ROMA_UNSB_003 \
|
||||||
--dataset_mode unaligned_double \
|
--dataset_mode unaligned_double \
|
||||||
--no_flip \
|
--no_flip \
|
||||||
--display_env ROMA \
|
--display_env ROMA \
|
||||||
|
|||||||
1
train.py
1
train.py
@ -44,6 +44,7 @@ if __name__ == '__main__':
|
|||||||
model.setup(opt) # regular setup: load and print networks; create schedulers
|
model.setup(opt) # regular setup: load and print networks; create schedulers
|
||||||
model.parallelize()
|
model.parallelize()
|
||||||
model.set_input(data) # unpack data from dataset and apply preprocessing
|
model.set_input(data) # unpack data from dataset and apply preprocessing
|
||||||
|
#print('Call opt paras')
|
||||||
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
|
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
|
||||||
if len(opt.gpu_ids) > 0:
|
if len(opt.gpu_ids) > 0:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user