From e5accb1d4c41da4e5812e0575ced67cc93e8f35c Mon Sep 17 00:00:00 2001 From: 123456 <3351416005@qq.com> Date: Sun, 23 Feb 2025 15:22:14 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20models/rome=5Funsb=5Fmodel?= =?UTF-8?q?.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/rome_unsb_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/rome_unsb_model.py b/models/rome_unsb_model.py index 952a097..fb62536 100644 --- a/models/rome_unsb_model.py +++ b/models/rome_unsb_model.py @@ -17,7 +17,7 @@ def warp(image, flow): #warp操作 基于光流的图像变形函数 Args: image: [B, C, H, W] 输入图像 - flow: [B, 2, H, W] 光流场(x/y方向位移) + flow: [B, 2, H, W] 光流场(x/y方向位移) Returns: warped: [B, C, H, W] 变形后的图像 """ @@ -70,7 +70,7 @@ class ContentAwareOptimization(nn.Module): """ 计算每个patch梯度与平均梯度的余弦相似度 Args: - gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h) + gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h) Returns: cosine_sim: [B, N] 每个patch的余弦相似度 """ @@ -164,7 +164,7 @@ class ContentAwareTemporalNorm(nn.Module): Args: weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块) Returns: - F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) + F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) """ B, _, H, W = weight_map.shape @@ -195,7 +195,7 @@ class CTNxModel(BaseModel): def modify_commandline_options(parser, is_train=True): """配置 CTNx 模型的特定选项""" - parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))') + parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))') parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)') parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss') parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')