first commit

areszz 2025-02-22 14:21:54 +08:00
commit 5cb1f58852
411 changed files with 47043 additions and 0 deletions

142
README.md Normal file

@ -0,0 +1,142 @@
# ROMA
This repository is the official PyTorch implementation of the ACM MM'22 paper
"ROMA: Cross-Domain Region Similarity Matching for Unpaired Nighttime Infrared to Daytime Visible Video Translation".[[Arxiv]](https://arxiv.org/abs/2204.12367)
**Examples of Object Detection:**
![detection1](./images/detection1.gif)
![](./images/detection2.gif)
**Examples of Video Fusion**
![fusion](./images/fusion.gif)
More experimental results can be obtained by contacting us.
# Introduction
## Method
![Method](images/method_final.jpg)
- The domain gap between unpaired nighttime infrared and daytime visible videos is even larger than that between paired videos captured at the same time; establishing an effective translation mapping will therefore greatly benefit various fields.
- Our proposed cross-similarity, calculated across domains, makes the generative process focus on learning the structural correspondence between real and synthesized frames, free from the negative effects of their differing styles (a rough illustrative sketch follows below).
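A rough, illustrative sketch of this idea is shown below. It is not the paper's exact loss; `region_similarity`, `cross_similarity_loss`, `feat_real`, and `feat_fake` are hypothetical names, and the feature maps are assumed to come from a shared encoder applied to a real frame and a synthesized frame.

```python
# Illustrative sketch only: match the region-similarity structure of a
# synthesized frame to that of the corresponding real frame, so the loss
# reflects structural relations between regions rather than per-domain style.
import torch
import torch.nn.functional as F

def region_similarity(feat, idx):
    """(N, C, H, W) features -> (N, P, P) cosine similarities among selected patches."""
    n, c, h, w = feat.shape
    patches = feat.flatten(2).permute(0, 2, 1)      # (N, H*W, C) patch vectors
    patches = F.normalize(patches[:, idx], dim=-1)  # unit-length features
    return patches @ patches.transpose(1, 2)

def cross_similarity_loss(feat_real, feat_fake, num_patches=64):
    """Use the same patch locations on both sides and match the two similarity maps."""
    h, w = feat_real.shape[2:]
    idx = torch.randperm(h * w)[:num_patches]
    return F.l1_loss(region_similarity(feat_fake, idx),
                     region_similarity(feat_real, idx))
```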
## Training
The following is the required dataset structure. In video mode, a single training sample is the concatenation of **two adjacent frames**; in image mode, a single training sample is **a single image**.
```
Video/Image mode:
trainA: \Path\of\trainA
trainB: \Path\of\trainB
```
Concrete examples of the training and testing are shown in the script files `./scripts/train.sh` and `./scripts/test.sh`, respectively.
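For video mode, each sample could be prepared with a small helper like the sketch below (a hypothetical `concat_adjacent_frames` with placeholder paths): it pastes two adjacent frames side by side into a single 512x256 image, which is consistent with the `(0, 0, 256, 256)` and `(256, 0, 512, 256)` crops used by the `UnalignedDoubleDataset` class in `data/`.

```python
# Hypothetical preprocessing helper for video mode: one training sample is a
# 512x256 image holding frame t (left) and frame t+1 (right).
from PIL import Image

def concat_adjacent_frames(path_t, path_t1, out_path, size=256):
    frame_t = Image.open(path_t).convert('RGB').resize((size, size))
    frame_t1 = Image.open(path_t1).convert('RGB').resize((size, size))
    pair = Image.new('RGB', (2 * size, size))
    pair.paste(frame_t, (0, 0))      # left half: frame t
    pair.paste(frame_t1, (size, 0))  # right half: frame t+1
    pair.save(out_path)
```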
## InfraredCity and InfraredCity-Lite Dataset
<table class="tg">
<thead>
<tr>
<th class="tg-uzvj" colspan="2">InfraredCity</th>
<th class="tg-uzvj" colspan="4">Total Frame</th>
</tr>
</thead>
<tbody>
<tr>
<td class="tg-9wq8" colspan="2">Nighttime Infrared</td>
<td class="tg-9wq8" colspan="4">201,856</td>
</tr>
<tr>
<td class="tg-9wq8" colspan="2">Nighttime Visible</td>
<td class="tg-9wq8" colspan="4">178,698</td>
</tr>
<tr>
<td class="tg-9wq8" colspan="2">Daytime Visible</td>
<td class="tg-9wq8" colspan="4">199,430</td>
</tr>
<tr>
<td class="tg-9wq8" colspan="6"></td>
</tr>
<tr>
<td class="tg-uzvj" colspan="2">InfraredCity-Lite</td>
<td class="tg-uzvj">Infrared<br>Train</td>
<td class="tg-uzvj">Infrared<br>Test</td>
<td class="tg-uzvj">Visible<br>Train</td>
<td class="tg-uzvj">Total</td>
</tr>
<tr>
<td class="tg-9wq8" rowspan="2">City</td>
<td class="tg-9wq8">clearday</td>
<td class="tg-9wq8">5,538</td>
<td class="tg-9wq8">1,000</td>
<td class="tg-9wq8" rowspan="2">5360</td>
<td class="tg-9wq8" rowspan="2">15,180</td>
</tr>
<tr>
<td class="tg-9wq8">overcast</td>
<td class="tg-9wq8">2,282</td>
<td class="tg-9wq8">1,000</td>
</tr>
<tr>
<td class="tg-9wq8" rowspan="2">Highway</td>
<td class="tg-9wq8">clearday</td>
<td class="tg-9wq8">4,412</td>
<td class="tg-9wq8">1,000</td>
<td class="tg-9wq8" rowspan="2">6,463</td>
<td class="tg-9wq8" rowspan="2">15,853</td>
</tr>
<tr>
<td class="tg-9wq8">overcast</td>
<td class="tg-9wq8">2,978</td>
<td class="tg-9wq8">1,000</td>
</tr>
<tr>
<td class="tg-9wq8" colspan="2">Monitor</td>
<td class="tg-9wq8">5,612</td>
<td class="tg-9wq8">500</td>
<td class="tg-9wq8">4,194</td>
<td class="tg-9wq8">10,306</td>
</tr>
</tbody>
</table>
The datasets and further details are available at [InfiRay](http://openai.raytrontek.com/apply/Infrared_city.html/).
### Citation
If you find our work useful in your research or publication, please cite our work:
```
@inproceedings{ROMA2022,
title = {ROMA: Cross-Domain Region Similarity Matching for Unpaired Nighttime Infrared to Daytime Visible Video Translation},
author = {Zhenjie Yu and Kai Chen and Shuang Li and Bingfeng Han and Chi Harold Liu and Shuigen Wang},
booktitle = {ACM MM},
pages = {5294--5302},
year = {2022}
}
```
#### Acknowledgements
This code borrows heavily from the PyTorch implementation of [Cycle-GAN and Pix2Pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [CUT](https://github.com/taesungp/contrastive-unpaired-translation).
A huge thanks to them!
```
@inproceedings{CycleGAN2017,
title = {Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
author = {Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
booktitle = {ICCV},
year = {2017}
}
@inproceedings{CUT2020,
author = {Taesung Park and Alexei A. Efros and Richard Zhang and Jun{-}Yan Zhu},
title = {Contrastive Learning for Unpaired Image-to-Image Translation},
booktitle = {ECCV},
pages = {319--345},
year = {2020},
}
```

98
data/__init__.py Normal file

@ -0,0 +1,98 @@
"""This package includes all the modules related to data loading and preprocessing
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset
def find_dataset_using_name(dataset_name):
"""Import the module "data/[dataset_name]_dataset.py".
In the file, the class called DatasetNameDataset() will
be instantiated. It has to be a subclass of BaseDataset,
and it is case-insensitive.
"""
dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
dataset = cls
if dataset is None:
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
return dataset
def get_option_setter(dataset_name):
"""Return the static method <modify_commandline_options> of the dataset class."""
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options
def create_dataset(opt):
"""Create a dataset given the option.
This function wraps the class CustomDatasetDataLoader.
This is the main interface between this package and 'train.py'/'test.py'
Example:
>>> from data import create_dataset
>>> dataset = create_dataset(opt)
"""
data_loader = CustomDatasetDataLoader(opt)
dataset = data_loader.load_data()
return dataset
class CustomDatasetDataLoader():
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
def __init__(self, opt):
"""Initialize this class
Step 1: create a dataset instance given the name [dataset_mode]
Step 2: create a multi-threaded data loader.
"""
self.opt = opt
dataset_class = find_dataset_using_name(opt.dataset_mode)
self.dataset = dataset_class(opt)
print("dataset [%s] was created" % type(self.dataset).__name__)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
shuffle=not opt.serial_batches,
num_workers=int(opt.num_threads),
drop_last=True if opt.isTrain else False,
)
def set_epoch(self, epoch):
self.dataset.current_epoch = epoch
def load_data(self):
return self
def __len__(self):
"""Return the number of data in the dataset"""
return min(len(self.dataset), self.opt.max_dataset_size)
def __iter__(self):
"""Return a batch of data"""
for i, data in enumerate(self.dataloader):
if i * self.opt.batch_size >= self.opt.max_dataset_size:
break
yield data

Binary files not shown.

230
data/base_dataset.py Normal file

@ -0,0 +1,230 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
class BaseDataset(data.Dataset, ABC):
"""This class is an abstract base class (ABC) for datasets.
To create a subclass, you need to implement the following four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the class; save the options in the class
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
self.opt = opt
self.root = opt.dataroot
self.current_epoch = 0
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def __len__(self):
"""Return the total number of images in the dataset."""
return 0
@abstractmethod
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index - - a random integer for data indexing
Returns:
a dictionary of data with their names. It usually contains the data itself and its metadata information.
"""
pass
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.preprocess == 'resize_and_crop':
new_h = new_w = opt.load_size
elif opt.preprocess == 'scale_width_and_crop':
new_w = opt.load_size
new_h = opt.load_size * h // w
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
flip = random.random() > 0.5
return {'crop_pos': (x, y), 'flip': flip}
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
transform_list = []
if grayscale:
transform_list.append(transforms.Grayscale(1))
if 'fixsize' in opt.preprocess:
transform_list.append(transforms.Resize(params["size"], method))
if 'resize' in opt.preprocess:
osize = [opt.load_size, opt.load_size]
if "gta2cityscapes" in opt.dataroot:
osize[0] = opt.load_size // 2
transform_list.append(transforms.Resize(osize, method))
elif 'scale_width' in opt.preprocess:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
elif 'scale_shortside' in opt.preprocess:
transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))
if 'zoom' in opt.preprocess:
if params is None:
transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
else:
transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))
if 'crop' in opt.preprocess:
if params is None or 'crop_pos' not in params:
transform_list.append(transforms.RandomCrop(opt.crop_size))
else:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
if 'patch' in opt.preprocess:
transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))
if 'trim' in opt.preprocess:
transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))
# if opt.preprocess == 'none':
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
if not opt.no_flip:
if params is None or 'flip' not in params:
transform_list.append(transforms.RandomHorizontalFlip())
elif 'flip' in params:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
if convert:
transform_list += [transforms.ToTensor()]
if grayscale:
transform_list += [transforms.Normalize((0.5,), (0.5,))]
else:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if h == oh and w == ow:
return img
return img.resize((w, h), method)
def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
if factor is None:
zoom_level = np.random.uniform(0.8, 1.0, size=[2])
else:
zoom_level = (factor[0], factor[1])
iw, ih = img.size
zoomw = max(crop_width, iw * zoom_level[0])
zoomh = max(crop_width, ih * zoom_level[1])
img = img.resize((int(round(zoomw)), int(round(zoomh))), method)
return img
def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
ow, oh = img.size
shortside = min(ow, oh)
if shortside >= target_width:
return img
else:
scale = target_width / shortside
return img.resize((round(ow * scale), round(oh * scale)), method)
def __trim(img, trim_width):
ow, oh = img.size
if ow > trim_width:
xstart = np.random.randint(ow - trim_width)
xend = xstart + trim_width
else:
xstart = 0
xend = ow
if oh > trim_width:
ystart = np.random.randint(oh - trim_width)
yend = ystart + trim_width
else:
ystart = 0
yend = oh
return img.crop((xstart, ystart, xend, yend))
def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
ow, oh = img.size
if ow == target_width and oh >= crop_width:
return img
w = target_width
h = int(max(target_width * oh / ow, crop_width))
return img.resize((w, h), method)
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img
def __patch(img, index, size):
ow, oh = img.size
nw, nh = ow // size, oh // size
roomx = ow - nw * size
roomy = oh - nh * size
startx = np.random.randint(int(roomx) + 1)
starty = np.random.randint(int(roomy) + 1)
index = index % (nw * nh)
ix = index // nh
iy = index % nh
gridx = startx + ix * size
gridy = starty + iy * size
return img.crop((gridx, gridy, gridx + size, gridy + size))
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
def __print_size_warning(ow, oh, w, h):
"""Print warning information about image size(only print once)"""
if not hasattr(__print_size_warning, 'has_printed'):
print("The image size needs to be a multiple of 4. "
"The loaded image size was (%d, %d), so it was adjusted to "
"(%d, %d). This adjustment will be done to all images "
"whose sizes are not multiples of 4" % (ow, oh, w, h))
__print_size_warning.has_printed = True
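As a minimal usage sketch of the transform pipeline above (hypothetical option values and a placeholder image path; only the fields actually read by `get_params`/`get_transform` are filled in):

```python
# Build the crop/flip parameters once and reuse them, e.g. for paired images.
from argparse import Namespace
from PIL import Image
from data.base_dataset import get_params, get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=286, crop_size=256,
                dataroot='./datasets/example', no_flip=False)

img = Image.open('example.jpg').convert('RGB')
params = get_params(opt, img.size)      # shared crop position and flip flag
transform = get_transform(opt, params)  # resize -> crop -> flip -> ToTensor -> Normalize
tensor = transform(img)                 # (3, 256, 256) tensor with values in [-1, 1]
```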

66
data/image_folder.py Normal file

@ -0,0 +1,66 @@
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import torch.utils.data as data
from PIL import Image
import os
import os.path
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
'.tif', '.TIF', '.tiff', '.TIFF',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir, max_dataset_size=float("inf")):
images = []
assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
return images[:min(max_dataset_size, len(images))]
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)

40
data/single_dataset.py Normal file

@ -0,0 +1,40 @@
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
class SingleDataset(BaseDataset):
"""This dataset class can load a set of images specified by the path --dataroot /path/to/data.
It can be used for generating CycleGAN results only for one side with the model option '--model test'.
"""
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseDataset.__init__(self, opt)
self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
self.transform = get_transform(opt, grayscale=(input_nc == 1))
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index - - a random integer for data indexing
Returns a dictionary that contains A and A_paths
A(tensor) - - an image in one domain
A_paths(str) - - the path of the image
"""
A_path = self.A_paths[index]
A_img = Image.open(A_path).convert('RGB')
A = self.transform(A_img)
return {'A': A, 'A_paths': A_path}
def __len__(self):
"""Return the total number of images in the dataset."""
return len(self.A_paths)

108
data/singleimage_dataset.py Normal file

@ -0,0 +1,108 @@
import numpy as np
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
class SingleImageDataset(BaseDataset):
"""
This dataset class can load unaligned/unpaired datasets.
It requires two directories to host training images from domain A '/path/to/data/trainA'
and from domain B '/path/to/data/trainB' respectively.
You can train the model with the dataset flag '--dataroot /path/to/data'.
Similarly, you need to prepare two directories:
'/path/to/data/testA' and '/path/to/data/testB' during test time.
"""
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseDataset.__init__(self, opt)
self.dir_A = os.path.join(opt.dataroot, 'trainA') # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(opt.dataroot, 'trainB') # create a path '/path/to/data/trainB'
if os.path.exists(self.dir_A) and os.path.exists(self.dir_B):
self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
assert len(self.A_paths) == 1 and len(self.B_paths) == 1,\
"SingleImageDataset class should be used with one image in each domain"
A_img = Image.open(self.A_paths[0]).convert('RGB')
B_img = Image.open(self.B_paths[0]).convert('RGB')
print("Image sizes %s and %s" % (str(A_img.size), str(B_img.size)))
self.A_img = A_img
self.B_img = B_img
# In single-image translation, we augment the data loader by applying
# random scaling. Still, we design the data loader such that the
# amount of scaling is the same within a minibatch. To do this,
# we precompute the random scaling values, and repeat them by |batch_size|.
A_zoom = 1 / self.opt.random_scale_max
zoom_levels_A = np.random.uniform(A_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
self.zoom_levels_A = np.reshape(np.tile(zoom_levels_A, (1, opt.batch_size, 1)), [-1, 2])
B_zoom = 1 / self.opt.random_scale_max
zoom_levels_B = np.random.uniform(B_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
self.zoom_levels_B = np.reshape(np.tile(zoom_levels_B, (1, opt.batch_size, 1)), [-1, 2])
# While the crop locations are randomized, the negative samples should
# not come from the same location. To do this, we precompute the
# crop locations with no repetition.
self.patch_indices_A = list(range(len(self)))
random.shuffle(self.patch_indices_A)
self.patch_indices_B = list(range(len(self)))
random.shuffle(self.patch_indices_B)
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
"""
A_path = self.A_paths[0]
B_path = self.B_paths[0]
A_img = self.A_img
B_img = self.B_img
# apply image transformation
if self.opt.phase == "train":
param = {'scale_factor': self.zoom_levels_A[index],
'patch_index': self.patch_indices_A[index],
'flip': random.random() > 0.5}
transform_A = get_transform(self.opt, params=param, method=Image.BILINEAR)
A = transform_A(A_img)
param = {'scale_factor': self.zoom_levels_B[index],
'patch_index': self.patch_indices_B[index],
'flip': random.random() > 0.5}
transform_B = get_transform(self.opt, params=param, method=Image.BILINEAR)
B = transform_B(B_img)
else:
transform = get_transform(self.opt, method=Image.BILINEAR)
A = transform(A_img)
B = transform(B_img)
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
def __len__(self):
""" Let's pretend the single image contains 100,000 crops for convenience.
"""
return 100000
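The zoom-level bookkeeping in `__init__` above can be seen in isolation in the small sketch below (made-up numbers): one zoom factor is drawn per minibatch and then repeated `batch_size` times, so every item in a minibatch is scaled identically.

```python
# Worked illustration of the per-minibatch zoom tiling (values are made up).
import numpy as np

length, batch_size = 6, 2
per_batch = np.random.uniform(0.25, 1.0, size=(length // batch_size + 1, 1, 2))
zoom_levels = np.reshape(np.tile(per_batch, (1, batch_size, 1)), [-1, 2])
print(zoom_levels.shape)  # (8, 2): rows 0-1 are equal, rows 2-3 are equal, ...
```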

75
data/template_dataset.py Normal file

@ -0,0 +1,75 @@
"""Dataset class template
This module provides a template for users to implement custom datasets.
You can specify '--dataset_mode template' to use this dataset.
The class name should be consistent with both the filename and its dataset_mode option.
The filename should be <dataset_mode>_dataset.py
The class name should be <Dataset_mode>Dataset
You need to implement the following functions:
-- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
-- <__init__>: Initialize this dataset class.
-- <__getitem__>: Return a data point and its metadata information.
-- <__len__>: Return the number of images.
"""
from data.base_dataset import BaseDataset, get_transform
# from data.image_folder import make_dataset
# from PIL import Image
class TemplateDataset(BaseDataset):
"""A template dataset class for you to implement custom datasets."""
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
return parser
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
A few things can be done here.
- save the options (have been done in BaseDataset)
- get image paths and meta information of the dataset.
- define the image transformation.
"""
# save the option and dataset root
BaseDataset.__init__(self, opt)
# get the image paths of your dataset;
self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
# define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
self.transform = get_transform(opt)
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index -- a random integer for data indexing
Returns:
a dictionary of data with their names. It usually contains the data itself and its metadata information.
Step 1: get a random image path: e.g., path = self.image_paths[index]
Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
Step 4: return a data point as a dictionary.
"""
path = 'temp' # needs to be a string
data_A = None # needs to be a tensor
data_B = None # needs to be a tensor
return {'data_A': data_A, 'data_B': data_B, 'path': path}
def __len__(self):
"""Return the total number of images."""
return len(self.image_paths)

79
data/unaligned_dataset.py Normal file

@ -0,0 +1,79 @@
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
class UnalignedDataset(BaseDataset):
"""
This dataset class can load unaligned/unpaired datasets.
It requires two directories to host training images from domain A '/path/to/data/trainA'
and from domain B '/path/to/data/trainB' respectively.
You can train the model with the dataset flag '--dataroot /path/to/data'.
Similarly, you need to prepare two directories:
'/path/to/data/testA' and '/path/to/data/testB' during test time.
"""
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseDataset.__init__(self, opt)
self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
if opt.phase == "test" and not os.path.exists(self.dir_A) \
and os.path.exists(os.path.join(opt.dataroot, "valA")):
self.dir_A = os.path.join(opt.dataroot, "valA")
self.dir_B = os.path.join(opt.dataroot, "valB")
self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
"""
A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
if self.opt.serial_batches:   # use a fixed index for domain B so that A-B pairs stay aligned
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
# Apply image transformation
# For FastCUT mode, if in finetuning phase (learning rate is decaying),
# do not perform resize-crop data augmentation of CycleGAN.
# print('current_epoch', self.current_epoch)
is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
transform = get_transform(modified_opt)
A = transform(A_img)
B = transform(B_img)
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
def __len__(self):
"""Return the total number of images in the dataset.
As we have two datasets with potentially different number of images,
we take the maximum of the two.
"""
return max(self.A_size, self.B_size)


@ -0,0 +1,100 @@
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
import torchvision.transforms.functional as TF
from torchvision.transforms import transforms as tfs
class UnalignedDoubleDataset(BaseDataset):
"""
This dataset class can load unaligned/unpaired datasets.
It requires two directories to host training images from domain A '/path/to/data/trainA'
and from domain B '/path/to/data/trainB' respectively.
You can train the model with the dataset flag '--dataroot /path/to/data'.
Similarly, you need to prepare two directories:
'/path/to/data/testA' and '/path/to/data/testB' during test time.
"""
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
# self.use_resize_crop = opt.use_resize_crop
BaseDataset.__init__(self, opt)
self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
self.opt = opt
if opt.phase == "test" and not os.path.exists(self.dir_A) \
and os.path.exists(os.path.join(opt.dataroot, "valA")):
self.dir_A = os.path.join(opt.dataroot, "valA")
self.dir_B = os.path.join(opt.dataroot, "valB")
self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
"""
A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
if self.opt.serial_batches:   # use a fixed index for domain B so that A-B pairs stay aligned
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
A_img = Image.open(A_path).convert('RGB')
A0 = A_img.crop((0,0,256,256))
A1 = A_img.crop((256,0,512,256))
B_img = Image.open(B_path).convert('RGB')
B0 = B_img.crop((0,0,256,256))
B1 = B_img.crop((256,0,512,256))
# Apply image transformation
# For FastCUT mode, if in finetuning phase (learning rate is decaying),
# do not perform resize-crop data augmentation of CycleGAN.
# print('current_epoch', self.current_epoch)
is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
resize = tfs.Resize(size=(self.opt.load_size, self.opt.load_size))
imgA = resize(A0)
param = dict()
i, j, h, w = tfs.RandomCrop.get_params(
imgA, output_size=(self.opt.crop_size, self.opt.crop_size))
param['crop_pos'] = (i, j)
transform = get_transform(modified_opt, param)
# print(transform)
# sys.exit(0)
# A = transform(A_img)
# B = transform(B_img)
A0 = transform(A0)
B0 = transform(B0)
A1 = transform(A1)
B1 = transform(B1)
return {'A0': A0, 'A1': A1, 'B0': B0, 'B1': B1, 'A_paths': A_path, 'B_paths': B_path}
def __len__(self):
"""Return the total number of images in the dataset.
As we have two datasets with potentially different number of images,
we take the maximum of the two.
"""
return max(self.A_size, self.B_size)


@ -0,0 +1,6 @@
@inproceedings{Cordts2016Cityscapes,
title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2016}
}


@ -0,0 +1,7 @@
@INPROCEEDINGS{Tylecek13,
author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra},
title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure},
booktitle = {Proc. GCPR},
year = {2013},
address = {Saarbrucken, Germany},
}


@ -0,0 +1,13 @@
@inproceedings{zhu2016generative,
title={Generative Visual Manipulation on the Natural Image Manifold},
author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.},
booktitle={Proceedings of European Conference on Computer Vision (ECCV)},
year={2016}
}
@InProceedings{xie15hed,
author = {"Xie, Saining and Tu, Zhuowen"},
Title = {Holistically-Nested Edge Detection},
Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
Year = {2015},
}

14
datasets/bibtex/shoes.tex Normal file

@ -0,0 +1,14 @@
@InProceedings{fine-grained,
author = {A. Yu and K. Grauman},
title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning},
booktitle = {Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2014}
}
@InProceedings{xie15hed,
author = {"Xie, Saining and Tu, Zhuowen"},
Title = {Holistically-Nested Edge Detection},
Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
Year = {2015},
}


@ -0,0 +1,8 @@
@article {Laffont14,
title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes},
author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays},
journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)},
volume = {33},
number = {4},
year = {2014}
}


@ -0,0 +1,48 @@
import os
import numpy as np
import cv2
import argparse
parser = argparse.ArgumentParser('create image pairs')
parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000)
parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true')
args = parser.parse_args()
for arg in vars(args):
print('[%s] = ' % arg, getattr(args, arg))
splits = os.listdir(args.fold_A)
for sp in splits:
img_fold_A = os.path.join(args.fold_A, sp)
img_fold_B = os.path.join(args.fold_B, sp)
img_list = os.listdir(img_fold_A)
if args.use_AB:
img_list = [img_path for img_path in img_list if '_A.' in img_path]
num_imgs = min(args.num_imgs, len(img_list))
print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
img_fold_AB = os.path.join(args.fold_AB, sp)
if not os.path.isdir(img_fold_AB):
os.makedirs(img_fold_AB)
print('split = %s, number of images = %d' % (sp, num_imgs))
for n in range(num_imgs):
name_A = img_list[n]
path_A = os.path.join(img_fold_A, name_A)
if args.use_AB:
name_B = name_A.replace('_A.', '_B.')
else:
name_B = name_A
path_B = os.path.join(img_fold_B, name_B)
if os.path.isfile(path_A) and os.path.isfile(path_B):
name_AB = name_A
if args.use_AB:
name_AB = name_AB.replace('_A.', '.') # remove _A
path_AB = os.path.join(img_fold_AB, name_AB)
im_A = cv2.imread(path_A, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
im_B = cv2.imread(path_B, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
im_AB = np.concatenate([im_A, im_B], 1)
cv2.imwrite(path_AB, im_AB)


@ -0,0 +1,64 @@
import cv2
import os
import glob
import argparse
def get_file_paths(folder):
image_file_paths = []
for root, dirs, filenames in os.walk(folder):
filenames = sorted(filenames)
for filename in filenames:
input_path = os.path.abspath(root)
file_path = os.path.join(input_path, filename)
if filename.endswith('.png') or filename.endswith('.jpg'):
image_file_paths.append(file_path)
break # prevent descending into subfolders
return image_file_paths
SF = 1.05
N = 3
def detect_cat(img_path, cat_cascade, output_dir, ratio=0.05, border_ratio=0.25):
print('processing {}'.format(img_path))
output_width = 286
img = cv2.imread(img_path)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
H, W = img.shape[0], img.shape[1]
minH = int(H * ratio)
minW = int(W * ratio)
cats = cat_cascade.detectMultiScale(gray, scaleFactor=SF, minNeighbors=N, minSize=(minH, minW))
for cat_id, (x, y, w, h) in enumerate(cats):
x1 = max(0, x - w * border_ratio)
x2 = min(W, x + w * (1 + border_ratio))
y1 = max(0, y - h * border_ratio)
y2 = min(H, y + h * (1 + border_ratio))
img_crop = img[int(y1):int(y2), int(x1):int(x2)]
img_name = os.path.basename(img_path)
out_path = os.path.join(output_dir, img_name.replace('.jpg', '_cat%d.jpg' % cat_id))
print('write', out_path)
img_crop = cv2.resize(img_crop, (output_width, output_width), interpolation=cv2.INTER_CUBIC)
cv2.imwrite(out_path, img_crop, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='detecting cat faces using opencv detector')
parser.add_argument('--input_dir', type=str, help='input image directory')
parser.add_argument('--output_dir', type=str, help='which directory to store cropped cat faces')
parser.add_argument('--use_ext', action='store_true', help='if use haarcascade_frontalcatface_extended or not')
args = parser.parse_args()
if args.use_ext:
cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface_extended.xml')
else:
cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface.xml')
img_paths = get_file_paths(args.input_dir)
print('total number of images {} from {}'.format(len(img_paths), args.input_dir))
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
for img_path in img_paths:
detect_cat(img_path, cat_cascade, args.output_dir)


@ -0,0 +1,23 @@
set -ex
FILE=$1
if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "mini" && $FILE != "mini_pix2pix" && $FILE != "mini_colorization" && $FILE != "grumpifycat" ]]; then
echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, grumpifycat"
exit 1
fi
if [[ $FILE == "cityscapes" ]]; then
echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
exit 1
fi
echo "Specified [$FILE]"
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE/
wget --no-check-certificate -N $URL -O $ZIP_FILE
mkdir $TARGET_DIR
unzip $ZIP_FILE -d ./datasets/
rm $ZIP_FILE


@ -0,0 +1,24 @@
set -ex
FILE=$1
if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then
echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps"
exit 1
fi
if [[ $FILE == "cityscapes" ]]; then
echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
exit 1
fi
echo "Specified [$FILE]"
URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz
TAR_FILE=./datasets/$FILE.tar.gz
TARGET_DIR=./datasets/$FILE/
wget -N $URL -O $TAR_FILE
mkdir -p $TARGET_DIR
tar -zxvf $TAR_FILE -C ./datasets/
rm $TAR_FILE


@ -0,0 +1,63 @@
import os
from PIL import Image
def get_file_paths(folder):
image_file_paths = []
for root, dirs, filenames in os.walk(folder):
filenames = sorted(filenames)
for filename in filenames:
input_path = os.path.abspath(root)
file_path = os.path.join(input_path, filename)
if filename.endswith('.png') or filename.endswith('.jpg'):
image_file_paths.append(file_path)
break # prevent descending into subfolders
return image_file_paths
def align_images(a_file_paths, b_file_paths, target_path):
if not os.path.exists(target_path):
os.makedirs(target_path)
for i in range(len(a_file_paths)):
img_a = Image.open(a_file_paths[i])
img_b = Image.open(b_file_paths[i])
assert(img_a.size == img_b.size)
aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
aligned_image.paste(img_a, (0, 0))
aligned_image.paste(img_b, (img_a.size[0], 0))
aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i)))
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'--dataset-path',
dest='dataset_path',
help='Which folder to process (it should have subfolders testA, testB, trainA and trainB)'
)
args = parser.parse_args()
dataset_folder = args.dataset_path
print(dataset_folder)
test_a_path = os.path.join(dataset_folder, 'testA')
test_b_path = os.path.join(dataset_folder, 'testB')
test_a_file_paths = get_file_paths(test_a_path)
test_b_file_paths = get_file_paths(test_b_path)
assert(len(test_a_file_paths) == len(test_b_file_paths))
test_path = os.path.join(dataset_folder, 'test')
train_a_path = os.path.join(dataset_folder, 'trainA')
train_b_path = os.path.join(dataset_folder, 'trainB')
train_a_file_paths = get_file_paths(train_a_path)
train_b_file_paths = get_file_paths(train_b_path)
assert(len(train_a_file_paths) == len(train_b_file_paths))
train_path = os.path.join(dataset_folder, 'train')
align_images(test_a_file_paths, test_b_file_paths, test_path)
align_images(train_a_file_paths, train_b_file_paths, train_path)


@ -0,0 +1,90 @@
import os
import glob
from PIL import Image
help_msg = """
The dataset can be downloaded from https://cityscapes-dataset.com.
Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them.
gtFine contains the semantic segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory.
leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory.
The processed images will be placed at --output_dir.
Example usage:
python prepare_cityscapes_dataset.py --gtFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./datasets/cityscapes/
"""
def load_resized_img(path):
return Image.open(path).convert('RGB').resize((256, 256))
def check_matching_pair(segmap_path, photo_path):
segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '')
photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '')
assert segmap_identifier == photo_identifier, \
"[%s] and [%s] don't seem to be matching. Aborting." % (segmap_path, photo_path)
def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
save_phase = 'test' if phase == 'val' else 'train'
savedir = os.path.join(output_dir, save_phase)
os.makedirs(savedir, exist_ok=True)
os.makedirs(savedir + 'A', exist_ok=True)
os.makedirs(savedir + 'B', exist_ok=True)
print("Directory structure prepared at %s" % output_dir)
segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png"
segmap_paths = glob.glob(segmap_expr)
segmap_paths = sorted(segmap_paths)
photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png"
photo_paths = glob.glob(photo_expr)
photo_paths = sorted(photo_paths)
assert len(segmap_paths) == len(photo_paths), \
"%d images that match [%s], and %d images that match [%s]. Aborting." % (len(segmap_paths), segmap_expr, len(photo_paths), photo_expr)
for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)):
check_matching_pair(segmap_path, photo_path)
segmap = load_resized_img(segmap_path)
photo = load_resized_img(photo_path)
# data for pix2pix where the two images are placed side-by-side
sidebyside = Image.new('RGB', (512, 256))
sidebyside.paste(segmap, (256, 0))
sidebyside.paste(photo, (0, 0))
savepath = os.path.join(savedir, "%d.jpg" % i)
sidebyside.save(savepath, format='JPEG', subsampling=0, quality=100)
# data for cyclegan where the two images are stored at two distinct directories
savepath = os.path.join(savedir + 'A', "%d_A.jpg" % i)
photo.save(savepath, format='JPEG', subsampling=0, quality=100)
savepath = os.path.join(savedir + 'B', "%d_B.jpg" % i)
segmap.save(savepath, format='JPEG', subsampling=0, quality=100)
if i % (len(segmap_paths) // 10) == 0:
print("%d / %d: last image saved at %s, " % (i, len(segmap_paths), savepath))
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--gtFine_dir', type=str, required=True,
help='Path to the Cityscapes gtFine directory.')
parser.add_argument('--leftImg8bit_dir', type=str, required=True,
help='Path to the Cityscapes leftImg8bit_trainvaltest directory.')
parser.add_argument('--output_dir', type=str, required=True,
default='./datasets/cityscapes',
help='Directory the output images will be written to.')
opt = parser.parse_args()
print(help_msg)
print('Preparing Cityscapes Dataset for val phase')
process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val")
print('Preparing Cityscapes Dataset for train phase')
process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train")
print('Done')

Binary file not shown (image, 289 KiB).

Binary file not shown (image, 606 KiB).

BIN
images/method_final.jpg Normal file
Binary file not shown (278 KiB).

67
models/__init__.py Normal file

@ -0,0 +1,67 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
import importlib
from models.base_model import BaseModel
def find_model_using_name(model_name):
"""Import the module "models/[model_name]_model.py".
In the file, the class called DatasetNameModel() will
be instantiated. It has to be a subclass of BaseModel,
and it is case-insensitive.
"""
model_filename = "models." + model_name + "_model"
modellib = importlib.import_module(model_filename)
model = None
target_model_name = model_name.replace('_', '') + 'model'
for name, cls in modellib.__dict__.items():
if name.lower() == target_model_name.lower() \
and issubclass(cls, BaseModel):
model = cls
if model is None:
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
exit(0)
return model
def get_option_setter(model_name):
"""Return the static method <modify_commandline_options> of the model class."""
model_class = find_model_using_name(model_name)
return model_class.modify_commandline_options
def create_model(opt):
"""Create a model given the option.
This function finds the model class using <find_model_using_name> and instantiates it.
This is the main interface between this package and 'train.py'/'test.py'
Example:
>>> from models import create_model
>>> model = create_model(opt)
"""
model = find_model_using_name(opt.model)
instance = model(opt)
print("model [%s] was created" % type(instance).__name__)
return instance
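As a hypothetical minimal example of the contract described in the docstring above, a file `models/dummy_model.py` along the lines of the sketch below would be picked up via `--model dummy`. The single-convolution generator, the L1 identity loss, and the `--input_nc`/`--output_nc`/`--lr` options are placeholder assumptions, not part of this repository.

```python
# Minimal sketch of a custom model plugging into find_model_using_name('dummy').
import torch
from .base_model import BaseModel

class DummyModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['G']               # losses to log and plot
        self.model_names = ['G']              # networks to save/load as net<name>
        self.visual_names = ['real', 'fake']  # images to display and save
        self.netG = torch.nn.Conv2d(opt.input_nc, opt.output_nc, 3, padding=1).to(self.device)
        if self.isTrain:
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr)
            self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        self.real = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        self.fake = self.netG(self.real)

    def optimize_parameters(self):
        self.forward()
        self.loss_G = torch.nn.functional.l1_loss(self.fake, self.real)  # placeholder loss
        self.optimizer_G.zero_grad()
        self.loss_G.backward()
        self.optimizer_G.step()
```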

Binary files not shown.

258
models/base_model.py Normal file

@ -0,0 +1,258 @@
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
"""
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
torch.backends.cudnn.benchmark = True
self.loss_names = []
self.model_names = []
self.visual_names = []
self.optimizers = []
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
@staticmethod
def dict_grad_hook_factory(add_func=lambda x: x):
saved_dict = dict()
def hook_gen(name):
def grad_hook(grad):
saved_vals = add_func(grad)
saved_dict[name] = saved_vals
return grad_hook
return hook_gen, saved_dict
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
@abstractmethod
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def setup(self, opt):
"""Load and print networks; create schedulers
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
if not self.isTrain or opt.continue_train:
load_suffix = opt.epoch
self.load_networks(load_suffix)
self.print_networks(opt.verbose)
def parallelize(self):
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids))
def data_dependent_initialize(self, data):
pass
def eval(self):
"""Make models eval mode during test time"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.eval()
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with torch.no_grad():
self.forward()
self.compute_visuals()
def compute_visuals(self):
"""Calculate additional output images for visdom and HTML visualization"""
pass
def get_image_paths(self):
""" Return image paths that are used to load current data"""
return self.image_paths
def update_learning_rate(self):
"""Update learning rates for all the networks; called at the end of every epoch"""
for scheduler in self.schedulers:
if self.opt.lr_policy == 'plateau':
scheduler.step(self.metric)
else:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = %.7f' % lr)
def get_current_visuals(self):
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name)
return visual_ret
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
return errors_ret
def save_networks(self, epoch):
"""Save all the networks to the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net_%s.pth' % (epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.cpu().state_dict(), save_path)
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
def load_networks(self, epoch):
"""Load all the networks from the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
load_filename = '%s_net_%s.pth' % (epoch, name)
if self.opt.isTrain and self.opt.pretrained_name is not None:
load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
else:
load_dir = self.save_dir
load_path = os.path.join(load_dir, load_filename)
net = getattr(self, 'net' + name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
print('loading the model from %s' % load_path)
# if you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device
state_dict = torch.load(load_path, map_location=str(self.device))
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
# patch InstanceNorm checkpoints prior to 0.4
# for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
# self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
net.load_state_dict(state_dict)
def print_networks(self, verbose):
"""Print the total number of parameters in the network and (if verbose) network architecture
Parameters:
verbose (bool) -- if verbose: print the network architecture
"""
print('---------- Networks initialized -------------')
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
num_params = 0
for param in net.parameters():
num_params += param.numel()
if verbose:
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
print('-----------------------------------------------')
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
def generate_visuals_for_evaluation(self, data, mode):
return {}
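# Example (not part of the original file): a minimal sketch of how these BaseModel hooks
# are typically driven by a training script. The names create_dataset and create_model are
# stand-ins for the usual CycleGAN/CUT entry points, not functions defined in this repository.
def example_train_loop(opt, create_dataset, create_model):
    dataset = create_dataset(opt)              # iterable of data dicts
    model = create_model(opt)                  # a BaseModel subclass
    model.data_dependent_initialize(next(iter(dataset)))
    model.setup(opt)                           # load/print networks, create schedulers
    for epoch in range(opt.n_epochs + opt.n_epochs_decay):
        for data in dataset:
            model.set_input(data)              # unpack a batch from the dataloader
            model.optimize_parameters()        # forward + backward + optimizer steps
        print(epoch, model.get_current_losses())
        model.update_learning_rate()           # step the schedulers once per epoch
        model.save_networks('latest')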

214
models/cut_model.py Normal file

@ -0,0 +1,214 @@
import numpy as np
import torch
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
class CUTModel(BaseModel):
""" This class implements CUT and FastCUT model, described in the paper
Contrastive Learning for Unpaired Image-to-Image Translation
Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
ECCV, 2020
The code borrows heavily from the PyTorch implementation of CycleGAN
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
""" Configures options specific for CUT model
"""
parser.add_argument('--CUT_mode', type=str, default="CUT", choices='(CUT, cut, FastCUT, fastcut)')
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y)')
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
type=util.str2bool, nargs='?', const=True, default=False,
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
parser.add_argument('--netF_nc', type=int, default=256)
parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
parser.add_argument('--flip_equivariance',
type=util.str2bool, nargs='?', const=True, default=False,
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
parser.set_defaults(pool_size=0) # no image pooling
opt, _ = parser.parse_known_args()
# Set default parameters for CUT and FastCUT
if opt.CUT_mode.lower() == "cut":
parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
elif opt.CUT_mode.lower() == "fastcut":
parser.set_defaults(
nce_idt=False, lambda_NCE=10.0, flip_equivariance=True,
n_epochs=150, n_epochs_decay=50
)
else:
raise ValueError(opt.CUT_mode)
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
# specify the training losses you want to print out.
# The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE']
self.visual_names = ['real_A', 'fake_B', 'real_B']
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
if opt.nce_idt and self.isTrain:
self.loss_names += ['NCE_Y']
self.visual_names += ['idt_B']
if self.isTrain:
self.model_names = ['G', 'F', 'D']
else: # during test time, only load G
self.model_names = ['G']
# define networks (both generator and discriminator)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
if self.isTrain:
self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
# define loss functions
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionNCE = []
for nce_layer in self.nce_layers:
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
self.criterionIdt = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
def data_dependent_initialize(self, data):
"""
The feature network netF is defined in terms of the shape of the intermediate, extracted
features of the encoder portion of netG. Because of this, the weights of netF are
initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
"""
self.set_input(data)
bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1)
self.real_A = self.real_A[:bs_per_gpu]
self.real_B = self.real_B[:bs_per_gpu]
self.forward() # compute fake images: G(A)
if self.opt.isTrain:
self.compute_D_loss().backward() # calculate gradients for D
self.compute_G_loss().backward() # calculate gradients for G
if self.opt.lambda_NCE > 0.0:
self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
self.optimizers.append(self.optimizer_F)
def optimize_parameters(self):
# forward
self.forward()
# update D
self.set_requires_grad(self.netD, True)
self.optimizer_D.zero_grad()
self.loss_D = self.compute_D_loss()
self.loss_D.backward()
self.optimizer_D.step()
# update G
self.set_requires_grad(self.netD, False)
self.optimizer_G.zero_grad()
if self.opt.netF == 'mlp_sample':
self.optimizer_F.zero_grad()
self.loss_G = self.compute_G_loss()
self.loss_G.backward()
self.optimizer_G.step()
if self.opt.netF == 'mlp_sample':
self.optimizer_F.step()
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A = input['A' if AtoB else 'B'].to(self.device)
self.real_B = input['B' if AtoB else 'A'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
if self.opt.flip_equivariance:
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
if self.flipped_for_equivariance:
self.real = torch.flip(self.real, [3])
self.fake = self.netG(self.real)
self.fake_B = self.fake[:self.real_A.size(0)]
if self.opt.nce_idt:
self.idt_B = self.fake[self.real_A.size(0):]
def compute_D_loss(self):
"""Calculate GAN loss for the discriminator"""
fake = self.fake_B.detach()
# Fake; stop backprop to the generator by detaching fake_B
pred_fake = self.netD(fake)
self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
# Real
self.pred_real = self.netD(self.real_B)
loss_D_real = self.criterionGAN(self.pred_real, True)
self.loss_D_real = loss_D_real.mean()
# combine loss and calculate gradients
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
return self.loss_D
def compute_G_loss(self):
"""Calculate GAN and NCE loss for the generator"""
fake = self.fake_B
# First, G(A) should fake the discriminator
if self.opt.lambda_GAN > 0.0:
pred_fake = self.netD(fake)
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
else:
self.loss_G_GAN = 0.0
if self.opt.lambda_NCE > 0.0:
self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
else:
self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0
if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
self.loss_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5
else:
loss_NCE_both = self.loss_NCE
self.loss_G = self.loss_G_GAN + loss_NCE_both
return self.loss_G
def calculate_NCE_loss(self, src, tgt):
n_layers = len(self.nce_layers)
feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
if self.opt.flip_equivariance and self.flipped_for_equivariance:
feat_q = [torch.flip(fq, [3]) for fq in feat_q]
feat_k = self.netG(src, self.nce_layers, encode_only=True)
feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
total_nce_loss = 0.0
for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
loss = crit(f_q, f_k) * self.opt.lambda_NCE
total_nce_loss += loss.mean()
return total_nce_loss / n_layers
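# Example (illustrative, not part of the original file): forward() above pushes real_A and
# real_B through the generator in one batch when nce_idt is enabled, then splits the output
# back into fake_B and idt_B. A toy version with an Identity stand-in for netG:
def example_nce_idt_split():
    netG = torch.nn.Identity()                   # stand-in generator
    real_A = torch.randn(2, 3, 64, 64)           # source-domain batch
    real_B = torch.randn(2, 3, 64, 64)           # target-domain batch
    real = torch.cat((real_A, real_B), dim=0)    # [4, 3, 64, 64]
    fake = netG(real)
    fake_B = fake[:real_A.size(0)]               # translated source images
    idt_B = fake[real_A.size(0):]                # identity-mapped target images
    assert fake_B.shape == real_A.shape and idt_B.shape == real_B.shape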

222
models/cycle_gan_model.py Normal file

@ -0,0 +1,222 @@
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
try:
from apex import amp
except ImportError as error:
print(error)
class CycleGANModel(BaseModel):
"""
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
The model training requires '--dataset_mode unaligned' dataset.
By default, it uses a '--netG resnet_9blocks' ResNet generator,
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
and a least-square GANs objective ('--gan_mode lsgan').
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
A (source domain), B (target domain).
Generators: G_A: A -> B; G_B: B -> A.
Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
Dropout is not used in the original CycleGAN paper.
"""
# parser.set_defaults(no_dropout=True, no_antialias=True, no_antialias_up=True) # default CycleGAN did not use dropout
# parser.set_defaults(no_dropout=True)
if is_train:
parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
return parser
def __init__(self, opt):
"""Initialize the CycleGAN class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B']
if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
visual_names_A.append('idt_B')
visual_names_B.append('idt_A')
self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
if self.isTrain:
self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
else: # during test time, only load Gs
self.model_names = ['G_A', 'G_B']
# define networks (both Generators and discriminators)
# The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG,
not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt=opt)
self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.normG,
not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt=opt)
if self.isTrain: # define discriminators
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt=opt)
self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt=opt)
if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert(opt.input_nc == opt.output_nc)
self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
# define loss functions
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
self.criterionCycle = torch.nn.L1Loss()
self.criterionIdt = torch.nn.L1Loss()
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A = input['A' if AtoB else 'B'].to(self.device)
self.real_B = input['B' if AtoB else 'A'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
Parameters:
netD (network) -- the discriminator D
real (tensor array) -- real images
fake (tensor array) -- images generated by a generator
Return the discriminator loss.
We also call loss_D.backward() to calculate the gradients.
"""
# Real
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
if self.opt.amp:
with amp.scale_loss(loss_D, self.optimizer_D) as scaled_loss:
scaled_loss.backward()
else:
loss_D.backward()
return loss_D
def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
# Identity loss
if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
# GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
if self.opt.amp:
with amp.scale_loss(self.loss_G, self.optimizer_G) as scaled_loss:
scaled_loss.backward()
else:
self.loss_G.backward()
def data_dependent_initialize(self):
return
def generate_visuals_for_evaluation(self, data, mode):
with torch.no_grad():
visuals = {}
AtoB = self.opt.direction == "AtoB"
G = self.netG_A
source = data["A" if AtoB else "B"].to(self.device)
if mode == "forward":
visuals["fake_B"] = G(source)
else:
raise ValueError("mode %s is not recognized" % mode)
return visuals
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
self.forward() # compute fake images and reconstruction images.
# G_A and G_B
self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs
self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
self.backward_G() # calculate gradients for G_A and G_B
self.optimizer_G.step() # update G_A and G_B's weights
# D_A and D_B
self.set_requires_grad([self.netD_A, self.netD_B], True)
self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
self.backward_D_A() # calculate gradients for D_A
self.backward_D_B() # calculate gradients for D_B
self.optimizer_D.step() # update D_A and D_B's weights
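# Example (illustrative numbers only): with the defaults above, lambda_A = lambda_B = 10 and
# lambda_identity = 0.5, so in backward_G each identity term carries half the weight of the
# corresponding cycle term (0.5 * 10 = 5 vs. 10).
def example_default_loss_weights():
    lambda_A, lambda_B, lambda_identity = 10.0, 10.0, 0.5
    cycle_weight_A = lambda_A                     # weight on ||G_B(G_A(A)) - A||
    idt_weight_A = lambda_identity * lambda_B     # weight on ||G_A(B) - B||, as in backward_G
    assert idt_weight_A == 0.5 * cycle_weight_A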

1530
models/networks.py Normal file

File diff suppressed because it is too large

55
models/patchnce.py Normal file

@ -0,0 +1,55 @@
from packaging import version
import torch
from torch import nn
class PatchNCELoss(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
def forward(self, feat_q, feat_k):
num_patches = feat_q.shape[0]
dim = feat_q.shape[1]
feat_k = feat_k.detach()
# pos logit
l_pos = torch.bmm(
feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
l_pos = l_pos.view(num_patches, 1)
# neg logit
# Should the negatives from the other samples of a minibatch be utilized?
# In CUT and FastCUT, we found that it's best to only include negatives
# from the same image. Therefore, we set
# --nce_includes_all_negatives_from_minibatch as False
# However, for single-image translation, the minibatch consists of
# crops from the "same" high-resolution image.
# Therefore, we will include the negatives from the entire minibatch.
if self.opt.nce_includes_all_negatives_from_minibatch:
# reshape features as if they are all negatives of minibatch of size 1.
batch_dim_for_bmm = 1
else:
batch_dim_for_bmm = self.opt.batch_size
# reshape features to batch size
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
npatches = feat_q.size(1)
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
# diagonal entries are similarity between same features, and hence meaningless.
# just fill the diagonal with very small number, which is exp(-10) and almost zero
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
l_neg_curbatch.masked_fill_(diagonal, -10.0)
l_neg = l_neg_curbatch.view(-1, npatches)
out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
device=feat_q.device))
return loss
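# Example (toy sizes, illustrative only): the logit layout built above puts the positive
# similarity in column 0, followed by same-image negatives whose diagonal (self-similarity)
# is masked with -10, before temperature-scaled cross-entropy against class 0.
def example_patchnce_logits():
    import torch.nn.functional as F
    num_patches, dim, nce_T = 4, 8, 0.07
    feat_q = F.normalize(torch.randn(num_patches, dim), dim=1)
    feat_k = F.normalize(torch.randn(num_patches, dim), dim=1).detach()
    l_pos = (feat_q * feat_k).sum(dim=1, keepdim=True)        # [4, 1] positive logits
    l_neg = feat_q @ feat_k.t()                                # [4, 4] same-image negatives
    l_neg.fill_diagonal_(-10.0)                                # mask self-similarity
    logits = torch.cat([l_pos, l_neg], dim=1) / nce_T
    target = torch.zeros(num_patches, dtype=torch.long)       # positives sit at index 0
    return F.cross_entropy(logits, target)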

363
models/roma_model.py Normal file

@ -0,0 +1,363 @@
import numpy as np
import torch
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
import timm
import time
import torch.nn.functional as F
import sys
from functools import partial
import torch.nn as nn
import math
from torchvision.transforms import transforms as tfs
class ROMAModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
""" Configures options specific for CUT model
"""
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
parser.add_argument('--local_nums', type=int, default=256)
parser.add_argument('--which_D_layer', type=int, default=-1)
parser.add_argument('--side_length', type=int, default=7)
parser.set_defaults(pool_size=0)
opt, _ = parser.parse_known_args()
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial', 'motion']
self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.isTrain:
self.model_names = ['G', 'D_ViT']
else: # during test time, only load G
self.model_names = ['G']
# define networks (both generator and discriminator)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
if self.isTrain:
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
self.norm = F.softmax
self.resize = tfs.Resize(size=(384,384))
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionNCE = []
for atten_layer in self.atten_layers:
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
self.criterionL1 = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D_ViT)
def data_dependent_initialize(self, data):
"""
The feature network netF is defined in terms of the shape of the intermediate, extracted
features of the encoder portion of netG. Because of this, the weights of netF are
initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
"""
pass
def optimize_parameters(self):
# forward
self.forward()
# update D
self.set_requires_grad(self.netD_ViT, True)
self.optimizer_D_ViT.zero_grad()
self.loss_D = self.compute_D_loss()
self.loss_D.backward()
self.optimizer_D_ViT.step()
# update G
self.set_requires_grad(self.netD_ViT, False)
self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss()
self.loss_G.backward()
self.optimizer_G.step()
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
# ============ Step 1: initialize the timesteps and sample the current time index ============
tau = self.opt.tau
T = self.opt.num_timesteps
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
times = np.cumsum(incs)
times = times / times[-1]
times = 0.5 * times[-1] + 0.5 * times #[0.5,1]
times = np.concatenate([np.zeros(1), times])
times = torch.tensor(times).float().cuda()
self.times = times
bs = self.mutil_real_A0_tokens.size(0)
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
self.time_idx = time_idx
with torch.no_grad():
self.netG.eval()
# ============ Step 2: run the multi-step stochastic generation for real_A / real_A2 ============
for t in range(self.time_idx.int().item() + 1):
# Compute the increment delta and the inter/scale factors used for per-timestep interpolation
if t > 0:
delta = times[t] - times[t - 1]
denom = times[-1] - times[t - 1]
inter = (delta / denom).reshape(-1, 1, 1, 1)
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
# Apply the stochastic noise update to Xt and Xt2
Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
self.time = times[time_idx]
Xt_1 = self.netG(Xt, self.time, z)
Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
Xt_12 = self.netG(Xt2, self.time, z)
# Save the denoised intermediate results (real_A_noisy, etc.) for concatenation in the next step
self.real_A_noisy = Xt.detach()
self.real_A_noisy2 = Xt2.detach()
# Save the noisy map
self.noisy_map = self.real_A_noisy - self.real_A
# ============ Step 3: concatenate the inputs and run the network ============
bs = self.mutil_real_A0_tokens.size(0)
z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
# Concatenate real_A and real_B (when nce_idt=True), and handle real_A_noisy and XtB the same way
self.real = self.mutil_real_A0_tokens
self.realt = self.real_A_noisy
if self.opt.flip_equivariance:
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
if self.flipped_for_equivariance:
self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3])
self.fake_B0 = self.netG(self.real_A0)
self.fake_B1 = self.netG(self.real_A1)
if self.opt.isTrain:
real_A0 = self.real_A0
real_A1 = self.real_A1
real_B0 = self.real_B0
real_B1 = self.real_B1
fake_B0 = self.fake_B0
fake_B1 = self.fake_B1
self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1)
real_B0 = self.resize(real_B0)
real_B1 = self.resize(real_B1)
self.fake_B0_resize = self.resize(fake_B0)
self.fake_B1_resize = self.resize(fake_B1)
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
S = int(math.sqrt(token_num))
if S * S != token_num:
print('Error! Not a square!')
token_map = origin_tokens.clone().reshape(B,S,S,C)
cut_patch_list = []
for i in range(0, S, adj_size):
for j in range(0, S, adj_size):
i_left = i
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
j_left = j
j_right = j + adj_size if j + adj_size <= S else S + 1
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
cut_patch= cut_patch.reshape(B,-1,C)
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list,dim=1)
return result
def cat_results(self, origin_tokens, adj_size_list):
res_list = [origin_tokens]
for ad_s in adj_size_list:
cat_result = self.tokens_concat(origin_tokens, ad_s)
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
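# Example (illustrative count): with vit_base_patch16_384 the token grid is 24 x 24 = 576, so
# cat_results with adj_size_list = [2, 4, 6, 8, 12] appends one averaged token per region at
# each scale before everything is fed to the MLP discriminator. The exact edge handling
# follows tokens_concat above.
def example_multiscale_token_count():
    S = 24                                        # sqrt(576) tokens per side
    adj_size_list = [2, 4, 6, 8, 12]
    counts = [S * S] + [math.ceil(S / adj) ** 2 for adj in adj_size_list]
    return counts, sum(counts)                    # [576, 144, 36, 16, 9, 4] -> 785 tokens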
def compute_D_loss(self):
"""Calculate GAN loss for the discriminator"""
lambda_D_ViT = self.opt.lambda_D_ViT
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer].detach()
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer].detach()
real_B0_tokens = self.mutil_real_B0_tokens[self.opt.which_D_layer]
real_B1_tokens = self.mutil_real_B1_tokens[self.opt.which_D_layer]
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
real_B0_tokens = self.cat_results(real_B0_tokens, self.opt.adj_size_list)
real_B1_tokens = self.cat_results(real_B1_tokens, self.opt.adj_size_list)
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
return self.loss_D_ViT
def compute_G_loss(self):
if self.opt.lambda_GAN > 0.0:
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer]
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer]
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN
else:
self.loss_G_GAN_ViT = 0.0
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
else:
self.loss_global, self.loss_spatial = 0.0, 0.0
if self.opt.lambda_motion > 0.0:
self.loss_motion = 0.0
for real_A0_tokens, real_A1_tokens, fake_B0_tokens, fake_B1_tokens in zip(self.mutil_real_A0_tokens, self.mutil_real_A1_tokens, self.mutil_fake_B0_tokens, self.mutil_fake_B1_tokens):
A0_B1 = real_A0_tokens.bmm(fake_B1_tokens.permute(0,2,1))
B0_A1 = fake_B0_tokens.bmm(real_A1_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(A0_B1, B0_A1, dim=-1)
self.loss_motion += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
else:
self.loss_motion = 0.0
self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.loss_motion
return self.loss_G
def calculate_attention_loss(self):
n_layers = len(self.atten_layers)
mutil_real_A0_tokens = self.mutil_real_A0_tokens
mutil_real_A1_tokens = self.mutil_real_A1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
loss_global *= 0.5
else:
loss_global = 0.0
if self.opt.lambda_spatial > 0.0:
loss_spatial = 0.0
local_nums = self.opt.local_nums
tokens_cnt = 576
local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
loss_spatial *= 0.5
else:
loss_spatial = 0.0
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0
n_layers = len(self.atten_layers)
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
loss = loss / n_layers
return loss
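# Example (toy tensors, illustrative only): the cross-similarity term above compares the
# src->tgt and tgt->src token-similarity maps row by row and pushes them to agree.
def example_cross_similarity():
    src_tokens = torch.randn(1, 6, 16)                      # e.g. real-frame ViT tokens
    tgt_tokens = torch.randn(1, 6, 16)                      # e.g. translated-frame ViT tokens
    src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))   # [1, 6, 6] cross-domain similarities
    tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))   # [1, 6, 6] reverse similarities
    cos = F.cosine_similarity(src_tgt, tgt_src, dim=-1)     # [1, 6] row-wise agreement
    return torch.nn.L1Loss()(torch.ones_like(cos), cos)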

272
models/roma_single_model.py Normal file

@ -0,0 +1,272 @@
import numpy as np
import torch
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
import timm
import time
import torch.nn.functional as F
import sys
from functools import partial
import torch.nn as nn
import math
from torchvision.transforms import transforms as tfs
class ROMASingleModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
""" Configures options specific for CUT model
"""
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
parser.add_argument('--local_nums', type=int, default=256)
parser.add_argument('--which_D_layer', type=int, default=-1)
parser.add_argument('--side_length', type=int, default=7)
parser.set_defaults(pool_size=0)
opt, _ = parser.parse_known_args()
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial']
self.visual_names = ['real_A', 'fake_B', 'real_B']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.isTrain:
self.model_names = ['G', 'D_ViT']
else: # during test time, only load G
self.model_names = ['G']
# define networks (both generator and discriminator)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
if self.isTrain:
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
# self.netPreViT = timm.create_model("vit_base_patch32_384",pretrained=True).to(self.device)
self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
self.norm = F.softmax
self.resize = tfs.Resize(size=(384,384))
# self.resize = tfs.Resize(size=(224, 224))
# define loss functions
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionNCE = []
for atten_layer in self.atten_layers:
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
self.criterionL1 = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D_ViT)
def data_dependent_initialize(self, data):
"""
The feature network netF is defined in terms of the shape of the intermediate, extracted
features of the encoder portion of netG. Because of this, the weights of netF are
initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
"""
pass
def optimize_parameters(self):
# forward
self.forward()
# update D
self.set_requires_grad(self.netD_ViT, True)
self.optimizer_D_ViT.zero_grad()
self.loss_D = self.compute_D_loss()
self.loss_D.backward()
self.optimizer_D_ViT.step()
# update G
self.set_requires_grad(self.netD_ViT, False)
self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss()
self.loss_G.backward()
self.optimizer_G.step()
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A = input['A' if AtoB else 'B'].to(self.device)
self.real_B = input['B' if AtoB else 'A'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG(self.real_A)
if self.opt.isTrain:
real_A = self.real_A
real_B = self.real_B
fake_B = self.fake_B
self.real_A_resize = self.resize(real_A)
real_B = self.resize(real_B)
self.fake_B_resize = self.resize(fake_B)
self.mutil_real_A_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B_tokens = self.netPreViT(real_B, self.atten_layers, get_tokens=True)
self.mutil_fake_B_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True)
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
S = int(math.sqrt(token_num))
if S * S != token_num:
print('Error! Not a square!')
token_map = origin_tokens.clone().reshape(B,S,S,C)
cut_patch_list = []
for i in range(0, S, adj_size):
for j in range(0, S, adj_size):
i_left = i
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
j_left = j
j_right = j + adj_size if j + adj_size <= S else S + 1
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
cut_patch= cut_patch.reshape(B,-1,C)
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list,dim=1)
return result
def cat_results(self, origin_tokens, adj_size_list):
res_list = [origin_tokens]
for ad_s in adj_size_list:
cat_result = self.tokens_concat(origin_tokens, ad_s)
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
def compute_D_loss(self):
"""Calculate GAN loss for the discriminator"""
lambda_D_ViT = self.opt.lambda_D_ViT
fake_B_tokens = self.mutil_fake_B_tokens[self.opt.which_D_layer].detach()
real_B_tokens = self.mutil_real_B_tokens[self.opt.which_D_layer]
fake_B_tokens = self.cat_results(fake_B_tokens, self.opt.adj_size_list)
real_B_tokens = self.cat_results(real_B_tokens, self.opt.adj_size_list)
pre_fake_ViT = self.netD_ViT(fake_B_tokens)
self.loss_D_fake_ViT = self.criterionGAN(pre_fake_ViT, False).mean() * lambda_D_ViT
pred_real_ViT = self.netD_ViT(real_B_tokens)
self.loss_D_real_ViT = self.criterionGAN(pred_real_ViT, True).mean() * lambda_D_ViT
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
return self.loss_D_ViT
def compute_G_loss(self):
if self.opt.lambda_GAN > 0.0:
fake_B_tokens = self.mutil_fake_B_tokens[self.opt.which_D_layer]
fake_B_tokens = self.cat_results(fake_B_tokens, self.opt.adj_size_list)
pred_fake_ViT = self.netD_ViT(fake_B_tokens)
self.loss_G_GAN_ViT = self.criterionGAN(pred_fake_ViT, True) * self.opt.lambda_GAN
else:
self.loss_G_GAN_ViT = 0.0
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
else:
self.loss_global, self.loss_spatial = 0.0, 0.0
self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial
return self.loss_G
def calculate_attention_loss(self):
n_layers = len(self.atten_layers)
mutil_real_A_tokens = self.mutil_real_A_tokens
mutil_fake_B_tokens = self.mutil_fake_B_tokens
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(mutil_real_A_tokens, mutil_fake_B_tokens)
else:
loss_global = 0.0
if self.opt.lambda_spatial > 0.0:
loss_spatial = 0.0
local_nums = self.opt.local_nums
tokens_cnt = 576
local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A_local_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B_local_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A_local_tokens, mutil_fake_B_local_tokens)
else:
loss_spatial = 0.0
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0
n_layers = len(self.atten_layers)
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
loss = loss / n_layers
return loss

655
models/self_build.py Normal file

@ -0,0 +1,655 @@
import numpy as np
import math
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
from torchvision.transforms import transforms as tfs
def warp(image, flow): # flow-based warping
"""
Optical-flow-based image warping.
Args:
image: [B, C, H, W] input image
flow: [B, 2, H, W] optical flow field (x/y displacement)
Returns:
warped: [B, C, H, W] warped image
"""
B, C, H, W = image.shape
# Build the base pixel-coordinate grid: grid_x indexes width (x), grid_y indexes height (y)
grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W))
grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device) # [2,H,W]
grid = grid.unsqueeze(0).repeat(B,1,1,1) # [B,2,H,W]
# Apply the flow displacement (coordinates normalized to [-1,1])
new_grid = grid + flow
new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0 # x direction
new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0 # y direction
new_grid = new_grid.permute(0,2,3,1) # [B,H,W,2]
# Bilinear sampling
return F.grid_sample(image, new_grid, align_corners=True)
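# Example (illustrative shapes): a quick zero-flow sanity check for warp; with zero
# displacement the bilinear sampling should reproduce the input up to interpolation error.
def example_warp_identity():
    img = torch.rand(1, 3, 32, 32)
    zero_flow = torch.zeros(1, 2, 32, 32)
    warped = warp(img, zero_flow)
    return (warped - img).abs().max()    # expected to be ~0 for an identity warp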
# Content-aware temporal normalization loss
def compute_ctn_loss(G, x, F_content): # Eq. (10)
"""
Compute the content-aware temporal normalization loss.
Args:
G: generator
x: input infrared image [B,C,H,W]
F_content: generated optical flow field [B,2,H,W]
"""
# Generate the visible-light image
y_fake = G(x) # [B,3,H,W]
# Warp the generated result with the flow
warped_fake = warp(y_fake, F_content) # [B,3,H,W]
# Apply the same flow to the input, then translate it
warped_x = warp(x, F_content) # [B,C,H,W]
y_fake_warped = G(warped_x) # [B,3,H,W]
# L2 loss between the two orderings
loss = F.mse_loss(warped_fake, y_fake_warped)
return loss
class ContentAwareOptimization(nn.Module):
def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
super().__init__()
self.lambda_inc = lambda_inc # weight amplification coefficient
self.eta_ratio = eta_ratio # ratio of content-rich regions to select
def compute_cosine_similarity(self, gradients):
"""
Cosine similarity between each patch gradient and the mean gradient.
Args:
gradients: [B, N, D] per-patch gradients of the discriminator output (N = w*h)
Returns:
cosine_sim: [B, N] cosine similarity for each patch
"""
mean_grad = torch.mean(gradients, dim=1, keepdim=True) # [B, 1, D]
# Cosine similarity against the mean gradient
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
return cosine_sim
def generate_weight_map(self, gradients_real, gradients_fake):
"""
Generate the content-aware weight maps.
Args:
gradients_real: [B, N, D] discriminator gradients for the real images
gradients_fake: [B, N, D] discriminator gradients for the generated images
Returns:
weight_real: [B, N] weight map for the real images
weight_fake: [B, N] weight map for the generated images
"""
# Cosine similarity for the real-image patches
cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] Eq. (5)
# Cosine similarity for the generated-image patches
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
# Select the content-rich regions: the eta_ratio fraction with the lowest cosine similarity
k = int(self.eta_ratio * cosine_real.shape[1])
# Weight map for the real images
_, real_indices = torch.topk(-cosine_real, k, dim=1) # pick the least-aligned regions
weight_real = torch.ones_like(cosine_real)
for b in range(cosine_real.shape[0]):
weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) # Eq. (6)
# Weight map for the generated images (same procedure)
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
weight_fake = torch.ones_like(cosine_fake)
for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
return weight_real, weight_fake
def forward(self, D_real, D_fake, real_scores, fake_scores):
"""
Compute the content-aware adversarial loss.
Args:
D_real: discriminator feature output for the real images [B, C, H, W]
D_fake: discriminator feature output for the generated images [B, C, H, W]
real_scores: discriminator predictions for the real images [B, N] (N=H*W)
fake_scores: discriminator predictions for the generated images [B, N]
Returns:
loss_co_adv: the content-aware adversarial loss
"""
B, C, H, W = D_real.shape
N = H * W
# Register hooks to capture the gradients
gradients_real = []
gradients_fake = []
def hook_real(grad):
gradients_real.append(grad.detach().view(B, N, -1))
def hook_fake(grad):
gradients_fake.append(grad.detach().view(B, N, -1))
D_real.register_hook(hook_real)
D_fake.register_hook(hook_fake)
# Compute the plain adversarial loss to trigger gradient computation
loss_real = torch.mean(torch.log(real_scores + 1e-8))
loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
# Add a dummy term involving D_real and D_fake so that gradients flow to them
loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
total_loss = loss_real + loss_fake + loss_dummy
total_loss.backward(retain_graph=True)
# Collect the captured gradients
gradients_real = gradients_real[0] # [B, N, D]
gradients_fake = gradients_fake[0] # [B, N, D]
# Generate the weight maps
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
# Apply the weights to the adversarial loss
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))
# Final content-aware adversarial loss
loss_co_adv = -(loss_co_real + loss_co_fake)
return loss_co_adv
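# Example (made-up numbers): the re-weighting rule of generate_weight_map (Eq. (6)) boosts
# the eta_ratio fraction of patches whose gradients agree least with the mean gradient,
# giving them adversarial weight lambda_inc / |cos| instead of 1.
def example_content_aware_weights():
    lambda_inc, eta_ratio = 2.0, 0.4
    cosine = torch.tensor([0.9, 0.8, 0.1, -0.2, 0.05])      # per-patch alignment with the mean gradient
    k = int(eta_ratio * cosine.numel())                      # 2 patches selected
    _, idx = torch.topk(-cosine, k)                          # least-aligned patches
    weight = torch.ones_like(cosine)
    weight[idx] = lambda_inc / (1e-6 + cosine[idx].abs())    # e.g. -0.2 -> ~10.0, 0.05 -> ~40.0
    return weight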
class ContentAwareTemporalNorm(nn.Module):
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
super().__init__()
self.gamma_stride = gamma_stride # controls the overall motion magnitude
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # Gaussian smoothing layer
def forward(self, weight_map):
"""
Generate the content-aware optical flow.
Args:
weight_map: [B, 1, H, W] weight map from the content-aware optimization module
Returns:
F_content: [B, 2, H, W] generated optical flow field (x/y displacement)
"""
B, _, H, W = weight_map.shape
# 1. Normalize the weight map
# Preserve the relative strength of regions while bounding the value range
weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1 normalization [B,1,H,W]
# 2. Sample Gaussian noise with the same size as the flow field
z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W]
# 3. Compose the raw flow
# Expand the weight map to 2 channels so the x/y directions share the same weights
weight_expanded = weight_norm.expand(-1, 2, -1, -1) # [B,2,H,W]
F_raw = self.gamma_stride * weight_expanded * z # [B,2,H,W] Eq. (9)
# 4. Smoothing (keeps the flow structurally continuous)
# Gaussian blur applied to each channel independently
F_smooth = self.smoother(F_raw) # [B,2,H,W]
# 5. Dynamic-range adjustment (optional)
# Bound the flow magnitude to avoid extreme displacements
F_content = torch.tanh(F_smooth) # squash into the [-1,1] range
return F_content
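# Example (toy generator and shapes, illustrative only): how the pseudo flow is meant to feed
# the temporal loss. The flow is built inline here (content weight * noise, Gaussian smoothing,
# tanh bound, mirroring Eq. (9)) rather than through ContentAwareTemporalNorm, then passed to
# compute_ctn_loss (Eq. (10)).
def example_ctn_pipeline():
    G = nn.Conv2d(3, 3, kernel_size=3, padding=1)        # stand-in generator
    x = torch.rand(1, 3, 64, 64)                          # toy "infrared" input
    weight_map = torch.rand(1, 1, 64, 64)                 # stand-in for the CAO weight map
    z = torch.randn(1, 2, 64, 64)                         # noise with the flow's shape
    F_raw = 0.1 * weight_map.expand(-1, 2, -1, -1) * z    # gamma_stride * weight * noise
    F_content = torch.tanh(GaussianBlur(21, sigma=5.0)(F_raw))
    loss_ctn = compute_ctn_loss(G, x, F_content)          # ||warp(G(x), F) - G(warp(x, F))||^2
    loss_ctn.backward()
    return loss_ctn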
class CTNxModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""配置 CTNx 模型的特定选项"""
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN lossGAN(G(X))')
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y)')
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
type=util.str2bool, nargs='?', const=True, default=False,
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
parser.add_argument('--netF_nc', type=int, default=256)
parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
parser.add_argument('--lmda_1', type=float, default=0.1)
parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
parser.add_argument('--flip_equivariance',
type=util.str2bool, nargs='?', const=True, default=False,
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
parser.set_defaults(pool_size=0) # no image pooling
opt, _ = parser.parse_known_args()
# Apply the SB-mode defaults directly
parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
return parser
def __init__(self, opt):
"""初始化 CTNx 模型"""
BaseModel.__init__(self, opt)
# 指定需要打印的训练损失
self.loss_names = ['G_GAN_1', 'D_real_1', 'D_fake_1', 'G_1', 'NCE_1', 'SB_1',
'G_2']
self.visual_names = ['real_A', 'real_A_noisy', 'fake_B', 'real_B']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.opt.phase == 'test':
self.visual_names = ['real']
for NFE in range(self.opt.num_timesteps):
fake_name = 'fake_' + str(NFE+1)
self.visual_names.append(fake_name)
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
if opt.nce_idt and self.isTrain:
self.loss_names += ['NCE_Y']
self.visual_names += ['idt_B']
if self.isTrain:
self.model_names = ['G1', 'F1', 'D1', 'E1',
'G2']
else:
self.model_names = ['G1']
# Create the networks
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
if self.isTrain:
self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
self.resize = tfs.Resize(size=(384,384))
# Add the pretrained ViT
self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
# Define the loss functions
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionNCE = []
for nce_layer in self.nce_layers:
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
self.criterionIdt = torch.nn.L1Loss().to(self.device)
self.optimizer_G1 = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D1 = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_E1 = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizers = [self.optimizer_G1, self.optimizer_D1, self.optimizer_E1]
self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) # content-aware adversarial loss
self.ctn = ContentAwareTemporalNorm() # generates the pseudo optical flow
def data_dependent_initialize(self, data):
"""
The feature network netF is defined in terms of the shape of the intermediate, extracted
features of the encoder portion of netG. Because of this, the weights of netF are
initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
"""
#bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
#self.set_input(data)
#self.real_A = self.real_A[:bs_per_gpu]
#self.real_B = self.real_B[:bs_per_gpu]
#self.forward() # compute fake images: G(A)
#if self.opt.isTrain:
#
# self.compute_G_loss().backward()
# self.compute_D_loss().backward()
# self.compute_E_loss().backward()
# if self.opt.lambda_NCE > 0.0:
# self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
# self.optimizers.append(self.optimizer_F)
pass
def optimize_parameters(self):
# forward
self.forward()
self.netG.train()
self.netE.train()
self.netD.train()
# update D
self.set_requires_grad(self.netD, True)
self.optimizer_D1.zero_grad()
self.loss_D = self.compute_D_loss()
self.loss_D.backward()
self.optimizer_D1.step()
# update E
self.set_requires_grad(self.netE, True)
self.optimizer_E1.zero_grad()
self.loss_E = self.compute_E_loss()
self.loss_E.backward()
self.optimizer_E1.step()
# update G
self.set_requires_grad(self.netD, False)
self.set_requires_grad(self.netE, False)
self.optimizer_G1.zero_grad()
self.loss_G = self.compute_G_loss()
self.loss_G.backward()
self.optimizer_G1.step()
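For context, a minimal outer loop that drives this alternating D / E / G update might look as follows; `create_dataset`, `create_model`, and the `opt` fields are placeholders for this repository's training script, not calls verified here:

```
# Hypothetical driver loop (names are placeholders).
dataset = create_dataset(opt)
model = create_model(opt)
for epoch in range(opt.n_epochs):
    for i, data in enumerate(dataset):
        if epoch == 0 and i == 0:
            model.data_dependent_initialize(data)
            model.setup(opt)
        model.set_input(data)        # unpack the A0/A1/B0/B1 frame pairs
        model.optimize_parameters()  # one D step, one E step, one G step
```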
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
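set_input expects two adjacent frames per domain, matching the video mode described in the README. A dummy batch with the assumed keys and illustrative shapes (`model` stands for an instance of this class):

```
# Illustrative dummy batch for set_input (shapes and paths are placeholders).
batch = {
    'A0': torch.randn(1, 3, 256, 256),   # infrared frame t
    'A1': torch.randn(1, 3, 256, 256),   # infrared frame t+1
    'B0': torch.randn(1, 3, 256, 256),   # visible frame t
    'B1': torch.randn(1, 3, 256, 256),   # visible frame t+1
    'A_paths': ['trainA/000001.png'],
    'B_paths': ['trainB/000001.png'],
}
model.set_input(batch)
```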
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
S = int(math.sqrt(token_num))
if S * S != token_num:
print('Error! Not a square!')
token_map = origin_tokens.clone().reshape(B,S,S,C)
cut_patch_list = []
for i in range(0, S, adj_size):
for j in range(0, S, adj_size):
i_left = i
i_right = min(i + adj_size, S)
j_left = j
j_right = min(j + adj_size, S)
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
cut_patch= cut_patch.reshape(B,-1,C)
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list,dim=1)
return result
def cat_results(self, origin_tokens, adj_size_list=None):
# default to no extra pooling scales, so the tokens pass through unchanged
if adj_size_list is None:
adj_size_list = []
res_list = [origin_tokens]
for ad_s in adj_size_list:
cat_result = self.tokens_concat(origin_tokens, ad_s)
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
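tokens_concat average-pools an S x S token map into cells of side adjacent_size, and cat_results appends those pooled tokens to the originals along the token dimension. A quick shape check for the 576-token (24 x 24) layout of vit_base_patch16_384; `model` stands for an instance of this class:

```
# Shape check (illustrative only).
tokens = torch.randn(2, 576, 768)          # (B, N, C), N = 24 * 24 ViT tokens
p2 = model.tokens_concat(tokens, 2)        # -> (2, 144, 768): 12 x 12 pooled cells
p3 = model.tokens_concat(tokens, 3)        # -> (2, 64, 768):   8 x 8 pooled cells
multi = model.cat_results(tokens, [2, 3])  # -> (2, 576 + 144 + 64, 768)
```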
def forward(self):
"""执行前向传递以生成输出图像"""
if self.opt.isTrain:
real_A0 = self.resize(self.real_A0)
real_A1 = self.resize(self.real_A1)
real_B0 = self.resize(self.real_B0)
real_B1 = self.resize(self.real_B1)
# extract ViT tokens from the real frames
self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True)
# run the SB module once
# ============ Step 1: initialize the timestep schedule and time index ============
# compute the times schedule and randomly pick time_idx as the current timestep
tau = self.opt.tau
T = self.opt.num_timesteps
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
times = np.cumsum(incs)
times = times / times[-1]
times = 0.5 * times[-1] + 0.5 * times #[0.5,1]
times = np.concatenate([np.zeros(1), times])
times = torch.tensor(times).float().cuda()
self.times = times
bs = self.mutil_real_A0_tokens.size(0)
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
self.time_idx = time_idx
with torch.no_grad():
self.netG.eval()
# ============ Step 2: multi-step stochastic generation for the real_A0 / real_A1 tokens ============
for t in range(self.time_idx.int().item() + 1):
# compute the increment delta and the inter/scale factors used to interpolate at each timestep
if t > 0:
delta = times[t] - times[t - 1]
denom = times[-1] - times[t - 1]
inter = (delta / denom).reshape(-1, 1, 1, 1)
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
# stochastic noise update for Xt and Xt2
Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
self.time = times[time_idx]
Xt_1 = self.netG(Xt, self.time, z)
Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
Xt_12 = self.netG(Xt2, self.time, z)
# keep the intermediate noisy results (real_A_noisy, etc.) for the concatenation step below
self.real_A_noisy = Xt.detach()
self.real_A_noisy2 = Xt2.detach()
# store the noise map (difference between the noisy tokens and the clean source tokens)
self.noisy_map = self.real_A_noisy - self.mutil_real_A0_tokens
# ============ Step 3: assemble the inputs and run the generator ============
bs = self.mutil_real_A0_tokens.size(0)
z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
# concatenate real_A and real_B (when nce_idt=True) and treat real_A_noisy / XtB the same way
self.real = self.mutil_real_A0_tokens
self.realt = self.real_A_noisy
if self.opt.flip_equivariance:
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
if self.flipped_for_equivariance:
self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3])
# use netG to produce the final fake_B and fake_B2 results
self.fake_B = self.netG(self.realt, self.time, z_in)
self.fake_B2 = self.netG(self.real, self.time, z_in2)
self.fake_B = self.resize(self.fake_B)
self.fake_B2 = self.resize(self.fake_B2)
self.fake_B0 = self.fake_B
self.fake_B1 = self.fake_B2
# extract ViT tokens from the generated frames
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B2, self.atten_layers, get_tokens=True)
# ============ Step 4: multi-step sampling at test time ============
if self.opt.phase == 'test':
tau = self.opt.tau
T = self.opt.num_timesteps
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
times = np.cumsum(incs)
times = times / times[-1]
times = 0.5 * times[-1] + 0.5 * times
times = np.concatenate([np.zeros(1),times])
times = torch.tensor(times).float().cuda()
self.times = times
bs = self.real.size(0)
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
self.time_idx = time_idx
visuals = []
with torch.no_grad():
self.netG.eval()
for t in range(self.opt.num_timesteps):
if t > 0:
delta = times[t] - times[t-1]
denom = times[-1] - times[t-1]
inter = (delta / denom).reshape(-1,1,1,1)
scale = (delta * (1 - delta / denom)).reshape(-1,1,1,1)
Xt = self.mutil_real_A0_tokens if (t == 0) else (1-inter) * Xt + inter * Xt_1.detach() + (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
time = times[time_idx]
z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
Xt_1 = self.netG(Xt, time, z)
setattr(self, "fake_"+str(t+1), Xt_1)
if self.opt.phase == 'train':
# gradient of the real image (real_B0 from the current batch)
real_B0 = self.real_B0.detach().requires_grad_(True)
real_gradient = torch.autograd.grad(real_B0.sum(), real_B0, create_graph=True)[0]
# gradient of the generated image
fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0]
# gradient-based weight maps
self.weight_real, self.weight_fake = self.cao.generate_weight_map(real_gradient, fake_gradient)
# CTN pseudo optical flow for the generated image
self.f_content = self.ctn(self.weight_fake)
# add the noise map back onto the generated result
self.fake_B_2 = self.fake_B + self.noisy_map
# warped image
warped_fake_B = warp(self.fake_B, self.f_content)
# second pass through the generator
self.fake_B_2 = self.netG(warped_fake_B, self.time, z_in)
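The schedule built at the top of forward uses shrinking increments 1/(i+1), normalizes them to [0, 1], maps them into [0.5, 1], and prepends a zero. A small worked example for num_timesteps = 4 (values rounded):

```
# Worked example of the timestep schedule for T = 4.
import numpy as np
T = 4
incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])  # [0, 1, 0.5, 0.333]
times = np.cumsum(incs)                                     # [0, 1, 1.5, 1.833]
times = times / times[-1]                                   # [0, 0.545, 0.818, 1]
times = 0.5 * times[-1] + 0.5 * times                       # [0.5, 0.773, 0.909, 1]
times = np.concatenate([np.zeros(1), times])                # [0, 0.5, 0.773, 0.909, 1]
```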
def compute_D_loss(self):
"""计算判别器的 GAN 损失"""
fake = self.cat_results(self.fake_B.detach())
pred_fake = self.netD(fake, self.time)
self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
self.pred_real = self.netD(self.real_B0, self.time)
loss_D_real = self.criterionGAN(self.pred_real, True)
self.loss_D_real = loss_D_real.mean()
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
return self.loss_D
def compute_E_loss(self):
"""计算判别器 E 的损失"""
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B.detach()], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2.detach()], dim=1)
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
return self.loss_E
def compute_G_loss(self):
"""计算生成器的 GAN 损失"""
bs = self.mutil_real_A0_tokens.size(0)
tau = self.opt.tau
fake = self.fake_B
std = torch.rand(size=[1]).item() * self.opt.std
if self.opt.lambda_GAN > 0.0:
pred_fake = self.netD(fake, self.time)
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
else:
self.loss_G_GAN = 0.0
self.loss_SB = 0
if self.opt.lambda_SB > 0.0:
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1)
bs = self.opt.batch_size
# eq.9
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens)
loss_global *= 0.5
else:
loss_global = 0.0
if self.opt.lambda_ctn > 0.0:
warped_fake_B = warp(self.fake_B, self.f_content)  # warp with the current pseudo flow
self.l2_loss = F.mse_loss(self.fake_B_2, warped_fake_B)  # temporal-consistency term
else:
self.l2_loss = 0.0
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
return self.loss_G
def calculate_attention_loss(self):
n_layers = len(self.atten_layers)
mutil_real_A0_tokens = self.mutil_real_A0_tokens
mutil_real_A1_tokens = self.mutil_real_A1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
loss_global *= 0.5
else:
loss_global = 0.0
if self.opt.lambda_spatial > 0.0:
loss_spatial = 0.0
local_nums = self.opt.local_nums
tokens_cnt = 576  # 24 x 24 patch tokens for a 384-resolution ViT with 16x16 patches
local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
loss_spatial *= 0.5
else:
loss_spatial = 0.0
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0
n_layers = len(self.atten_layers)
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
loss = loss / n_layers
return loss
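calculate_similarity compares the two cross-Gram matrices src·tgtᵀ and tgt·srcᵀ row by row with cosine similarity and penalizes any deviation from 1. The same computation as a self-contained sketch on dummy token batches (the L1 criterion mirrors self.criterionL1 above):

```
# Stand-alone sketch of the cross-similarity term on dummy tokens.
import torch
import torch.nn.functional as F

src = torch.randn(2, 576, 768)                        # tokens of a real frame
tgt = torch.randn(2, 576, 768)                        # tokens of a generated frame
src_tgt = src.bmm(tgt.permute(0, 2, 1))               # (2, 576, 576)
tgt_src = tgt.bmm(src.permute(0, 2, 1))               # (2, 576, 576), transpose of src_tgt
cos = F.cosine_similarity(src_tgt, tgt_src, dim=-1)   # (2, 576)
loss = torch.nn.L1Loss()(torch.ones_like(cos), cos)   # pushes every row similarity toward 1
```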

914
models/stylegan_networks.py Normal file
View File

@ -0,0 +1,914 @@
"""
The network architectures are based on the PyTorch implementation of StyleGAN2Encoder.
Original PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch
Original StyleGAN2 paper: https://github.com/NVlabs/stylegan2
We use this network architecture for our single-image training setting.
"""
import math
import numpy as np
import random
import torch
from torch import nn
from torch.nn import functional as F
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
return F.leaky_relu(input + bias, negative_slope) * scale
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
super().__init__()
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
# print("FusedLeakyReLU: ", input.abs().mean())
out = fused_leaky_relu(input, self.bias,
self.negative_slope,
self.scale)
# print("FusedLeakyReLU: ", out.abs().mean())
return out
def upfirdn2d_native(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
_, minor, in_h, in_w = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, minor, in_h, 1, in_w, 1)
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
out = F.pad(
out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
)
out = out[
:,
:,
max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0),
]
# out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
)
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
# out = out.permute(0, 2, 3, 1)
return out[:, :, ::down_y, ::down_x]
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
class PixelNorm(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
def make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)
if len(k.shape) == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
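make_kernel turns a 1-D tap list into a normalized separable 2-D blur kernel via an outer product. For the [1, 3, 3, 1] kernel used throughout this file:

```
# make_kernel([1, 3, 3, 1]) -> a 4x4 kernel that sums to 1 (entries are k_i * k_j / 64).
k = make_kernel([1, 3, 3, 1])
# tensor([[0.0156, 0.0469, 0.0469, 0.0156],
#         [0.0469, 0.1406, 0.1406, 0.0469],
#         [0.0469, 0.1406, 0.1406, 0.0469],
#         [0.0156, 0.0469, 0.0469, 0.0156]])
```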
class Upsample(nn.Module):
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = make_kernel(kernel) * (factor ** 2)
self.register_buffer('kernel', kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, input):
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
return out
class Downsample(nn.Module):
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = make_kernel(kernel)
self.register_buffer('kernel', kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, input):
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
return out
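Upsample and Downsample wrap upfirdn2d with a padded blur kernel; the spatial size is exactly doubled or halved. A quick shape check:

```
# Blur-based resampling (shapes only).
up = Upsample([1, 3, 3, 1])
down = Downsample([1, 3, 3, 1])
x = torch.randn(1, 3, 16, 16)
print(up(x).shape)    # torch.Size([1, 3, 32, 32])
print(down(x).shape)  # torch.Size([1, 3, 8, 8])
```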
class Blur(nn.Module):
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor ** 2)
self.register_buffer('kernel', kernel)
self.pad = pad
def forward(self, input):
out = upfirdn2d(input, self.kernel, pad=self.pad)
return out
class EqualConv2d(nn.Module):
def __init__(
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
):
super().__init__()
self.weight = nn.Parameter(
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
)
self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2))
self.stride = stride
self.padding = padding
if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))
else:
self.bias = None
def forward(self, input):
# print("Before EqualConv2d: ", input.abs().mean())
out = F.conv2d(
input,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
# print("After EqualConv2d: ", out.abs().mean(), (self.weight * self.scale).abs().mean())
return out
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
)
class EqualLinear(nn.Module):
def __init__(
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.activation:
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
else:
out = F.linear(
input, self.weight * self.scale, bias=self.bias * self.lr_mul
)
return out
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
)
class ScaledLeakyReLU(nn.Module):
def __init__(self, negative_slope=0.2):
super().__init__()
self.negative_slope = negative_slope
def forward(self, input):
out = F.leaky_relu(input, negative_slope=self.negative_slope)
return out * math.sqrt(2)
class ModulatedConv2d(nn.Module):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
style_dim,
demodulate=True,
upsample=False,
downsample=False,
blur_kernel=[1, 3, 3, 1],
):
super().__init__()
self.eps = 1e-8
self.kernel_size = kernel_size
self.in_channel = in_channel
self.out_channel = out_channel
self.upsample = upsample
self.downsample = downsample
if upsample:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
fan_in = in_channel * kernel_size ** 2
self.scale = math.sqrt(1) / math.sqrt(fan_in)
self.padding = kernel_size // 2
self.weight = nn.Parameter(
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
)
if style_dim is not None and style_dim > 0:
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
self.demodulate = demodulate
def __repr__(self):
return (
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
f'upsample={self.upsample}, downsample={self.downsample})'
)
def forward(self, input, style):
batch, in_channel, height, width = input.shape
if style is not None:
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
else:
style = torch.ones(batch, 1, in_channel, 1, 1).cuda()
weight = self.scale * self.weight * style
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
weight = weight.view(
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
)
if self.upsample:
input = input.view(1, batch * in_channel, height, width)
weight = weight.view(
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
)
weight = weight.transpose(1, 2).reshape(
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
)
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
out = self.blur(out)
elif self.downsample:
input = self.blur(input)
_, _, height, width = input.shape
input = input.view(1, batch * in_channel, height, width)
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
else:
input = input.view(1, batch * in_channel, height, width)
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
return out
class NoiseInjection(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1))
def forward(self, image, noise=None):
if noise is None:
batch, _, height, width = image.shape
noise = image.new_empty(batch, 1, height, width).normal_()
return image + self.weight * noise
class ConstantInput(nn.Module):
def __init__(self, channel, size=4):
super().__init__()
self.input = nn.Parameter(torch.randn(1, channel, size, size))
def forward(self, input):
batch = input.shape[0]
out = self.input.repeat(batch, 1, 1, 1)
return out
class StyledConv(nn.Module):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
style_dim=None,
upsample=False,
blur_kernel=[1, 3, 3, 1],
demodulate=True,
inject_noise=True,
):
super().__init__()
self.inject_noise = inject_noise
self.conv = ModulatedConv2d(
in_channel,
out_channel,
kernel_size,
style_dim,
upsample=upsample,
blur_kernel=blur_kernel,
demodulate=demodulate,
)
self.noise = NoiseInjection()
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
# self.activate = ScaledLeakyReLU(0.2)
self.activate = FusedLeakyReLU(out_channel)
def forward(self, input, style=None, noise=None):
out = self.conv(input, style)
if self.inject_noise:
out = self.noise(out, noise=noise)
# out = out + self.bias
out = self.activate(out)
return out
class ToRGB(nn.Module):
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
super().__init__()
if upsample:
self.upsample = Upsample(blur_kernel)
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
def forward(self, input, style, skip=None):
out = self.conv(input, style)
out = out + self.bias
if skip is not None:
skip = self.upsample(skip)
out = out + skip
return out
class Generator(nn.Module):
def __init__(
self,
size,
style_dim,
n_mlp,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
):
super().__init__()
self.size = size
self.style_dim = style_dim
layers = [PixelNorm()]
for i in range(n_mlp):
layers.append(
EqualLinear(
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
)
)
self.style = nn.Sequential(*layers)
self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
self.input = ConstantInput(self.channels[4])
self.conv1 = StyledConv(
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
)
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
self.log_size = int(math.log(size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.convs = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()
in_channel = self.channels[4]
for layer_idx in range(self.num_layers):
res = (layer_idx + 5) // 2
shape = [1, 1, 2 ** res, 2 ** res]
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
for i in range(3, self.log_size + 1):
out_channel = self.channels[2 ** i]
self.convs.append(
StyledConv(
in_channel,
out_channel,
3,
style_dim,
upsample=True,
blur_kernel=blur_kernel,
)
)
self.convs.append(
StyledConv(
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
)
)
self.to_rgbs.append(ToRGB(out_channel, style_dim))
in_channel = out_channel
self.n_latent = self.log_size * 2 - 2
def make_noise(self):
device = self.input.input.device
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
return noises
def mean_latent(self, n_latent):
latent_in = torch.randn(
n_latent, self.style_dim, device=self.input.input.device
)
latent = self.style(latent_in).mean(0, keepdim=True)
return latent
def get_latent(self, input):
return self.style(input)
def forward(
self,
styles,
return_latents=False,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
noise=None,
randomize_noise=True,
):
if not input_is_latent:
styles = [self.style(s) for s in styles]
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
else:
noise = [
getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
]
if truncation < 1:
style_t = []
for style in styles:
style_t.append(
truncation_latent + truncation * (style - truncation_latent)
)
styles = style_t
if len(styles) < 2:
inject_index = self.n_latent
if len(styles[0].shape) < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
else:
if inject_index is None:
inject_index = random.randint(1, self.n_latent - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
out = self.input(latent)
out = self.conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
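An illustrative sampling call for this Generator; the resolution, latent size, and MLP depth below are placeholders, not values prescribed by this repository:

```
# Illustrative sampling from the StyleGAN2 generator.
g = Generator(size=256, style_dim=512, n_mlp=8)
z = torch.randn(4, 512)          # one latent code per image
images, _ = g([z])               # styles are passed as a list; no truncation
print(images.shape)              # torch.Size([4, 3, 256, 256])
```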
class ConvLayer(nn.Sequential):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
layers.append(
EqualConv2d(
in_channel,
out_channel,
kernel_size,
padding=self.padding,
stride=stride,
bias=bias and not activate,
)
)
if activate:
if bias:
layers.append(FusedLeakyReLU(out_channel))
else:
layers.append(ScaledLeakyReLU(0.2))
super().__init__(*layers)
class ResBlock(nn.Module):
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0):
super().__init__()
self.skip_gain = skip_gain
self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel)
if in_channel != out_channel or downsample:
self.skip = ConvLayer(
in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
)
else:
self.skip = nn.Identity()
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
skip = self.skip(input)
out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0)
return out
class StyleGAN2Discriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None):
super().__init__()
self.opt = opt
self.stddev_group = 16
if size is None:
size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
size = 2 ** int(np.log2(self.opt.D_patch_size))
blur_kernel = [1, 3, 3, 1]
channel_multiplier = ndf / 64
channels = {
4: min(384, int(4096 * channel_multiplier)),
8: min(384, int(2048 * channel_multiplier)),
16: min(384, int(1024 * channel_multiplier)),
32: min(384, int(512 * channel_multiplier)),
64: int(256 * channel_multiplier),
128: int(128 * channel_multiplier),
256: int(64 * channel_multiplier),
512: int(32 * channel_multiplier),
1024: int(16 * channel_multiplier),
}
convs = [ConvLayer(3, channels[size], 1)]
log_size = int(math.log(size, 2))
in_channel = channels[size]
if "smallpatch" in self.opt.netD:
final_res_log2 = 4
elif "patch" in self.opt.netD:
final_res_log2 = 3
else:
final_res_log2 = 2
for i in range(log_size, final_res_log2, -1):
out_channel = channels[2 ** (i - 1)]
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
in_channel = out_channel
self.convs = nn.Sequential(*convs)
if False and "tile" in self.opt.netD:
in_channel += 1
self.final_conv = ConvLayer(in_channel, channels[4], 3)
if "patch" in self.opt.netD:
self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False)
else:
self.final_linear = nn.Sequential(
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
EqualLinear(channels[4], 1),
)
def forward(self, input, get_minibatch_features=False):
if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
h, w = input.size(2), input.size(3)
y = torch.randint(h - self.opt.D_patch_size, ())
x = torch.randint(w - self.opt.D_patch_size, ())
input = input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size]
out = input
for i, conv in enumerate(self.convs):
out = conv(out)
# print(i, out.abs().mean())
# out = self.convs(input)
batch, channel, height, width = out.shape
if False and "tile" in self.opt.netD:
group = min(batch, self.stddev_group)
stddev = out.view(
group, -1, 1, channel // 1, height, width
)
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
stddev = stddev.repeat(group, 1, height, width)
out = torch.cat([out, stddev], 1)
out = self.final_conv(out)
# print(out.abs().mean())
if "patch" not in self.opt.netD:
out = out.view(batch, -1)
out = self.final_linear(out)
return out
class TileStyleGAN2Discriminator(StyleGAN2Discriminator):
def forward(self, input):
B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
size = self.opt.D_patch_size
Y = H // size
X = W // size
input = input.view(B, C, Y, size, X, size)
input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
return super().forward(input)
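The tile discriminator folds non-overlapping D_patch_size patches into the batch dimension before scoring them; the reshape is equivalent to the following (illustrative sizes):

```
# Equivalent patch tiling for a 256x256 input and 64x64 patches.
x = torch.randn(2, 3, 256, 256)
size = 64
patches = (x.view(2, 3, 4, size, 4, size)      # B, C, Y, size, X, size
             .permute(0, 2, 4, 1, 3, 5)
             .contiguous()
             .view(2 * 4 * 4, 3, size, size))  # -> (32, 3, 64, 64)
```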
class StyleGAN2Encoder(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
super().__init__()
assert opt is not None
self.opt = opt
channel_multiplier = ngf / 32
channels = {
4: min(512, int(round(4096 * channel_multiplier))),
8: min(512, int(round(2048 * channel_multiplier))),
16: min(512, int(round(1024 * channel_multiplier))),
32: min(512, int(round(512 * channel_multiplier))),
64: int(round(256 * channel_multiplier)),
128: int(round(128 * channel_multiplier)),
256: int(round(64 * channel_multiplier)),
512: int(round(32 * channel_multiplier)),
1024: int(round(16 * channel_multiplier)),
}
blur_kernel = [1, 3, 3, 1]
cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
convs = [nn.Identity(),
ConvLayer(3, channels[cur_res], 1)]
num_downsampling = self.opt.stylegan2_G_num_downsampling
for i in range(num_downsampling):
in_channel = channels[cur_res]
out_channel = channels[cur_res // 2]
convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True))
cur_res = cur_res // 2
for i in range(n_blocks // 2):
n_channel = channels[cur_res]
convs.append(ResBlock(n_channel, n_channel, downsample=False))
self.convs = nn.Sequential(*convs)
def forward(self, input, layers=[], get_features=False):
feat = input
feats = []
if -1 in layers:
layers.append(len(self.convs) - 1)
for layer_id, layer in enumerate(self.convs):
feat = layer(feat)
# print(layer_id, " features ", feat.abs().mean())
if layer_id in layers:
feats.append(feat)
if get_features:
return feat, feats
else:
return feat
class StyleGAN2Decoder(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
super().__init__()
assert opt is not None
self.opt = opt
blur_kernel = [1, 3, 3, 1]
channel_multiplier = ngf / 32
channels = {
4: min(512, int(round(4096 * channel_multiplier))),
8: min(512, int(round(2048 * channel_multiplier))),
16: min(512, int(round(1024 * channel_multiplier))),
32: min(512, int(round(512 * channel_multiplier))),
64: int(round(256 * channel_multiplier)),
128: int(round(128 * channel_multiplier)),
256: int(round(64 * channel_multiplier)),
512: int(round(32 * channel_multiplier)),
1024: int(round(16 * channel_multiplier)),
}
num_downsampling = self.opt.stylegan2_G_num_downsampling
cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling)
convs = []
for i in range(n_blocks // 2):
n_channel = channels[cur_res]
convs.append(ResBlock(n_channel, n_channel, downsample=False))
for i in range(num_downsampling):
in_channel = channels[cur_res]
out_channel = channels[cur_res * 2]
inject_noise = "small" not in self.opt.netG
convs.append(
StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise)
)
cur_res = cur_res * 2
convs.append(ConvLayer(channels[cur_res], 3, 1))
self.convs = nn.Sequential(*convs)
def forward(self, input):
return self.convs(input)
class StyleGAN2Generator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
super().__init__()
self.opt = opt
self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
def forward(self, input, layers=[], encode_only=False):
feat, feats = self.encoder(input, layers, True)
if encode_only:
return feats
else:
fake = self.decoder(feat)
if len(layers) > 0:
return fake, feats
else:
return fake
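The encoder/decoder wrapper exposes intermediate encoder features through layers / encode_only, the hook that patch-based losses use. An illustrative call; the opt fields below are the ones these classes read, with placeholder values:

```
# Illustrative feature extraction (opt values are placeholders).
from argparse import Namespace
opt = Namespace(load_size=256, crop_size=256,
                stylegan2_G_num_downsampling=1, netG='stylegan2')
net = StyleGAN2Generator(3, 3, ngf=64, n_blocks=6, opt=opt)
x = torch.randn(1, 3, 256, 256)
feats = net(x, layers=[0, 1, 2], encode_only=True)  # list of three encoder feature maps
# The full decode path builds its default style tensor with .cuda(), so translation
# itself expects a GPU: fake = net.cuda()(x.cuda())
```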

99
models/template_model.py Normal file
View File

@ -0,0 +1,99 @@
"""Model class template
This module provides a template for users to implement custom models.
You can specify '--model template' to use this model.
The class name should be consistent with both the filename and its model option.
The filename should be <model>_model.py
The class name should be <Model>Model.
It implements a simple image-to-image translation baseline based on regression loss.
Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
min_<netG> ||netG(data_A) - data_B||_1
You need to implement the following functions:
<modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
<__init__>: Initialize this model class.
<set_input>: Unpack input data and perform data pre-processing.
<forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
<optimize_parameters>: Update network weights; it will be called in every training iteration.
"""
import torch
from .base_model import BaseModel
from . import networks
class TemplateModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""Add new model-specific options and rewrite default values for existing options.
Parameters:
parser -- the option parser
is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
if is_train:
parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
return parser
def __init__(self, opt):
"""Initialize this model class.
Parameters:
opt -- training/test options
A few things can be done here.
- (required) call the initialization function of BaseModel
- define loss function, visualization images, model names, and optimizers
"""
BaseModel.__init__(self, opt) # call the initialization method of BaseModel
# specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
self.loss_names = ['loss_G']
# specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
self.visual_names = ['data_A', 'data_B', 'output']
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
# you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
self.model_names = ['G']
# define networks; you can use opt.isTrain to specify different behaviors for training and test.
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
if self.isTrain: # only defined during training time
# define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
# We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
self.criterionLoss = torch.nn.L1Loss()
# define and initialize optimizers. You can define one optimizer for each network.
# If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers = [self.optimizer]
# Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input: a dictionary that contains the data itself and its metadata information.
"""
AtoB = self.opt.direction == 'AtoB' # use <direction> to swap data_A and data_B
self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
def forward(self):
"""Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
self.output = self.netG(self.data_A) # generate output image given the input data_A
def backward(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# calculate the intermediate results if necessary; here self.output has been computed during function <forward>
# calculate loss given the input and intermediate results
self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
def optimize_parameters(self):
"""Update network weights; it will be called in every training iteration."""
self.forward() # first call forward to calculate intermediate results
self.optimizer.zero_grad() # clear network G's existing gradients
self.backward() # calculate gradients for network G
self.optimizer.step() # update gradients for network G


42
models/util/crop.py Normal file
View File

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
class RandomResizedCrop(transforms.RandomResizedCrop):
"""
RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
This may lead to results different with torchvision's version.
Following BYOL's TF code:
https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
"""
@staticmethod
def get_params(img, scale, ratio):
width, height = F._get_image_size(img)
area = height * width
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
log_ratio = torch.log(torch.tensor(ratio))
aspect_ratio = torch.exp(
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
w = min(w, width)
h = min(h, height)
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
return i, j, h, w

65
models/util/datasets.py Normal file
View File

@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
import os
import PIL
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args):
transform = build_transform(is_train, args)
root = os.path.join(args.data_path, 'train' if is_train else 'val')
dataset = datasets.ImageFolder(root, transform=transform)
print(dataset)
return dataset
def build_transform(is_train, args):
mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD
# train transform
if is_train:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=args.input_size,
is_training=True,
color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation='bicubic',
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
mean=mean,
std=std,
)
return transform
# eval transform
t = []
if args.input_size <= 224:
crop_pct = 224 / 256
else:
crop_pct = 1.0
size = int(args.input_size / crop_pct)
t.append(
transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(args.input_size))
t.append(transforms.ToTensor())
t.append(transforms.Normalize(mean, std))
return transforms.Compose(t)

47
models/util/lars.py Normal file
View File

@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
import torch
class LARS(torch.optim.Optimizer):
"""
LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
"""
def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
super().__init__(params, defaults)
@torch.no_grad()
def step(self):
for g in self.param_groups:
for p in g['params']:
dp = p.grad
if dp is None:
continue
if p.ndim > 1: # if not normalization gamma/beta or bias
dp = dp.add(p, alpha=g['weight_decay'])
param_norm = torch.norm(p)
update_norm = torch.norm(dp)
one = torch.ones_like(param_norm)
q = torch.where(param_norm > 0.,
torch.where(update_norm > 0,
(g['trust_coefficient'] * param_norm / update_norm), one),
one)
dp = dp.mul(q)
param_state = self.state[p]
if 'mu' not in param_state:
param_state['mu'] = torch.zeros_like(p)
mu = param_state['mu']
mu.mul_(g['momentum']).add_(dp)
p.add_(mu, alpha=-g['lr'])
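LARS is used as a drop-in replacement for SGD, with the trust-ratio scaling applied only to parameters of more than one dimension. A minimal usage sketch (hyper-parameters are illustrative):

```
# Minimal LARS usage on a toy model.
import torch
model = torch.nn.Linear(128, 10)
optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9)
loss = model(torch.randn(4, 128)).pow(2).mean()
loss.backward()
optimizer.step()       # the 2-D weight gets trust-ratio scaling; the 1-D bias does not
optimizer.zero_grad()
```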

76
models/util/lr_decay.py Normal file
View File

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ELECTRA https://github.com/google-research/electra
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import json
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
"""
Parameter groups for layer-wise lr decay
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
"""
param_group_names = {}
param_groups = {}
num_layers = len(model.blocks) + 1
layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
for n, p in model.named_parameters():
if not p.requires_grad:
continue
# no decay: all 1D parameters and model specific ones
if p.ndim == 1 or n in no_weight_decay_list:
g_decay = "no_decay"
this_decay = 0.
else:
g_decay = "decay"
this_decay = weight_decay
layer_id = get_layer_id_for_vit(n, num_layers)
group_name = "layer_%d_%s" % (layer_id, g_decay)
if group_name not in param_group_names:
this_scale = layer_scales[layer_id]
param_group_names[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"params": [],
}
param_groups[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"params": [],
}
param_group_names[group_name]["params"].append(n)
param_groups[group_name]["params"].append(p)
# print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
return list(param_groups.values())
def get_layer_id_for_vit(name, num_layers):
"""
Assign a parameter with its layer id
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
"""
if name in ['cls_token', 'pos_embed']:
return 0
elif name.startswith('patch_embed'):
return 0
elif name.startswith('blocks'):
return int(name.split('.')[1]) + 1
else:
return num_layers
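The mapping sends the patch embedding and the class/positional tokens to group 0, transformer blocks to groups 1..N, and everything else (e.g. the classification head) to the last group. For a 12-block ViT, where num_layers = len(model.blocks) + 1 = 13:

```
# Example group ids for a 12-block ViT.
get_layer_id_for_vit('cls_token', 13)                 # -> 0
get_layer_id_for_vit('patch_embed.proj.weight', 13)   # -> 0
get_layer_id_for_vit('blocks.0.attn.qkv.weight', 13)  # -> 1
get_layer_id_for_vit('blocks.11.mlp.fc2.bias', 13)    # -> 12
get_layer_id_for_vit('head.weight', 13)               # -> 13
```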

21
models/util/lr_sched.py Normal file
View File

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
def adjust_learning_rate(optimizer, epoch, args):
"""Decay the learning rate with half-cycle cosine after warmup"""
if epoch < args.warmup_epochs:
lr = args.lr * epoch / args.warmup_epochs
else:
lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
(1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
for param_group in optimizer.param_groups:
if "lr_scale" in param_group:
param_group["lr"] = lr * param_group["lr_scale"]
else:
param_group["lr"] = lr
return lr
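The schedule warms up linearly for warmup_epochs and then decays with a half-cycle cosine down to min_lr. A quick check of the values it produces; the args fields are exactly the ones the function reads, with placeholder numbers:

```
# Quick check of the warmup + cosine schedule.
import torch
from argparse import Namespace
args = Namespace(lr=1e-3, min_lr=1e-5, warmup_epochs=5, epochs=50)
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
for epoch in [0, 2, 5, 27, 49]:
    print(epoch, adjust_learning_rate(opt, epoch, args))
# -> 0.0, 4.0e-4, 1.0e-3 (end of warmup), ~5.2e-4 (mid-cosine), ~1.1e-5 (near min_lr)
```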

340
models/util/misc.py Normal file
View File

@ -0,0 +1,340 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path
import torch
import torch.distributed as dist
from torch._six import inf
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
builtin_print = builtins.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
force = force or (get_world_size() > 8)
if is_master or force:
now = datetime.datetime.now().time()
builtin_print('[{}] '.format(now), end='') # print with time stamp
builtin_print(*args, **kwargs)
builtins.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if args.dist_on_itp:
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
os.environ['LOCAL_RANK'] = str(args.gpu)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
setup_for_distributed(is_master=True) # hack
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
args.rank, args.dist_url, args.gpu), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
class NativeScalerWithGradNormCount:
state_dict_key = "amp_scaler"
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
self._scaler.scale(loss).backward(create_graph=create_graph)
if update_grad:
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
else:
self._scaler.unscale_(optimizer)
norm = get_grad_norm_(parameters)
self._scaler.step(optimizer)
self._scaler.update()
else:
norm = None
return norm
def state_dict(self):
return self._scaler.state_dict()
def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
return total_norm
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
output_dir = Path(args.output_dir)
epoch_name = str(epoch)
if loss_scaler is not None:
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
for checkpoint_path in checkpoint_paths:
to_save = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'scaler': loss_scaler.state_dict(),
'args': args,
}
save_on_master(to_save, checkpoint_path)
else:
client_state = {'epoch': epoch}
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
def load_model(args, model_without_ddp, optimizer, loss_scaler):
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
print("Resume checkpoint %s" % args.resume)
if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
optimizer.load_state_dict(checkpoint['optimizer'])
args.start_epoch = checkpoint['epoch'] + 1
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
print("With optim & sched!")
def all_reduce_mean(x):
world_size = get_world_size()
if world_size > 1:
x_reduce = torch.tensor(x).cuda()
dist.all_reduce(x_reduce)
x_reduce /= world_size
return x_reduce.item()
else:
return x

96
models/util/pos_embed.py Normal file
View File

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
import numpy as np
import torch
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
if 'pos_embed' in checkpoint_model:
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
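For reference, a short usage sketch of the helpers above; the shapes follow directly from the code, while the checkpoint path in the comment is hypothetical.

```python
# Illustrative sketch: fixed sin-cos embedding for a 14x14 patch grid with a cls token slot.
pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)  # (1 + 14*14, 768) == (197, 768)

# interpolate_pos_embed(model, checkpoint_model) additionally expects a model exposing
# patch_embed.num_patches and pos_embed (e.g. a ViT), plus a checkpoint state dict:
#   checkpoint_model = torch.load('pretrained.pth', map_location='cpu')['model']
#   interpolate_pos_embed(model, checkpoint_model)
```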

1
options/__init__.py Normal file
View File

@ -0,0 +1 @@
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

167
options/base_options.py Normal file
View File

@ -0,0 +1,167 @@
import argparse
import os
from util import util
import torch
import models
import data
class BaseOptions():
"""This class defines options used during both training and test time.
It also implements several helper functions such as parsing, printing, and saving the options.
It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
"""
def __init__(self, cmd_line=None):
"""Reset the class; indicates the class hasn't been initailized"""
self.initialized = False
self.cmd_line = None
if cmd_line is not None:
self.cmd_line = cmd_line.split()
def initialize(self, parser):
"""Define the common options that are used in both training and test."""
# basic parameters
parser.add_argument('--dataroot', default='placeholder', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--easy_label', type=str, default='experiment_name', help='Interpretable name')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
parser.add_argument('--use_idt', action='store_true', help='use_idt')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
# model parameters
parser.add_argument('--model', type=str, default='cut', help='chooses which model to use.')
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat'], help='specify generator architecture')
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
parser.add_argument('--init_type', type=str, default='xavier', choices=['normal', 'xavier', 'kaiming', 'orthogonal'], help='network initialization')
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
parser.add_argument('--no_dropout', type=util.str2bool, nargs='?', const=True, default=True,
help='no dropout for the generator')
parser.add_argument('--no_antialias', action='store_true', help='if specified, use stride=2 convs instead of antialiased-downsampling (sad)')
parser.add_argument('--no_antialias_up', action='store_true', help='if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]')
# dataset parameters
parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
parser.add_argument('--random_scale_max', type=float, default=3.0,
help='(used for single image translation) Randomly scale the image by the specified factor as data augmentation.')
# additional parameters
parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
# parameters related to StyleGAN2-based networks
parser.add_argument('--stylegan2_G_num_downsampling',
default=1, type=int,
help='Number of downsampling layers used by StyleGAN2Generator')
self.initialized = True
return parser
def gather_options(self):
"""Initialize our parser with basic options(only once).
Add additional model-specific and dataset-specific options.
These options are defined in the <modify_commandline_options> function
in model and dataset classes.
"""
if not self.initialized: # check if it has been initialized
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = self.initialize(parser)
# get the basic options
if self.cmd_line is None:
opt, _ = parser.parse_known_args()
else:
opt, _ = parser.parse_known_args(self.cmd_line)
# modify model-related parser options
model_name = opt.model
model_option_setter = models.get_option_setter(model_name)
parser = model_option_setter(parser, self.isTrain)
if self.cmd_line is None:
print(parser)
opt, _ = parser.parse_known_args() # parse again with new defaults
else:
opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
# modify dataset-related parser options
dataset_name = opt.dataset_mode
dataset_option_setter = data.get_option_setter(dataset_name)
parser = dataset_option_setter(parser, self.isTrain)
# save and return the parser
self.parser = parser
if self.cmd_line is None:
return parser.parse_args()
else:
return parser.parse_args(self.cmd_line)
def print_options(self, opt):
"""Print and save options
It will print both current options and default values (if different).
It will save options into a text file: [checkpoints_dir]/[name]/[phase]_opt.txt
"""
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
# save to the disk
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
try:
with open(file_name, 'wt') as opt_file:
opt_file.write(message)
opt_file.write('\n')
except PermissionError as error:
print("permission error {}".format(error))
pass
def parse(self):
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
opt = self.gather_options()
opt.isTrain = self.isTrain # train or test
# process opt.suffix
if opt.suffix:
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
opt.name = opt.name + suffix
self.print_options(opt)
# set gpu ids
str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
opt.gpu_ids.append(id)
if len(opt.gpu_ids) > 0:
torch.cuda.set_device(opt.gpu_ids[0])
self.opt = opt
return self.opt

21
options/test_options.py Normal file
View File

@ -0,0 +1,21 @@
from .base_options import BaseOptions
class TestOptions(BaseOptions):
"""This class includes test options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser) # define shared options
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
# Dropout and BatchNorm behave differently during training and test.
parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
# To avoid cropping, the load_size should be the same as crop_size
parser.set_defaults(load_size=parser.get_default('crop_size'))
self.isTrain = False
return parser

47
options/train_options.py Normal file
View File

@ -0,0 +1,47 @@
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
"""This class includes training options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser)
# visdom and HTML visualization parameters
parser.add_argument('--display_freq', type=int, default=50, help='frequency of showing training results on screen')
parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
parser.add_argument('--display_id', type=int, default=None, help='window id of the web display. Default is random window id')
parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
# network saving and loading parameters
parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
# parser.add_argument('--use_mlp', action='store_true', help='use_mlp')
# parser.add_argument('--use_tgt_style_src', action='store_true', help='use_tgt_style_src')
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
# training parameters
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
self.isTrain = True
return parser
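The options can also be gathered programmatically by passing a command-line string to the constructor. The sketch below is illustrative only: the dataroot and experiment name are placeholders, and the model/dataset choices mirror the image-mode command in scripts/train.sh further down.

```python
# Illustrative sketch: parse training options without touching sys.argv. parse() cascades
# through the base, model-specific and dataset-specific option setters and writes
# <checkpoints_dir>/<name>/train_opt.txt as a side effect.
from options.train_options import TrainOptions

cmd = '--dataroot /path/of/dataset --name demo_run --model roma --dataset_mode unaligned --gpu_ids -1'
opt = TrainOptions(cmd_line=cmd).parse()
print(opt.model, opt.lr, opt.gpu_ids)   # e.g. roma 0.0002 []
```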

1
scripts/test.sh Normal file
View File

@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0 python test.py --dataroot /path/of/test_dataset --checkpoints_dir ./checkpoints --name train1 --model roma_single --num_test 10000 --epoch latest

5
scripts/train.sh Normal file
View File

@ -0,0 +1,5 @@
# Train for video mode
CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned_double --no_flip --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --lambda_motion 1.0 --atten_layers 1,3,5 --lr 0.00001
# Train for image mode
CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --atten_layers 1,3,5 --lr 0.00001

70
test.py Normal file
View File

@ -0,0 +1,70 @@
"""General-purpose test script for image-to-image translation.
Once you have trained your model with train.py, you can use this script to test the model.
It will load a saved model from --checkpoints_dir and save the results to --results_dir.
It first creates model and dataset given the option. It will hard-code some parameters.
It then runs inference for --num_test images and saves the results to an HTML file.
Example (You need to train models first or download pre-trained models from our website):
Test a CycleGAN model (both sides):
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
Test a CycleGAN model (one side only):
python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
The option '--model test' is used for generating CycleGAN results only for one side.
This option will automatically set '--dataset_mode single', which only loads the images from one set.
In contrast, using '--model cycle_gan' requires loading and generating results in both directions,
which is sometimes unnecessary. The results will be saved at ./results/.
Use '--results_dir <directory_path_to_save_result>' to specify the results directory.
Test a pix2pix model:
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
See options/base_options.py and options/test_options.py for more test options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html
import util.util as util
if __name__ == '__main__':
opt = TestOptions().parse() # get test options
# hard-code some parameters for test
opt.num_threads = 0 # test code only supports num_threads = 0
opt.batch_size = 1 # test code only supports batch_size = 1
opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
opt.display_id = -1 # no visdom display; the test code saves the results to an HTML file.
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
# train_dataset = create_dataset(util.copyconf(opt, phase="train"))
model = create_model(opt) # create a model given opt.model and other options
# create a webpage for viewing the results
web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory
print('creating web directory', web_dir)
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
for i, data in enumerate(dataset):
if i == 0:
model.data_dependent_initialize(data)
model.setup(opt) # regular setup: load and print networks; create schedulers
model.parallelize()
if opt.eval:
model.eval()
if i >= opt.num_test: # only apply our model to opt.num_test images.
break
model.set_input(data) # unpack data from data loader
model.test() # run inference
visuals = model.get_current_visuals() # get image results
img_path = model.get_image_paths() # get image paths
if i % 5 == 0: # save images to an HTML file
print('processing (%04d)-th image... %s' % (i, img_path))
save_images(webpage, visuals, img_path, width=opt.display_winsize)
webpage.save() # save the HTML

4
timm/__init__.py Normal file
View File

@ -0,0 +1,4 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
get_model_default_value, is_model_pretrained
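Assuming the bundled model registry is intact, the vendored timm package keeps the standard factory interface, so models can be listed and created by name; the model names below are only examples.

```python
# Illustrative sketch using the functions re-exported in timm/__init__.py above.
import timm

print(timm.__version__)
print(timm.list_models('vit_*')[:5])    # registered ViT variants, if any are bundled
model = timm.create_model('resnet18', pretrained=False, num_classes=10)
```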

Binary file not shown.

Binary file not shown.

12
timm/data/__init__.py Normal file
View File

@ -0,0 +1,12 @@
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .config import resolve_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

865
timm/data/auto_augment.py Normal file
View File

@ -0,0 +1,865 @@
""" AutoAugment, RandAugment, and AugMix for PyTorch
This code implements the searched ImageNet policies with various tweaks and improvements and
does not include any of the search code.
AA and RA Implementation adapted from:
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
AugMix adapted from:
https://github.com/google-research/augmix
Papers:
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
Hacked together by / Copyright 2020 Ross Wightman
"""
import random
import math
import re
from PIL import Image, ImageOps, ImageEnhance, ImageChops
import PIL
import numpy as np
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
_FILL = (128, 128, 128)
_LEVEL_DENOM = 10. # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments
_HPARAMS_DEFAULT = dict(
translate_const=250,
img_mean=_FILL,
)
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
def _interpolation(kwargs):
interpolation = kwargs.pop('resample', Image.BILINEAR)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
else:
return interpolation
def _check_args_tf(kwargs):
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
kwargs.pop('fillcolor')
kwargs['resample'] = _interpolation(kwargs)
def shear_x(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
def shear_y(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
def translate_x_rel(img, pct, **kwargs):
pixels = pct * img.size[0]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_y_rel(img, pct, **kwargs):
pixels = pct * img.size[1]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
def translate_x_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_y_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
def rotate(img, degrees, **kwargs):
_check_args_tf(kwargs)
if _PIL_VER >= (5, 2):
return img.rotate(degrees, **kwargs)
elif _PIL_VER >= (5, 0):
w, h = img.size
post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0)
angle = -math.radians(degrees)
matrix = [
round(math.cos(angle), 15),
round(math.sin(angle), 15),
0.0,
round(-math.sin(angle), 15),
round(math.cos(angle), 15),
0.0,
]
def transform(x, y, matrix):
(a, b, c, d, e, f) = matrix
return a * x + b * y + c, d * x + e * y + f
matrix[2], matrix[5] = transform(
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
)
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
else:
return img.rotate(degrees, resample=kwargs['resample'])
def auto_contrast(img, **__):
return ImageOps.autocontrast(img)
def invert(img, **__):
return ImageOps.invert(img)
def equalize(img, **__):
return ImageOps.equalize(img)
def solarize(img, thresh, **__):
return ImageOps.solarize(img, thresh)
def solarize_add(img, add, thresh=128, **__):
lut = []
for i in range(256):
if i < thresh:
lut.append(min(255, i + add))
else:
lut.append(i)
if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut
return img.point(lut)
else:
return img
def posterize(img, bits_to_keep, **__):
if bits_to_keep >= 8:
return img
return ImageOps.posterize(img, bits_to_keep)
def contrast(img, factor, **__):
return ImageEnhance.Contrast(img).enhance(factor)
def color(img, factor, **__):
return ImageEnhance.Color(img).enhance(factor)
def brightness(img, factor, **__):
return ImageEnhance.Brightness(img).enhance(factor)
def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor)
def _randomly_negate(v):
"""With 50% prob, negate the value"""
return -v if random.random() > 0.5 else v
def _rotate_level_to_arg(level, _hparams):
# range [-30, 30]
level = (level / _LEVEL_DENOM) * 30.
level = _randomly_negate(level)
return level,
def _enhance_level_to_arg(level, _hparams):
# range [0.1, 1.9]
return (level / _LEVEL_DENOM) * 1.8 + 0.1,
def _enhance_increasing_level_to_arg(level, _hparams):
# the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
# range [0.1, 1.9] if level <= _LEVEL_DENOM
level = (level / _LEVEL_DENOM) * .9
level = max(0.1, 1.0 + _randomly_negate(level)) # keep it >= 0.1
return level,
def _shear_level_to_arg(level, _hparams):
# range [-0.3, 0.3]
level = (level / _LEVEL_DENOM) * 0.3
level = _randomly_negate(level)
return level,
def _translate_abs_level_to_arg(level, hparams):
translate_const = hparams['translate_const']
level = (level / _LEVEL_DENOM) * float(translate_const)
level = _randomly_negate(level)
return level,
def _translate_rel_level_to_arg(level, hparams):
# default range [-0.45, 0.45]
translate_pct = hparams.get('translate_pct', 0.45)
level = (level / _LEVEL_DENOM) * translate_pct
level = _randomly_negate(level)
return level,
def _posterize_level_to_arg(level, _hparams):
# As per Tensorflow TPU EfficientNet impl
# range [0, 4], 'keep 0 up to 4 MSB of original image'
# intensity/severity of augmentation decreases with level
return int((level / _LEVEL_DENOM) * 4),
def _posterize_increasing_level_to_arg(level, hparams):
# As per Tensorflow models research and UDA impl
# range [4, 0], 'keep 4 down to 0 MSB of original image',
# intensity/severity of augmentation increases with level
return 4 - _posterize_level_to_arg(level, hparams)[0],
def _posterize_original_level_to_arg(level, _hparams):
# As per original AutoAugment paper description
# range [4, 8], 'keep 4 up to 8 MSB of image'
# intensity/severity of augmentation decreases with level
return int((level / _LEVEL_DENOM) * 4) + 4,
def _solarize_level_to_arg(level, _hparams):
# range [0, 256]
# intensity/severity of augmentation decreases with level
return int((level / _LEVEL_DENOM) * 256),
def _solarize_increasing_level_to_arg(level, _hparams):
# range [0, 256]
# intensity/severity of augmentation increases with level
return 256 - _solarize_level_to_arg(level, _hparams)[0],
def _solarize_add_level_to_arg(level, _hparams):
# range [0, 110]
return int((level / _LEVEL_DENOM) * 110),
LEVEL_TO_ARG = {
'AutoContrast': None,
'Equalize': None,
'Invert': None,
'Rotate': _rotate_level_to_arg,
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
'Posterize': _posterize_level_to_arg,
'PosterizeIncreasing': _posterize_increasing_level_to_arg,
'PosterizeOriginal': _posterize_original_level_to_arg,
'Solarize': _solarize_level_to_arg,
'SolarizeIncreasing': _solarize_increasing_level_to_arg,
'SolarizeAdd': _solarize_add_level_to_arg,
'Color': _enhance_level_to_arg,
'ColorIncreasing': _enhance_increasing_level_to_arg,
'Contrast': _enhance_level_to_arg,
'ContrastIncreasing': _enhance_increasing_level_to_arg,
'Brightness': _enhance_level_to_arg,
'BrightnessIncreasing': _enhance_increasing_level_to_arg,
'Sharpness': _enhance_level_to_arg,
'SharpnessIncreasing': _enhance_increasing_level_to_arg,
'ShearX': _shear_level_to_arg,
'ShearY': _shear_level_to_arg,
'TranslateX': _translate_abs_level_to_arg,
'TranslateY': _translate_abs_level_to_arg,
'TranslateXRel': _translate_rel_level_to_arg,
'TranslateYRel': _translate_rel_level_to_arg,
}
NAME_TO_OP = {
'AutoContrast': auto_contrast,
'Equalize': equalize,
'Invert': invert,
'Rotate': rotate,
'Posterize': posterize,
'PosterizeIncreasing': posterize,
'PosterizeOriginal': posterize,
'Solarize': solarize,
'SolarizeIncreasing': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'ColorIncreasing': color,
'Contrast': contrast,
'ContrastIncreasing': contrast,
'Brightness': brightness,
'BrightnessIncreasing': brightness,
'Sharpness': sharpness,
'SharpnessIncreasing': sharpness,
'ShearX': shear_x,
'ShearY': shear_y,
'TranslateX': translate_x_abs,
'TranslateY': translate_y_abs,
'TranslateXRel': translate_x_rel,
'TranslateYRel': translate_y_rel,
}
class AugmentOp:
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
self.name = name
self.aug_fn = NAME_TO_OP[name]
self.level_fn = LEVEL_TO_ARG[name]
self.prob = prob
self.magnitude = magnitude
self.hparams = hparams.copy()
self.kwargs = dict(
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
)
# If magnitude_std is > 0, we introduce some randomness
# in the usually fixed policy and sample magnitude from a normal distribution
# with mean `magnitude` and std-dev of `magnitude_std`.
# NOTE This is my own hack, being tested, not in papers or reference impls.
# If magnitude_std is inf, we sample magnitude from a uniform distribution
self.magnitude_std = self.hparams.get('magnitude_std', 0)
self.magnitude_max = self.hparams.get('magnitude_max', None)
def __call__(self, img):
if self.prob < 1.0 and random.random() > self.prob:
return img
magnitude = self.magnitude
if self.magnitude_std > 0:
# magnitude randomization enabled
if self.magnitude_std == float('inf'):
magnitude = random.uniform(0, magnitude)
elif self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
# default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
# setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
upper_bound = self.magnitude_max or _LEVEL_DENOM
magnitude = max(0., min(magnitude, upper_bound))
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
return self.aug_fn(img, *level_args, **self.kwargs)
def __repr__(self):
fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
if self.magnitude_max is not None:
fs += f', mmax={self.magnitude_max}'
fs += ')'
return fs
def auto_augment_policy_v0(hparams):
# ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in a black image with TPU posterize
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy_v0r(hparams):
# ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
# in Google research implementation (number of bits discarded increases with magnitude)
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy_original(hparams):
# ImageNet policy from https://arxiv.org/abs/1805.09501
policy = [
[('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy_originalr(hparams):
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
policy = [
[('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy(name='v0', hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
if name == 'original':
return auto_augment_policy_original(hparams)
elif name == 'originalr':
return auto_augment_policy_originalr(hparams)
elif name == 'v0':
return auto_augment_policy_v0(hparams)
elif name == 'v0r':
return auto_augment_policy_v0r(hparams)
else:
assert False, 'Unknown AA policy (%s)' % name
class AutoAugment:
def __init__(self, policy):
self.policy = policy
def __call__(self, img):
sub_policy = random.choice(self.policy)
for op in sub_policy:
img = op(img)
return img
def __repr__(self):
fs = self.__class__.__name__ + f'(policy='
for p in self.policy:
fs += '\n\t['
fs += ', '.join([str(op) for op in p])
fs += ']'
fs += ')'
return fs
def auto_augment_transform(config_str, hparams):
"""
Create an AutoAugment transform
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
The remaining sections, which are not order specific, determine
'mstd' - float std deviation of magnitude noise applied
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
:return: A PyTorch compatible Transform
"""
config = config_str.split('-')
policy_name = config[0]
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
else:
assert False, 'Unknown AutoAugment config section'
aa_policy = auto_augment_policy(policy_name, hparams=hparams)
return AutoAugment(aa_policy)
_RAND_TRANSFORMS = [
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'Posterize',
'Solarize',
'SolarizeAdd',
'Color',
'Contrast',
'Brightness',
'Sharpness',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # NOTE I've implemented this as random erasing separately
]
_RAND_INCREASING_TRANSFORMS = [
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'PosterizeIncreasing',
'SolarizeIncreasing',
'SolarizeAdd',
'ColorIncreasing',
'ContrastIncreasing',
'BrightnessIncreasing',
'SharpnessIncreasing',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # NOTE I've implemented this as random erasing separately
]
# These experimental weights are based loosely on the relative improvements mentioned in the paper.
# They may not result in increased performance, but could likely be tuned to do so.
_RAND_CHOICE_WEIGHTS_0 = {
'Rotate': 0.3,
'ShearX': 0.2,
'ShearY': 0.2,
'TranslateXRel': 0.1,
'TranslateYRel': 0.1,
'Color': .025,
'Sharpness': 0.025,
'AutoContrast': 0.025,
'Solarize': .005,
'SolarizeAdd': .005,
'Contrast': .005,
'Brightness': .005,
'Equalize': .005,
'Posterize': 0,
'Invert': 0,
}
def _select_rand_weights(weight_idx=0, transforms=None):
transforms = transforms or _RAND_TRANSFORMS
assert weight_idx == 0 # only one set of weights currently
rand_weights = _RAND_CHOICE_WEIGHTS_0
probs = [rand_weights[k] for k in transforms]
probs /= np.sum(probs)
return probs
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS
return [AugmentOp(
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
class RandAugment:
def __init__(self, ops, num_layers=2, choice_weights=None):
self.ops = ops
self.num_layers = num_layers
self.choice_weights = choice_weights
def __call__(self, img):
# no replacement when using weighted choice
ops = np.random.choice(
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
for op in ops:
img = op(img)
return img
def __repr__(self):
fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
for op in self.ops:
fs += f'\n\t{op}'
fs += ')'
return fs
def rand_augment_transform(config_str, hparams):
"""
Create a RandAugment transform
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
sections, which are not order specific, determine
'm' - integer magnitude of rand augment
'n' - integer num layers (number of transform ops selected per image)
'w' - integer probability weight index (index of a set of weights to influence choice of op)
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
:return: A PyTorch compatible Transform
"""
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image
weight_idx = None # default to no probability weights for op choice
transforms = _RAND_TRANSFORMS
config = config_str.split('-')
assert config[0] == 'rand'
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param / randomization of magnitude values
mstd = float(val)
if mstd > 100:
# use uniform sampling in 0 to magnitude if mstd is > 100
mstd = float('inf')
hparams.setdefault('magnitude_std', mstd)
elif key == 'mmax':
# clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
hparams.setdefault('magnitude_max', int(val))
elif key == 'inc':
if bool(val):
transforms = _RAND_INCREASING_TRANSFORMS
elif key == 'm':
magnitude = int(val)
elif key == 'n':
num_layers = int(val)
elif key == 'w':
weight_idx = int(val)
else:
assert False, 'Unknown RandAugment config section'
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
_AUGMIX_TRANSFORMS = [
'AutoContrast',
'ColorIncreasing', # not in paper
'ContrastIncreasing', # not in paper
'BrightnessIncreasing', # not in paper
'SharpnessIncreasing', # not in paper
'Equalize',
'Rotate',
'PosterizeIncreasing',
'SolarizeIncreasing',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
]
def augmix_ops(magnitude=10, hparams=None, transforms=None):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _AUGMIX_TRANSFORMS
return [AugmentOp(
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
class AugMixAugment:
""" AugMix Transform
Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' -
https://arxiv.org/abs/1912.02781
"""
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
self.ops = ops
self.alpha = alpha
self.width = width
self.depth = depth
self.blended = blended # blended mode is faster but not well tested
def _calc_blended_weights(self, ws, m):
ws = ws * m
cump = 1.
rws = []
for w in ws[::-1]:
alpha = w / cump
cump *= (1 - alpha)
rws.append(alpha)
return np.array(rws[::-1], dtype=np.float32)
def _apply_blended(self, img, mixing_weights, m):
# This is my first crack at implementing a slightly faster mixed augmentation. Instead
# of accumulating the mix for each chain in a Numpy array and then blending with original,
# it recomputes the blending coefficients and applies one PIL image blend per chain.
# TODO the results appear in the right ballpark but they differ by more than rounding.
img_orig = img.copy()
ws = self._calc_blended_weights(mixing_weights, m)
for w in ws:
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
ops = np.random.choice(self.ops, depth, replace=True)
img_aug = img_orig # no ops are in-place, deep copy not necessary
for op in ops:
img_aug = op(img_aug)
img = Image.blend(img, img_aug, w)
return img
def _apply_basic(self, img, mixing_weights, m):
# This is a literal adaptation of the paper/official implementation without normalizations and
# PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
# typical augmentation transforms, could use a GPU / Kornia implementation.
img_shape = img.size[0], img.size[1], len(img.getbands())
mixed = np.zeros(img_shape, dtype=np.float32)
for mw in mixing_weights:
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
ops = np.random.choice(self.ops, depth, replace=True)
img_aug = img # no ops are in-place, deep copy not necessary
for op in ops:
img_aug = op(img_aug)
mixed += mw * np.asarray(img_aug, dtype=np.float32)
np.clip(mixed, 0, 255., out=mixed)
mixed = Image.fromarray(mixed.astype(np.uint8))
return Image.blend(img, mixed, m)
def __call__(self, img):
mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
m = np.float32(np.random.beta(self.alpha, self.alpha))
if self.blended:
mixed = self._apply_blended(img, mixing_weights, m)
else:
mixed = self._apply_basic(img, mixing_weights, m)
return mixed
def __repr__(self):
fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
for op in self.ops:
fs += f'\n\t{op}'
fs += ')'
return fs
def augment_and_mix_transform(config_str, hparams):
""" Create AugMix PyTorch transform
:param config_str: String defining configuration of the AugMix augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the specific variant (currently only 'augmix'). The remaining
sections, which are not order specific, determine
'm' - integer magnitude (severity) of augmentation mix (default: 3)
'w' - integer width of augmentation chain (default: 3)
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
'mstd' - float std deviation of magnitude noise applied (default: 0)
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
:param hparams: Other hparams (kwargs) for the Augmentation transforms
:return: A PyTorch compatible Transform
"""
magnitude = 3
width = 3
depth = -1
alpha = 1.
blended = False
config = config_str.split('-')
assert config[0] == 'augmix'
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
elif key == 'm':
magnitude = int(val)
elif key == 'w':
width = int(val)
elif key == 'd':
depth = int(val)
elif key == 'a':
alpha = float(val)
elif key == 'b':
blended = bool(val)
else:
assert False, 'Unknown AugMix config section'
hparams.setdefault('magnitude_std', float('inf')) # default to uniform sampling (if not set via mstd arg)
ops = augmix_ops(magnitude=magnitude, hparams=hparams)
return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
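As a quick reference, the three factory functions above can be exercised directly on a PIL image; the config strings are the documented examples from the docstrings, and the gray test image is only a placeholder.

```python
# Illustrative sketch: build AutoAugment / RandAugment / AugMix transforms from config strings.
from PIL import Image

hparams = dict(translate_const=250, img_mean=(128, 128, 128))
aa = auto_augment_transform('original-mstd0.5', dict(hparams))
ra = rand_augment_transform('rand-m9-n3-mstd0.5', dict(hparams))
am = augment_and_mix_transform('augmix-m5-w4-d2', dict(hparams))

img = Image.new('RGB', (224, 224), (96, 96, 96))
for t in (aa, ra, am):
    out = t(img)    # each transform maps a PIL.Image to a PIL.Image
```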

78
timm/data/config.py Normal file
View File

@ -0,0 +1,78 @@
import logging
from .constants import *
_logger = logging.getLogger(__name__)
def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
new_config = {}
default_cfg = default_cfg
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
default_cfg = model.default_cfg
# Resolve input/image size
in_chans = 3
if 'chans' in args and args['chans'] is not None:
in_chans = args['chans']
input_size = (in_chans, 224, 224)
if 'input_size' in args and args['input_size'] is not None:
assert isinstance(args['input_size'], (tuple, list))
assert len(args['input_size']) == 3
input_size = tuple(args['input_size'])
in_chans = input_size[0] # input_size overrides in_chans
elif 'img_size' in args and args['img_size'] is not None:
assert isinstance(args['img_size'], int)
input_size = (in_chans, args['img_size'], args['img_size'])
else:
if use_test_size and 'test_input_size' in default_cfg:
input_size = default_cfg['test_input_size']
elif 'input_size' in default_cfg:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size
# resolve interpolation method
new_config['interpolation'] = 'bicubic'
if 'interpolation' in args and args['interpolation']:
new_config['interpolation'] = args['interpolation']
elif 'interpolation' in default_cfg:
new_config['interpolation'] = default_cfg['interpolation']
# resolve dataset + model mean for normalization
new_config['mean'] = IMAGENET_DEFAULT_MEAN
if 'mean' in args and args['mean'] is not None:
mean = tuple(args['mean'])
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
assert len(mean) == in_chans
new_config['mean'] = mean
elif 'mean' in default_cfg:
new_config['mean'] = default_cfg['mean']
# resolve dataset + model std deviation for normalization
new_config['std'] = IMAGENET_DEFAULT_STD
if 'std' in args and args['std'] is not None:
std = tuple(args['std'])
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
assert len(std) == in_chans
new_config['std'] = std
elif 'std' in default_cfg:
new_config['std'] = default_cfg['std']
# resolve default crop percentage
new_config['crop_pct'] = DEFAULT_CROP_PCT
if 'crop_pct' in args and args['crop_pct'] is not None:
new_config['crop_pct'] = args['crop_pct']
elif 'crop_pct' in default_cfg:
new_config['crop_pct'] = default_cfg['crop_pct']
if verbose:
_logger.info('Data processing configuration for current model + dataset:')
for n, v in new_config.items():
_logger.info('\t%s: %s' % (n, str(v)))
return new_config
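A small sketch of how resolve_data_config fills in defaults from a partial args dict; the values below are illustrative, and a model carrying a default_cfg could be passed instead.

```python
# Illustrative sketch: resolve preprocessing settings from partial arguments.
args = dict(img_size=224, chans=None, input_size=None, interpolation='',
            mean=None, std=None, crop_pct=None)
cfg = resolve_data_config(args, verbose=True)
# With these inputs: {'input_size': (3, 224, 224), 'interpolation': 'bicubic',
#                     'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875}
```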

7
timm/data/constants.py Normal file
View File

@ -0,0 +1,7 @@
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
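These constants are meant to feed normalization transforms; a brief example, assuming torchvision is installed:

```python
# Illustrative sketch: standard ImageNet-style preprocessing built on the constants above.
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(int(224 / DEFAULT_CROP_PCT)),   # 224 / 0.875 = 256
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
```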

152
timm/data/dataset.py Normal file
View File

@ -0,0 +1,152 @@
""" Quick n Simple Image Folder, Tarfile based DataSet
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.utils.data as data
import os
import torch
import logging
from PIL import Image
from .parsers import create_parser
_logger = logging.getLogger(__name__)
_ERROR_RETRY = 50
class ImageDataset(data.Dataset):
def __init__(
self,
root,
parser=None,
class_map=None,
load_bytes=False,
transform=None,
target_transform=None,
):
if parser is None or isinstance(parser, str):
parser = create_parser(parser or '', root=root, class_map=class_map)
self.parser = parser
self.load_bytes = load_bytes
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __getitem__(self, index):
img, target = self.parser[index]
try:
img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
except Exception as e:
_logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
self._consecutive_errors += 1
if self._consecutive_errors < _ERROR_RETRY:
return self.__getitem__((index + 1) % len(self.parser))
else:
raise e
self._consecutive_errors = 0
if self.transform is not None:
img = self.transform(img)
if target is None:
target = -1
elif self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.parser)
def filename(self, index, basename=False, absolute=False):
return self.parser.filename(index, basename, absolute)
def filenames(self, basename=False, absolute=False):
return self.parser.filenames(basename, absolute)
class IterableImageDataset(data.IterableDataset):
def __init__(
self,
root,
parser=None,
split='train',
is_training=False,
batch_size=None,
repeats=0,
download=False,
transform=None,
target_transform=None,
):
assert parser is not None
if isinstance(parser, str):
self.parser = create_parser(
parser, root=root, split=split, is_training=is_training,
batch_size=batch_size, repeats=repeats, download=download)
else:
self.parser = parser
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __iter__(self):
for img, target in self.parser:
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
yield img, target
def __len__(self):
if hasattr(self.parser, '__len__'):
return len(self.parser)
else:
return 0
def filename(self, index, basename=False, absolute=False):
assert False, 'Filename lookup by index not supported, use filenames().'
def filenames(self, basename=False, absolute=False):
return self.parser.filenames(basename, absolute)
class AugMixDataset(torch.utils.data.Dataset):
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
def __init__(self, dataset, num_splits=2):
self.augmentation = None
self.normalize = None
self.dataset = dataset
if self.dataset.transform is not None:
self._set_transforms(self.dataset.transform)
self.num_splits = num_splits
def _set_transforms(self, x):
assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
self.dataset.transform = x[0]
self.augmentation = x[1]
self.normalize = x[2]
@property
def transform(self):
return self.dataset.transform
@transform.setter
def transform(self, x):
self._set_transforms(x)
def _normalize(self, x):
return x if self.normalize is None else self.normalize(x)
def __getitem__(self, i):
x, y = self.dataset[i] # all splits share the same dataset base transform
x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
# run the full augmentation on the remaining splits
for _ in range(self.num_splits - 1):
x_list.append(self._normalize(self.augmentation(x)))
return tuple(x_list), y
def __len__(self):
return len(self.dataset)
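A hedged usage sketch for the wrappers above: the folder path is a placeholder, and the list of three transforms corresponds to the (base, augmentation, normalize) triple that AugMixDataset._set_transforms expects.

```python
# Illustrative sketch: wrap an image-folder dataset for AugMix-style training.
from torchvision import transforms
from timm.data import ImageDataset, AugMixDataset
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

base = transforms.RandomResizedCrop(224)
augment = transforms.ColorJitter(0.4, 0.4, 0.4)
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

ds = ImageDataset('/path/of/train_images', transform=[base, augment, normalize])
aug_ds = AugMixDataset(ds, num_splits=2)   # splits the transform list internally
(clean, augmented), target = aug_ds[0]     # one normalized clean view + one augmented view
```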

Some files were not shown because too many files have changed in this diff.