更新 models/rome_unsb_model.py
This commit is contained in:
parent
a62a234d83
commit
e5accb1d4c
@ -17,7 +17,7 @@ def warp(image, flow): #warp操作
|
|||||||
基于光流的图像变形函数
|
基于光流的图像变形函数
|
||||||
Args:
|
Args:
|
||||||
image: [B, C, H, W] 输入图像
|
image: [B, C, H, W] 输入图像
|
||||||
flow: [B, 2, H, W] 光流场(x/y方向位移)
|
flow: [B, 2, H, W] 光流场(x/y方向位移)
|
||||||
Returns:
|
Returns:
|
||||||
warped: [B, C, H, W] 变形后的图像
|
warped: [B, C, H, W] 变形后的图像
|
||||||
"""
|
"""
|
||||||
@ -70,7 +70,7 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
"""
|
"""
|
||||||
计算每个patch梯度与平均梯度的余弦相似度
|
计算每个patch梯度与平均梯度的余弦相似度
|
||||||
Args:
|
Args:
|
||||||
gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h)
|
gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h)
|
||||||
Returns:
|
Returns:
|
||||||
cosine_sim: [B, N] 每个patch的余弦相似度
|
cosine_sim: [B, N] 每个patch的余弦相似度
|
||||||
"""
|
"""
|
||||||
@ -164,7 +164,7 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块)
|
weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块)
|
||||||
Returns:
|
Returns:
|
||||||
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
||||||
"""
|
"""
|
||||||
B, _, H, W = weight_map.shape
|
B, _, H, W = weight_map.shape
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ class CTNxModel(BaseModel):
|
|||||||
def modify_commandline_options(parser, is_train=True):
|
def modify_commandline_options(parser, is_train=True):
|
||||||
"""配置 CTNx 模型的特定选项"""
|
"""配置 CTNx 模型的特定选项"""
|
||||||
|
|
||||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))')
|
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
||||||
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
|
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
|
||||||
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
|
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
|
||||||
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user