更新 models/rome_unsb_model.py

This commit is contained in:
123456 2025-02-23 15:22:14 +08:00
parent a62a234d83
commit e5accb1d4c

View File

@ -17,7 +17,7 @@ def warp(image, flow): #warp操作
基于光流的图像变形函数
Args:
image: [B, C, H, W] 输入图像
flow: [B, 2, H, W] 光流场x/y方向位移
flow: [B, 2, H, W] 光流场(x/y方向位移)
Returns:
warped: [B, C, H, W] 变形后的图像
"""
@ -70,7 +70,7 @@ class ContentAwareOptimization(nn.Module):
"""
计算每个patch梯度与平均梯度的余弦相似度
Args:
gradients: [B, N, D] 判别器输出的每个patch的梯度N=w*h
gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h)
Returns:
cosine_sim: [B, N] 每个patch的余弦相似度
"""
@ -164,7 +164,7 @@ class ContentAwareTemporalNorm(nn.Module):
Args:
weight_map: [B, 1, H, W] 权重图来自内容感知优化模块
Returns:
F_content: [B, 2, H, W] 生成的光流场x/y方向位移
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
"""
B, _, H, W = weight_map.shape
@ -195,7 +195,7 @@ class CTNxModel(BaseModel):
def modify_commandline_options(parser, is_train=True):
"""配置 CTNx 模型的特定选项"""
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN lossGAN(G(X))')
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')