From e0dc08030cd878e36704940a716ee54abc74ddc4 Mon Sep 17 00:00:00 2001 From: bishe <123456789@163.com> Date: Sun, 9 Mar 2025 21:41:52 +0800 Subject: [PATCH] =?UTF-8?q?cptrans=E5=A4=8D=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/roma_unsb_model.py | 10 +++++----- scripts/train.sh | 2 +- scripts/traincp.sh | 9 +++++---- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 151f483..bdc98f4 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -227,7 +227,7 @@ class RomaUnsbModel(BaseModel): # 指定需要打印的训练损失 self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn'] - self.visual_names = ['real_A0', 'fake_B0_1','fake_B0', 'real_B0','real_A1', 'fake_B1_1', 'fake_B1', 'real_B1'] + self.visual_names = ['real_A0', 'fake_B0', 'real_B0','real_A1', 'fake_B1', 'real_B1'] self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')] @@ -380,8 +380,8 @@ class RomaUnsbModel(BaseModel): """计算生成器的 GAN 损失""" if self.opt.lambda_ctn > 0.0: # 生成图像的CTN光流图 - self.f_content0 = self.ctn(self.weight_fake0) - self.f_content1 = self.ctn(self.weight_fake1) + self.f_content0 = self.ctn(self.weight_fake0.detach()) + self.f_content1 = self.ctn(self.weight_fake1.detach()) # 变换后的图片 self.warped_real_A0 = warp(self.real_A0, self.f_content0) @@ -429,8 +429,8 @@ class RomaUnsbModel(BaseModel): n_layers = len(self.atten_layers) mutil_real_A0_tokens = self.mutil_real_A0_tokens mutil_real_A1_tokens = self.mutil_real_A1_tokens - mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1] - mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1] + mutil_fake_B0_tokens = self.mutil_fake_B0_tokens + mutil_fake_B1_tokens = self.mutil_fake_B1_tokens if self.opt.lambda_global > 0.0: diff --git a/scripts/train.sh b/scripts/train.sh index bc7c6c3..6c84aa2 100755 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -17,7 +17,7 @@ python train.py \ --lambda_global 6.0 \ --gamma_stride 20 \ --lr 0.000002 \ - --gpu_id 1 \ + --gpu_id 0 \ --nce_idt False \ --netF mlp_sample \ --eta_ratio 0.4 \ diff --git a/scripts/traincp.sh b/scripts/traincp.sh index f26e8f3..5b0625c 100644 --- a/scripts/traincp.sh +++ b/scripts/traincp.sh @@ -1,6 +1,6 @@ python train.py \ --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ - --name cp_1 \ + --name cp_2 \ --dataset_mode unaligned_double \ --display_env CP \ --model roma_unsb \ @@ -9,9 +9,10 @@ python train.py \ --lambda_global 6.0 \ --lambda_spatial 6.0 \ --gamma_stride 20 \ - --lr 0.000001 \ - --gpu_id 2 \ + --lr 0.000002 \ + --gpu_id 0 \ --eta_ratio 0.4 \ --n_epochs 100 \ --n_epochs_decay 100 \ -# cp1 复现cptrans的效果 \ No newline at end of file +# cp1 复现cptrans的效果 --lr 0.000001 +# cp2 修了一下cp1的代码,--lr 0.000002 \ No newline at end of file