withoutCNT
This commit is contained in:
parent
67151c73f7
commit
4af0d7463d
@ -68,3 +68,4 @@
|
|||||||
================ Training Loss (Sun Feb 23 23:13:05 2025) ================
|
================ Training Loss (Sun Feb 23 23:13:05 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 23:13:59 2025) ================
|
================ Training Loss (Sun Feb 23 23:13:59 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 23:14:59 2025) ================
|
================ Training Loss (Sun Feb 23 23:14:59 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 22:59:41 2025) ================
|
||||||
|
|||||||
@ -19,9 +19,9 @@
|
|||||||
easy_label: experiment_name
|
easy_label: experiment_name
|
||||||
epoch: latest
|
epoch: latest
|
||||||
epoch_count: 1
|
epoch_count: 1
|
||||||
eta_ratio: 0.1
|
eta_ratio: 0.4
|
||||||
evaluation_freq: 5000
|
evaluation_freq: 5000
|
||||||
flip_equivariance: False
|
flip_equivariance: True [default: False]
|
||||||
gan_mode: lsgan
|
gan_mode: lsgan
|
||||||
gpu_ids: 0
|
gpu_ids: 0
|
||||||
init_gain: 0.02
|
init_gain: 0.02
|
||||||
@ -31,11 +31,10 @@
|
|||||||
lambda_D_ViT: 1.0
|
lambda_D_ViT: 1.0
|
||||||
lambda_GAN: 8.0 [default: 1.0]
|
lambda_GAN: 8.0 [default: 1.0]
|
||||||
lambda_NCE: 8.0 [default: 1.0]
|
lambda_NCE: 8.0 [default: 1.0]
|
||||||
lambda_SB: 0.1
|
lambda_SB: 1.0 [default: 0.1]
|
||||||
lambda_ctn: 1.0
|
lambda_ctn: 1.0
|
||||||
lambda_global: 1.0
|
lambda_global: 1.0
|
||||||
lambda_inc: 1.0
|
lambda_inc: 1.0
|
||||||
lmda_1: 0.1
|
|
||||||
load_size: 286
|
load_size: 286
|
||||||
lr: 1e-05 [default: 0.0002]
|
lr: 1e-05 [default: 0.0002]
|
||||||
lr_decay_iters: 50
|
lr_decay_iters: 50
|
||||||
@ -47,14 +46,12 @@
|
|||||||
n_layers_D: 3
|
n_layers_D: 3
|
||||||
n_mlp: 3
|
n_mlp: 3
|
||||||
name: ROMA_UNSB_001 [default: experiment_name]
|
name: ROMA_UNSB_001 [default: experiment_name]
|
||||||
nce_T: 0.07
|
|
||||||
nce_idt: False [default: True]
|
nce_idt: False [default: True]
|
||||||
nce_includes_all_negatives_from_minibatch: False
|
nce_includes_all_negatives_from_minibatch: False
|
||||||
nce_layers: 0,4,8,12,16
|
nce_layers: 0,4,8,12,16
|
||||||
ndf: 64
|
ndf: 64
|
||||||
netD: basic_cond
|
netD: basic_cond
|
||||||
netF: mlp_sample
|
netF: mlp_sample
|
||||||
netF_nc: 256
|
|
||||||
netG: resnet_9blocks_cond
|
netG: resnet_9blocks_cond
|
||||||
ngf: 64
|
ngf: 64
|
||||||
no_antialias: False
|
no_antialias: False
|
||||||
@ -64,9 +61,8 @@ nce_includes_all_negatives_from_minibatch: False
|
|||||||
no_html: False
|
no_html: False
|
||||||
normD: instance
|
normD: instance
|
||||||
normG: instance
|
normG: instance
|
||||||
num_patches: 256
|
|
||||||
num_threads: 4
|
num_threads: 4
|
||||||
num_timesteps: 10 [default: 5]
|
num_timesteps: 4 [default: 5]
|
||||||
output_nc: 3
|
output_nc: 3
|
||||||
phase: train
|
phase: train
|
||||||
pool_size: 0
|
pool_size: 0
|
||||||
|
|||||||
Binary file not shown.
@ -166,7 +166,7 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
||||||
"""
|
"""
|
||||||
print(weight_map.shape)
|
#print(weight_map.shape)
|
||||||
B, _, H, W = weight_map.shape
|
B, _, H, W = weight_map.shape
|
||||||
|
|
||||||
# 1. 归一化权重图
|
# 1. 归一化权重图
|
||||||
@ -204,23 +204,19 @@ class RomaUnsbModel(BaseModel):
|
|||||||
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
||||||
|
|
||||||
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
|
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
|
||||||
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
|
|
||||||
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
|
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
|
||||||
type=util.str2bool, nargs='?', const=True, default=False,
|
type=util.str2bool, nargs='?', const=True, default=False,
|
||||||
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
|
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
|
||||||
|
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
|
||||||
|
|
||||||
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
|
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
|
||||||
parser.add_argument('--netF_nc', type=int, default=256)
|
|
||||||
parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
|
|
||||||
|
|
||||||
parser.add_argument('--lmda_1', type=float, default=0.1)
|
|
||||||
parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
|
|
||||||
parser.add_argument('--flip_equivariance',
|
parser.add_argument('--flip_equivariance',
|
||||||
type=util.str2bool, nargs='?', const=True, default=False,
|
type=util.str2bool, nargs='?', const=True, default=False,
|
||||||
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
|
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
|
||||||
|
|
||||||
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
||||||
parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
|
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
|
||||||
|
|
||||||
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
|
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
|
||||||
|
|
||||||
@ -261,12 +257,10 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
if self.isTrain:
|
if self.isTrain:
|
||||||
self.model_names = ['G', 'D_ViT', 'E']
|
self.model_names = ['G', 'D_ViT', 'E']
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.model_names = ['G']
|
self.model_names = ['G']
|
||||||
|
|
||||||
print(f'input_nc = {self.opt.input_nc}')
|
|
||||||
# 创建网络
|
# 创建网络
|
||||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||||
|
|
||||||
@ -284,9 +278,6 @@ class RomaUnsbModel(BaseModel):
|
|||||||
# 定义损失函数
|
# 定义损失函数
|
||||||
self.criterionL1 = torch.nn.L1Loss().to(self.device)
|
self.criterionL1 = torch.nn.L1Loss().to(self.device)
|
||||||
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
||||||
self.criterionNCE = []
|
|
||||||
for nce_layer in self.nce_layers:
|
|
||||||
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
|
|
||||||
self.criterionIdt = torch.nn.L1Loss().to(self.device)
|
self.criterionIdt = torch.nn.L1Loss().to(self.device)
|
||||||
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||||
self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||||
@ -459,10 +450,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.real = torch.flip(self.real, [3])
|
self.real = torch.flip(self.real, [3])
|
||||||
self.realt = torch.flip(self.realt, [3])
|
self.realt = torch.flip(self.realt, [3])
|
||||||
|
|
||||||
print(f'fake_B0: {self.real_A0.shape}, fake_B1: {self.real_A1.shape}')
|
|
||||||
self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
|
self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
|
||||||
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
|
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
|
||||||
print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')
|
|
||||||
|
|
||||||
if self.opt.phase == 'train':
|
if self.opt.phase == 'train':
|
||||||
real_A0 = self.real_A0
|
real_A0 = self.real_A0
|
||||||
@ -540,7 +529,6 @@ class RomaUnsbModel(BaseModel):
|
|||||||
def compute_E_loss(self):
|
def compute_E_loss(self):
|
||||||
"""计算判别器 E 的损失"""
|
"""计算判别器 E 的损失"""
|
||||||
|
|
||||||
print(f'resl_A_noisy: {self.real_A_noisy.shape} \n fake_B0: {self.fake_B0.shape}')
|
|
||||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
|
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
|
||||||
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
|
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
|
||||||
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
|
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
|
||||||
|
|||||||
@ -14,20 +14,15 @@ python train.py \
|
|||||||
--model roma_unsb \
|
--model roma_unsb \
|
||||||
--lambda_GAN 8.0 \
|
--lambda_GAN 8.0 \
|
||||||
--lambda_NCE 8.0 \
|
--lambda_NCE 8.0 \
|
||||||
--lambda_SB 0.1 \
|
--lambda_SB 1.0 \
|
||||||
--lambda_ctn 1.0 \
|
--lambda_ctn 1.0 \
|
||||||
--lambda_inc 1.0 \
|
--lambda_inc 1.0 \
|
||||||
--lr 0.00001 \
|
--lr 0.00001 \
|
||||||
--gpu_id 0 \
|
--gpu_id 0 \
|
||||||
--nce_idt False \
|
--nce_idt False \
|
||||||
--nce_layers 0,4,8,12,16 \
|
|
||||||
--netF mlp_sample \
|
--netF mlp_sample \
|
||||||
--netF_nc 256 \
|
--flip_equivariance True \
|
||||||
--nce_T 0.07 \
|
--eta_ratio 0.4 \
|
||||||
--lmda_1 0.1 \
|
|
||||||
--num_patches 256 \
|
|
||||||
--flip_equivariance False \
|
|
||||||
--eta_ratio 0.1 \
|
|
||||||
--tau 0.01 \
|
--tau 0.01 \
|
||||||
--num_timesteps 10 \
|
--num_timesteps 4 \
|
||||||
--input_nc 3
|
--input_nc 3
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user