diff --git a/models/__pycache__/networks.cpython-39.pyc b/models/__pycache__/networks.cpython-39.pyc index 91353a9..d690c50 100644 Binary files a/models/__pycache__/networks.cpython-39.pyc and b/models/__pycache__/networks.cpython-39.pyc differ diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index 728de33..7ad0b8a 100644 Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 9de01c6..151f483 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -198,7 +198,7 @@ class RomaUnsbModel(BaseModel): """配置 CTNx 模型的特定选项""" parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))') - parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss') + parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm') parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator') parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency') @@ -206,14 +206,8 @@ class RomaUnsbModel(BaseModel): parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization') parser.add_argument('--local_nums', type=int, default=64, help='number of local patches') parser.add_argument('--side_length', type=int, default=7) - parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))') - parser.add_argument('--nce_includes_all_negatives_from_minibatch', - type=util.str2bool, nargs='?', const=True, default=False, - help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.') parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers') - - parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map') - + parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions') parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix') parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers') @@ -253,7 +247,7 @@ class RomaUnsbModel(BaseModel): # 创建网络 self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt) - + if self.isTrain: @@ -321,88 +315,28 @@ class RomaUnsbModel(BaseModel): def forward(self): """Run forward pass; called by both functions and .""" + self.fake_B0 = self.netG(self.real_A0) + self.fake_B1 = self.netG(self.real_A1) - # ============ 第一步:对 real_A / real_A2 进行多步随机生成过程 ============ - tau = self.opt.tau - T = self.opt.num_timesteps - incs = np.array([0] + [1/(i+1) for i in range(T-1)]) - times = np.cumsum(incs) - times = times / times[-1] - times = 0.5 * times[-1] + 0.5 * times #[0.5,1] - times = np.concatenate([np.zeros(1), times]) - times = torch.tensor(times).float().cuda() - self.times = times - bs = self.real_A0.size(0) - time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long() - self.time_idx = time_idx - self.fake_B0_list = [] - self.fake_B1_list = [] - - with torch.no_grad(): - self.netG.eval() - # ============ 第二步:对 real_A / real_A2 进行多步随机生成过程 ============ - for t in range(self.time_idx.int().item() + 1): - # 计算增量 delta 与 inter/scale,用于每个时间步的插值等 - if t > 0: - delta = times[t] - times[t - 1] - denom = times[-1] - times[t - 1] - inter = (delta / denom).reshape(-1, 1, 1, 1) - scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1) - - # 对 Xt、Xt2 进行随机噪声更新 - Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \ - (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device) - time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long() - z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device) - time = times[time_idx] - Xt_1 = self.netG(Xt.detach(), time, z) - - Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \ - (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device) - time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long() - z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device) - Xt_12 = self.netG(Xt2.detach(), time, z) - self.fake_B0_list.append(Xt_1) - self.fake_B1_list.append(Xt_12) - - self.fake_B0_1 = self.fake_B0_list[0] - self.fake_B1_1 = self.fake_B0_list[0] - self.fake_B0 = self.fake_B0_list[-1] - self.fake_B1 = self.fake_B1_list[-1] - self.z_in = z - self.z_in2 = z - if self.opt.phase == 'train': + if self.opt.isTrain: real_A0 = self.real_A0 real_A1 = self.real_A1 real_B0 = self.real_B0 real_B1 = self.real_B1 fake_B0 = self.fake_B0 fake_B1 = self.fake_B1 - self.mutil_fake_B0_tokens_list = [] - self.mutil_fake_B1_tokens_list = [] - for fake_B0_t in self.fake_B0_list: - fake_B0_t_resize = self.resize(fake_B0_t) # 调整到 ViT 输入尺寸 - tokens = self.netPreViT(fake_B0_t_resize, self.atten_layers, get_tokens=True) - self.mutil_fake_B0_tokens_list.append(tokens) - for fake_B1_t in self.fake_B1_list: - fake_B1_t_resize = self.resize(fake_B1_t) - tokens = self.netPreViT(fake_B1_t_resize, self.atten_layers, get_tokens=True) - self.mutil_fake_B1_tokens_list.append(tokens) - self.real_A0_resize = self.resize(real_A0) self.real_A1_resize = self.resize(real_A1) real_B0 = self.resize(real_B0) real_B1 = self.resize(real_B1) self.fake_B0_resize = self.resize(fake_B0) self.fake_B1_resize = self.resize(fake_B1) - self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True) self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True) self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True) self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True) - # [[1,576,768],[1,576,768],[1,576,768]] - # [3,576,768] - + self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True) + self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) def compute_D_loss(self): """Calculate GAN loss with Content-Aware Optimization""" @@ -414,27 +348,25 @@ class RomaUnsbModel(BaseModel): real_B1_tokens = self.mutil_real_B1_tokens[0] pred_real1, real_features1 = self.netD_ViT(real_B1_tokens) # scores, features - for fake0_token, fake1_token in zip(self.mutil_fake_B0_tokens_list, self.mutil_fake_B1_tokens_list): - pre_fake0, fake_features0 = self.netD_ViT(fake0_token[0].detach()) - pre_fake1, fake_features1 = self.netD_ViT(fake1_token[0].detach()) - loss_cao0, self.weight_real0, self.weight_fake0 = self.cao( - D_real=real_features0, - D_fake=fake_features0, - real_scores=pred_real0, - fake_scores=pre_fake0 - ) - loss_cao1, self.weight_real1, self.weight_fake1 = self.cao( - D_real=real_features1, - D_fake=fake_features1, - real_scores=pred_real1, - fake_scores=pre_fake1 - ) - loss_cao += loss_cao0 + loss_cao1 + pre_fake0, fake_features0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach()) + pre_fake1, fake_features1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach()) + loss_cao0, self.weight_real0, self.weight_fake0 = self.cao( + D_real=real_features0, + D_fake=fake_features0, + real_scores=pred_real0, + fake_scores=pre_fake0 + ) + loss_cao1, self.weight_real1, self.weight_fake1 = self.cao( + D_real=real_features1, + D_fake=fake_features1, + real_scores=pred_real1, + fake_scores=pre_fake1 + ) + loss_cao += loss_cao0 + loss_cao1 # ===== 综合损失 ===== - total_steps = len(self.fake_B0_list) - self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT/ total_steps + self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT # 记录损失值供可视化 @@ -458,8 +390,8 @@ class RomaUnsbModel(BaseModel): self.warped_fake_B1 = warp(self.fake_B1,self.f_content1) # 经过第二次生成器 - self.warped_fake_B0_2 = self.netG(self.warped_real_A0, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in) - self.warped_fake_B1_2 = self.netG(self.warped_real_A1, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in2) + self.warped_fake_B0_2 = self.netG(self.warped_real_A0) + self.warped_fake_B1_2 = self.netG(self.warped_real_A1) warped_fake_B0_2=self.warped_fake_B0_2 warped_fake_B1_2=self.warped_fake_B1_2 @@ -472,8 +404,8 @@ class RomaUnsbModel(BaseModel): if self.opt.lambda_GAN > 0.0: - pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens_list[-1][0]) - pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens_list[-1][0]) + pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens[0]) + pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens[0]) self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean() self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean() self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5 diff --git a/options/__pycache__/base_options.cpython-39.pyc b/options/__pycache__/base_options.cpython-39.pyc index 7658be9..a03f478 100644 Binary files a/options/__pycache__/base_options.cpython-39.pyc and b/options/__pycache__/base_options.cpython-39.pyc differ diff --git a/options/__pycache__/train_options.cpython-39.pyc b/options/__pycache__/train_options.cpython-39.pyc index 76d2a2f..1b019a1 100644 Binary files a/options/__pycache__/train_options.cpython-39.pyc and b/options/__pycache__/train_options.cpython-39.pyc differ diff --git a/options/base_options.py b/options/base_options.py index b20e1b4..d63d895 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -36,7 +36,7 @@ class BaseOptions(): parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') parser.add_argument('--netD', type=str, default='basic_cond', choices=['basic_cond', 'basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') - parser.add_argument('--netG', type=str, default='resnet_9blocks_cond', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture') + parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture') parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G') parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D') diff --git a/scripts/traincp.sh b/scripts/traincp.sh new file mode 100644 index 0000000..f26e8f3 --- /dev/null +++ b/scripts/traincp.sh @@ -0,0 +1,17 @@ +python train.py \ + --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ + --name cp_1 \ + --dataset_mode unaligned_double \ + --display_env CP \ + --model roma_unsb \ + --lambda_ctn 10 \ + --lambda_inc 1.0 \ + --lambda_global 6.0 \ + --lambda_spatial 6.0 \ + --gamma_stride 20 \ + --lr 0.000001 \ + --gpu_id 2 \ + --eta_ratio 0.4 \ + --n_epochs 100 \ + --n_epochs_decay 100 \ +# cp1 复现cptrans的效果 \ No newline at end of file