Compare commits
3 Commits: f98c285950 ... 77468f16f9

| Author | SHA1 | Date |
|---|---|---|
| | 77468f16f9 | |
| | c173e29ea6 | |
| | 537cb050a5 | |
@@ -1411,13 +1411,12 @@ class MLPDiscriminator(nn.Module):
         self.activation = nn.GELU()
         self.linear2 = nn.Linear(hid_feat, out_feat)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):
-        features = self.linear1(x)  # intermediate features, i.e. D_real or D_fake
-        x = self.activation(features)
+        x = self.linear1(x)
+        x = self.activation(x)
-        x = self.dropout(x)
-        scores = self.linear2(x)  # final scores, i.e. real_scores or fake_scores
-        return scores, features
+        x = self.linear2(x)
+        return self.dropout(x)

 class NLayerDiscriminator(nn.Module):
     """Defines a PatchGAN discriminator"""

@@ -13,6 +13,7 @@ import util.util as util

+from torchvision.transforms import transforms as tfs


 def warp(image, flow):  # warp operation
     """
     Optical-flow-based image warping function
@@ -37,76 +38,74 @@ def warp(image, flow):  # warp operation
     # bilinear interpolation
     return F.grid_sample(image, new_grid, align_corners=True)

-# temporal normalization loss computation
-def compute_ctn_loss(G, x, F_content):  # Eq. 10
-    """
-    Compute the content-aware temporal normalization loss
-    Args:
-        G: generator
-        x: input infrared image [B,C,H,W]
-        F_content: generated optical-flow field [B,2,H,W]
-    """
-
-    # generate the visible-light image
-    y_fake = G(x)  # [B,3,H,W]
-
-    # warp the generated result with the flow
-    warped_fake = warp(y_fake, F_content)  # [B,3,H,W]
-
-    # apply the same flow to the input, then generate
-    warped_x = warp(x, F_content)  # [B,C,H,W]
-    y_fake_warped = G(warped_x)  # [B,3,H,W]
-
-    # compute the L2 loss
-    loss = F.mse_loss(warped_fake, y_fake_warped)
-    return loss


 class ContentAwareOptimization(nn.Module):
     def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
         super().__init__()
-        self.lambda_inc = lambda_inc
-        self.eta_ratio = eta_ratio
-        self.gradients = []  # changed to a single gradient list, more generic
-        self.criterionGAN = networks.GANLoss('lsgan').cuda()
+        self.lambda_inc = lambda_inc  # weight increment for content-rich regions
+        self.eta_ratio = eta_ratio  # fraction of tokens treated as content-rich
+        self.criterionGAN = networks.GANLoss('lsgan').cuda()  # LSGAN loss

-    def compute_cosine_similarity(self, gradients):
-        mean_grad = torch.mean(gradients, dim=1, keepdim=True)
-        return F.cosine_similarity(gradients, mean_grad, dim=2)
+    def compute_cosine_similarity(self, grad_patch, grad_mean):
+        """
+        Cosine similarity between each token gradient and the overall mean gradient
+        Args:
+            grad_patch: [B, N, D], per-token gradients (from scores)
+            grad_mean: [B, D], overall mean gradient
+        Returns:
+            cosine: [B, N], cosine similarities δ_i
+        """
+        # cosine similarity for every token
+        cosine = F.cosine_similarity(grad_patch, grad_mean.unsqueeze(1), dim=2)  # [B, N]
+        return cosine

-    def generate_weight_map(self, gradients):
-        cosine = self.compute_cosine_similarity(gradients)
-        k = int(self.eta_ratio * cosine.shape[1])
-        _, indices = torch.topk(-cosine, k, dim=1)
+    def generate_weight_map(self, cosine):
+        """
+        Build the weight map from the cosine similarities
+        Args:
+            cosine: [B, N], cosine similarities δ_i
+        Returns:
+            weights: [B, N], weight map w_i
+        """
+        B, N = cosine.shape
+        k = int(self.eta_ratio * N)  # select an eta_ratio fraction of the tokens
+        _, indices = torch.topk(-cosine, k, dim=1)  # the k tokens deviating most
         weights = torch.ones_like(cosine)
-        weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
+        for b in range(B):
+            selected_cosine = cosine[b, indices[b]]
+            weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
         return weights

-    def forward(self, features, scores, target):
+    def forward(self, scores, target):
         """
         Forward pass: compute the weighted GAN loss
         Args:
-            features: feature tensor (the discriminator's real/fake features, or the generator's fake features)
-            scores: discriminator predictions for those features
-            target: target label (True = should be judged real, False = should be judged fake)
+            scores: [B, N, D], discriminator predictions
+            target: target label (True or False)
         Returns:
-            loss: weighted GAN loss
-            weight: generated weight map
+            weighted_loss: weighted GAN loss
+            weight: weight map [B, N]
         """
-        self.gradients.clear()
-        # register a gradient hook
-        hook = lambda grad: self.gradients.append(grad.detach())
-        features.register_hook(hook)
+        # raw GAN loss (assumes criterionGAN returns a [B, N] loss map)
+        loss = self.criterionGAN(scores, target)

-        # trigger the gradient computation
-        scores.mean().backward(retain_graph=True)
+        # capture the gradient of scores, shape [B, N, D]
+        grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]

-        # fetch the gradient and reshape
-        grad = self.gradients[0].flatten(1)  # [B, N, D] → [B, N*D]
-        weight = self.generate_weight_map(grad.view(*features.shape))
+        # overall mean gradient (mean over the N dimension)
+        grad_mean = torch.mean(grad_scores, dim=1)  # [B, D]

-        # weighted GAN loss
-        loss = torch.mean(weight * self.criterionGAN(scores, target))
-        return loss, weight
+        # cosine similarities δ_i
+        cosine = self.compute_cosine_similarity(grad_scores, grad_mean)  # [B, N]

+        # weight map w_i
+        weight = self.generate_weight_map(cosine)  # [B, N]

+        # weighted GAN loss
+        weighted_loss = torch.mean(weight * self.criterionGAN(scores, target))

+        return weighted_loss, weight


 class ContentAwareTemporalNorm(nn.Module):
     def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
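For orientation, here is a minimal standalone sketch of the new weighting pipeline (dummy tensors and an MSE stand-in for networks.GANLoss, so it runs outside the repo; the final scatter_ line is a vectorized equivalent of the per-batch loop above):

import torch
import torch.nn.functional as F

B, N, D = 2, 576, 256
scores = torch.randn(B, N, D, requires_grad=True)
loss = F.mse_loss(scores, torch.ones_like(scores))  # stand-in for criterionGAN

grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]     # [B, N, D]
grad_mean = grad_scores.mean(dim=1)                                       # [B, D]
cosine = F.cosine_similarity(grad_scores, grad_mean.unsqueeze(1), dim=2)  # [B, N]

eta_ratio, lambda_inc = 0.4, 2.0
k = int(eta_ratio * N)
_, idx = torch.topk(-cosine, k, dim=1)  # k tokens whose gradients deviate most
weights = torch.ones_like(cosine)
# same update as the loop: w = lambda_inc / (exp(|cosine|) + eps) on selected tokens
weights.scatter_(1, idx, lambda_inc / (torch.exp(cosine.gather(1, idx).abs()) + 1e-6))
print(weights.shape)  # torch.Size([2, 576])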
@@ -115,65 +114,51 @@ class ContentAwareTemporalNorm(nn.Module):
         self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

     def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
-        """
-        Upsample the patch-level weight map to the target resolution
-        Args:
-            weight_patch: [B, 1, 24, 24] patch weight map from the ViT
-            target_size: target resolution (H, W)
-        Returns:
-            weight_full: [B, 1, 256, 256] upsampled full-resolution weight map
-        """
-        # bilinear upsampling
-        B = weight_patch.shape[0]
-        weight_patch = weight_patch.view(B, 1, 24, 24)

+        # weight_patch: [B, 1, H, W], from the reshaped weight_map
         weight_full = F.interpolate(
             weight_patch,
             size=target_size,
-            mode='bilinear',
+            mode='bilinear',  # or 'nearest', depending on the need
             align_corners=False
         )

-        # keep the weight constant inside each 16x16 patch (optional)
-        # average-pool then re-expand, removing the gradient introduced by interpolation
-        weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
-        weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')

         return weight_full

     def forward(self, weight_map):
         """
         Generate the content-aware optical flow
         Args:
-            weight_map: [B, 1, H, W] weight map (from the content-aware optimization module)
+            weight_map: [B, N] weight map (from ContentAwareOptimization), with N=576
         Returns:
             F_content: [B, 2, H, W] generated flow field (x/y displacements)
         """
+        B = weight_map.shape[0]
+        N = weight_map.shape[1]
+        # assume N is a perfect square and compute the side length (e.g. 576 -> 24x24)
+        side = int(math.sqrt(N))
+        weight_map_2d = weight_map.view(B, 1, side, side)  # reshape to [B, 1, side, side]

         # upsample the weight map to full resolution
-        weight_full = self.upsample_weight_map(weight_map)  # [B,1,384,384]
+        weight_full = self.upsample_weight_map(weight_map_2d)  # [B, 1, 256, 256] (for example)

-        # 1. normalize the weight map
-        # keep the relative region strength while bounding the value range
-        weight_norm = F.normalize(weight_full, p=1, dim=(2,3))  # L1 normalization [B,1,H,W]
+        # normalize the weight map (L1 normalization)
+        weight_norm = F.normalize(weight_full, p=1, dim=(2,3))

-        # 2. draw Gaussian noise
+        # draw Gaussian noise
         B, _, H, W = weight_norm.shape
-        z = torch.randn(B, 2, H, W, device=weight_norm.device)  # [B,2,H,W]
+        z = torch.randn(B, 2, H, W, device=weight_norm.device)

-        # 3. synthesize the base flow
-        # expand the weight map to 2 channels (x/y directions share the weights)
-        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B,2,H,W]
-        F_raw = self.gamma_stride * weight_expanded * z  # [B,2,H,W]  # Eq. 9
+        # synthesize the base flow
+        weight_expanded = weight_norm.expand(-1, 2, -1, -1)
+        F_raw = self.gamma_stride * weight_expanded * z

-        # 4. smoothing (preserves structural continuity)
-        # Gaussian-blur each channel independently
-        F_smooth = self.smoother(F_raw)  # [B,2,H,W]
+        # smoothing
+        F_smooth = self.smoother(F_raw)

-        # 5. dynamic-range adjustment (optional)
-        # cap the flow magnitude to avoid extreme displacements
-        F_content = torch.tanh(F_smooth)  # squash into [-1,1]
+        # dynamic-range adjustment
+        F_content = torch.tanh(F_smooth)

         return F_content


 class RomaUnsbModel(BaseModel):
     @staticmethod
@@ -327,21 +312,21 @@ class RomaUnsbModel(BaseModel):

         # process real_B0 and fake_B0
         real_B0_tokens = self.mutil_real_B0_tokens[0]
-        pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)
+        pred_real0 = self.netD_ViT(real_B0_tokens)
         fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
-        pred_fake0, fake_features0 = self.netD_ViT(fake_B0_tokens)
+        pred_fake0 = self.netD_ViT(fake_B0_tokens)

-        loss_real0, self.weight_real0 = self.cao(real_features0, pred_real0, True)
-        loss_fake0, self.weight_fake0 = self.cao(fake_features0, pred_fake0, False)
+        loss_real0, self.weight_real0 = self.cao(pred_real0, True)
+        loss_fake0, self.weight_fake0 = self.cao(pred_fake0, False)

         # process real_B1 and fake_B1
         real_B1_tokens = self.mutil_real_B1_tokens[0]
-        pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)
+        pred_real1 = self.netD_ViT(real_B1_tokens)
         fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
-        pred_fake1, fake_features1 = self.netD_ViT(fake_B1_tokens)
+        pred_fake1 = self.netD_ViT(fake_B1_tokens)

-        loss_real1, self.weight_real1 = self.cao(real_features1, pred_real1, True)
-        loss_fake1, self.weight_fake1 = self.cao(fake_features1, pred_fake1, False)
+        loss_real1, self.weight_real1 = self.cao(pred_real1, True)
+        loss_fake1, self.weight_fake1 = self.cao(pred_fake1, False)

         # combined loss
         self.loss_D_ViT = (loss_real0 + loss_fake0 + loss_real1 + loss_fake1) * 0.25 * lambda_D_ViT

@@ -380,8 +365,8 @@ class RomaUnsbModel(BaseModel):
         # compute the GAN loss (with ContentAwareOptimization)
         if self.opt.lambda_GAN > 0.0:

-            pred_fake0, _ = self.netD_ViT(self.mutil_fake_B0_tokens[0])
-            pred_fake1, _ = self.netD_ViT(self.mutil_fake_B1_tokens[0])
+            pred_fake0 = self.netD_ViT(self.mutil_fake_B0_tokens[0])
+            pred_fake1 = self.netD_ViT(self.mutil_fake_B1_tokens[0])
             self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
             self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
             self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1) * 0.5

models/roma_unsb_single_model.py (new file, 391 lines)
@@ -0,0 +1,391 @@
import numpy as np
import math
import timm
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util

from torchvision.transforms import transforms as tfs

def warp(image, flow):  # warp operation
    """
    Optical-flow-based image warping function
    Args:
        image: [B, C, H, W] input image
        flow: [B, 2, H, W] optical-flow field (x/y displacements)
    Returns:
        warped: [B, C, H, W] warped image
    """
    B, C, H, W = image.shape
    # build the pixel-coordinate grid: x indexes columns (W), y indexes rows (H)
    grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2,H,W]
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B,2,H,W]

    # apply the flow displacement (normalized to [-1,1])
    new_grid = grid + flow
    new_grid[:, 0, :, :] = 2.0 * new_grid[:, 0, :, :] / (W - 1) - 1.0  # x direction
    new_grid[:, 1, :, :] = 2.0 * new_grid[:, 1, :, :] / (H - 1) - 1.0  # y direction
    new_grid = new_grid.permute(0, 2, 3, 1)  # [B,H,W,2]

    # bilinear interpolation
    return F.grid_sample(image, new_grid, align_corners=True)

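A quick sanity check for warp (a sketch assuming the function as defined above, with the grid built in x-along-W, y-along-H order): a zero flow field should return the input unchanged up to floating-point error.

import torch

img = torch.rand(1, 3, 64, 64)
flow = torch.zeros(1, 2, 64, 64)
out = warp(img, flow)
print((out - img).abs().max())  # ~0: zero flow reproduces the input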
class ContentAwareOptimization(nn.Module):
    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc  # weight increment for content-rich regions
        self.eta_ratio = eta_ratio  # fraction of tokens treated as content-rich
        self.criterionGAN = networks.GANLoss('lsgan').cuda()  # LSGAN loss

    def compute_cosine_similarity(self, grad_patch, grad_mean):
        """
        Cosine similarity between each token gradient and the overall mean gradient
        Args:
            grad_patch: [B, N, D], per-token gradients (from scores)
            grad_mean: [B, D], overall mean gradient
        Returns:
            cosine: [B, N], cosine similarities δ_i
        """
        # cosine similarity for every token
        cosine = F.cosine_similarity(grad_patch, grad_mean.unsqueeze(1), dim=2)  # [B, N]
        return cosine

    def generate_weight_map(self, cosine):
        """
        Build the weight map from the cosine similarities
        Args:
            cosine: [B, N], cosine similarities δ_i
        Returns:
            weights: [B, N], weight map w_i
        """
        B, N = cosine.shape
        k = int(self.eta_ratio * N)  # select an eta_ratio fraction of the tokens
        _, indices = torch.topk(-cosine, k, dim=1)  # the k tokens deviating most
        weights = torch.ones_like(cosine)
        for b in range(B):
            selected_cosine = cosine[b, indices[b]]
            weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
        return weights

    def forward(self, scores, target):
        """
        Forward pass: compute the weighted GAN loss
        Args:
            scores: [B, N, D], discriminator predictions
            target: target label (True or False)
        Returns:
            weighted_loss: weighted GAN loss
            weight: weight map [B, N]
        """
        # raw GAN loss (assumes criterionGAN returns a [B, N] loss map)
        loss = self.criterionGAN(scores, target)

        # capture the gradient of scores, shape [B, N, D]
        grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]

        # overall mean gradient (mean over the N dimension)
        grad_mean = torch.mean(grad_scores, dim=1)  # [B, D]

        # cosine similarities δ_i
        cosine = self.compute_cosine_similarity(grad_scores, grad_mean)  # [B, N]

        # weight map w_i
        weight = self.generate_weight_map(cosine)  # [B, N]

        # weighted GAN loss
        weighted_loss = torch.mean(weight * self.criterionGAN(scores, target))

        return weighted_loss, weight

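In equation form (reconstructed from the code above; g_i is the gradient of the loss at token i and ḡ is the mean gradient over the N tokens):

$$\delta_i=\cos\big(g_i,\ \bar g\big),\qquad w_i=\begin{cases}\dfrac{\lambda_{\mathrm{inc}}}{e^{|\delta_i|}+\varepsilon}, & i\in\text{bottom-}\lfloor\eta N\rfloor\text{ tokens by }\delta_i,\\[4pt]1, & \text{otherwise,}\end{cases}$$

and the module returns the mean of w_i times the per-token GAN loss.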
class ContentAwareTemporalNorm(nn.Module):
    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
        # weight_patch: [B, 1, H, W], from the reshaped weight_map
        weight_full = F.interpolate(
            weight_patch,
            size=target_size,
            mode='bilinear',  # or 'nearest', depending on the need
            align_corners=False
        )
        return weight_full

    def forward(self, weight_map):
        """
        Generate the content-aware optical flow
        Args:
            weight_map: [B, N] weight map (from ContentAwareOptimization), with N=576
        Returns:
            F_content: [B, 2, H, W] generated flow field (x/y displacements)
        """
        B = weight_map.shape[0]
        N = weight_map.shape[1]
        # assume N is a perfect square and compute the side length (e.g. 576 -> 24x24)
        side = int(math.sqrt(N))
        weight_map_2d = weight_map.view(B, 1, side, side)  # reshape to [B, 1, side, side]

        # upsample the weight map to full resolution
        weight_full = self.upsample_weight_map(weight_map_2d)  # [B, 1, 256, 256] (for example)

        # normalize the weight map (L1 normalization)
        weight_norm = F.normalize(weight_full, p=1, dim=(2,3))

        # draw Gaussian noise
        B, _, H, W = weight_norm.shape
        z = torch.randn(B, 2, H, W, device=weight_norm.device)

        # synthesize the base flow
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)
        F_raw = self.gamma_stride * weight_expanded * z

        # smoothing
        F_smooth = self.smoother(F_raw)

        # dynamic-range adjustment
        F_content = torch.tanh(F_smooth)

        return F_content

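A quick shape check for this module (a sketch assuming the class as defined above; the [2, 576] input mirrors the weight map that self.cao returns for 24x24 ViT patch tokens):

import torch

ctn = ContentAwareTemporalNorm(gamma_stride=0.1)
weight_fake = torch.rand(2, 576)   # e.g. the weight map returned by self.cao
flow = ctn(weight_fake)
print(flow.shape)               # torch.Size([2, 2, 256, 256])
print(flow.abs().max() <= 1.0)  # True: the tanh caps displacements at 1 pixel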
class RomaUnsbSingleModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Configure the options specific to the CTNx model"""

        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')

        parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
        parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
        parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
        parser.add_argument('--side_length', type=int, default=7)
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')

        parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
        parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
        parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')

        parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
        parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')

        parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')

        opt, _ = parser.parse_known_args()

        return parser
    def __init__(self, opt):
        BaseModel.__init__(self, opt)

        self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial', 'ctn']
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]

        if self.isTrain:
            self.model_names = ['G', 'D_ViT']
        else:  # during test time, only load G
            self.model_names = ['G']

        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)

        if self.isTrain:
            self.netD_ViT = networks.MLPDiscriminator().to(self.device)
            # self.netPreViT = timm.create_model("vit_base_patch32_384", pretrained=True).to(self.device)
            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)

            self.resize = tfs.Resize(size=(384, 384))
            # self.resize = tfs.Resize(size=(224, 224))

            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)

            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_ViT)

            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # loss function
            self.ctn = ContentAwareTemporalNorm()  # generates the pseudo optical flow

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        pass
    def optimize_parameters(self):
        # forward
        self.forward()

        # update D
        self.set_requires_grad(self.netD_ViT, True)
        self.optimizer_D_ViT.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D_ViT.step()

        # update G
        self.set_requires_grad(self.netD_ViT, False)
        self.optimizer_G.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        Parameters:
            input (dict): include the data itself and its metadata information.
        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)

        if self.opt.isTrain:
            real_A = self.real_A
            real_B = self.real_B
            fake_B = self.fake_B
            self.real_A_resize = self.resize(real_A)
            real_B = self.resize(real_B)
            self.fake_B_resize = self.resize(fake_B)
            self.mutil_real_A_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_B_tokens = self.netPreViT(real_B, self.atten_layers, get_tokens=True)
            self.mutil_fake_B_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True)
    def compute_D_loss(self):
        """Calculate GAN loss for the discriminator"""

        lambda_D_ViT = self.opt.lambda_D_ViT
        fake_B_tokens = self.mutil_fake_B_tokens[0].detach()
        real_B_tokens = self.mutil_real_B_tokens[0]
        pred_fake_ViT = self.netD_ViT(fake_B_tokens)
        pred_real_ViT = self.netD_ViT(real_B_tokens)

        self.loss_D_real_ViT, self.weight_real = self.cao(pred_real_ViT, True)
        self.loss_D_fake_ViT, self.weight_fake = self.cao(pred_fake_ViT, False)

        self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5 * lambda_D_ViT

        return self.loss_D_ViT
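In equation form (a reconstruction from the code; ℓ is the LSGAN objective inside self.cao, ⟨·⟩ the mean over tokens, and w^r, w^f the weight maps it returns):

$$\mathcal{L}_{D}^{\mathrm{ViT}}=\frac{\lambda_{D\,\mathrm{ViT}}}{2}\Big(\big\langle w^{r}\odot\ell\big(D(T_{\mathrm{real}}),1\big)\big\rangle+\big\langle w^{f}\odot\ell\big(D(T_{\mathrm{fake}}),0\big)\big\rangle\Big)$$

Note that the fake-side weight map stays on self.weight_fake and is reused, detached, by compute_G_loss below to drive the temporal-norm flow.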
    def compute_G_loss(self):
        if self.opt.lambda_ctn > 0.0:
            # generate the flow map (using the discriminator's weight map)
            self.f_content = self.ctn(self.weight_fake.detach())

            # warped images
            self.warped_real_A = warp(self.real_A, self.f_content)
            self.warped_fake_B = warp(self.fake_B, self.f_content)
            # second generation pass
            self.warped_fake_B2 = self.netG(self.warped_real_A)

            # compute the loss
            self.loss_ctn = self.criterionL1(self.warped_fake_B, self.warped_fake_B2) * self.opt.lambda_ctn
        else:
            self.loss_ctn = 0.0

        # if self.opt.lambda_GAN > 0.0:
        #     fake_B_tokens = self.mutil_fake_B_tokens[0]
        #     pred_fake_ViT = self.netD_ViT(fake_B_tokens)
        #     self.loss_G_GAN = self.criterionGAN(pred_fake_ViT, True) * self.opt.lambda_GAN
        # else:
        #     self.loss_G_GAN = 0.0
        if self.opt.lambda_GAN > 0.0:

            fake_B_tokens = self.mutil_fake_B_tokens[0]
            pred_fake_ViT = self.netD_ViT(fake_B_tokens)
            self.loss_G_fake_ViT, self.weight_real = self.cao(pred_fake_ViT, True)
            self.loss_G_GAN = self.loss_G_fake_ViT * self.opt.lambda_GAN
        else:
            self.loss_G_GAN = 0.0
        if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
            self.loss_global, self.loss_spatial = self.calculate_attention_loss()
        else:
            self.loss_global, self.loss_spatial = 0.0, 0.0

        self.loss_G = self.loss_G_GAN + self.loss_global + self.loss_spatial + self.loss_ctn
        return self.loss_G
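The CTN branch above is the single-frame analogue of the earlier compute_ctn_loss (Eq. 10 in the comments), now with an L1 norm (a reconstruction from the code; F_c is the flow produced by self.ctn):

$$\mathcal{L}_{\mathrm{ctn}}=\lambda_{\mathrm{ctn}}\,\big\lVert\,\mathrm{warp}\big(G(x),F_{c}\big)-G\big(\mathrm{warp}(x,F_{c})\big)\,\big\rVert_{1}$$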
    def calculate_attention_loss(self):
        n_layers = len(self.atten_layers)
        mutil_real_A_tokens = self.mutil_real_A_tokens
        mutil_fake_B_tokens = self.mutil_fake_B_tokens

        if self.opt.lambda_global > 0.0:
            loss_global = self.calculate_similarity(mutil_real_A_tokens, mutil_fake_B_tokens)
        else:
            loss_global = 0.0

        if self.opt.lambda_spatial > 0.0:
            loss_spatial = 0.0
            local_nums = self.opt.local_nums
            tokens_cnt = 576
            local_id = np.random.permutation(tokens_cnt)
            local_id = local_id[:int(min(local_nums, tokens_cnt))]

            mutil_real_A_local_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)

            mutil_fake_B_local_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)

            loss_spatial = self.calculate_similarity(mutil_real_A_local_tokens, mutil_fake_B_local_tokens)
        else:
            loss_spatial = 0.0

        return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
    def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
        loss = 0.0
        n_layers = len(self.atten_layers)

        for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
            src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))
            tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))
            cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
            loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()

        loss = loss / n_layers
        return loss
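In equation form (reconstructed from the loop above; T_A^l and T_B^l are the layer-l token matrices of the source and translated images, cos is taken row-wise over the last dimension, and ⟨·⟩ is the mean):

$$\mathcal{L}_{\mathrm{sim}}=\frac{1}{L}\sum_{l=1}^{L}\Big\langle\,\big|\,1-\cos\big(T_{A}^{l}(T_{B}^{l})^{\top},\ T_{B}^{l}(T_{A}^{l})^{\top}\big)\big|\,\Big\rangle$$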
@@ -1,2 +1,2 @@
 # CUDA_VISIBLE_DEVICES=0 python test.py --dataroot /path/of/test_dataset --checkpoints_dir ./checkpoints --name train1 --model roma_single --num_test 10000 --epoch latest
-python test.py --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Single/Monitor --checkpoints_dir ./checkpoints --name cp_4 --model roma_single --num_test 4132 --epoch 150 --gpu_id 2
+python test.py --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Single/Monitor --checkpoints_dir ./checkpoints --name cp_2 --model roma_single --num_test 4132 --epoch 120 --gpu_id 1
@@ -1,16 +1,16 @@
 python train.py \
-    --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
-    --name cp_5 \
-    --dataset_mode unaligned_double \
-    --display_env CP \
-    --model roma_unsb \
-    --lambda_ctn 0 \
+    --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Single/Monitor \
+    --name cp_2 \
+    --dataset_mode unaligned \
+    --display_env NEWCP \
+    --model roma_unsb_single \
+    --lambda_ctn 10 \
     --lambda_inc 8.0 \
     --lambda_global 6.0 \
     --lambda_spatial 6.0 \
     --gamma_stride 20 \
-    --lr 0.000005 \
-    --gpu_id 0 \
+    --lr 0.000002 \
+    --gpu_id 1 \
     --eta_ratio 0.4 \
     --n_epochs 100 \
     --n_epochs_decay 100 \
@@ -18,4 +18,6 @@ python train.py \
 # cp2 fixed up cp1's code, --lr 0.000002
 # cp3 added --lambda_inc 8.0 --gpu_id 2
 # cp4 on top of cp3, applied the gradient weighting to the generator's GAN loss --gpu_id 1
 # cp5 on top of cp3: --lambda_ctn 0, --gpu_id 0, --lr 0.000005
+# # newcp1 reworked the optical-flow algorithm and turned it into a single-frame script; this is the final reproduction. --gpu_id 0
+# # newcp2 applied the gradient weight map's effect on the loss to G_GAN as well.