first commit
This commit is contained in:
commit
5cb1f58852
142
README.md
Normal file
142
README.md
Normal file
@ -0,0 +1,142 @@
|
||||
# ROMA
|
||||
This repository is the official Pytorch implementation for ACM MM'22 paper
|
||||
"ROMA: Cross-Domain Region Similarity Matching for Unpaired Nighttime Infrared to Daytime Visible Video Translation".[[Arxiv]](https://arxiv.org/abs/2204.12367)
|
||||
|
||||
**Examples of Object Detection:**
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
**Examples of Video Fusion**
|
||||
|
||||

|
||||
|
||||
More experimental results can be obtained by contacting us.
|
||||
|
||||
# Introduction
|
||||
|
||||
## Method
|
||||

|
||||
|
||||
- The domain gaps between unpaired nighttime infrared and daytime visible videos are even huger than paired ones that captured at the same time, establishing an effective translation mapping will greatly contribute to various fields.
|
||||
- Our proposed cross-similarity, which are calculated across domains, could make the generative process focus on learning the content of structural correspondence between real and synthesized frames, getting rid of the negative effects of different styles.
|
||||
|
||||
|
||||
|
||||
## Training
|
||||
The following is the required structure of dataset. For the video mode, the input of a single data is the result of concatenating **two adjacent frames**; for the image mode, the input of a single data is **a single image**.
|
||||
```
|
||||
Video/Image mode:
|
||||
trainA: \Path\of\trainA
|
||||
trainB: \Path\of\trainB
|
||||
|
||||
```
|
||||
Concrete examples of the training and testing are shown in the script files `./scripts/train.sh` and `./scripts/test.sh`, respectively.
|
||||
|
||||
|
||||
|
||||
|
||||
## InfraredCity and InfraredCity-Lite Dataset
|
||||
|
||||
|
||||
<table class="tg">
|
||||
<thead>
|
||||
<tr>
|
||||
<th class="tg-uzvj" colspan="2">InfraredCity</th>
|
||||
<th class="tg-uzvj" colspan="4">Total Frame</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td class="tg-9wq8" colspan="2">Nighttime Infrared</td>
|
||||
<td class="tg-9wq8" colspan="4">201,856</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-9wq8" colspan="2">Nighttime Visible</td>
|
||||
<td class="tg-9wq8" colspan="4">178,698</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-9wq8" colspan="2">Daytime Visible</td>
|
||||
<td class="tg-9wq8" colspan="4">199,430</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-9wq8" colspan="6"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-uzvj" colspan="2">InfraredCity-Lite</td>
|
||||
<td class="tg-uzvj">Infrared<br>Train</td>
|
||||
<td class="tg-uzvj">Infrared<br>Test</td>
|
||||
<td class="tg-uzvj">Visible<br>Train</td>
|
||||
<td class="tg-uzvj">Total</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-9wq8" rowspan="2">City</td>
|
||||
<td class="tg-9wq8">clearday</td>
|
||||
<td class="tg-9wq8">5,538</td>
|
||||
<td class="tg-9wq8">1,000</td>
|
||||
<td class="tg-9wq8" rowspan="2">5360</td>
|
||||
<td class="tg-9wq8" rowspan="2">15,180</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-9wq8">overcast</td>
|
||||
<td class="tg-9wq8">2,282</td>
|
||||
<td class="tg-9wq8">1,000</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-9wq8" rowspan="2">Highway</td>
|
||||
<td class="tg-9wq8">clearday</td>
|
||||
<td class="tg-9wq8">4,412</td>
|
||||
<td class="tg-9wq8">1,000</td>
|
||||
<td class="tg-9wq8" rowspan="2">6,463</td>
|
||||
<td class="tg-9wq8" rowspan="2">15,853</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-9wq8">overcast</td>
|
||||
<td class="tg-9wq8">2,978</td>
|
||||
<td class="tg-9wq8">1,000</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-9wq8" colspan="2">Monitor</td>
|
||||
<td class="tg-9wq8">5,612</td>
|
||||
<td class="tg-9wq8">500</td>
|
||||
<td class="tg-9wq8">4,194</td>
|
||||
<td class="tg-9wq8">10,306</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
The datasets and their more details are available in [InfiRay](http://openai.raytrontek.com/apply/Infrared_city.html/).
|
||||
|
||||
|
||||
### Citation
|
||||
If you find our work useful in your research or publication, please cite our work:
|
||||
```
|
||||
@inproceedings{ROMA2022,
|
||||
title = {ROMA: Cross-Domain Region Similarity Matching for Unpaired Nighttime Infrared to Daytime Visible Video Translation},
|
||||
author = {Zhenjie Yu and Kai Chen and Shuang Li and Bingfeng Han and Chi Harold Liu and Shuigen Wang},
|
||||
booktitle = {ACM MM},
|
||||
pages = {5294--5302},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
#### Acknowledgements
|
||||
This code borrows heavily from the PyTorch implementation of [Cycle-GAN and Pix2Pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [CUT](https://github.com/taesungp/contrastive-unpaired-translation).
|
||||
A huge thanks to them!
|
||||
```
|
||||
@inproceedings{CycleGAN2017,
|
||||
title = {Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networkss},
|
||||
author = {Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
|
||||
booktitle = {ICCV},
|
||||
year = {2017}
|
||||
}
|
||||
|
||||
@inproceedings{CUT2020,
|
||||
author = {Taesung Park and Alexei A. Efros and Richard Zhang and Jun{-}Yan Zhu},
|
||||
title = {Contrastive Learning for Unpaired Image-to-Image Translation},
|
||||
booktitle = {ECCV},
|
||||
pages = {319--345},
|
||||
year = {2020},
|
||||
}
|
||||
```
|
||||
98
data/__init__.py
Normal file
98
data/__init__.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""This package includes all the modules related to data loading and preprocessing
|
||||
|
||||
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
|
||||
You need to implement four functions:
|
||||
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
||||
-- <__len__>: return the size of dataset.
|
||||
-- <__getitem__>: get a data point from data loader.
|
||||
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
||||
|
||||
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
|
||||
See our template dataset class 'template_dataset.py' for more details.
|
||||
"""
|
||||
import importlib
|
||||
import torch.utils.data
|
||||
from data.base_dataset import BaseDataset
|
||||
|
||||
|
||||
def find_dataset_using_name(dataset_name):
|
||||
"""Import the module "data/[dataset_name]_dataset.py".
|
||||
|
||||
In the file, the class called DatasetNameDataset() will
|
||||
be instantiated. It has to be a subclass of BaseDataset,
|
||||
and it is case-insensitive.
|
||||
"""
|
||||
dataset_filename = "data." + dataset_name + "_dataset"
|
||||
datasetlib = importlib.import_module(dataset_filename)
|
||||
|
||||
dataset = None
|
||||
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
|
||||
for name, cls in datasetlib.__dict__.items():
|
||||
if name.lower() == target_dataset_name.lower() \
|
||||
and issubclass(cls, BaseDataset):
|
||||
dataset = cls
|
||||
|
||||
if dataset is None:
|
||||
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def get_option_setter(dataset_name):
|
||||
"""Return the static method <modify_commandline_options> of the dataset class."""
|
||||
dataset_class = find_dataset_using_name(dataset_name)
|
||||
return dataset_class.modify_commandline_options
|
||||
|
||||
|
||||
def create_dataset(opt):
|
||||
"""Create a dataset given the option.
|
||||
|
||||
This function wraps the class CustomDatasetDataLoader.
|
||||
This is the main interface between this package and 'train.py'/'test.py'
|
||||
|
||||
Example:
|
||||
>>> from data import create_dataset
|
||||
>>> dataset = create_dataset(opt)
|
||||
"""
|
||||
data_loader = CustomDatasetDataLoader(opt)
|
||||
dataset = data_loader.load_data()
|
||||
return dataset
|
||||
|
||||
|
||||
class CustomDatasetDataLoader():
|
||||
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize this class
|
||||
|
||||
Step 1: create a dataset instance given the name [dataset_mode]
|
||||
Step 2: create a multi-threaded data loader.
|
||||
"""
|
||||
self.opt = opt
|
||||
dataset_class = find_dataset_using_name(opt.dataset_mode)
|
||||
self.dataset = dataset_class(opt)
|
||||
print("dataset [%s] was created" % type(self.dataset).__name__)
|
||||
self.dataloader = torch.utils.data.DataLoader(
|
||||
self.dataset,
|
||||
batch_size=opt.batch_size,
|
||||
shuffle=not opt.serial_batches,
|
||||
num_workers=int(opt.num_threads),
|
||||
drop_last=True if opt.isTrain else False,
|
||||
)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.dataset.current_epoch = epoch
|
||||
|
||||
def load_data(self):
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of data in the dataset"""
|
||||
return min(len(self.dataset), self.opt.max_dataset_size)
|
||||
|
||||
def __iter__(self):
|
||||
"""Return a batch of data"""
|
||||
for i, data in enumerate(self.dataloader):
|
||||
if i * self.opt.batch_size >= self.opt.max_dataset_size:
|
||||
break
|
||||
yield data
|
||||
BIN
data/__pycache__/__init__.cpython-36.pyc
Normal file
BIN
data/__pycache__/__init__.cpython-36.pyc
Normal file
Binary file not shown.
BIN
data/__pycache__/base_dataset.cpython-36.pyc
Normal file
BIN
data/__pycache__/base_dataset.cpython-36.pyc
Normal file
Binary file not shown.
BIN
data/__pycache__/image_folder.cpython-36.pyc
Normal file
BIN
data/__pycache__/image_folder.cpython-36.pyc
Normal file
Binary file not shown.
BIN
data/__pycache__/unaligned_dataset.cpython-36.pyc
Normal file
BIN
data/__pycache__/unaligned_dataset.cpython-36.pyc
Normal file
Binary file not shown.
BIN
data/__pycache__/unaligned_double_dataset.cpython-36.pyc
Normal file
BIN
data/__pycache__/unaligned_double_dataset.cpython-36.pyc
Normal file
Binary file not shown.
230
data/base_dataset.py
Normal file
230
data/base_dataset.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
|
||||
|
||||
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
|
||||
"""
|
||||
import random
|
||||
import numpy as np
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseDataset(data.Dataset, ABC):
|
||||
"""This class is an abstract base class (ABC) for datasets.
|
||||
|
||||
To create a subclass, you need to implement the following four functions:
|
||||
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
||||
-- <__len__>: return the size of dataset.
|
||||
-- <__getitem__>: get a data point.
|
||||
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize the class; save the options in the class
|
||||
|
||||
Parameters:
|
||||
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
self.opt = opt
|
||||
self.root = opt.dataroot
|
||||
self.current_epoch = 0
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
"""Add new dataset-specific options, and rewrite default values for existing options.
|
||||
|
||||
Parameters:
|
||||
parser -- original option parser
|
||||
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
||||
|
||||
Returns:
|
||||
the modified parser.
|
||||
"""
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self):
|
||||
"""Return the total number of images in the dataset."""
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, index):
|
||||
"""Return a data point and its metadata information.
|
||||
|
||||
Parameters:
|
||||
index - - a random integer for data indexing
|
||||
|
||||
Returns:
|
||||
a dictionary of data with their names. It ususally contains the data itself and its metadata information.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_params(opt, size):
|
||||
w, h = size
|
||||
new_h = h
|
||||
new_w = w
|
||||
if opt.preprocess == 'resize_and_crop':
|
||||
new_h = new_w = opt.load_size
|
||||
elif opt.preprocess == 'scale_width_and_crop':
|
||||
new_w = opt.load_size
|
||||
new_h = opt.load_size * h // w
|
||||
|
||||
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
|
||||
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
|
||||
|
||||
flip = random.random() > 0.5
|
||||
|
||||
return {'crop_pos': (x, y), 'flip': flip}
|
||||
|
||||
|
||||
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
|
||||
transform_list = []
|
||||
if grayscale:
|
||||
transform_list.append(transforms.Grayscale(1))
|
||||
if 'fixsize' in opt.preprocess:
|
||||
transform_list.append(transforms.Resize(params["size"], method))
|
||||
if 'resize' in opt.preprocess:
|
||||
osize = [opt.load_size, opt.load_size]
|
||||
if "gta2cityscapes" in opt.dataroot:
|
||||
osize[0] = opt.load_size // 2
|
||||
transform_list.append(transforms.Resize(osize, method))
|
||||
elif 'scale_width' in opt.preprocess:
|
||||
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
|
||||
elif 'scale_shortside' in opt.preprocess:
|
||||
transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))
|
||||
|
||||
if 'zoom' in opt.preprocess:
|
||||
if params is None:
|
||||
transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
|
||||
else:
|
||||
transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))
|
||||
|
||||
if 'crop' in opt.preprocess:
|
||||
if params is None or 'crop_pos' not in params:
|
||||
transform_list.append(transforms.RandomCrop(opt.crop_size))
|
||||
else:
|
||||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
|
||||
|
||||
if 'patch' in opt.preprocess:
|
||||
transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))
|
||||
|
||||
if 'trim' in opt.preprocess:
|
||||
transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))
|
||||
|
||||
# if opt.preprocess == 'none':
|
||||
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
|
||||
|
||||
if not opt.no_flip:
|
||||
if params is None or 'flip' not in params:
|
||||
transform_list.append(transforms.RandomHorizontalFlip())
|
||||
elif 'flip' in params:
|
||||
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
|
||||
|
||||
if convert:
|
||||
transform_list += [transforms.ToTensor()]
|
||||
if grayscale:
|
||||
transform_list += [transforms.Normalize((0.5,), (0.5,))]
|
||||
else:
|
||||
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
||||
return transforms.Compose(transform_list)
|
||||
|
||||
|
||||
def __make_power_2(img, base, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
h = int(round(oh / base) * base)
|
||||
w = int(round(ow / base) * base)
|
||||
if h == oh and w == ow:
|
||||
return img
|
||||
|
||||
return img.resize((w, h), method)
|
||||
|
||||
|
||||
def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
|
||||
if factor is None:
|
||||
zoom_level = np.random.uniform(0.8, 1.0, size=[2])
|
||||
else:
|
||||
zoom_level = (factor[0], factor[1])
|
||||
iw, ih = img.size
|
||||
zoomw = max(crop_width, iw * zoom_level[0])
|
||||
zoomh = max(crop_width, ih * zoom_level[1])
|
||||
img = img.resize((int(round(zoomw)), int(round(zoomh))), method)
|
||||
return img
|
||||
|
||||
|
||||
def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
shortside = min(ow, oh)
|
||||
if shortside >= target_width:
|
||||
return img
|
||||
else:
|
||||
scale = target_width / shortside
|
||||
return img.resize((round(ow * scale), round(oh * scale)), method)
|
||||
|
||||
|
||||
def __trim(img, trim_width):
|
||||
ow, oh = img.size
|
||||
if ow > trim_width:
|
||||
xstart = np.random.randint(ow - trim_width)
|
||||
xend = xstart + trim_width
|
||||
else:
|
||||
xstart = 0
|
||||
xend = ow
|
||||
if oh > trim_width:
|
||||
ystart = np.random.randint(oh - trim_width)
|
||||
yend = ystart + trim_width
|
||||
else:
|
||||
ystart = 0
|
||||
yend = oh
|
||||
return img.crop((xstart, ystart, xend, yend))
|
||||
|
||||
|
||||
def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
if ow == target_width and oh >= crop_width:
|
||||
return img
|
||||
w = target_width
|
||||
h = int(max(target_width * oh / ow, crop_width))
|
||||
return img.resize((w, h), method)
|
||||
|
||||
|
||||
def __crop(img, pos, size):
|
||||
ow, oh = img.size
|
||||
x1, y1 = pos
|
||||
tw = th = size
|
||||
if (ow > tw or oh > th):
|
||||
return img.crop((x1, y1, x1 + tw, y1 + th))
|
||||
return img
|
||||
|
||||
|
||||
def __patch(img, index, size):
|
||||
ow, oh = img.size
|
||||
nw, nh = ow // size, oh // size
|
||||
roomx = ow - nw * size
|
||||
roomy = oh - nh * size
|
||||
startx = np.random.randint(int(roomx) + 1)
|
||||
starty = np.random.randint(int(roomy) + 1)
|
||||
|
||||
index = index % (nw * nh)
|
||||
ix = index // nh
|
||||
iy = index % nh
|
||||
gridx = startx + ix * size
|
||||
gridy = starty + iy * size
|
||||
return img.crop((gridx, gridy, gridx + size, gridy + size))
|
||||
|
||||
|
||||
def __flip(img, flip):
|
||||
if flip:
|
||||
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
return img
|
||||
|
||||
|
||||
def __print_size_warning(ow, oh, w, h):
|
||||
"""Print warning information about image size(only print once)"""
|
||||
if not hasattr(__print_size_warning, 'has_printed'):
|
||||
print("The image size needs to be a multiple of 4. "
|
||||
"The loaded image size was (%d, %d), so it was adjusted to "
|
||||
"(%d, %d). This adjustment will be done to all images "
|
||||
"whose sizes are not multiples of 4" % (ow, oh, w, h))
|
||||
__print_size_warning.has_printed = True
|
||||
66
data/image_folder.py
Normal file
66
data/image_folder.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""A modified image folder class
|
||||
|
||||
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
|
||||
so that this class can load images from both current directory and its subdirectories.
|
||||
"""
|
||||
|
||||
import torch.utils.data as data
|
||||
|
||||
from PIL import Image
|
||||
import os
|
||||
import os.path
|
||||
|
||||
IMG_EXTENSIONS = [
|
||||
'.jpg', '.JPG', '.jpeg', '.JPEG',
|
||||
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
|
||||
'.tif', '.TIF', '.tiff', '.TIFF',
|
||||
]
|
||||
|
||||
|
||||
def is_image_file(filename):
|
||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def make_dataset(dir, max_dataset_size=float("inf")):
|
||||
images = []
|
||||
assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
|
||||
|
||||
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
|
||||
for fname in fnames:
|
||||
if is_image_file(fname):
|
||||
path = os.path.join(root, fname)
|
||||
images.append(path)
|
||||
return images[:min(max_dataset_size, len(images))]
|
||||
|
||||
|
||||
def default_loader(path):
|
||||
return Image.open(path).convert('RGB')
|
||||
|
||||
|
||||
class ImageFolder(data.Dataset):
|
||||
|
||||
def __init__(self, root, transform=None, return_paths=False,
|
||||
loader=default_loader):
|
||||
imgs = make_dataset(root)
|
||||
if len(imgs) == 0:
|
||||
raise(RuntimeError("Found 0 images in: " + root + "\n"
|
||||
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
|
||||
|
||||
self.root = root
|
||||
self.imgs = imgs
|
||||
self.transform = transform
|
||||
self.return_paths = return_paths
|
||||
self.loader = loader
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.imgs[index]
|
||||
img = self.loader(path)
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if self.return_paths:
|
||||
return img, path
|
||||
else:
|
||||
return img
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
||||
40
data/single_dataset.py
Normal file
40
data/single_dataset.py
Normal file
@ -0,0 +1,40 @@
|
||||
from data.base_dataset import BaseDataset, get_transform
|
||||
from data.image_folder import make_dataset
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class SingleDataset(BaseDataset):
|
||||
"""This dataset class can load a set of images specified by the path --dataroot /path/to/data.
|
||||
|
||||
It can be used for generating CycleGAN results only for one side with the model option '-model test'.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize this dataset class.
|
||||
|
||||
Parameters:
|
||||
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
BaseDataset.__init__(self, opt)
|
||||
self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
|
||||
input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
|
||||
self.transform = get_transform(opt, grayscale=(input_nc == 1))
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Return a data point and its metadata information.
|
||||
|
||||
Parameters:
|
||||
index - - a random integer for data indexing
|
||||
|
||||
Returns a dictionary that contains A and A_paths
|
||||
A(tensor) - - an image in one domain
|
||||
A_paths(str) - - the path of the image
|
||||
"""
|
||||
A_path = self.A_paths[index]
|
||||
A_img = Image.open(A_path).convert('RGB')
|
||||
A = self.transform(A_img)
|
||||
return {'A': A, 'A_paths': A_path}
|
||||
|
||||
def __len__(self):
|
||||
"""Return the total number of images in the dataset."""
|
||||
return len(self.A_paths)
|
||||
108
data/singleimage_dataset.py
Normal file
108
data/singleimage_dataset.py
Normal file
@ -0,0 +1,108 @@
|
||||
import numpy as np
|
||||
import os.path
|
||||
from data.base_dataset import BaseDataset, get_transform
|
||||
from data.image_folder import make_dataset
|
||||
from PIL import Image
|
||||
import random
|
||||
import util.util as util
|
||||
|
||||
|
||||
class SingleImageDataset(BaseDataset):
|
||||
"""
|
||||
This dataset class can load unaligned/unpaired datasets.
|
||||
|
||||
It requires two directories to host training images from domain A '/path/to/data/trainA'
|
||||
and from domain B '/path/to/data/trainB' respectively.
|
||||
You can train the model with the dataset flag '--dataroot /path/to/data'.
|
||||
Similarly, you need to prepare two directories:
|
||||
'/path/to/data/testA' and '/path/to/data/testB' during test time.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize this dataset class.
|
||||
|
||||
Parameters:
|
||||
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
BaseDataset.__init__(self, opt)
|
||||
|
||||
self.dir_A = os.path.join(opt.dataroot, 'trainA') # create a path '/path/to/data/trainA'
|
||||
self.dir_B = os.path.join(opt.dataroot, 'trainB') # create a path '/path/to/data/trainB'
|
||||
|
||||
if os.path.exists(self.dir_A) and os.path.exists(self.dir_B):
|
||||
self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
|
||||
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
|
||||
self.A_size = len(self.A_paths) # get the size of dataset A
|
||||
self.B_size = len(self.B_paths) # get the size of dataset B
|
||||
|
||||
assert len(self.A_paths) == 1 and len(self.B_paths) == 1,\
|
||||
"SingleImageDataset class should be used with one image in each domain"
|
||||
A_img = Image.open(self.A_paths[0]).convert('RGB')
|
||||
B_img = Image.open(self.B_paths[0]).convert('RGB')
|
||||
print("Image sizes %s and %s" % (str(A_img.size), str(B_img.size)))
|
||||
|
||||
self.A_img = A_img
|
||||
self.B_img = B_img
|
||||
|
||||
# In single-image translation, we augment the data loader by applying
|
||||
# random scaling. Still, we design the data loader such that the
|
||||
# amount of scaling is the same within a minibatch. To do this,
|
||||
# we precompute the random scaling values, and repeat them by |batch_size|.
|
||||
A_zoom = 1 / self.opt.random_scale_max
|
||||
zoom_levels_A = np.random.uniform(A_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
|
||||
self.zoom_levels_A = np.reshape(np.tile(zoom_levels_A, (1, opt.batch_size, 1)), [-1, 2])
|
||||
|
||||
B_zoom = 1 / self.opt.random_scale_max
|
||||
zoom_levels_B = np.random.uniform(B_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
|
||||
self.zoom_levels_B = np.reshape(np.tile(zoom_levels_B, (1, opt.batch_size, 1)), [-1, 2])
|
||||
|
||||
# While the crop locations are randomized, the negative samples should
|
||||
# not come from the same location. To do this, we precompute the
|
||||
# crop locations with no repetition.
|
||||
self.patch_indices_A = list(range(len(self)))
|
||||
random.shuffle(self.patch_indices_A)
|
||||
self.patch_indices_B = list(range(len(self)))
|
||||
random.shuffle(self.patch_indices_B)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Return a data point and its metadata information.
|
||||
|
||||
Parameters:
|
||||
index (int) -- a random integer for data indexing
|
||||
|
||||
Returns a dictionary that contains A, B, A_paths and B_paths
|
||||
A (tensor) -- an image in the input domain
|
||||
B (tensor) -- its corresponding image in the target domain
|
||||
A_paths (str) -- image paths
|
||||
B_paths (str) -- image paths
|
||||
"""
|
||||
A_path = self.A_paths[0]
|
||||
B_path = self.B_paths[0]
|
||||
A_img = self.A_img
|
||||
B_img = self.B_img
|
||||
|
||||
# apply image transformation
|
||||
if self.opt.phase == "train":
|
||||
param = {'scale_factor': self.zoom_levels_A[index],
|
||||
'patch_index': self.patch_indices_A[index],
|
||||
'flip': random.random() > 0.5}
|
||||
|
||||
transform_A = get_transform(self.opt, params=param, method=Image.BILINEAR)
|
||||
A = transform_A(A_img)
|
||||
|
||||
param = {'scale_factor': self.zoom_levels_B[index],
|
||||
'patch_index': self.patch_indices_B[index],
|
||||
'flip': random.random() > 0.5}
|
||||
transform_B = get_transform(self.opt, params=param, method=Image.BILINEAR)
|
||||
B = transform_B(B_img)
|
||||
else:
|
||||
transform = get_transform(self.opt, method=Image.BILINEAR)
|
||||
A = transform(A_img)
|
||||
B = transform(B_img)
|
||||
|
||||
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
|
||||
|
||||
def __len__(self):
|
||||
""" Let's pretend the single image contains 100,000 crops for convenience.
|
||||
"""
|
||||
return 100000
|
||||
75
data/template_dataset.py
Normal file
75
data/template_dataset.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""Dataset class template
|
||||
|
||||
This module provides a template for users to implement custom datasets.
|
||||
You can specify '--dataset_mode template' to use this dataset.
|
||||
The class name should be consistent with both the filename and its dataset_mode option.
|
||||
The filename should be <dataset_mode>_dataset.py
|
||||
The class name should be <Dataset_mode>Dataset.py
|
||||
You need to implement the following functions:
|
||||
-- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
|
||||
-- <__init__>: Initialize this dataset class.
|
||||
-- <__getitem__>: Return a data point and its metadata information.
|
||||
-- <__len__>: Return the number of images.
|
||||
"""
|
||||
from data.base_dataset import BaseDataset, get_transform
|
||||
# from data.image_folder import make_dataset
|
||||
# from PIL import Image
|
||||
|
||||
|
||||
class TemplateDataset(BaseDataset):
|
||||
"""A template dataset class for you to implement custom datasets."""
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
"""Add new dataset-specific options, and rewrite default values for existing options.
|
||||
|
||||
Parameters:
|
||||
parser -- original option parser
|
||||
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
||||
|
||||
Returns:
|
||||
the modified parser.
|
||||
"""
|
||||
parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
|
||||
parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize this dataset class.
|
||||
|
||||
Parameters:
|
||||
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
|
||||
A few things can be done here.
|
||||
- save the options (have been done in BaseDataset)
|
||||
- get image paths and meta information of the dataset.
|
||||
- define the image transformation.
|
||||
"""
|
||||
# save the option and dataset root
|
||||
BaseDataset.__init__(self, opt)
|
||||
# get the image paths of your dataset;
|
||||
self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
|
||||
# define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
|
||||
self.transform = get_transform(opt)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Return a data point and its metadata information.
|
||||
|
||||
Parameters:
|
||||
index -- a random integer for data indexing
|
||||
|
||||
Returns:
|
||||
a dictionary of data with their names. It usually contains the data itself and its metadata information.
|
||||
|
||||
Step 1: get a random image path: e.g., path = self.image_paths[index]
|
||||
Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
|
||||
Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
|
||||
Step 4: return a data point as a dictionary.
|
||||
"""
|
||||
path = 'temp' # needs to be a string
|
||||
data_A = None # needs to be a tensor
|
||||
data_B = None # needs to be a tensor
|
||||
return {'data_A': data_A, 'data_B': data_B, 'path': path}
|
||||
|
||||
def __len__(self):
|
||||
"""Return the total number of images."""
|
||||
return len(self.image_paths)
|
||||
79
data/unaligned_dataset.py
Normal file
79
data/unaligned_dataset.py
Normal file
@ -0,0 +1,79 @@
|
||||
import os.path
|
||||
from data.base_dataset import BaseDataset, get_transform
|
||||
from data.image_folder import make_dataset
|
||||
from PIL import Image
|
||||
import random
|
||||
import util.util as util
|
||||
|
||||
|
||||
class UnalignedDataset(BaseDataset):
|
||||
"""
|
||||
This dataset class can load unaligned/unpaired datasets.
|
||||
|
||||
It requires two directories to host training images from domain A '/path/to/data/trainA'
|
||||
and from domain B '/path/to/data/trainB' respectively.
|
||||
You can train the model with the dataset flag '--dataroot /path/to/data'.
|
||||
Similarly, you need to prepare two directories:
|
||||
'/path/to/data/testA' and '/path/to/data/testB' during test time.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize this dataset class.
|
||||
|
||||
Parameters:
|
||||
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
BaseDataset.__init__(self, opt)
|
||||
self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
|
||||
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
|
||||
|
||||
if opt.phase == "test" and not os.path.exists(self.dir_A) \
|
||||
and os.path.exists(os.path.join(opt.dataroot, "valA")):
|
||||
self.dir_A = os.path.join(opt.dataroot, "valA")
|
||||
self.dir_B = os.path.join(opt.dataroot, "valB")
|
||||
|
||||
self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
|
||||
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
|
||||
self.A_size = len(self.A_paths) # get the size of dataset A
|
||||
self.B_size = len(self.B_paths) # get the size of dataset B
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Return a data point and its metadata information.
|
||||
|
||||
Parameters:
|
||||
index (int) -- a random integer for data indexing
|
||||
|
||||
Returns a dictionary that contains A, B, A_paths and B_paths
|
||||
A (tensor) -- an image in the input domain
|
||||
B (tensor) -- its corresponding image in the target domain
|
||||
A_paths (str) -- image paths
|
||||
B_paths (str) -- image paths
|
||||
"""
|
||||
A_path = self.A_paths[index % self.A_size] # make sure index is within then range
|
||||
if self.opt.serial_batches: # make sure index is within then range
|
||||
index_B = index % self.B_size
|
||||
else: # randomize the index for domain B to avoid fixed pairs.
|
||||
index_B = random.randint(0, self.B_size - 1)
|
||||
B_path = self.B_paths[index_B]
|
||||
A_img = Image.open(A_path).convert('RGB')
|
||||
B_img = Image.open(B_path).convert('RGB')
|
||||
|
||||
# Apply image transformation
|
||||
# For FastCUT mode, if in finetuning phase (learning rate is decaying),
|
||||
# do not perform resize-crop data augmentation of CycleGAN.
|
||||
# print('current_epoch', self.current_epoch)
|
||||
is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
|
||||
modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
|
||||
transform = get_transform(modified_opt)
|
||||
A = transform(A_img)
|
||||
B = transform(B_img)
|
||||
|
||||
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
|
||||
|
||||
def __len__(self):
|
||||
"""Return the total number of images in the dataset.
|
||||
|
||||
As we have two datasets with potentially different number of images,
|
||||
we take a maximum of
|
||||
"""
|
||||
return max(self.A_size, self.B_size)
|
||||
100
data/unaligned_double_dataset.py
Normal file
100
data/unaligned_double_dataset.py
Normal file
@ -0,0 +1,100 @@
|
||||
import os.path
|
||||
from data.base_dataset import BaseDataset, get_transform
|
||||
from data.image_folder import make_dataset
|
||||
from PIL import Image
|
||||
import random
|
||||
import util.util as util
|
||||
import torchvision.transforms.functional as TF
|
||||
import random
|
||||
from torchvision.transforms import transforms as tfs
|
||||
|
||||
class UnalignedDoubleDataset(BaseDataset):
|
||||
"""
|
||||
This dataset class can load unaligned/unpaired datasets.
|
||||
|
||||
It requires two directories to host training images from domain A '/path/to/data/trainA'
|
||||
and from domain B '/path/to/data/trainB' respectively.
|
||||
You can train the model with the dataset flag '--dataroot /path/to/data'.
|
||||
Similarly, you need to prepare two directories:
|
||||
'/path/to/data/testA' and '/path/to/data/testB' during test time.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize this dataset class.
|
||||
|
||||
Parameters:
|
||||
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
# self.use_resize_crop = opt.use_resize_crop
|
||||
BaseDataset.__init__(self, opt)
|
||||
self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
|
||||
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
|
||||
self.opt = opt
|
||||
if opt.phase == "test" and not os.path.exists(self.dir_A) \
|
||||
and os.path.exists(os.path.join(opt.dataroot, "valA")):
|
||||
self.dir_A = os.path.join(opt.dataroot, "valA")
|
||||
self.dir_B = os.path.join(opt.dataroot, "valB")
|
||||
|
||||
self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
|
||||
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
|
||||
self.A_size = len(self.A_paths) # get the size of dataset A
|
||||
self.B_size = len(self.B_paths) # get the size of dataset B
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Return a data point and its metadata information.
|
||||
|
||||
Parameters:
|
||||
index (int) -- a random integer for data indexing
|
||||
|
||||
Returns a dictionary that contains A, B, A_paths and B_paths
|
||||
A (tensor) -- an image in the input domain
|
||||
B (tensor) -- its corresponding image in the target domain
|
||||
A_paths (str) -- image paths
|
||||
B_paths (str) -- image paths
|
||||
"""
|
||||
A_path = self.A_paths[index % self.A_size] # make sure index is within then range
|
||||
if self.opt.serial_batches: # make sure index is within then range
|
||||
index_B = index % self.B_size
|
||||
else: # randomize the index for domain B to avoid fixed pairs.
|
||||
index_B = random.randint(0, self.B_size - 1)
|
||||
B_path = self.B_paths[index_B]
|
||||
A_img = Image.open(A_path).convert('RGB')
|
||||
A0 = A_img.crop((0,0,256,256))
|
||||
A1 = A_img.crop((256,0,512,256))
|
||||
B_img = Image.open(B_path).convert('RGB')
|
||||
B0 = B_img.crop((0,0,256,256))
|
||||
B1 = B_img.crop((256,0,512,256))
|
||||
|
||||
# Apply image transformation
|
||||
# For FastCUT mode, if in finetuning phase (learning rate is decaying),
|
||||
# do not perform resize-crop data augmentation of CycleGAN.
|
||||
# print('current_epoch', self.current_epoch)
|
||||
is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
|
||||
modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
|
||||
|
||||
resize = tfs.Resize(size=(self.opt.load_size, self.opt.load_size))
|
||||
imgA = resize(A0)
|
||||
param = dict()
|
||||
i, j, h, w = tfs.RandomCrop.get_params(
|
||||
imgA, output_size=(self.opt.crop_size, self.opt.crop_size))
|
||||
param['crop_pos'] = (i, j)
|
||||
transform = get_transform(modified_opt, param)
|
||||
# print(transform)
|
||||
# sys.exit(0)
|
||||
# A = transform(A_img)
|
||||
# B = transform(B_img)
|
||||
|
||||
A0 = transform(A0)
|
||||
B0 = transform(B0)
|
||||
A1 = transform(A1)
|
||||
B1 = transform(B1)
|
||||
|
||||
return {'A0': A0, 'A1': A1, 'B0': B0, 'B1': B1, 'A_paths': A_path, 'B_paths': B_path}
|
||||
|
||||
def __len__(self):
|
||||
"""Return the total number of images in the dataset.
|
||||
|
||||
As we have two datasets with potentially different number of images,
|
||||
we take a maximum of
|
||||
"""
|
||||
return max(self.A_size, self.B_size)
|
||||
6
datasets/bibtex/cityscapes.tex
Normal file
6
datasets/bibtex/cityscapes.tex
Normal file
@ -0,0 +1,6 @@
|
||||
@inproceedings{Cordts2016Cityscapes,
|
||||
title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
|
||||
author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
|
||||
booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
year={2016}
|
||||
}
|
||||
7
datasets/bibtex/facades.tex
Normal file
7
datasets/bibtex/facades.tex
Normal file
@ -0,0 +1,7 @@
|
||||
@INPROCEEDINGS{Tylecek13,
|
||||
author = {Radim Tyle{\v c}ek, Radim {\v S}{\' a}ra},
|
||||
title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure},
|
||||
booktitle = {Proc. GCPR},
|
||||
year = {2013},
|
||||
address = {Saarbrucken, Germany},
|
||||
}
|
||||
13
datasets/bibtex/handbags.tex
Normal file
13
datasets/bibtex/handbags.tex
Normal file
@ -0,0 +1,13 @@
|
||||
@inproceedings{zhu2016generative,
|
||||
title={Generative Visual Manipulation on the Natural Image Manifold},
|
||||
author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.},
|
||||
booktitle={Proceedings of European Conference on Computer Vision (ECCV)},
|
||||
year={2016}
|
||||
}
|
||||
|
||||
@InProceedings{xie15hed,
|
||||
author = {"Xie, Saining and Tu, Zhuowen"},
|
||||
Title = {Holistically-Nested Edge Detection},
|
||||
Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
|
||||
Year = {2015},
|
||||
}
|
||||
14
datasets/bibtex/shoes.tex
Normal file
14
datasets/bibtex/shoes.tex
Normal file
@ -0,0 +1,14 @@
|
||||
@InProceedings{fine-grained,
|
||||
author = {A. Yu and K. Grauman},
|
||||
title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning},
|
||||
booktitle = {Computer Vision and Pattern Recognition (CVPR)},
|
||||
month = {June},
|
||||
year = {2014}
|
||||
}
|
||||
|
||||
@InProceedings{xie15hed,
|
||||
author = {"Xie, Saining and Tu, Zhuowen"},
|
||||
Title = {Holistically-Nested Edge Detection},
|
||||
Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
|
||||
Year = {2015},
|
||||
}
|
||||
8
datasets/bibtex/transattr.tex
Normal file
8
datasets/bibtex/transattr.tex
Normal file
@ -0,0 +1,8 @@
|
||||
@article {Laffont14,
|
||||
title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes},
|
||||
author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays},
|
||||
journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)},
|
||||
volume = {33},
|
||||
number = {4},
|
||||
year = {2014}
|
||||
}
|
||||
48
datasets/combine_A_and_B.py
Normal file
48
datasets/combine_A_and_B.py
Normal file
@ -0,0 +1,48 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser('create image pairs')
|
||||
parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
|
||||
parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
|
||||
parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
|
||||
parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000)
|
||||
parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
for arg in vars(args):
|
||||
print('[%s] = ' % arg, getattr(args, arg))
|
||||
|
||||
splits = os.listdir(args.fold_A)
|
||||
|
||||
for sp in splits:
|
||||
img_fold_A = os.path.join(args.fold_A, sp)
|
||||
img_fold_B = os.path.join(args.fold_B, sp)
|
||||
img_list = os.listdir(img_fold_A)
|
||||
if args.use_AB:
|
||||
img_list = [img_path for img_path in img_list if '_A.' in img_path]
|
||||
|
||||
num_imgs = min(args.num_imgs, len(img_list))
|
||||
print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
|
||||
img_fold_AB = os.path.join(args.fold_AB, sp)
|
||||
if not os.path.isdir(img_fold_AB):
|
||||
os.makedirs(img_fold_AB)
|
||||
print('split = %s, number of images = %d' % (sp, num_imgs))
|
||||
for n in range(num_imgs):
|
||||
name_A = img_list[n]
|
||||
path_A = os.path.join(img_fold_A, name_A)
|
||||
if args.use_AB:
|
||||
name_B = name_A.replace('_A.', '_B.')
|
||||
else:
|
||||
name_B = name_A
|
||||
path_B = os.path.join(img_fold_B, name_B)
|
||||
if os.path.isfile(path_A) and os.path.isfile(path_B):
|
||||
name_AB = name_A
|
||||
if args.use_AB:
|
||||
name_AB = name_AB.replace('_A.', '.') # remove _A
|
||||
path_AB = os.path.join(img_fold_AB, name_AB)
|
||||
im_A = cv2.imread(path_A, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
|
||||
im_B = cv2.imread(path_B, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
|
||||
im_AB = np.concatenate([im_A, im_B], 1)
|
||||
cv2.imwrite(path_AB, im_AB)
|
||||
64
datasets/detect_cat_face.py
Normal file
64
datasets/detect_cat_face.py
Normal file
@ -0,0 +1,64 @@
|
||||
import cv2
|
||||
import os
|
||||
import glob
|
||||
import argparse
|
||||
|
||||
|
||||
def get_file_paths(folder):
|
||||
image_file_paths = []
|
||||
for root, dirs, filenames in os.walk(folder):
|
||||
filenames = sorted(filenames)
|
||||
for filename in filenames:
|
||||
input_path = os.path.abspath(root)
|
||||
file_path = os.path.join(input_path, filename)
|
||||
if filename.endswith('.png') or filename.endswith('.jpg'):
|
||||
image_file_paths.append(file_path)
|
||||
|
||||
break # prevent descending into subfolders
|
||||
return image_file_paths
|
||||
|
||||
|
||||
SF = 1.05
|
||||
N = 3
|
||||
|
||||
|
||||
def detect_cat(img_path, cat_cascade, output_dir, ratio=0.05, border_ratio=0.25):
|
||||
print('processing {}'.format(img_path))
|
||||
output_width = 286
|
||||
img = cv2.imread(img_path)
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
H, W = img.shape[0], img.shape[1]
|
||||
minH = int(H * ratio)
|
||||
minW = int(W * ratio)
|
||||
cats = cat_cascade.detectMultiScale(gray, scaleFactor=SF, minNeighbors=N, minSize=(minH, minW))
|
||||
|
||||
for cat_id, (x, y, w, h) in enumerate(cats):
|
||||
x1 = max(0, x - w * border_ratio)
|
||||
x2 = min(W, x + w * (1 + border_ratio))
|
||||
y1 = max(0, y - h * border_ratio)
|
||||
y2 = min(H, y + h * (1 + border_ratio))
|
||||
img_crop = img[int(y1):int(y2), int(x1):int(x2)]
|
||||
img_name = os.path.basename(img_path)
|
||||
out_path = os.path.join(output_dir, img_name.replace('.jpg', '_cat%d.jpg' % cat_id))
|
||||
print('write', out_path)
|
||||
img_crop = cv2.resize(img_crop, (output_width, output_width), interpolation=cv2.INTER_CUBIC)
|
||||
cv2.imwrite(out_path, img_crop, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='detecting cat faces using opencv detector')
|
||||
parser.add_argument('--input_dir', type=str, help='input image directory')
|
||||
parser.add_argument('--output_dir', type=str, help='wihch directory to store cropped cat faces')
|
||||
parser.add_argument('--use_ext', action='store_true', help='if use haarcascade_frontalcatface_extended or not')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.use_ext:
|
||||
cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface.xml')
|
||||
else:
|
||||
cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface_extended.xml')
|
||||
img_paths = get_file_paths(args.input_dir)
|
||||
print('total number of images {} from {}'.format(len(img_paths), args.input_dir))
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
for img_path in img_paths:
|
||||
detect_cat(img_path, cat_cascade, args.output_dir)
|
||||
23
datasets/download_cut_dataset.sh
Normal file
23
datasets/download_cut_dataset.sh
Normal file
@ -0,0 +1,23 @@
|
||||
set -ex
|
||||
|
||||
FILE=$1
|
||||
|
||||
if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "mini" && $FILE != "mini_pix2pix" && $FILE != "mini_colorization" && $FILE != "grumpifycat" ]]; then
|
||||
echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, grumpifycat"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $FILE == "cityscapes" ]]; then
|
||||
echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
|
||||
echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Specified [$FILE]"
|
||||
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
|
||||
ZIP_FILE=./datasets/$FILE.zip
|
||||
TARGET_DIR=./datasets/$FILE/
|
||||
wget --no-check-certificate -N $URL -O $ZIP_FILE
|
||||
mkdir $TARGET_DIR
|
||||
unzip $ZIP_FILE -d ./datasets/
|
||||
rm $ZIP_FILE
|
||||
24
datasets/download_pix2pix_dataset.sh
Normal file
24
datasets/download_pix2pix_dataset.sh
Normal file
@ -0,0 +1,24 @@
|
||||
set -ex
|
||||
|
||||
FILE=$1
|
||||
|
||||
if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then
|
||||
echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $FILE == "cityscapes" ]]; then
|
||||
echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
|
||||
echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Specified [$FILE]"
|
||||
|
||||
URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz
|
||||
TAR_FILE=./datasets/$FILE.tar.gz
|
||||
TARGET_DIR=./datasets/$FILE/
|
||||
wget -N $URL -O $TAR_FILE
|
||||
mkdir -p $TARGET_DIR
|
||||
tar -zxvf $TAR_FILE -C ./datasets/
|
||||
rm $TAR_FILE
|
||||
63
datasets/make_dataset_aligned.py
Normal file
63
datasets/make_dataset_aligned.py
Normal file
@ -0,0 +1,63 @@
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_file_paths(folder):
|
||||
image_file_paths = []
|
||||
for root, dirs, filenames in os.walk(folder):
|
||||
filenames = sorted(filenames)
|
||||
for filename in filenames:
|
||||
input_path = os.path.abspath(root)
|
||||
file_path = os.path.join(input_path, filename)
|
||||
if filename.endswith('.png') or filename.endswith('.jpg'):
|
||||
image_file_paths.append(file_path)
|
||||
|
||||
break # prevent descending into subfolders
|
||||
return image_file_paths
|
||||
|
||||
|
||||
def align_images(a_file_paths, b_file_paths, target_path):
|
||||
if not os.path.exists(target_path):
|
||||
os.makedirs(target_path)
|
||||
|
||||
for i in range(len(a_file_paths)):
|
||||
img_a = Image.open(a_file_paths[i])
|
||||
img_b = Image.open(b_file_paths[i])
|
||||
assert(img_a.size == img_b.size)
|
||||
|
||||
aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
|
||||
aligned_image.paste(img_a, (0, 0))
|
||||
aligned_image.paste(img_b, (img_a.size[0], 0))
|
||||
aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--dataset-path',
|
||||
dest='dataset_path',
|
||||
help='Which folder to process (it should have subfolders testA, testB, trainA and trainB'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_folder = args.dataset_path
|
||||
print(dataset_folder)
|
||||
|
||||
test_a_path = os.path.join(dataset_folder, 'testA')
|
||||
test_b_path = os.path.join(dataset_folder, 'testB')
|
||||
test_a_file_paths = get_file_paths(test_a_path)
|
||||
test_b_file_paths = get_file_paths(test_b_path)
|
||||
assert(len(test_a_file_paths) == len(test_b_file_paths))
|
||||
test_path = os.path.join(dataset_folder, 'test')
|
||||
|
||||
train_a_path = os.path.join(dataset_folder, 'trainA')
|
||||
train_b_path = os.path.join(dataset_folder, 'trainB')
|
||||
train_a_file_paths = get_file_paths(train_a_path)
|
||||
train_b_file_paths = get_file_paths(train_b_path)
|
||||
assert(len(train_a_file_paths) == len(train_b_file_paths))
|
||||
train_path = os.path.join(dataset_folder, 'train')
|
||||
|
||||
align_images(test_a_file_paths, test_b_file_paths, test_path)
|
||||
align_images(train_a_file_paths, train_b_file_paths, train_path)
|
||||
90
datasets/prepare_cityscapes_dataset.py
Normal file
90
datasets/prepare_cityscapes_dataset.py
Normal file
@ -0,0 +1,90 @@
|
||||
import os
|
||||
import glob
|
||||
from PIL import Image
|
||||
|
||||
help_msg = """
|
||||
The dataset can be downloaded from https://cityscapes-dataset.com.
|
||||
Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them.
|
||||
gtFine contains the semantics segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory.
|
||||
leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory.
|
||||
The processed images will be placed at --output_dir.
|
||||
|
||||
Example usage:
|
||||
|
||||
python prepare_cityscapes_dataset.py --gitFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./datasets/cityscapes/
|
||||
"""
|
||||
|
||||
|
||||
def load_resized_img(path):
|
||||
return Image.open(path).convert('RGB').resize((256, 256))
|
||||
|
||||
|
||||
def check_matching_pair(segmap_path, photo_path):
|
||||
segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '')
|
||||
photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '')
|
||||
|
||||
assert segmap_identifier == photo_identifier, \
|
||||
"[%s] and [%s] don't seem to be matching. Aborting." % (segmap_path, photo_path)
|
||||
|
||||
|
||||
def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
|
||||
save_phase = 'test' if phase == 'val' else 'train'
|
||||
savedir = os.path.join(output_dir, save_phase)
|
||||
os.makedirs(savedir, exist_ok=True)
|
||||
os.makedirs(savedir + 'A', exist_ok=True)
|
||||
os.makedirs(savedir + 'B', exist_ok=True)
|
||||
print("Directory structure prepared at %s" % output_dir)
|
||||
|
||||
segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png"
|
||||
segmap_paths = glob.glob(segmap_expr)
|
||||
segmap_paths = sorted(segmap_paths)
|
||||
|
||||
photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png"
|
||||
photo_paths = glob.glob(photo_expr)
|
||||
photo_paths = sorted(photo_paths)
|
||||
|
||||
assert len(segmap_paths) == len(photo_paths), \
|
||||
"%d images that match [%s], and %d images that match [%s]. Aborting." % (len(segmap_paths), segmap_expr, len(photo_paths), photo_expr)
|
||||
|
||||
for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)):
|
||||
check_matching_pair(segmap_path, photo_path)
|
||||
segmap = load_resized_img(segmap_path)
|
||||
photo = load_resized_img(photo_path)
|
||||
|
||||
# data for pix2pix where the two images are placed side-by-side
|
||||
sidebyside = Image.new('RGB', (512, 256))
|
||||
sidebyside.paste(segmap, (256, 0))
|
||||
sidebyside.paste(photo, (0, 0))
|
||||
savepath = os.path.join(savedir, "%d.jpg" % i)
|
||||
sidebyside.save(savepath, format='JPEG', subsampling=0, quality=100)
|
||||
|
||||
# data for cyclegan where the two images are stored at two distinct directories
|
||||
savepath = os.path.join(savedir + 'A', "%d_A.jpg" % i)
|
||||
photo.save(savepath, format='JPEG', subsampling=0, quality=100)
|
||||
savepath = os.path.join(savedir + 'B', "%d_B.jpg" % i)
|
||||
segmap.save(savepath, format='JPEG', subsampling=0, quality=100)
|
||||
|
||||
if i % (len(segmap_paths) // 10) == 0:
|
||||
print("%d / %d: last image saved at %s, " % (i, len(segmap_paths), savepath))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--gtFine_dir', type=str, required=True,
|
||||
help='Path to the Cityscapes gtFine directory.')
|
||||
parser.add_argument('--leftImg8bit_dir', type=str, required=True,
|
||||
help='Path to the Cityscapes leftImg8bit_trainvaltest directory.')
|
||||
parser.add_argument('--output_dir', type=str, required=True,
|
||||
default='./datasets/cityscapes',
|
||||
help='Directory the output images will be written to.')
|
||||
opt = parser.parse_args()
|
||||
|
||||
print(help_msg)
|
||||
|
||||
print('Preparing Cityscapes Dataset for val phase')
|
||||
process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val")
|
||||
print('Preparing Cityscapes Dataset for train phase')
|
||||
process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train")
|
||||
|
||||
print('Done')
|
||||
BIN
datasets/single_image_monet_etretat/trainA/monet.jpg
Normal file
Binary file not shown. (After: 289 KiB)
Binary file not shown. (After: 606 KiB)
BIN
images/method_final.jpg
Normal file
Binary file not shown. (After: 278 KiB)
67
models/__init__.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
||||
|
||||
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel that inherits from BaseModel.
|
||||
You need to implement the following five functions:
|
||||
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
||||
-- <set_input>: unpack data from dataset and apply preprocessing.
|
||||
-- <forward>: produce intermediate results.
|
||||
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
||||
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
||||
|
||||
In the function <__init__>, you need to define four lists:
|
||||
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
||||
-- self.model_names (str list): define networks used in our training.
|
||||
-- self.visual_names (str list): specify the images that you want to display and save.
|
||||
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
||||
|
||||
Now you can use the model class by specifying flag '--model dummy'.
|
||||
See our template model class 'template_model.py' for more details.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from models.base_model import BaseModel
|
||||
|
||||
|
||||
def find_model_using_name(model_name):
|
||||
"""Import the module "models/[model_name]_model.py".
|
||||
|
||||
In the file, the class called [ModelName]Model will
|
||||
be instantiated. It has to be a subclass of BaseModel,
|
||||
and it is case-insensitive.
|
||||
"""
|
||||
model_filename = "models." + model_name + "_model"
|
||||
modellib = importlib.import_module(model_filename)
|
||||
model = None
|
||||
target_model_name = model_name.replace('_', '') + 'model'
|
||||
for name, cls in modellib.__dict__.items():
|
||||
if name.lower() == target_model_name.lower() \
|
||||
and issubclass(cls, BaseModel):
|
||||
model = cls
|
||||
|
||||
if model is None:
|
||||
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
||||
exit(0)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_option_setter(model_name):
|
||||
"""Return the static method <modify_commandline_options> of the model class."""
|
||||
model_class = find_model_using_name(model_name)
|
||||
return model_class.modify_commandline_options
|
||||
|
||||
|
||||
def create_model(opt):
|
||||
"""Create a model given the option.
|
||||
|
||||
This function instantiates the model class specified by opt.model.
|
||||
This is the main interface between this package and 'train.py'/'test.py'
|
||||
|
||||
Example:
|
||||
>>> from models import create_model
|
||||
>>> model = create_model(opt)
|
||||
"""
|
||||
model = find_model_using_name(opt.model)
|
||||
instance = model(opt)
|
||||
print("model [%s] was created" % type(instance).__name__)
|
||||
return instance
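# Illustrative usage sketch (assumptions noted inline): this mirrors how a
# train.py-style script is typically wired to this package; the options module
# name below is an assumption, not confirmed by this file.
#
#   from options.train_options import TrainOptions   # assumed options module
#   from models import create_model
#
#   opt = TrainOptions().parse()        # option parsing uses get_option_setter(opt.model)
#   model = create_model(opt)           # imports models/<opt.model>_model.py and instantiates it
#   model.setup(opt)                    # create schedulers / load weights (see BaseModel.setup)
#   for data in dataset:                # 'dataset' comes from the data package
#       model.set_input(data)
#       model.optimize_parameters()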
|
||||
BIN models/__pycache__/__init__.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/base_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/cut_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/mae.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/models_mae.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/mutilvitgloballocal_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/networks.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/patchnce.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/region0_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/region_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/stylegan_networks.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vit2Gmask_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vit2_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vit2patchmask_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vit2tokenmask_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vitD_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vit_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vitdonly2_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vitdonly_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vitgloballocal_model.cpython-36.pyc (Normal file; binary file not shown)
BIN models/__pycache__/vitlocalgloballocal_model.cpython-36.pyc (Normal file; binary file not shown)
258
models/base_model.py
Normal file
@ -0,0 +1,258 @@
|
||||
import os
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from . import networks
|
||||
|
||||
|
||||
class BaseModel(ABC):
|
||||
"""This class is an abstract base class (ABC) for models.
|
||||
To create a subclass, you need to implement the following five functions:
|
||||
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
||||
-- <set_input>: unpack data from dataset and apply preprocessing.
|
||||
-- <forward>: produce intermediate results.
|
||||
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
|
||||
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize the BaseModel class.
|
||||
|
||||
Parameters:
|
||||
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
|
||||
When creating your custom class, you need to implement your own initialization.
|
||||
In this function, you should first call <BaseModel.__init__(self, opt)>
|
||||
Then, you need to define four lists:
|
||||
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
||||
-- self.model_names (str list): define networks used in our training.
|
||||
-- self.visual_names (str list): specify the images that you want to display and save.
|
||||
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
||||
"""
|
||||
self.opt = opt
|
||||
self.gpu_ids = opt.gpu_ids
|
||||
self.isTrain = opt.isTrain
|
||||
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
||||
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
|
||||
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
||||
torch.backends.cudnn.benchmark = True
|
||||
self.loss_names = []
|
||||
self.model_names = []
|
||||
self.visual_names = []
|
||||
self.optimizers = []
|
||||
self.image_paths = []
|
||||
self.metric = 0 # used for learning rate policy 'plateau'
|
||||
|
||||
@staticmethod
|
||||
def dict_grad_hook_factory(add_func=lambda x: x):
|
||||
saved_dict = dict()
|
||||
|
||||
def hook_gen(name):
|
||||
def grad_hook(grad):
|
||||
saved_vals = add_func(grad)
|
||||
saved_dict[name] = saved_vals
|
||||
return grad_hook
|
||||
return hook_gen, saved_dict
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
"""Add new model-specific options, and rewrite default values for existing options.
|
||||
|
||||
Parameters:
|
||||
parser -- original option parser
|
||||
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
||||
|
||||
Returns:
|
||||
the modified parser.
|
||||
"""
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def set_input(self, input):
|
||||
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||
|
||||
Parameters:
|
||||
input (dict): includes the data itself and its metadata information.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def optimize_parameters(self):
|
||||
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
||||
pass
|
||||
|
||||
def setup(self, opt):
|
||||
"""Load and print networks; create schedulers
|
||||
|
||||
Parameters:
|
||||
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
if self.isTrain:
|
||||
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
|
||||
if not self.isTrain or opt.continue_train:
|
||||
load_suffix = opt.epoch
|
||||
self.load_networks(load_suffix)
|
||||
|
||||
self.print_networks(opt.verbose)
|
||||
|
||||
def parallelize(self):
|
||||
for name in self.model_names:
|
||||
if isinstance(name, str):
|
||||
net = getattr(self, 'net' + name)
|
||||
setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids))
|
||||
|
||||
def data_dependent_initialize(self, data):
|
||||
pass
|
||||
|
||||
def eval(self):
|
||||
"""Make models eval mode during test time"""
|
||||
for name in self.model_names:
|
||||
if isinstance(name, str):
|
||||
net = getattr(self, 'net' + name)
|
||||
net.eval()
|
||||
|
||||
def test(self):
|
||||
"""Forward function used in test time.
|
||||
|
||||
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
|
||||
It also calls <compute_visuals> to produce additional visualization results
|
||||
"""
|
||||
with torch.no_grad():
|
||||
self.forward()
|
||||
self.compute_visuals()
|
||||
|
||||
def compute_visuals(self):
|
||||
"""Calculate additional output images for visdom and HTML visualization"""
|
||||
pass
|
||||
|
||||
def get_image_paths(self):
|
||||
""" Return image paths that are used to load current data"""
|
||||
return self.image_paths
|
||||
|
||||
def update_learning_rate(self):
|
||||
"""Update learning rates for all the networks; called at the end of every epoch"""
|
||||
for scheduler in self.schedulers:
|
||||
if self.opt.lr_policy == 'plateau':
|
||||
scheduler.step(self.metric)
|
||||
else:
|
||||
scheduler.step()
|
||||
|
||||
lr = self.optimizers[0].param_groups[0]['lr']
|
||||
print('learning rate = %.7f' % lr)
|
||||
|
||||
def get_current_visuals(self):
|
||||
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
|
||||
visual_ret = OrderedDict()
|
||||
for name in self.visual_names:
|
||||
if isinstance(name, str):
|
||||
visual_ret[name] = getattr(self, name)
|
||||
return visual_ret
|
||||
|
||||
def get_current_losses(self):
|
||||
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
|
||||
errors_ret = OrderedDict()
|
||||
for name in self.loss_names:
|
||||
if isinstance(name, str):
|
||||
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
|
||||
return errors_ret
|
||||
|
||||
def save_networks(self, epoch):
|
||||
"""Save all the networks to the disk.
|
||||
|
||||
Parameters:
|
||||
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
||||
"""
|
||||
for name in self.model_names:
|
||||
if isinstance(name, str):
|
||||
save_filename = '%s_net_%s.pth' % (epoch, name)
|
||||
save_path = os.path.join(self.save_dir, save_filename)
|
||||
net = getattr(self, 'net' + name)
|
||||
|
||||
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
||||
torch.save(net.module.cpu().state_dict(), save_path)
|
||||
net.cuda(self.gpu_ids[0])
|
||||
else:
|
||||
torch.save(net.cpu().state_dict(), save_path)
|
||||
|
||||
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
|
||||
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
|
||||
key = keys[i]
|
||||
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
|
||||
if module.__class__.__name__.startswith('InstanceNorm') and \
|
||||
(key == 'running_mean' or key == 'running_var'):
|
||||
if getattr(module, key) is None:
|
||||
state_dict.pop('.'.join(keys))
|
||||
if module.__class__.__name__.startswith('InstanceNorm') and \
|
||||
(key == 'num_batches_tracked'):
|
||||
state_dict.pop('.'.join(keys))
|
||||
else:
|
||||
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
|
||||
|
||||
def load_networks(self, epoch):
|
||||
"""Load all the networks from the disk.
|
||||
|
||||
Parameters:
|
||||
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
||||
"""
|
||||
for name in self.model_names:
|
||||
if isinstance(name, str):
|
||||
load_filename = '%s_net_%s.pth' % (epoch, name)
|
||||
if self.opt.isTrain and self.opt.pretrained_name is not None:
|
||||
load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
|
||||
else:
|
||||
load_dir = self.save_dir
|
||||
|
||||
load_path = os.path.join(load_dir, load_filename)
|
||||
net = getattr(self, 'net' + name)
|
||||
if isinstance(net, torch.nn.DataParallel):
|
||||
net = net.module
|
||||
print('loading the model from %s' % load_path)
|
||||
# if you are using PyTorch newer than 0.4 (e.g., built from
|
||||
# GitHub source), you can remove str() on self.device
|
||||
state_dict = torch.load(load_path, map_location=str(self.device))
|
||||
if hasattr(state_dict, '_metadata'):
|
||||
del state_dict._metadata
|
||||
|
||||
# patch InstanceNorm checkpoints prior to 0.4
|
||||
# for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
||||
# self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
|
||||
net.load_state_dict(state_dict)
|
||||
|
||||
def print_networks(self, verbose):
|
||||
"""Print the total number of parameters in the network and (if verbose) network architecture
|
||||
|
||||
Parameters:
|
||||
verbose (bool) -- if verbose: print the network architecture
|
||||
"""
|
||||
print('---------- Networks initialized -------------')
|
||||
for name in self.model_names:
|
||||
if isinstance(name, str):
|
||||
net = getattr(self, 'net' + name)
|
||||
num_params = 0
|
||||
for param in net.parameters():
|
||||
num_params += param.numel()
|
||||
if verbose:
|
||||
print(net)
|
||||
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
|
||||
print('-----------------------------------------------')
|
||||
|
||||
def set_requires_grad(self, nets, requires_grad=False):
|
||||
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
||||
Parameters:
|
||||
nets (network list) -- a list of networks
|
||||
requires_grad (bool) -- whether the networks require gradients or not
|
||||
"""
|
||||
if not isinstance(nets, list):
|
||||
nets = [nets]
|
||||
for net in nets:
|
||||
if net is not None:
|
||||
for param in net.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
def generate_visuals_for_evaluation(self, data, mode):
|
||||
return {}
|
||||
214
models/cut_model.py
Normal file
@ -0,0 +1,214 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
from .patchnce import PatchNCELoss
|
||||
import util.util as util
|
||||
|
||||
|
||||
class CUTModel(BaseModel):
|
||||
""" This class implements CUT and FastCUT model, described in the paper
|
||||
Contrastive Learning for Unpaired Image-to-Image Translation
|
||||
Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
|
||||
ECCV, 2020
|
||||
|
||||
The code borrows heavily from the PyTorch implementation of CycleGAN
|
||||
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
|
||||
"""
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train=True):
|
||||
""" Configures options specific for CUT model
|
||||
"""
|
||||
parser.add_argument('--CUT_mode', type=str, default="CUT", choices=['CUT', 'cut', 'FastCUT', 'fastcut'])
|
||||
|
||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))')
|
||||
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
|
||||
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
|
||||
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
|
||||
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
|
||||
type=util.str2bool, nargs='?', const=True, default=False,
|
||||
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
|
||||
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
|
||||
parser.add_argument('--netF_nc', type=int, default=256)
|
||||
parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
|
||||
parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
|
||||
parser.add_argument('--flip_equivariance',
|
||||
type=util.str2bool, nargs='?', const=True, default=False,
|
||||
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
|
||||
|
||||
parser.set_defaults(pool_size=0) # no image pooling
|
||||
|
||||
opt, _ = parser.parse_known_args()
|
||||
|
||||
# Set default parameters for CUT and FastCUT
|
||||
if opt.CUT_mode.lower() == "cut":
|
||||
parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
|
||||
elif opt.CUT_mode.lower() == "fastcut":
|
||||
parser.set_defaults(
|
||||
nce_idt=False, lambda_NCE=10.0, flip_equivariance=True,
|
||||
n_epochs=150, n_epochs_decay=50
|
||||
)
|
||||
else:
|
||||
raise ValueError(opt.CUT_mode)
|
||||
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
BaseModel.__init__(self, opt)
|
||||
|
||||
# specify the training losses you want to print out.
|
||||
# The training/test scripts will call <BaseModel.get_current_losses>
|
||||
self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE']
|
||||
self.visual_names = ['real_A', 'fake_B', 'real_B']
|
||||
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
|
||||
|
||||
if opt.nce_idt and self.isTrain:
|
||||
self.loss_names += ['NCE_Y']
|
||||
self.visual_names += ['idt_B']
|
||||
|
||||
if self.isTrain:
|
||||
self.model_names = ['G', 'F', 'D']
|
||||
else: # during test time, only load G
|
||||
self.model_names = ['G']
|
||||
|
||||
# define networks (both generator and discriminator)
|
||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||
self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
|
||||
|
||||
if self.isTrain:
|
||||
self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
|
||||
|
||||
# define loss functions
|
||||
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
||||
self.criterionNCE = []
|
||||
|
||||
for nce_layer in self.nce_layers:
|
||||
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
|
||||
|
||||
self.criterionIdt = torch.nn.L1Loss().to(self.device)
|
||||
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
self.optimizers.append(self.optimizer_D)
|
||||
|
||||
def data_dependent_initialize(self, data):
|
||||
"""
|
||||
The feature network netF is defined in terms of the shape of the intermediate, extracted
|
||||
features of the encoder portion of netG. Because of this, the weights of netF are
|
||||
initialized at the first feedforward pass with some input images.
|
||||
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
|
||||
"""
|
||||
self.set_input(data)
|
||||
bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1)
|
||||
self.real_A = self.real_A[:bs_per_gpu]
|
||||
self.real_B = self.real_B[:bs_per_gpu]
|
||||
self.forward() # compute fake images: G(A)
|
||||
if self.opt.isTrain:
|
||||
self.compute_D_loss().backward() # calculate gradients for D
|
||||
self.compute_G_loss().backward()                   # calculate gradients for G
|
||||
if self.opt.lambda_NCE > 0.0:
|
||||
self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
|
||||
self.optimizers.append(self.optimizer_F)
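# Illustrative call-order sketch (an assumption about the surrounding training
# script, not code from this file): because netF builds its layers lazily at the
# first forward pass, this method must see one batch before setup() creates the
# schedulers, so that optimizer_F is already in self.optimizers.
#
#   model = create_model(opt)                 # builds netG / netF / netD
#   model.data_dependent_initialize(data)     # first forward; creates netF weights and optimizer_F
#   model.setup(opt)                          # schedulers, optional checkpoint loading
#   model.parallelize()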
|
||||
|
||||
def optimize_parameters(self):
|
||||
# forward
|
||||
self.forward()
|
||||
|
||||
# update D
|
||||
self.set_requires_grad(self.netD, True)
|
||||
self.optimizer_D.zero_grad()
|
||||
self.loss_D = self.compute_D_loss()
|
||||
self.loss_D.backward()
|
||||
self.optimizer_D.step()
|
||||
|
||||
# update G
|
||||
self.set_requires_grad(self.netD, False)
|
||||
self.optimizer_G.zero_grad()
|
||||
if self.opt.netF == 'mlp_sample':
|
||||
self.optimizer_F.zero_grad()
|
||||
self.loss_G = self.compute_G_loss()
|
||||
self.loss_G.backward()
|
||||
self.optimizer_G.step()
|
||||
if self.opt.netF == 'mlp_sample':
|
||||
self.optimizer_F.step()
|
||||
|
||||
def set_input(self, input):
|
||||
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||
Parameters:
|
||||
input (dict): include the data itself and its metadata information.
|
||||
The option 'direction' can be used to swap domain A and domain B.
|
||||
"""
|
||||
AtoB = self.opt.direction == 'AtoB'
|
||||
self.real_A = input['A' if AtoB else 'B'].to(self.device)
|
||||
self.real_B = input['B' if AtoB else 'A'].to(self.device)
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
def forward(self):
|
||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||
self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
|
||||
if self.opt.flip_equivariance:
|
||||
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
|
||||
if self.flipped_for_equivariance:
|
||||
self.real = torch.flip(self.real, [3])
|
||||
|
||||
self.fake = self.netG(self.real)
|
||||
self.fake_B = self.fake[:self.real_A.size(0)]
|
||||
if self.opt.nce_idt:
|
||||
self.idt_B = self.fake[self.real_A.size(0):]
|
||||
|
||||
def compute_D_loss(self):
|
||||
"""Calculate GAN loss for the discriminator"""
|
||||
fake = self.fake_B.detach()
|
||||
# Fake; stop backprop to the generator by detaching fake_B
|
||||
pred_fake = self.netD(fake)
|
||||
self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
|
||||
# Real
|
||||
self.pred_real = self.netD(self.real_B)
|
||||
loss_D_real = self.criterionGAN(self.pred_real, True)
|
||||
self.loss_D_real = loss_D_real.mean()
|
||||
|
||||
# combine loss and calculate gradients
|
||||
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
|
||||
return self.loss_D
|
||||
|
||||
def compute_G_loss(self):
|
||||
"""Calculate GAN and NCE loss for the generator"""
|
||||
fake = self.fake_B
|
||||
# First, G(A) should fake the discriminator
|
||||
if self.opt.lambda_GAN > 0.0:
|
||||
pred_fake = self.netD(fake)
|
||||
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
|
||||
else:
|
||||
self.loss_G_GAN = 0.0
|
||||
|
||||
if self.opt.lambda_NCE > 0.0:
|
||||
self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
|
||||
else:
|
||||
self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0
|
||||
|
||||
if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
|
||||
self.loss_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
|
||||
loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5
|
||||
else:
|
||||
loss_NCE_both = self.loss_NCE
|
||||
|
||||
self.loss_G = self.loss_G_GAN + loss_NCE_both
|
||||
return self.loss_G
|
||||
|
||||
def calculate_NCE_loss(self, src, tgt):
|
||||
n_layers = len(self.nce_layers)
|
||||
feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
|
||||
|
||||
if self.opt.flip_equivariance and self.flipped_for_equivariance:
|
||||
feat_q = [torch.flip(fq, [3]) for fq in feat_q]
|
||||
|
||||
feat_k = self.netG(src, self.nce_layers, encode_only=True)
|
||||
feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
|
||||
feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
|
||||
|
||||
total_nce_loss = 0.0
|
||||
for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
|
||||
loss = crit(f_q, f_k) * self.opt.lambda_NCE
|
||||
total_nce_loss += loss.mean()
|
||||
|
||||
return total_nce_loss / n_layers
|
||||
222
models/cycle_gan_model.py
Normal file
@ -0,0 +1,222 @@
|
||||
import torch
|
||||
import itertools
|
||||
from util.image_pool import ImagePool
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError as error:
|
||||
print(error)
|
||||
|
||||
|
||||
class CycleGANModel(BaseModel):
|
||||
"""
|
||||
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
|
||||
|
||||
The model training requires '--dataset_mode unaligned' dataset.
|
||||
By default, it uses a '--netG resnet_9blocks' ResNet generator,
|
||||
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
|
||||
and a least-square GANs objective ('--gan_mode lsgan').
|
||||
|
||||
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
|
||||
"""
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train=True):
|
||||
"""Add new dataset-specific options, and rewrite default values for existing options.
|
||||
|
||||
Parameters:
|
||||
parser -- original option parser
|
||||
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
||||
|
||||
Returns:
|
||||
the modified parser.
|
||||
|
||||
For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
|
||||
A (source domain), B (target domain).
|
||||
Generators: G_A: A -> B; G_B: B -> A.
|
||||
Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
|
||||
Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
|
||||
Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
|
||||
Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
|
||||
Dropout is not used in the original CycleGAN paper.
|
||||
"""
|
||||
# parser.set_defaults(no_dropout=True, no_antialias=True, no_antialias_up=True) # default CycleGAN did not use dropout
|
||||
# parser.set_defaults(no_dropout=True)
|
||||
if is_train:
|
||||
parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
|
||||
parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
|
||||
parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
|
||||
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize the CycleGAN class.
|
||||
|
||||
Parameters:
|
||||
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
BaseModel.__init__(self, opt)
|
||||
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
|
||||
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
|
||||
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
|
||||
visual_names_A = ['real_A', 'fake_B', 'rec_A']
|
||||
visual_names_B = ['real_B', 'fake_A', 'rec_B']
|
||||
if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
|
||||
visual_names_A.append('idt_B')
|
||||
visual_names_B.append('idt_A')
|
||||
|
||||
self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
|
||||
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
|
||||
if self.isTrain:
|
||||
self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
|
||||
else: # during test time, only load Gs
|
||||
self.model_names = ['G_A', 'G_B']
|
||||
|
||||
# define networks (both Generators and discriminators)
|
||||
# The naming is different from those used in the paper.
|
||||
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
|
||||
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG,
|
||||
not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt=opt)
|
||||
self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.normG,
|
||||
not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt=opt)
|
||||
|
||||
if self.isTrain: # define discriminators
|
||||
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
|
||||
opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt=opt)
|
||||
self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
|
||||
opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt=opt)
|
||||
|
||||
if self.isTrain:
|
||||
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
|
||||
assert(opt.input_nc == opt.output_nc)
|
||||
self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
|
||||
self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
|
||||
# define loss functions
|
||||
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
|
||||
self.criterionCycle = torch.nn.L1Loss()
|
||||
self.criterionIdt = torch.nn.L1Loss()
|
||||
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
|
||||
self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
self.optimizers.append(self.optimizer_D)
|
||||
|
||||
def set_input(self, input):
|
||||
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||
|
||||
Parameters:
|
||||
input (dict): include the data itself and its metadata information.
|
||||
|
||||
The option 'direction' can be used to swap domain A and domain B.
|
||||
"""
|
||||
AtoB = self.opt.direction == 'AtoB'
|
||||
self.real_A = input['A' if AtoB else 'B'].to(self.device)
|
||||
self.real_B = input['B' if AtoB else 'A'].to(self.device)
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
def forward(self):
|
||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||
self.fake_B = self.netG_A(self.real_A) # G_A(A)
|
||||
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
|
||||
self.fake_A = self.netG_B(self.real_B) # G_B(B)
|
||||
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
|
||||
|
||||
def backward_D_basic(self, netD, real, fake):
|
||||
"""Calculate GAN loss for the discriminator
|
||||
|
||||
Parameters:
|
||||
netD (network) -- the discriminator D
|
||||
real (tensor array) -- real images
|
||||
fake (tensor array) -- images generated by a generator
|
||||
|
||||
Return the discriminator loss.
|
||||
We also call loss_D.backward() to calculate the gradients.
|
||||
"""
|
||||
# Real
|
||||
pred_real = netD(real)
|
||||
loss_D_real = self.criterionGAN(pred_real, True)
|
||||
# Fake
|
||||
pred_fake = netD(fake.detach())
|
||||
loss_D_fake = self.criterionGAN(pred_fake, False)
|
||||
# Combined loss and calculate gradients
|
||||
loss_D = (loss_D_real + loss_D_fake) * 0.5
|
||||
if self.opt.amp:
|
||||
with amp.scale_loss(loss_D, self.optimizer_D) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss_D.backward()
|
||||
return loss_D
|
||||
|
||||
def backward_D_A(self):
|
||||
"""Calculate GAN loss for discriminator D_A"""
|
||||
fake_B = self.fake_B_pool.query(self.fake_B)
|
||||
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
|
||||
|
||||
def backward_D_B(self):
|
||||
"""Calculate GAN loss for discriminator D_B"""
|
||||
fake_A = self.fake_A_pool.query(self.fake_A)
|
||||
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
|
||||
|
||||
def backward_G(self):
|
||||
"""Calculate the loss for generators G_A and G_B"""
|
||||
lambda_idt = self.opt.lambda_identity
|
||||
lambda_A = self.opt.lambda_A
|
||||
lambda_B = self.opt.lambda_B
|
||||
# Identity loss
|
||||
if lambda_idt > 0:
|
||||
# G_A should be identity if real_B is fed: ||G_A(B) - B||
|
||||
self.idt_A = self.netG_A(self.real_B)
|
||||
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
|
||||
# G_B should be identity if real_A is fed: ||G_B(A) - A||
|
||||
self.idt_B = self.netG_B(self.real_A)
|
||||
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
|
||||
else:
|
||||
self.loss_idt_A = 0
|
||||
self.loss_idt_B = 0
|
||||
|
||||
# GAN loss D_A(G_A(A))
|
||||
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
|
||||
# GAN loss D_B(G_B(B))
|
||||
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
|
||||
# Forward cycle loss || G_B(G_A(A)) - A||
|
||||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
|
||||
# Backward cycle loss || G_A(G_B(B)) - B||
|
||||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
|
||||
# combined loss and calculate gradients
|
||||
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
|
||||
if self.opt.amp:
|
||||
with amp.scale_loss(self.loss_G, self.optimizer_G) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
self.loss_G.backward()
|
||||
|
||||
def data_dependent_initialize(self, data):
|
||||
return
|
||||
|
||||
def generate_visuals_for_evaluation(self, data, mode):
|
||||
with torch.no_grad():
|
||||
visuals = {}
|
||||
AtoB = self.opt.direction == "AtoB"
|
||||
G = self.netG_A
|
||||
source = data["A" if AtoB else "B"].to(self.device)
|
||||
if mode == "forward":
|
||||
visuals["fake_B"] = G(source)
|
||||
else:
|
||||
raise ValueError("mode %s is not recognized" % mode)
|
||||
return visuals
|
||||
|
||||
def optimize_parameters(self):
|
||||
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
||||
# forward
|
||||
self.forward() # compute fake images and reconstruction images.
|
||||
# G_A and G_B
|
||||
self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs
|
||||
self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
|
||||
self.backward_G() # calculate gradients for G_A and G_B
|
||||
self.optimizer_G.step() # update G_A and G_B's weights
|
||||
# D_A and D_B
|
||||
self.set_requires_grad([self.netD_A, self.netD_B], True)
|
||||
self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
|
||||
self.backward_D_A() # calculate gradients for D_A
|
||||
self.backward_D_B()      # calculate gradients for D_B
|
||||
self.optimizer_D.step() # update D_A and D_B's weights
|
||||
1530
models/networks.py
Normal file
File diff suppressed because it is too large
55
models/patchnce.py
Normal file
@ -0,0 +1,55 @@
|
||||
from packaging import version
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PatchNCELoss(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
|
||||
|
||||
def forward(self, feat_q, feat_k):
|
||||
num_patches = feat_q.shape[0]
|
||||
dim = feat_q.shape[1]
|
||||
feat_k = feat_k.detach()
|
||||
|
||||
# pos logit
|
||||
l_pos = torch.bmm(
|
||||
feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
|
||||
l_pos = l_pos.view(num_patches, 1)
|
||||
|
||||
# neg logit
|
||||
|
||||
# Should the negatives from the other samples of a minibatch be utilized?
|
||||
# In CUT and FastCUT, we found that it's best to only include negatives
|
||||
# from the same image. Therefore, we set
|
||||
# --nce_includes_all_negatives_from_minibatch as False
|
||||
# However, for single-image translation, the minibatch consists of
|
||||
# crops from the "same" high-resolution image.
|
||||
# Therefore, we will include the negatives from the entire minibatch.
|
||||
if self.opt.nce_includes_all_negatives_from_minibatch:
|
||||
# reshape features as if they are all negatives of minibatch of size 1.
|
||||
batch_dim_for_bmm = 1
|
||||
else:
|
||||
batch_dim_for_bmm = self.opt.batch_size
|
||||
|
||||
# reshape features to batch size
|
||||
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
|
||||
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
|
||||
npatches = feat_q.size(1)
|
||||
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
|
||||
|
||||
# diagonal entries are similarity between same features, and hence meaningless.
|
||||
# just fill the diagonal with very small number, which is exp(-10) and almost zero
|
||||
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
|
||||
l_neg_curbatch.masked_fill_(diagonal, -10.0)
|
||||
l_neg = l_neg_curbatch.view(-1, npatches)
|
||||
|
||||
out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
|
||||
|
||||
loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
|
||||
device=feat_q.device))
|
||||
|
||||
return loss
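# Illustrative shape check (a sketch; the Namespace below stands in for the real
# options object used by this repo):
#
#   from argparse import Namespace
#   opt = Namespace(nce_includes_all_negatives_from_minibatch=False, batch_size=1, nce_T=0.07)
#   crit = PatchNCELoss(opt)
#   feat_q = torch.randn(256, 128)    # 256 sampled patches with 128-dim features
#   feat_k = torch.randn(256, 128)
#   loss = crit(feat_q, feat_k)       # l_pos: (256, 1), l_neg: (256, 256), loss: (256,)
#   print(loss.shape)                 # torch.Size([256])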
|
||||
363
models/roma_model.py
Normal file
@ -0,0 +1,363 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
from .patchnce import PatchNCELoss
|
||||
import util.util as util
|
||||
import timm
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
import sys
|
||||
from functools import partial
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torchvision.transforms import transforms as tfs
|
||||
|
||||
class ROMAModel(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train=True):
|
||||
""" Configures options specific for CUT model
|
||||
"""
|
||||
parser.add_argument('--adj_size_list', type=int, nargs='+', default=[2, 4, 6, 8, 12], help='different scales of perception field')
|
||||
parser.add_argument('--lambda_mlp', type=float, default=1.0, help='learning-rate multiplier for the ViT discriminator')
|
||||
parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
|
||||
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
||||
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
||||
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
|
||||
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
|
||||
parser.add_argument('--local_nums', type=int, default=256)
|
||||
parser.add_argument('--which_D_layer', type=int, default=-1)
|
||||
parser.add_argument('--side_length', type=int, default=7)
|
||||
|
||||
parser.set_defaults(pool_size=0)
|
||||
|
||||
opt, _ = parser.parse_known_args()
|
||||
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
BaseModel.__init__(self, opt)
|
||||
|
||||
|
||||
self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial', 'motion']
|
||||
self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1']
|
||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||
|
||||
|
||||
if self.isTrain:
|
||||
self.model_names = ['G', 'D_ViT']
|
||||
else: # during test time, only load G
|
||||
self.model_names = ['G']
|
||||
|
||||
|
||||
# define networks (both generator and discriminator)
|
||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||
|
||||
|
||||
if self.isTrain:
|
||||
|
||||
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
|
||||
self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
|
||||
|
||||
|
||||
self.norm = F.softmax
|
||||
|
||||
self.resize = tfs.Resize(size=(384,384))
|
||||
|
||||
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
||||
self.criterionNCE = []
|
||||
|
||||
for atten_layer in self.atten_layers:
|
||||
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
|
||||
|
||||
self.criterionL1 = torch.nn.L1Loss().to(self.device)
|
||||
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||
self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
self.optimizers.append(self.optimizer_D_ViT)
|
||||
|
||||
def data_dependent_initialize(self, data):
|
||||
"""
|
||||
The feature network netF is defined in terms of the shape of the intermediate, extracted
|
||||
features of the encoder portion of netG. Because of this, the weights of netF are
|
||||
initialized at the first feedforward pass with some input images.
|
||||
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def optimize_parameters(self):
|
||||
# forward
|
||||
self.forward()
|
||||
|
||||
# update D
|
||||
self.set_requires_grad(self.netD_ViT, True)
|
||||
self.optimizer_D_ViT.zero_grad()
|
||||
self.loss_D = self.compute_D_loss()
|
||||
self.loss_D.backward()
|
||||
self.optimizer_D_ViT.step()
|
||||
|
||||
# update G
|
||||
self.set_requires_grad(self.netD_ViT, False)
|
||||
self.optimizer_G.zero_grad()
|
||||
self.loss_G = self.compute_G_loss()
|
||||
self.loss_G.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
def set_input(self, input):
|
||||
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||
Parameters:
|
||||
input (dict): include the data itself and its metadata information.
|
||||
The option 'direction' can be used to swap domain A and domain B.
|
||||
"""
|
||||
AtoB = self.opt.direction == 'AtoB'
|
||||
self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
|
||||
self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
|
||||
self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
|
||||
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
def forward(self):
|
||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||
|
||||
# ============ Step 1: build the time schedule for the multi-step stochastic generation over real_A / real_A2 ============
|
||||
tau = self.opt.tau
|
||||
T = self.opt.num_timesteps
|
||||
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
|
||||
times = np.cumsum(incs)
|
||||
times = times / times[-1]
|
||||
times = 0.5 * times[-1] + 0.5 * times  # rescale into [0.5, 1]
|
||||
times = np.concatenate([np.zeros(1), times])
|
||||
times = torch.tensor(times).float().cuda()
|
||||
self.times = times
|
||||
bs = self.mutil_real_A0_tokens.size(0)
|
||||
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
|
||||
self.time_idx = time_idx
|
||||
|
||||
with torch.no_grad():
|
||||
self.netG.eval()
|
||||
# ============ Step 2: run the multi-step stochastic generation for real_A / real_A2 ============
|
||||
for t in range(self.time_idx.int().item() + 1):
|
||||
# compute the increment delta and the inter/scale factors used for per-step interpolation
|
||||
if t > 0:
|
||||
delta = times[t] - times[t - 1]
|
||||
denom = times[-1] - times[t - 1]
|
||||
inter = (delta / denom).reshape(-1, 1, 1, 1)
|
||||
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
|
||||
|
||||
# apply the stochastic noise update to Xt and Xt2
|
||||
Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
|
||||
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
|
||||
time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
|
||||
z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
|
||||
self.time = times[time_idx]
|
||||
Xt_1 = self.netG(Xt, self.time, z)
|
||||
|
||||
Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
|
||||
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
|
||||
time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
|
||||
z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
|
||||
Xt_12 = self.netG(Xt2, self.time, z)
|
||||
|
||||
# keep the intermediate denoised results (real_A_noisy, etc.) for the concatenation step below
|
||||
self.real_A_noisy = Xt.detach()
|
||||
self.real_A_noisy2 = Xt2.detach()
|
||||
# store the noise map (real_A_noisy - real_A)
|
||||
self.noisy_map = self.real_A_noisy - self.real_A
|
||||
|
||||
# ============ Step 3: assemble the inputs and run the network ============
|
||||
bs = self.mutil_real_A0_tokens.size(0)
|
||||
z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
|
||||
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
|
||||
# concatenate real_A and real_B (when nce_idt=True), and handle real_A_noisy and XtB the same way
|
||||
self.real = self.mutil_real_A0_tokens
|
||||
self.realt = self.real_A_noisy
|
||||
|
||||
if self.opt.flip_equivariance:
|
||||
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
|
||||
if self.flipped_for_equivariance:
|
||||
self.real = torch.flip(self.real, [3])
|
||||
self.realt = torch.flip(self.realt, [3])
|
||||
|
||||
|
||||
self.fake_B0 = self.netG(self.real_A0)
|
||||
self.fake_B1 = self.netG(self.real_A1)
|
||||
|
||||
if self.opt.isTrain:
|
||||
real_A0 = self.real_A0
|
||||
real_A1 = self.real_A1
|
||||
real_B0 = self.real_B0
|
||||
real_B1 = self.real_B1
|
||||
fake_B0 = self.fake_B0
|
||||
fake_B1 = self.fake_B1
|
||||
self.real_A0_resize = self.resize(real_A0)
|
||||
self.real_A1_resize = self.resize(real_A1)
|
||||
real_B0 = self.resize(real_B0)
|
||||
real_B1 = self.resize(real_B1)
|
||||
self.fake_B0_resize = self.resize(fake_B0)
|
||||
self.fake_B1_resize = self.resize(fake_B1)
|
||||
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||
|
||||
def tokens_concat(self, origin_tokens, adjacent_size):
|
||||
adj_size = adjacent_size
|
||||
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
|
||||
S = int(math.sqrt(token_num))
|
||||
if S * S != token_num:
|
||||
raise ValueError('tokens_concat expects a square number of tokens, got %d' % token_num)
|
||||
token_map = origin_tokens.clone().reshape(B,S,S,C)
|
||||
cut_patch_list = []
|
||||
for i in range(0, S, adj_size):
|
||||
for j in range(0, S, adj_size):
|
||||
i_left = i
|
||||
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
|
||||
j_left = j
|
||||
j_right = j + adj_size if j + adj_size <= S else S + 1
|
||||
|
||||
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
|
||||
cut_patch= cut_patch.reshape(B,-1,C)
|
||||
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
|
||||
cut_patch_list.append(cut_patch)
|
||||
|
||||
|
||||
result = torch.cat(cut_patch_list,dim=1)
|
||||
return result
|
||||
|
||||
|
||||
def cat_results(self, origin_tokens, adj_size_list):
|
||||
res_list = [origin_tokens]
|
||||
for ad_s in adj_size_list:
|
||||
cat_result = self.tokens_concat(origin_tokens, ad_s)
|
||||
res_list.append(cat_result)
|
||||
|
||||
result = torch.cat(res_list, dim=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def compute_D_loss(self):
|
||||
"""Calculate GAN loss for the discriminator"""
|
||||
|
||||
|
||||
lambda_D_ViT = self.opt.lambda_D_ViT
|
||||
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer].detach()
|
||||
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer].detach()
|
||||
|
||||
real_B0_tokens = self.mutil_real_B0_tokens[self.opt.which_D_layer]
|
||||
real_B1_tokens = self.mutil_real_B1_tokens[self.opt.which_D_layer]
|
||||
|
||||
|
||||
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
|
||||
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
|
||||
|
||||
|
||||
|
||||
real_B0_tokens = self.cat_results(real_B0_tokens, self.opt.adj_size_list)
|
||||
real_B1_tokens = self.cat_results(real_B1_tokens, self.opt.adj_size_list)
|
||||
|
||||
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
|
||||
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
|
||||
|
||||
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
|
||||
|
||||
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
|
||||
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
|
||||
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
|
||||
|
||||
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
|
||||
|
||||
|
||||
return self.loss_D_ViT
|
||||
|
||||
def compute_G_loss(self):
|
||||
|
||||
if self.opt.lambda_GAN > 0.0:
|
||||
|
||||
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer]
|
||||
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer]
|
||||
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
|
||||
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
|
||||
pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
|
||||
pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
|
||||
self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN
|
||||
else:
|
||||
self.loss_G_GAN_ViT = 0.0
|
||||
|
||||
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
|
||||
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
|
||||
else:
|
||||
self.loss_global, self.loss_spatial = 0.0, 0.0
|
||||
|
||||
if self.opt.lambda_motion > 0.0:
|
||||
self.loss_motion = 0.0
|
||||
for real_A0_tokens, real_A1_tokens, fake_B0_tokens, fake_B1_tokens in zip(self.mutil_real_A0_tokens, self.mutil_real_A1_tokens, self.mutil_fake_B0_tokens, self.mutil_fake_B1_tokens):
|
||||
A0_B1 = real_A0_tokens.bmm(fake_B1_tokens.permute(0,2,1))
|
||||
B0_A1 = fake_B0_tokens.bmm(real_A1_tokens.permute(0,2,1))
|
||||
cos_dis_global = F.cosine_similarity(A0_B1, B0_A1, dim=-1)
|
||||
self.loss_motion += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
|
||||
else:
|
||||
self.loss_motion = 0.0
|
||||
|
||||
self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.loss_motion
|
||||
return self.loss_G
|
||||
|
||||
def calculate_attention_loss(self):
|
||||
n_layers = len(self.atten_layers)
|
||||
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
||||
mutil_real_A1_tokens = self.mutil_real_A1_tokens
|
||||
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
|
||||
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
|
||||
|
||||
|
||||
if self.opt.lambda_global > 0.0:
|
||||
loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
|
||||
loss_global *= 0.5
|
||||
|
||||
else:
|
||||
loss_global = 0.0
|
||||
|
||||
if self.opt.lambda_spatial > 0.0:
|
||||
loss_spatial = 0.0
|
||||
local_nums = self.opt.local_nums
|
||||
tokens_cnt = 576
|
||||
local_id = np.random.permutation(tokens_cnt)
|
||||
local_id = local_id[:int(min(local_nums, tokens_cnt))]
|
||||
|
||||
mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
|
||||
mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
|
||||
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
|
||||
loss_spatial *= 0.5
|
||||
|
||||
else:
|
||||
loss_spatial = 0.0
|
||||
|
||||
|
||||
|
||||
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
|
||||
|
||||
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
|
||||
loss = 0.0
|
||||
n_layers = len(self.atten_layers)
|
||||
|
||||
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
|
||||
|
||||
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
|
||||
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
|
||||
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
|
||||
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
|
||||
|
||||
loss = loss / n_layers
|
||||
return loss
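# Rough equivalence note: for one ViT layer with src_tokens, tgt_tokens of shape
# [B, N, C], the loss above compares the two cross-affinity maps row-wise and
# pushes their cosine similarity towards 1, i.e. roughly
#   a = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))   # [B, N, N]
#   b = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))   # [B, N, N]
#   loss_layer = (1 - F.cosine_similarity(a, b, dim=-1)).abs().mean()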
|
||||
|
||||
272
models/roma_single_model.py
Normal file
@ -0,0 +1,272 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
from .patchnce import PatchNCELoss
|
||||
import util.util as util
|
||||
import timm
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
import sys
|
||||
from functools import partial
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torchvision.transforms import transforms as tfs
|
||||
|
||||
class ROMASingleModel(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train=True):
|
||||
""" Configures options specific for CUT model
|
||||
"""
|
||||
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
|
||||
parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
|
||||
parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
|
||||
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
||||
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
||||
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
|
||||
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
|
||||
parser.add_argument('--local_nums', type=int, default=256)
|
||||
parser.add_argument('--which_D_layer', type=int, default=-1)
|
||||
parser.add_argument('--side_length', type=int, default=7)
|
||||
|
||||
parser.set_defaults(pool_size=0)
|
||||
|
||||
opt, _ = parser.parse_known_args()
|
||||
|
||||
return parser
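# Illustrative usage note (the launcher name and model key below are assumptions,
# not taken from the repo docs): these flags are consumed by a CUT-style training
# entry point, e.g.
#   python train.py --model roma_single --atten_layers 1,3,5 --which_D_layer -1 --local_nums 256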
|
||||
|
||||
def __init__(self, opt):
|
||||
BaseModel.__init__(self, opt)
|
||||
|
||||
|
||||
self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial']
|
||||
self.visual_names = ['real_A', 'fake_B', 'real_B']
|
||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||
|
||||
|
||||
if self.isTrain:
|
||||
self.model_names = ['G', 'D_ViT']
|
||||
else: # during test time, only load G
|
||||
self.model_names = ['G']
|
||||
|
||||
|
||||
# define networks (both generator and discriminator)
|
||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||
|
||||
|
||||
if self.isTrain:
|
||||
|
||||
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
|
||||
# self.netPreViT = timm.create_model("vit_base_patch32_384",pretrained=True).to(self.device)
|
||||
self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
|
||||
|
||||
|
||||
self.norm = F.softmax
|
||||
|
||||
self.resize = tfs.Resize(size=(384,384))
|
||||
# self.resize = tfs.Resize(size=(224, 224))
|
||||
|
||||
# define loss functions
|
||||
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
||||
self.criterionNCE = []
|
||||
|
||||
for atten_layer in self.atten_layers:
|
||||
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
|
||||
|
||||
self.criterionL1 = torch.nn.L1Loss().to(self.device)
|
||||
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||
self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
self.optimizers.append(self.optimizer_D_ViT)
|
||||
|
||||
def data_dependent_initialize(self, data):
|
||||
"""
|
||||
The feature network netF is defined in terms of the shape of the intermediate, extracted
|
||||
features of the encoder portion of netG. Because of this, the weights of netF are
|
||||
initialized at the first feedforward pass with some input images.
|
||||
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def optimize_parameters(self):
|
||||
# forward
|
||||
self.forward()
|
||||
|
||||
# update D
|
||||
self.set_requires_grad(self.netD_ViT, True)
|
||||
self.optimizer_D_ViT.zero_grad()
|
||||
self.loss_D = self.compute_D_loss()
|
||||
self.loss_D.backward()
|
||||
self.optimizer_D_ViT.step()
|
||||
|
||||
# update G
|
||||
self.set_requires_grad(self.netD_ViT, False)
|
||||
self.optimizer_G.zero_grad()
|
||||
self.loss_G = self.compute_G_loss()
|
||||
self.loss_G.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
def set_input(self, input):
|
||||
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||
Parameters:
|
||||
input (dict): include the data itself and its metadata information.
|
||||
The option 'direction' can be used to swap domain A and domain B.
|
||||
"""
|
||||
AtoB = self.opt.direction == 'AtoB'
|
||||
self.real_A = input['A' if AtoB else 'B'].to(self.device)
|
||||
self.real_B = input['B' if AtoB else 'A'].to(self.device)
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
def forward(self):
|
||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||
self.fake_B = self.netG(self.real_A)
|
||||
|
||||
if self.opt.isTrain:
|
||||
real_A = self.real_A
|
||||
real_B = self.real_B
|
||||
fake_B = self.fake_B
|
||||
self.real_A_resize = self.resize(real_A)
|
||||
real_B = self.resize(real_B)
|
||||
self.fake_B_resize = self.resize(fake_B)
|
||||
self.mutil_real_A_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_B_tokens = self.netPreViT(real_B, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True)
|
||||
|
||||
def tokens_concat(self, origin_tokens, adjacent_size):
|
||||
adj_size = adjacent_size
|
||||
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
|
||||
S = int(math.sqrt(token_num))
|
||||
if S * S != token_num:
|
||||
print('Error! Not a square!')
|
||||
token_map = origin_tokens.clone().reshape(B,S,S,C)
|
||||
cut_patch_list = []
|
||||
for i in range(0, S, adj_size):
|
||||
for j in range(0, S, adj_size):
|
||||
i_left = i
|
||||
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
|
||||
j_left = j
|
||||
j_right = j + adj_size if j + adj_size <= S else S + 1
|
||||
|
||||
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
|
||||
cut_patch= cut_patch.reshape(B,-1,C)
|
||||
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
|
||||
cut_patch_list.append(cut_patch)
|
||||
|
||||
|
||||
result = torch.cat(cut_patch_list,dim=1)
|
||||
return result
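# Shape note: with the default 576 ViT tokens (a 24x24 map) and adjacent_size=4,
# the loop above averages each 4x4 neighbourhood into a single token, giving a
# [B, 36, C] summary; cat_results below concatenates the original tokens with
# these pooled versions along the token dimension to form the multi-scale input
# for the MLP discriminator.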
|
||||
|
||||
|
||||
def cat_results(self, origin_tokens, adj_size_list):
|
||||
res_list = [origin_tokens]
|
||||
for ad_s in adj_size_list:
|
||||
cat_result = self.tokens_concat(origin_tokens, ad_s)
|
||||
res_list.append(cat_result)
|
||||
|
||||
result = torch.cat(res_list, dim=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def compute_D_loss(self):
|
||||
"""Calculate GAN loss for the discriminator"""
|
||||
|
||||
|
||||
lambda_D_ViT = self.opt.lambda_D_ViT
|
||||
fake_B_tokens = self.mutil_fake_B_tokens[self.opt.which_D_layer].detach()
|
||||
|
||||
real_B_tokens = self.mutil_real_B_tokens[self.opt.which_D_layer]
|
||||
|
||||
|
||||
fake_B_tokens = self.cat_results(fake_B_tokens, self.opt.adj_size_list)
|
||||
|
||||
real_B_tokens = self.cat_results(real_B_tokens, self.opt.adj_size_list)
|
||||
|
||||
pre_fake_ViT = self.netD_ViT(fake_B_tokens)
|
||||
|
||||
|
||||
self.loss_D_fake_ViT = self.criterionGAN(pre_fake_ViT, False).mean() * lambda_D_ViT
|
||||
|
||||
pred_real_ViT = self.netD_ViT(real_B_tokens)
|
||||
self.loss_D_real_ViT = self.criterionGAN(pred_real_ViT, True).mean() * lambda_D_ViT
|
||||
|
||||
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
|
||||
|
||||
|
||||
return self.loss_D_ViT
|
||||
|
||||
def compute_G_loss(self):
|
||||
|
||||
if self.opt.lambda_GAN > 0.0:
|
||||
|
||||
fake_B_tokens = self.mutil_fake_B_tokens[self.opt.which_D_layer]
|
||||
fake_B_tokens = self.cat_results(fake_B_tokens, self.opt.adj_size_list)
|
||||
pred_fake_ViT = self.netD_ViT(fake_B_tokens)
|
||||
self.loss_G_GAN_ViT = self.criterionGAN(pred_fake_ViT, True) * self.opt.lambda_GAN
|
||||
else:
|
||||
self.loss_G_GAN_ViT = 0.0
|
||||
|
||||
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
|
||||
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
|
||||
else:
|
||||
self.loss_global, self.loss_spatial = 0.0, 0.0
|
||||
|
||||
|
||||
|
||||
self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial
|
||||
return self.loss_G
|
||||
|
||||
def calculate_attention_loss(self):
|
||||
n_layers = len(self.atten_layers)
|
||||
mutil_real_A_tokens = self.mutil_real_A_tokens
|
||||
mutil_fake_B_tokens = self.mutil_fake_B_tokens
|
||||
|
||||
|
||||
|
||||
if self.opt.lambda_global > 0.0:
|
||||
loss_global = self.calculate_similarity(mutil_real_A_tokens, mutil_fake_B_tokens)
|
||||
|
||||
|
||||
else:
|
||||
loss_global = 0.0
|
||||
|
||||
if self.opt.lambda_spatial > 0.0:
|
||||
loss_spatial = 0.0
|
||||
local_nums = self.opt.local_nums
|
||||
tokens_cnt = 576
|
||||
local_id = np.random.permutation(tokens_cnt)
|
||||
local_id = local_id[:int(min(local_nums, tokens_cnt))]
|
||||
|
||||
mutil_real_A_local_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
|
||||
mutil_fake_B_local_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
|
||||
loss_spatial = self.calculate_similarity(mutil_real_A_local_tokens, mutil_fake_B_local_tokens)
|
||||
|
||||
|
||||
else:
|
||||
loss_spatial = 0.0
|
||||
|
||||
|
||||
|
||||
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
|
||||
|
||||
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
|
||||
loss = 0.0
|
||||
n_layers = len(self.atten_layers)
|
||||
|
||||
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
|
||||
|
||||
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
|
||||
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
|
||||
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
|
||||
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
|
||||
|
||||
loss = loss / n_layers
|
||||
return loss
|
||||
|
||||
655
models/self_build.py
Normal file
@ -0,0 +1,655 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms import GaussianBlur
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
from .patchnce import PatchNCELoss
|
||||
import util.util as util
|
||||
|
||||
from torchvision.transforms import transforms as tfs
|
||||
|
||||
def warp(image, flow):  # flow-based warping
|
||||
"""
|
||||
基于光流的图像变形函数
|
||||
Args:
|
||||
image: [B, C, H, W] 输入图像
|
||||
flow: [B, 2, H, W] 光流场(x/y方向位移)
|
||||
Returns:
|
||||
warped: [B, C, H, W] 变形后的图像
|
||||
"""
|
||||
B, C, H, W = image.shape
|
||||
# build the base sampling grid (indexed as (H, W) so the stacked grid is [2, H, W])
grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W))
grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2,H,W]
grid = grid.unsqueeze(0).repeat(B,1,1,1)  # [B,2,H,W]
|
||||
|
||||
# apply the flow displacement, then normalise coordinates to [-1, 1]
|
||||
new_grid = grid + flow
|
||||
new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0  # x direction
new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0  # y direction
|
||||
new_grid = new_grid.permute(0,2,3,1) # [B,H,W,2]
|
||||
|
||||
# bilinear sampling
|
||||
return F.grid_sample(image, new_grid, align_corners=True)
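# Illustrative sanity check (assumed usage, not part of the original file):
# a zero flow should reproduce the input up to interpolation error, e.g.
#   img = torch.randn(1, 3, 64, 64)
#   assert torch.allclose(warp(img, torch.zeros(1, 2, 64, 64)), img, atol=1e-4)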
|
||||
|
||||
# content-aware temporal-normalisation loss
def compute_ctn_loss(G, x, F_content):  # Eq. 10
|
||||
"""
|
||||
计算内容感知时序归一化损失
|
||||
Args:
|
||||
G: 生成器
|
||||
x: 输入红外图像 [B,C,H,W]
|
||||
F_content: 生成的光流场 [B,2,H,W]
|
||||
"""
|
||||
|
||||
# translate to the visible domain
y_fake = G(x)  # [B,3,H,W]

# warp the generated frame with the content-aware flow
warped_fake = warp(y_fake, F_content)  # [B,3,H,W]

# apply the same flow to the input, then translate the warped input
warped_x = warp(x, F_content)  # [B,C,H,W]
y_fake_warped = G(warped_x)  # [B,3,H,W]

# L2 consistency between the two paths
loss = F.mse_loss(warped_fake, y_fake_warped)
|
||||
return loss
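# Reading of Eq. 10: the generator should commute with the content-aware flow,
# i.e. warp(G(x), F_content) ≈ G(warp(x, F_content)); the MSE above penalises
# the gap between the two paths.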
|
||||
|
||||
class ContentAwareOptimization(nn.Module):
|
||||
def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
|
||||
super().__init__()
|
||||
self.lambda_inc = lambda_inc  # weight-amplification coefficient
self.eta_ratio = eta_ratio  # fraction of patches treated as content-rich
|
||||
|
||||
def compute_cosine_similarity(self, gradients):
|
||||
"""
|
||||
计算每个patch梯度与平均梯度的余弦相似度
|
||||
Args:
|
||||
gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h)
|
||||
Returns:
|
||||
cosine_sim: [B, N] 每个patch的余弦相似度
|
||||
"""
|
||||
mean_grad = torch.mean(gradients, dim=1, keepdim=True) # [B, 1, D]
|
||||
# cosine similarity against the mean gradient
|
||||
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
|
||||
return cosine_sim
|
||||
|
||||
def generate_weight_map(self, gradients_real, gradients_fake):
|
||||
"""
|
||||
生成内容感知权重图
|
||||
Args:
|
||||
gradients_real: [B, N, D] 真实图像判别器梯度
|
||||
gradients_fake: [B, N, D] 生成图像判别器梯度
|
||||
Returns:
|
||||
weight_real: [B, N] 真实图像权重图
|
||||
weight_fake: [B, N] 生成图像权重图
|
||||
"""
|
||||
# cosine similarity of the real-image patches
cosine_real = self.compute_cosine_similarity(gradients_real)  # [B, N], Eq. 5
# cosine similarity of the generated-image patches
cosine_fake = self.compute_cosine_similarity(gradients_fake)  # [B, N]
|
||||
|
||||
# select the content-rich regions (the eta_ratio fraction with the lowest cosine similarity)
|
||||
k = int(self.eta_ratio * cosine_real.shape[1])
|
||||
|
||||
# weight map for the real images
_, real_indices = torch.topk(-cosine_real, k, dim=1)  # pick the least-similar regions
|
||||
weight_real = torch.ones_like(cosine_real)
|
||||
for b in range(cosine_real.shape[0]):
|
||||
weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]]))  # Eq. 6
|
||||
|
||||
# weight map for the generated images (same procedure)
|
||||
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
||||
weight_fake = torch.ones_like(cosine_fake)
|
||||
for b in range(cosine_fake.shape[0]):
|
||||
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
|
||||
|
||||
return weight_real, weight_fake
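# Summary of Eqs. 5-6 as implemented above: patches whose gradient direction
# disagrees most with the batch-mean gradient (lowest cosine similarity) are
# treated as content-rich; their adversarial weight becomes lambda_inc / |cos|,
# while all other patches keep weight 1, so content regions dominate the
# weighted GAN loss computed in forward().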
|
||||
|
||||
def forward(self, D_real, D_fake, real_scores, fake_scores):
|
||||
"""
|
||||
计算内容感知对抗损失
|
||||
Args:
|
||||
D_real: 判别器对真实图像的特征输出 [B, C, H, W]
|
||||
D_fake: 判别器对生成图像的特征输出 [B, C, H, W]
|
||||
real_scores: 真实图像的判别器预测 [B, N] (N=H*W)
|
||||
fake_scores: 生成图像的判别器预测 [B, N]
|
||||
Returns:
|
||||
loss_co_adv: 内容感知对抗损失
|
||||
"""
|
||||
B, C, H, W = D_real.shape
|
||||
N = H * W
|
||||
|
||||
# register hooks to capture the per-patch gradients
|
||||
gradients_real = []
|
||||
gradients_fake = []
|
||||
|
||||
def hook_real(grad):
|
||||
gradients_real.append(grad.detach().view(B, N, -1))
|
||||
|
||||
def hook_fake(grad):
|
||||
gradients_fake.append(grad.detach().view(B, N, -1))
|
||||
|
||||
D_real.register_hook(hook_real)
|
||||
D_fake.register_hook(hook_fake)
|
||||
|
||||
# compute the vanilla adversarial loss to trigger the gradient computation
|
||||
loss_real = torch.mean(torch.log(real_scores + 1e-8))
|
||||
loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
|
||||
# add a dummy term involving D_real and D_fake so gradients flow to both tensors
|
||||
loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
|
||||
total_loss = loss_real + loss_fake + loss_dummy
|
||||
total_loss.backward(retain_graph=True)
|
||||
|
||||
# collect the captured gradients
|
||||
gradients_real = gradients_real[0] # [B, N, D]
|
||||
gradients_fake = gradients_fake[0] # [B, N, D]
|
||||
|
||||
# build the weight maps
|
||||
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
|
||||
|
||||
# apply the weights to the adversarial loss
|
||||
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
|
||||
loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))
|
||||
|
||||
# final content-aware adversarial loss
|
||||
loss_co_adv = -(loss_co_real + loss_co_fake)
|
||||
|
||||
return loss_co_adv
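# Note on the hook mechanism above: register_hook only fires when backward()
# actually reaches D_real / D_fake, which is why a tiny dummy term touching both
# tensors is added before the extra backward(retain_graph=True) call; the
# weighted loss is then rebuilt from the captured per-patch gradients.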
|
||||
|
||||
class ContentAwareTemporalNorm(nn.Module):
|
||||
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
|
||||
super().__init__()
|
||||
self.gamma_stride = gamma_stride  # controls the overall motion magnitude
self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer
|
||||
|
||||
def forward(self, weight_map):
|
||||
"""
|
||||
生成内容感知光流
|
||||
Args:
|
||||
weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块)
|
||||
Returns:
|
||||
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
||||
"""
|
||||
B, _, H, W = weight_map.shape
|
||||
|
||||
# 1. normalise the weight map
# keeps the relative strength of each region while bounding the value range
weight_norm = F.normalize(weight_map, p=1, dim=(2,3))  # L1 normalisation, [B,1,H,W]
|
||||
|
||||
# 2. sample Gaussian noise with the same size as the flow field
|
||||
z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W]
|
||||
|
||||
# 3. synthesise the raw flow
# expand the weight map to two channels (x/y directions share the same weights)
weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B,2,H,W]
F_raw = self.gamma_stride * weight_expanded * z  # [B,2,H,W], Eq. 9
|
||||
|
||||
# 4. smoothing (keeps the flow structurally continuous)
# Gaussian blur applied to each channel independently
F_smooth = self.smoother(F_raw)  # [B,2,H,W]
|
||||
|
||||
# 5. dynamic-range adjustment (optional)
# bound the flow magnitude to avoid extreme displacements
F_content = torch.tanh(F_smooth)  # squash into the [-1, 1] range
|
||||
|
||||
return F_content
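# Illustrative usage sketch (shapes assumed from the docstring above; fake_frame
# is a placeholder tensor, not a variable defined in this file):
#   ctn = ContentAwareTemporalNorm(gamma_stride=0.1)
#   weight_map = torch.rand(2, 1, 256, 256)   # e.g. from ContentAwareOptimization
#   flow = ctn(weight_map)                    # [2, 2, 256, 256], values in [-1, 1]
#   pseudo_next = warp(fake_frame, flow)      # pseudo "next frame" for the CTN loss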
|
||||
|
||||
class CTNxModel(BaseModel):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train=True):
|
||||
"""配置 CTNx 模型的特定选项"""
|
||||
|
||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))')
|
||||
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
|
||||
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
|
||||
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
||||
|
||||
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
|
||||
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
|
||||
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
|
||||
type=util.str2bool, nargs='?', const=True, default=False,
|
||||
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
|
||||
|
||||
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
|
||||
parser.add_argument('--netF_nc', type=int, default=256)
|
||||
parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
|
||||
|
||||
parser.add_argument('--lmda_1', type=float, default=0.1)
|
||||
parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
|
||||
parser.add_argument('--flip_equivariance',
|
||||
type=util.str2bool, nargs='?', const=True, default=False,
|
||||
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
|
||||
|
||||
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
||||
parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
|
||||
|
||||
|
||||
parser.set_defaults(pool_size=0) # no image pooling
|
||||
|
||||
opt, _ = parser.parse_known_args()
|
||||
|
||||
# set the defaults for the SB mode directly
|
||||
parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
|
||||
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
"""初始化 CTNx 模型"""
|
||||
BaseModel.__init__(self, opt)
|
||||
|
||||
# training losses to print/log
|
||||
self.loss_names = ['G_GAN_1', 'D_real_1', 'D_fake_1', 'G_1', 'NCE_1', 'SB_1',
|
||||
'G_2']
|
||||
self.visual_names = ['real_A', 'real_A_noisy', 'fake_B', 'real_B']
|
||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||
|
||||
if self.opt.phase == 'test':
|
||||
self.visual_names = ['real']
|
||||
for NFE in range(self.opt.num_timesteps):
|
||||
fake_name = 'fake_' + str(NFE+1)
|
||||
self.visual_names.append(fake_name)
|
||||
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
|
||||
|
||||
if opt.nce_idt and self.isTrain:
|
||||
self.loss_names += ['NCE_Y']
|
||||
self.visual_names += ['idt_B']
|
||||
|
||||
if self.isTrain:
|
||||
self.model_names = ['G1', 'F1', 'D1', 'E1',
|
||||
'G2']
|
||||
|
||||
|
||||
else:
|
||||
self.model_names = ['G1']
|
||||
|
||||
# create the networks
|
||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||
|
||||
|
||||
if self.isTrain:
|
||||
self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
|
||||
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
|
||||
|
||||
self.resize = tfs.Resize(size=(384,384))
|
||||
|
||||
# add the pretrained ViT
|
||||
self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
|
||||
|
||||
# define the loss functions
|
||||
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
||||
self.criterionNCE = []
|
||||
for nce_layer in self.nce_layers:
|
||||
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
|
||||
self.criterionIdt = torch.nn.L1Loss().to(self.device)
self.criterionL1 = torch.nn.L1Loss().to(self.device)  # used by calculate_similarity below
|
||||
# optimizer names match the optimizer_G / optimizer_D / optimizer_E attributes used in optimize_parameters
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]
|
||||
|
||||
self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware adversarial loss
self.ctn = ContentAwareTemporalNorm()  # generates the pseudo optical flow
|
||||
|
||||
def data_dependent_initialize(self, data):
|
||||
"""
|
||||
The feature network netF is defined in terms of the shape of the intermediate, extracted
|
||||
features of the encoder portion of netG. Because of this, the weights of netF are
|
||||
initialized at the first feedforward pass with some input images.
|
||||
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
|
||||
"""
|
||||
#bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
|
||||
#self.set_input(data)
|
||||
#self.real_A = self.real_A[:bs_per_gpu]
|
||||
#self.real_B = self.real_B[:bs_per_gpu]
|
||||
#self.forward() # compute fake images: G(A)
|
||||
#if self.opt.isTrain:
|
||||
#
|
||||
# self.compute_G_loss().backward()
|
||||
# self.compute_D_loss().backward()
|
||||
# self.compute_E_loss().backward()
|
||||
# if self.opt.lambda_NCE > 0.0:
|
||||
# self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
|
||||
# self.optimizers.append(self.optimizer_F)
|
||||
pass
|
||||
|
||||
def optimize_parameters(self):
|
||||
# forward
|
||||
self.forward()
|
||||
|
||||
self.netG.train()
|
||||
self.netE.train()
|
||||
self.netD.train()
|
||||
|
||||
# update D
|
||||
self.set_requires_grad(self.netD, True)
|
||||
self.optimizer_D.zero_grad()
|
||||
self.loss_D = self.compute_D_loss()
|
||||
self.loss_D.backward()
|
||||
self.optimizer_D.step()
|
||||
|
||||
self.set_requires_grad(self.netE, True)
|
||||
self.optimizer_E.zero_grad()
|
||||
self.loss_E = self.compute_E_loss()
|
||||
self.loss_E.backward()
|
||||
self.optimizer_E.step()
|
||||
|
||||
# update G
|
||||
self.set_requires_grad(self.netD, False)
|
||||
self.set_requires_grad(self.netE, False)
|
||||
|
||||
self.optimizer_G.zero_grad()
|
||||
|
||||
self.loss_G = self.compute_G_loss()
|
||||
self.loss_G.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
|
||||
def set_input(self, input):
|
||||
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||
Parameters:
|
||||
input (dict): include the data itself and its metadata information.
|
||||
The option 'direction' can be used to swap domain A and domain B.
|
||||
"""
|
||||
AtoB = self.opt.direction == 'AtoB'
|
||||
self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
|
||||
self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
|
||||
self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
|
||||
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
|
||||
def tokens_concat(self, origin_tokens, adjacent_size):
|
||||
adj_size = adjacent_size
|
||||
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
|
||||
S = int(math.sqrt(token_num))
|
||||
if S * S != token_num:
|
||||
print('Error! Not a square!')
|
||||
token_map = origin_tokens.clone().reshape(B,S,S,C)
|
||||
cut_patch_list = []
|
||||
for i in range(0, S, adj_size):
|
||||
for j in range(0, S, adj_size):
|
||||
i_left = i
|
||||
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
|
||||
j_left = j
|
||||
j_right = j + adj_size if j + adj_size <= S else S + 1
|
||||
|
||||
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
|
||||
cut_patch= cut_patch.reshape(B,-1,C)
|
||||
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
|
||||
cut_patch_list.append(cut_patch)
|
||||
|
||||
|
||||
result = torch.cat(cut_patch_list,dim=1)
|
||||
return result
|
||||
|
||||
def cat_results(self, origin_tokens, adj_size_list):
|
||||
res_list = [origin_tokens]
|
||||
for ad_s in adj_size_list:
|
||||
cat_result = self.tokens_concat(origin_tokens, ad_s)
|
||||
res_list.append(cat_result)
|
||||
|
||||
result = torch.cat(res_list, dim=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def forward(self):
|
||||
"""执行前向传递以生成输出图像"""
|
||||
|
||||
if self.opt.isTrain:
|
||||
real_A0 = self.resize(self.real_A0)
|
||||
real_A1 = self.resize(self.real_A1)
|
||||
real_B0 = self.resize(self.real_B0)
|
||||
real_B1 = self.resize(self.real_B1)
|
||||
# extract ViT tokens
|
||||
self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True)
|
||||
|
||||
# run the SB module once

# ============ Step 1: initialise the time steps and the time index ============
# compute `times` and pick the current time_idx (sampled at random to represent the current step)
|
||||
tau = self.opt.tau
|
||||
T = self.opt.num_timesteps
|
||||
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
|
||||
times = np.cumsum(incs)
|
||||
times = times / times[-1]
|
||||
times = 0.5 * times[-1] + 0.5 * times #[0.5,1]
|
||||
times = np.concatenate([np.zeros(1), times])
|
||||
times = torch.tensor(times).float().cuda()
|
||||
self.times = times
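# Numeric example of the schedule above: with num_timesteps T = 4,
# incs = [0, 1, 1/2, 1/3], the normalised cumulative sums are
# [0, 0.545, 0.818, 1.0]; the 0.5*times[-1] + 0.5*times step maps them into
# [0.5, 1] as [0.5, 0.773, 0.909, 1.0], and the prepended zero gives
# times = [0, 0.5, 0.773, 0.909, 1.0].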
|
||||
bs = self.mutil_real_A0_tokens.size(0)
|
||||
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
|
||||
self.time_idx = time_idx
|
||||
|
||||
with torch.no_grad():
|
||||
self.netG.eval()
|
||||
# ============ Step 2: multi-step stochastic generation for real_A / real_A2 ============
|
||||
for t in range(self.time_idx.int().item() + 1):
|
||||
# compute the increment delta and the inter/scale factors used for interpolation at each time step
|
||||
if t > 0:
|
||||
delta = times[t] - times[t - 1]
|
||||
denom = times[-1] - times[t - 1]
|
||||
inter = (delta / denom).reshape(-1, 1, 1, 1)
|
||||
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
|
||||
|
||||
# stochastic noise update of Xt and Xt2
|
||||
Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
|
||||
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
|
||||
time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
|
||||
z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
|
||||
self.time = times[time_idx]
|
||||
Xt_1 = self.netG(Xt, self.time, z)
|
||||
|
||||
Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
|
||||
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
|
||||
time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
|
||||
z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
|
||||
Xt_12 = self.netG(Xt2, self.time, z)
|
||||
|
||||
# keep the denoised intermediate results (real_A_noisy, ...) for the concatenation step below
|
||||
self.real_A_noisy = Xt.detach()
|
||||
self.real_A_noisy2 = Xt2.detach()
|
||||
# store the noisy map
|
||||
self.noisy_map = self.real_A_noisy - self.real_A
|
||||
|
||||
# ============ Step 3: concatenate the inputs and run the network =============
|
||||
bs = self.mutil_real_A0_tokens.size(0)
|
||||
z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
|
||||
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
|
||||
# concatenate real_A and real_B (when nce_idt=True) and treat real_A_noisy / XtB the same way
|
||||
self.real = self.mutil_real_A0_tokens
|
||||
self.realt = self.real_A_noisy
|
||||
|
||||
if self.opt.flip_equivariance:
|
||||
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
|
||||
if self.flipped_for_equivariance:
|
||||
self.real = torch.flip(self.real, [3])
|
||||
self.realt = torch.flip(self.realt, [3])
|
||||
|
||||
# use netG to produce the final fake, fake_B2, ... results
|
||||
self.fake_B = self.netG(self.realt, self.time, z_in)
|
||||
self.fake_B2 = self.netG(self.real, self.time, z_in2)
|
||||
|
||||
self.fake_B = self.resize(self.fake_B)
|
||||
self.fake_B2 = self.resize(self.fake_B2)
|
||||
|
||||
self.fake_B0 = self.fake_B
|
||||
self.fake_B1 = self.fake_B2
|
||||
|
||||
# extract ViT tokens from the generated frames
|
||||
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B2, self.atten_layers, get_tokens=True)
|
||||
|
||||
# ============ Step 4: multi-step sampling at test time ============
|
||||
if self.opt.phase == 'test':
|
||||
tau = self.opt.tau
|
||||
T = self.opt.num_timesteps
|
||||
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
|
||||
times = np.cumsum(incs)
|
||||
times = times / times[-1]
|
||||
times = 0.5 * times[-1] + 0.5 * times
|
||||
times = np.concatenate([np.zeros(1),times])
|
||||
times = torch.tensor(times).float().cuda()
|
||||
self.times = times
|
||||
bs = self.real.size(0)
|
||||
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
|
||||
self.time_idx = time_idx
|
||||
visuals = []
|
||||
with torch.no_grad():
|
||||
self.netG.eval()
|
||||
for t in range(self.opt.num_timesteps):
|
||||
|
||||
if t > 0:
|
||||
delta = times[t] - times[t-1]
|
||||
denom = times[-1] - times[t-1]
|
||||
inter = (delta / denom).reshape(-1,1,1,1)
|
||||
scale = (delta * (1 - delta / denom)).reshape(-1,1,1,1)
|
||||
Xt = self.mutil_real_A0_tokens if (t == 0) else (1-inter) * Xt + inter * Xt_1.detach() + (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
|
||||
time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
|
||||
time = times[time_idx]
|
||||
z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
|
||||
Xt_1 = self.netG(Xt, time_idx, z)
|
||||
|
||||
setattr(self, "fake_"+str(t+1), Xt_1)
|
||||
|
||||
if self.opt.phase == 'train':
|
||||
# gradient w.r.t. the real image
real_gradient = torch.autograd.grad(self.real_B.sum(), self.real_B, create_graph=True)[0]
# gradient w.r.t. the generated image
fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0]
# gradient-based weight maps
|
||||
self.weight_real, self.weight_fake = self.cao.generate_weight_map(real_gradient, fake_gradient)
|
||||
|
||||
# CTN optical flow for the generated image
|
||||
self.f_content = self.ctn(self.weight_fake)
|
||||
|
||||
# add the noisy map back onto the generated image
self.fake_B_2 = self.fake_B + self.noisy_map
|
||||
|
||||
# warped version of the generated image
warped_fake_B = warp(self.fake_B, self.f_content)

# pass the warped image through the generator a second time
self.fake_B_2 = self.netG(warped_fake_B, self.time, z_in)
|
||||
|
||||
def compute_D_loss(self):
|
||||
"""计算判别器的 GAN 损失"""
|
||||
|
||||
fake = self.fake_B.detach()  # passed to netD directly, matching the real-image branch below
|
||||
pred_fake = self.netD(fake, self.time)
|
||||
self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
|
||||
|
||||
self.pred_real = self.netD(self.real_B0, self.time)
|
||||
loss_D_real = self.criterionGAN(self.pred_real, True)
|
||||
self.loss_D_real = loss_D_real.mean()
|
||||
|
||||
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
|
||||
return self.loss_D
|
||||
|
||||
def compute_E_loss(self):
|
||||
"""计算判别器 E 的损失"""
|
||||
|
||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B.detach()], dim=1)
|
||||
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2.detach()], dim=1)
|
||||
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
|
||||
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
|
||||
|
||||
return self.loss_E
|
||||
|
||||
def compute_G_loss(self):
|
||||
"""计算生成器的 GAN 损失"""
|
||||
|
||||
bs = self.mutil_real_A0_tokens.size(0)
|
||||
tau = self.opt.tau
|
||||
|
||||
fake = self.fake_B
|
||||
std = torch.rand(size=[1]).item() * self.opt.std
|
||||
|
||||
if self.opt.lambda_GAN > 0.0:
|
||||
pred_fake = self.netD(fake, self.time)
|
||||
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
|
||||
else:
|
||||
self.loss_G_GAN = 0.0
|
||||
self.loss_SB = 0
|
||||
if self.opt.lambda_SB > 0.0:
|
||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1)
|
||||
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1)
|
||||
|
||||
bs = self.opt.batch_size
|
||||
|
||||
# eq.9
|
||||
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
|
||||
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
|
||||
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)
|
||||
|
||||
if self.opt.lambda_global > 0.0:
|
||||
loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens)
|
||||
loss_global *= 0.5
|
||||
else:
|
||||
loss_global = 0.0
|
||||
|
||||
if self.opt.lambda_ctn > 0.0:
warped_fake_B = warp(self.fake_B, self.f_content)  # use the updated self.f_content
self.l2_loss = F.mse_loss(self.fake_B_2, warped_fake_B)  # content-aware temporal-norm loss
else:
self.l2_loss = 0.0
|
||||
|
||||
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
|
||||
return self.loss_G
|
||||
|
||||
def calculate_attention_loss(self):
|
||||
n_layers = len(self.atten_layers)
|
||||
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
||||
mutil_real_A1_tokens = self.mutil_real_A1_tokens
|
||||
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
|
||||
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
|
||||
|
||||
|
||||
if self.opt.lambda_global > 0.0:
|
||||
loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
|
||||
loss_global *= 0.5
|
||||
|
||||
else:
|
||||
loss_global = 0.0
|
||||
|
||||
if self.opt.lambda_spatial > 0.0:
|
||||
loss_spatial = 0.0
|
||||
local_nums = self.opt.local_nums
|
||||
tokens_cnt = 576
|
||||
local_id = np.random.permutation(tokens_cnt)
|
||||
local_id = local_id[:int(min(local_nums, tokens_cnt))]
|
||||
|
||||
mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
|
||||
mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
|
||||
|
||||
mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
|
||||
mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
|
||||
|
||||
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
|
||||
loss_spatial *= 0.5
|
||||
|
||||
else:
|
||||
loss_spatial = 0.0
|
||||
|
||||
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
|
||||
|
||||
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
|
||||
loss = 0.0
|
||||
n_layers = len(self.atten_layers)
|
||||
|
||||
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
|
||||
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
|
||||
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
|
||||
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
|
||||
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
|
||||
|
||||
loss = loss / n_layers
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
914
models/stylegan_networks.py
Normal file
@ -0,0 +1,914 @@
|
||||
"""
|
||||
The network architecture is based on the PyTorch implementation of StyleGAN2Encoder.
Original PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch
Original StyleGAN2 paper: https://github.com/NVlabs/stylegan2
We use this network architecture for our single-image training setting.
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
||||
return F.leaky_relu(input + bias, negative_slope) * scale
|
||||
|
||||
|
||||
class FusedLeakyReLU(nn.Module):
|
||||
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
||||
super().__init__()
|
||||
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
||||
self.negative_slope = negative_slope
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, input):
|
||||
# print("FusedLeakyReLU: ", input.abs().mean())
|
||||
out = fused_leaky_relu(input, self.bias,
|
||||
self.negative_slope,
|
||||
self.scale)
|
||||
# print("FusedLeakyReLU: ", out.abs().mean())
|
||||
return out
|
||||
|
||||
|
||||
def upfirdn2d_native(
|
||||
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
||||
):
|
||||
_, minor, in_h, in_w = input.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = input.view(-1, minor, in_h, 1, in_w, 1)
|
||||
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
|
||||
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
|
||||
|
||||
out = F.pad(
|
||||
out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
||||
)
|
||||
out = out[
|
||||
:,
|
||||
:,
|
||||
max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0),
|
||||
]
|
||||
|
||||
# out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape(
|
||||
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
||||
)
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
# out = out.permute(0, 2, 3, 1)
|
||||
|
||||
return out[:, :, ::down_y, ::down_x]
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
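# Note: upfirdn2d = upsample by `up` (zero insertion), pad, FIR-filter with
# `kernel`, then downsample by `down`; with up=2, down=1 and the [1, 3, 3, 1]
# kernel this is the anti-aliased 2x resampling used by Upsample/Downsample/Blur below.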
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
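# Note: PixelNorm rescales every spatial location to unit RMS across channels,
# x / sqrt(mean_c(x^2) + 1e-8), so the per-pixel channel vector has roughly unit
# scale regardless of the input magnitude.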
|
||||
|
||||
|
||||
def make_kernel(k):
|
||||
k = torch.tensor(k, dtype=torch.float32)
|
||||
|
||||
if len(k.shape) == 1:
|
||||
k = k[None, :] * k[:, None]
|
||||
|
||||
k /= k.sum()
|
||||
|
||||
return k
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, kernel, factor=2):
|
||||
super().__init__()
|
||||
|
||||
self.factor = factor
|
||||
kernel = make_kernel(kernel) * (factor ** 2)
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
p = kernel.shape[0] - factor
|
||||
|
||||
pad0 = (p + 1) // 2 + factor - 1
|
||||
pad1 = p // 2
|
||||
|
||||
self.pad = (pad0, pad1)
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, kernel, factor=2):
|
||||
super().__init__()
|
||||
|
||||
self.factor = factor
|
||||
kernel = make_kernel(kernel)
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
p = kernel.shape[0] - factor
|
||||
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
self.pad = (pad0, pad1)
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Blur(nn.Module):
|
||||
def __init__(self, kernel, pad, upsample_factor=1):
|
||||
super().__init__()
|
||||
|
||||
kernel = make_kernel(kernel)
|
||||
|
||||
if upsample_factor > 1:
|
||||
kernel = kernel * (upsample_factor ** 2)
|
||||
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
self.pad = pad
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
|
||||
def __init__(
|
||||
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
||||
)
|
||||
self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2))
|
||||
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_channel))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
# print("Before EqualConv2d: ", input.abs().mean())
|
||||
out = F.conv2d(
|
||||
input,
|
||||
self.weight * self.scale,
|
||||
bias=self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
)
|
||||
# print("After EqualConv2d: ", out.abs().mean(), (self.weight * self.scale).abs().mean())
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
||||
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
||||
)
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul
|
||||
self.lr_mul = lr_mul
|
||||
|
||||
def forward(self, input):
|
||||
if self.activation:
|
||||
out = F.linear(input, self.weight * self.scale)
|
||||
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
||||
|
||||
else:
|
||||
out = F.linear(
|
||||
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
||||
)
|
||||
|
||||
|
||||
class ScaledLeakyReLU(nn.Module):
|
||||
def __init__(self, negative_slope=0.2):
|
||||
super().__init__()
|
||||
|
||||
self.negative_slope = negative_slope
|
||||
|
||||
def forward(self, input):
|
||||
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
||||
|
||||
return out * math.sqrt(2)
|
||||
|
||||
|
||||
class ModulatedConv2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
style_dim,
|
||||
demodulate=True,
|
||||
upsample=False,
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.eps = 1e-8
|
||||
self.kernel_size = kernel_size
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
self.upsample = upsample
|
||||
self.downsample = downsample
|
||||
|
||||
if upsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2 + factor - 1
|
||||
pad1 = p // 2 + 1
|
||||
|
||||
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
||||
|
||||
fan_in = in_channel * kernel_size ** 2
|
||||
self.scale = math.sqrt(1) / math.sqrt(fan_in)
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
||||
)
|
||||
|
||||
if style_dim is not None and style_dim > 0:
|
||||
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
||||
|
||||
self.demodulate = demodulate
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
|
||||
f'upsample={self.upsample}, downsample={self.downsample})'
|
||||
)
|
||||
|
||||
def forward(self, input, style):
|
||||
batch, in_channel, height, width = input.shape
|
||||
|
||||
if style is not None:
|
||||
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
||||
else:
|
||||
style = torch.ones(batch, 1, in_channel, 1, 1).cuda()
|
||||
weight = self.scale * self.weight * style
|
||||
|
||||
if self.demodulate:
|
||||
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
||||
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
||||
|
||||
weight = weight.view(
|
||||
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
||||
)
|
||||
|
||||
if self.upsample:
|
||||
input = input.view(1, batch * in_channel, height, width)
|
||||
weight = weight.view(
|
||||
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
||||
)
|
||||
weight = weight.transpose(1, 2).reshape(
|
||||
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
||||
)
|
||||
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
||||
_, _, height, width = out.shape
|
||||
out = out.view(batch, self.out_channel, height, width)
|
||||
out = self.blur(out)
|
||||
|
||||
elif self.downsample:
|
||||
input = self.blur(input)
|
||||
_, _, height, width = input.shape
|
||||
input = input.view(1, batch * in_channel, height, width)
|
||||
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
||||
_, _, height, width = out.shape
|
||||
out = out.view(batch, self.out_channel, height, width)
|
||||
|
||||
else:
|
||||
input = input.view(1, batch * in_channel, height, width)
|
||||
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
||||
_, _, height, width = out.shape
|
||||
out = out.view(batch, self.out_channel, height, width)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class NoiseInjection(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, image, noise=None):
|
||||
if noise is None:
|
||||
batch, _, height, width = image.shape
|
||||
noise = image.new_empty(batch, 1, height, width).normal_()
|
||||
|
||||
return image + self.weight * noise
|
||||
|
||||
|
||||
class ConstantInput(nn.Module):
|
||||
def __init__(self, channel, size=4):
|
||||
super().__init__()
|
||||
|
||||
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
||||
|
||||
def forward(self, input):
|
||||
batch = input.shape[0]
|
||||
out = self.input.repeat(batch, 1, 1, 1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class StyledConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
style_dim=None,
|
||||
upsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
demodulate=True,
|
||||
inject_noise=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.inject_noise = inject_noise
|
||||
self.conv = ModulatedConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
style_dim,
|
||||
upsample=upsample,
|
||||
blur_kernel=blur_kernel,
|
||||
demodulate=demodulate,
|
||||
)
|
||||
|
||||
self.noise = NoiseInjection()
|
||||
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
||||
# self.activate = ScaledLeakyReLU(0.2)
|
||||
self.activate = FusedLeakyReLU(out_channel)
|
||||
|
||||
def forward(self, input, style=None, noise=None):
|
||||
out = self.conv(input, style)
|
||||
if self.inject_noise:
|
||||
out = self.noise(out, noise=noise)
|
||||
# out = out + self.bias
|
||||
out = self.activate(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
|
||||
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
if upsample:
|
||||
self.upsample = Upsample(blur_kernel)
|
||||
|
||||
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
||||
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
||||
|
||||
def forward(self, input, style, skip=None):
|
||||
out = self.conv(input, style)
|
||||
out = out + self.bias
|
||||
|
||||
if skip is not None:
|
||||
skip = self.upsample(skip)
|
||||
|
||||
out = out + skip
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
style_dim,
|
||||
n_mlp,
|
||||
channel_multiplier=2,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
lr_mlp=0.01,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.size = size
|
||||
|
||||
self.style_dim = style_dim
|
||||
|
||||
layers = [PixelNorm()]
|
||||
|
||||
for i in range(n_mlp):
|
||||
layers.append(
|
||||
EqualLinear(
|
||||
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
|
||||
)
|
||||
)
|
||||
|
||||
self.style = nn.Sequential(*layers)
|
||||
|
||||
self.channels = {
|
||||
4: 512,
|
||||
8: 512,
|
||||
16: 512,
|
||||
32: 512,
|
||||
64: 256 * channel_multiplier,
|
||||
128: 128 * channel_multiplier,
|
||||
256: 64 * channel_multiplier,
|
||||
512: 32 * channel_multiplier,
|
||||
1024: 16 * channel_multiplier,
|
||||
}
|
||||
|
||||
self.input = ConstantInput(self.channels[4])
|
||||
self.conv1 = StyledConv(
|
||||
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
||||
)
|
||||
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
||||
|
||||
self.log_size = int(math.log(size, 2))
|
||||
self.num_layers = (self.log_size - 2) * 2 + 1
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.upsamples = nn.ModuleList()
|
||||
self.to_rgbs = nn.ModuleList()
|
||||
self.noises = nn.Module()
|
||||
|
||||
in_channel = self.channels[4]
|
||||
|
||||
for layer_idx in range(self.num_layers):
|
||||
res = (layer_idx + 5) // 2
|
||||
shape = [1, 1, 2 ** res, 2 ** res]
|
||||
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
|
||||
|
||||
for i in range(3, self.log_size + 1):
|
||||
out_channel = self.channels[2 ** i]
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
in_channel,
|
||||
out_channel,
|
||||
3,
|
||||
style_dim,
|
||||
upsample=True,
|
||||
blur_kernel=blur_kernel,
|
||||
)
|
||||
)
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
||||
)
|
||||
)
|
||||
|
||||
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
||||
|
||||
in_channel = out_channel
|
||||
|
||||
self.n_latent = self.log_size * 2 - 2
|
||||
|
||||
def make_noise(self):
|
||||
device = self.input.input.device
|
||||
|
||||
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
||||
|
||||
for i in range(3, self.log_size + 1):
|
||||
for _ in range(2):
|
||||
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
||||
|
||||
return noises
|
||||
|
||||
def mean_latent(self, n_latent):
|
||||
latent_in = torch.randn(
|
||||
n_latent, self.style_dim, device=self.input.input.device
|
||||
)
|
||||
latent = self.style(latent_in).mean(0, keepdim=True)
|
||||
|
||||
return latent
|
||||
|
||||
def get_latent(self, input):
|
||||
return self.style(input)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
styles,
|
||||
return_latents=False,
|
||||
inject_index=None,
|
||||
truncation=1,
|
||||
truncation_latent=None,
|
||||
input_is_latent=False,
|
||||
noise=None,
|
||||
randomize_noise=True,
|
||||
):
|
||||
if not input_is_latent:
|
||||
styles = [self.style(s) for s in styles]
|
||||
|
||||
if noise is None:
|
||||
if randomize_noise:
|
||||
noise = [None] * self.num_layers
|
||||
else:
|
||||
noise = [
|
||||
getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
|
||||
]
|
||||
|
||||
if truncation < 1:
|
||||
style_t = []
|
||||
|
||||
for style in styles:
|
||||
style_t.append(
|
||||
truncation_latent + truncation * (style - truncation_latent)
|
||||
)
|
||||
|
||||
styles = style_t
|
||||
|
||||
if len(styles) < 2:
|
||||
inject_index = self.n_latent
|
||||
|
||||
if len(styles[0].shape) < 3:
|
||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
|
||||
else:
|
||||
latent = styles[0]
|
||||
|
||||
else:
|
||||
if inject_index is None:
|
||||
inject_index = random.randint(1, self.n_latent - 1)
|
||||
|
||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
||||
|
||||
latent = torch.cat([latent, latent2], 1)
|
||||
|
||||
out = self.input(latent)
|
||||
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
||||
|
||||
skip = self.to_rgb1(out, latent[:, 1])
|
||||
|
||||
i = 1
|
||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
||||
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
||||
):
|
||||
out = conv1(out, latent[:, i], noise=noise1)
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
|
||||
if return_latents:
|
||||
return image, latent
|
||||
|
||||
else:
|
||||
return image, None
|
||||
|
||||
|
||||
class ConvLayer(nn.Sequential):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
bias=True,
|
||||
activate=True,
|
||||
):
|
||||
layers = []
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
||||
|
||||
stride = 2
|
||||
self.padding = 0
|
||||
|
||||
else:
|
||||
stride = 1
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
layers.append(
|
||||
EqualConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
padding=self.padding,
|
||||
stride=stride,
|
||||
bias=bias and not activate,
|
||||
)
|
||||
)
|
||||
|
||||
if activate:
|
||||
if bias:
|
||||
layers.append(FusedLeakyReLU(out_channel))
|
||||
|
||||
else:
|
||||
layers.append(ScaledLeakyReLU(0.2))
|
||||
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0):
|
||||
super().__init__()
|
||||
|
||||
self.skip_gain = skip_gain
|
||||
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
||||
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel)
|
||||
|
||||
if in_channel != out_channel or downsample:
|
||||
self.skip = ConvLayer(
|
||||
in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
|
||||
)
|
||||
else:
|
||||
self.skip = nn.Identity()
|
||||
|
||||
def forward(self, input):
|
||||
out = self.conv1(input)
|
||||
out = self.conv2(out)
|
||||
|
||||
skip = self.skip(input)
|
||||
out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class StyleGAN2Discriminator(nn.Module):
|
||||
def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
self.stddev_group = 16
|
||||
if size is None:
|
||||
size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
|
||||
if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
|
||||
size = 2 ** int(np.log2(self.opt.D_patch_size))
|
||||
|
||||
blur_kernel = [1, 3, 3, 1]
|
||||
channel_multiplier = ndf / 64
|
||||
channels = {
|
||||
4: min(384, int(4096 * channel_multiplier)),
|
||||
8: min(384, int(2048 * channel_multiplier)),
|
||||
16: min(384, int(1024 * channel_multiplier)),
|
||||
32: min(384, int(512 * channel_multiplier)),
|
||||
64: int(256 * channel_multiplier),
|
||||
128: int(128 * channel_multiplier),
|
||||
256: int(64 * channel_multiplier),
|
||||
512: int(32 * channel_multiplier),
|
||||
1024: int(16 * channel_multiplier),
|
||||
}
|
||||
|
||||
convs = [ConvLayer(3, channels[size], 1)]
|
||||
|
||||
log_size = int(math.log(size, 2))
|
||||
|
||||
in_channel = channels[size]
|
||||
|
||||
if "smallpatch" in self.opt.netD:
|
||||
final_res_log2 = 4
|
||||
elif "patch" in self.opt.netD:
|
||||
final_res_log2 = 3
|
||||
else:
|
||||
final_res_log2 = 2
|
||||
|
||||
for i in range(log_size, final_res_log2, -1):
|
||||
out_channel = channels[2 ** (i - 1)]
|
||||
|
||||
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
||||
|
||||
in_channel = out_channel
|
||||
|
||||
self.convs = nn.Sequential(*convs)
|
||||
|
||||
if False and "tile" in self.opt.netD:
|
||||
in_channel += 1
|
||||
self.final_conv = ConvLayer(in_channel, channels[4], 3)
|
||||
if "patch" in self.opt.netD:
|
||||
self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False)
|
||||
else:
|
||||
self.final_linear = nn.Sequential(
|
||||
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
|
||||
EqualLinear(channels[4], 1),
|
||||
)
|
||||
|
||||
def forward(self, input, get_minibatch_features=False):
|
||||
if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
|
||||
h, w = input.size(2), input.size(3)
|
||||
y = torch.randint(h - self.opt.D_patch_size, ())
|
||||
x = torch.randint(w - self.opt.D_patch_size, ())
|
||||
input = input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size]
|
||||
out = input
|
||||
for i, conv in enumerate(self.convs):
|
||||
out = conv(out)
|
||||
# print(i, out.abs().mean())
|
||||
# out = self.convs(input)
|
||||
|
||||
batch, channel, height, width = out.shape
|
||||
|
||||
if False and "tile" in self.opt.netD:
|
||||
group = min(batch, self.stddev_group)
|
||||
stddev = out.view(
|
||||
group, -1, 1, channel // 1, height, width
|
||||
)
|
||||
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
||||
stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
|
||||
stddev = stddev.repeat(group, 1, height, width)
|
||||
out = torch.cat([out, stddev], 1)
|
||||
|
||||
out = self.final_conv(out)
|
||||
# print(out.abs().mean())
|
||||
|
||||
if "patch" not in self.opt.netD:
|
||||
out = out.view(batch, -1)
|
||||
out = self.final_linear(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TileStyleGAN2Discriminator(StyleGAN2Discriminator):
|
||||
def forward(self, input):
|
||||
B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
|
||||
size = self.opt.D_patch_size
|
||||
Y = H // size
|
||||
X = W // size
|
||||
input = input.view(B, C, Y, size, X, size)
|
||||
input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
class StyleGAN2Encoder(nn.Module):
|
||||
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
|
||||
super().__init__()
|
||||
assert opt is not None
|
||||
self.opt = opt
|
||||
channel_multiplier = ngf / 32
|
||||
channels = {
|
||||
4: min(512, int(round(4096 * channel_multiplier))),
|
||||
8: min(512, int(round(2048 * channel_multiplier))),
|
||||
16: min(512, int(round(1024 * channel_multiplier))),
|
||||
32: min(512, int(round(512 * channel_multiplier))),
|
||||
64: int(round(256 * channel_multiplier)),
|
||||
128: int(round(128 * channel_multiplier)),
|
||||
256: int(round(64 * channel_multiplier)),
|
||||
512: int(round(32 * channel_multiplier)),
|
||||
1024: int(round(16 * channel_multiplier)),
|
||||
}
|
||||
|
||||
blur_kernel = [1, 3, 3, 1]
|
||||
|
||||
cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
|
||||
convs = [nn.Identity(),
|
||||
ConvLayer(3, channels[cur_res], 1)]
|
||||
|
||||
num_downsampling = self.opt.stylegan2_G_num_downsampling
|
||||
for i in range(num_downsampling):
|
||||
in_channel = channels[cur_res]
|
||||
out_channel = channels[cur_res // 2]
|
||||
convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True))
|
||||
cur_res = cur_res // 2
|
||||
|
||||
for i in range(n_blocks // 2):
|
||||
n_channel = channels[cur_res]
|
||||
convs.append(ResBlock(n_channel, n_channel, downsample=False))
|
||||
|
||||
self.convs = nn.Sequential(*convs)
|
||||
|
||||
def forward(self, input, layers=[], get_features=False):
|
||||
feat = input
|
||||
feats = []
|
||||
if -1 in layers:
|
||||
layers.append(len(self.convs) - 1)
|
||||
for layer_id, layer in enumerate(self.convs):
|
||||
feat = layer(feat)
|
||||
# print(layer_id, " features ", feat.abs().mean())
|
||||
if layer_id in layers:
|
||||
feats.append(feat)
|
||||
|
||||
if get_features:
|
||||
return feat, feats
|
||||
else:
|
||||
return feat
|
||||
|
||||
|
||||
class StyleGAN2Decoder(nn.Module):
|
||||
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
|
||||
super().__init__()
|
||||
assert opt is not None
|
||||
self.opt = opt
|
||||
|
||||
blur_kernel = [1, 3, 3, 1]
|
||||
|
||||
channel_multiplier = ngf / 32
|
||||
channels = {
|
||||
4: min(512, int(round(4096 * channel_multiplier))),
|
||||
8: min(512, int(round(2048 * channel_multiplier))),
|
||||
16: min(512, int(round(1024 * channel_multiplier))),
|
||||
32: min(512, int(round(512 * channel_multiplier))),
|
||||
64: int(round(256 * channel_multiplier)),
|
||||
128: int(round(128 * channel_multiplier)),
|
||||
256: int(round(64 * channel_multiplier)),
|
||||
512: int(round(32 * channel_multiplier)),
|
||||
1024: int(round(16 * channel_multiplier)),
|
||||
}
|
||||
|
||||
num_downsampling = self.opt.stylegan2_G_num_downsampling
|
||||
cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling)
|
||||
convs = []
|
||||
|
||||
for i in range(n_blocks // 2):
|
||||
n_channel = channels[cur_res]
|
||||
convs.append(ResBlock(n_channel, n_channel, downsample=False))
|
||||
|
||||
for i in range(num_downsampling):
|
||||
in_channel = channels[cur_res]
|
||||
out_channel = channels[cur_res * 2]
|
||||
inject_noise = "small" not in self.opt.netG
|
||||
convs.append(
|
||||
StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise)
|
||||
)
|
||||
cur_res = cur_res * 2
|
||||
|
||||
convs.append(ConvLayer(channels[cur_res], 3, 1))
|
||||
|
||||
self.convs = nn.Sequential(*convs)
|
||||
|
||||
def forward(self, input):
|
||||
return self.convs(input)
|
||||
|
||||
|
||||
class StyleGAN2Generator(nn.Module):
|
||||
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
|
||||
self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
|
||||
|
||||
def forward(self, input, layers=[], encode_only=False):
|
||||
feat, feats = self.encoder(input, layers, True)
|
||||
if encode_only:
|
||||
return feats
|
||||
else:
|
||||
fake = self.decoder(feat)
|
||||
|
||||
if len(layers) > 0:
|
||||
return fake, feats
|
||||
else:
|
||||
return fake
|
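The generator above doubles as a feature extractor: with `encode_only=True` it returns only the encoder activations requested via `layers`, while the default path runs the full encoder-decoder translation. A minimal sketch of both paths follows; the option fields and tensor sizes are illustrative assumptions, not values taken from this repository's scripts.

```
import torch
from argparse import Namespace

# Hypothetical options; only the fields the encoder/decoder actually read.
opt = Namespace(load_size=256, crop_size=256,
                stylegan2_G_num_downsampling=1, netG='stylegan2')

netG = StyleGAN2Generator(input_nc=3, output_nc=3, ngf=64, opt=opt)
img = torch.randn(1, 3, 256, 256)

fake = netG(img)                                        # full translation pass
feats = netG(img, layers=[0, 2, 4], encode_only=True)   # encoder features only
print(fake.shape, [f.shape for f in feats])
```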
||||
99
models/template_model.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""Model class template
|
||||
|
||||
This module provides a template for users to implement custom models.
|
||||
You can specify '--model template' to use this model.
|
||||
The class name should be consistent with both the filename and its model option.
|
||||
The filename should be <model>_model.py
|
||||
The class name should be <Model>Model.
|
||||
It implements a simple image-to-image translation baseline based on regression loss.
|
||||
Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
|
||||
min_<netG> ||netG(data_A) - data_B||_1
|
||||
You need to implement the following functions:
|
||||
<modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
|
||||
<__init__>: Initialize this model class.
|
||||
<set_input>: Unpack input data and perform data pre-processing.
|
||||
<forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
|
||||
<optimize_parameters>: Update network weights; it will be called in every training iteration.
|
||||
"""
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
|
||||
|
||||
class TemplateModel(BaseModel):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train=True):
|
||||
"""Add new model-specific options and rewrite default values for existing options.
|
||||
|
||||
Parameters:
|
||||
parser -- the option parser
|
||||
is_train -- whether it is the training phase or the test phase. You can use this flag to add training-specific or test-specific options.
|
||||
|
||||
Returns:
|
||||
the modified parser.
|
||||
"""
|
||||
parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
|
||||
if is_train:
|
||||
parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
|
||||
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize this model class.
|
||||
|
||||
Parameters:
|
||||
opt -- training/test options
|
||||
|
||||
A few things can be done here.
|
||||
- (required) call the initialization function of BaseModel
|
||||
- define loss function, visualization images, model names, and optimizers
|
||||
"""
|
||||
BaseModel.__init__(self, opt) # call the initialization method of BaseModel
|
||||
# specify the training losses you want to print out. The program will call base_model.get_current_losses to print the losses to the console and save them to the disk.
|
||||
self.loss_names = ['loss_G']
|
||||
# specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
|
||||
self.visual_names = ['data_A', 'data_B', 'output']
|
||||
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
|
||||
# you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
|
||||
self.model_names = ['G']
|
||||
# define networks; you can use opt.isTrain to specify different behaviors for training and test.
|
||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
|
||||
if self.isTrain: # only defined during training time
|
||||
# define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
|
||||
# We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
|
||||
self.criterionLoss = torch.nn.L1Loss()
|
||||
# define and initialize optimizers. You can define one optimizer for each network.
|
||||
# If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
||||
self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
self.optimizers = [self.optimizer]
|
||||
|
||||
# Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
|
||||
|
||||
def set_input(self, input):
|
||||
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||
|
||||
Parameters:
|
||||
input: a dictionary that contains the data itself and its metadata information.
|
||||
"""
|
||||
AtoB = self.opt.direction == 'AtoB' # use <direction> to swap data_A and data_B
|
||||
self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
|
||||
self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
|
||||
|
||||
def forward(self):
|
||||
"""Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
|
||||
self.output = self.netG(self.data_A) # generate output image given the input data_A
|
||||
|
||||
def backward(self):
|
||||
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
||||
# calculate intermediate results if necessary; here self.output has already been computed during <forward>
|
||||
# calculate loss given the input and intermediate results
|
||||
self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
|
||||
self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
|
||||
|
||||
def optimize_parameters(self):
|
||||
"""Update network weights; it will be called in every training iteration."""
|
||||
self.forward() # first call forward to calculate intermediate results
|
||||
self.optimizer.zero_grad() # clear network G's existing gradients
|
||||
self.backward() # calculate gradients for network G
|
||||
self.optimizer.step()        # update network G's weights with the computed gradients
|
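For orientation, a minimal sketch of the loop that normally drives this model (the actual loop lives in train.py; `opt` and `dataset` are placeholders here, with `opt` coming from `TrainOptions().parse()`):

```
model = TemplateModel(opt)
model.setup(opt)                     # schedulers, optional loading, printing (BaseModel)
for epoch in range(opt.n_epochs + opt.n_epochs_decay):
    for data in dataset:             # e.g. {'A': ..., 'B': ..., 'A_paths': ..., 'B_paths': ...}
        model.set_input(data)        # unpack the batch and move it to the device
        model.optimize_parameters()  # forward -> zero_grad -> backward -> step
```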
||||
BIN
models/util/__pycache__/pos_embed.cpython-36.pyc
Normal file
Binary file not shown.
42
models/util/crop.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
class RandomResizedCrop(transforms.RandomResizedCrop):
|
||||
"""
|
||||
RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
|
||||
This may lead to results different from torchvision's version.
|
||||
Following BYOL's TF code:
|
||||
https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
|
||||
"""
|
||||
@staticmethod
|
||||
def get_params(img, scale, ratio):
|
||||
width, height = F._get_image_size(img)
|
||||
area = height * width
|
||||
|
||||
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
|
||||
log_ratio = torch.log(torch.tensor(ratio))
|
||||
aspect_ratio = torch.exp(
|
||||
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
|
||||
).item()
|
||||
|
||||
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
w = min(w, width)
|
||||
h = min(h, height)
|
||||
|
||||
i = torch.randint(0, height - h + 1, size=(1,)).item()
|
||||
j = torch.randint(0, width - w + 1, size=(1,)).item()
|
||||
|
||||
return i, j, h, w
|
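A minimal usage sketch; the crop size and scale range below are illustrative defaults, not values mandated by this repository:

```
from torchvision import transforms

# Drop-in replacement for torchvision's RandomResizedCrop in a training transform.
pretrain_transform = transforms.Compose([
    RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
```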
||||
65
models/util/datasets.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# DeiT: https://github.com/facebookresearch/deit
|
||||
# --------------------------------------------------------
|
||||
|
||||
import os
|
||||
import PIL
|
||||
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
from timm.data import create_transform
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
|
||||
def build_dataset(is_train, args):
|
||||
transform = build_transform(is_train, args)
|
||||
|
||||
root = os.path.join(args.data_path, 'train' if is_train else 'val')
|
||||
dataset = datasets.ImageFolder(root, transform=transform)
|
||||
|
||||
print(dataset)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def build_transform(is_train, args):
|
||||
mean = IMAGENET_DEFAULT_MEAN
|
||||
std = IMAGENET_DEFAULT_STD
|
||||
# train transform
|
||||
if is_train:
|
||||
# this should always dispatch to transforms_imagenet_train
|
||||
transform = create_transform(
|
||||
input_size=args.input_size,
|
||||
is_training=True,
|
||||
color_jitter=args.color_jitter,
|
||||
auto_augment=args.aa,
|
||||
interpolation='bicubic',
|
||||
re_prob=args.reprob,
|
||||
re_mode=args.remode,
|
||||
re_count=args.recount,
|
||||
mean=mean,
|
||||
std=std,
|
||||
)
|
||||
return transform
|
||||
|
||||
# eval transform
|
||||
t = []
|
||||
if args.input_size <= 224:
|
||||
crop_pct = 224 / 256
|
||||
else:
|
||||
crop_pct = 1.0
|
||||
size = int(args.input_size / crop_pct)
|
||||
t.append(
|
||||
transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images
|
||||
)
|
||||
t.append(transforms.CenterCrop(args.input_size))
|
||||
|
||||
t.append(transforms.ToTensor())
|
||||
t.append(transforms.Normalize(mean, std))
|
||||
return transforms.Compose(t)
|
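A worked example of the eval-resize arithmetic above, assuming `input_size=224`:

```
input_size = 224
crop_pct = 224 / 256 if input_size <= 224 else 1.0   # 0.875 for 224-pixel eval
size = int(input_size / crop_pct)                     # 256
# Resize(256) keeps the shorter side at 256, then CenterCrop(224) reproduces the
# usual 224/256 center-crop ratio used for ImageNet evaluation.
```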
||||
47
models/util/lars.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# LARS optimizer, implementation from MoCo v3:
|
||||
# https://github.com/facebookresearch/moco-v3
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LARS(torch.optim.Optimizer):
|
||||
"""
|
||||
LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
|
||||
"""
|
||||
def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
|
||||
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for g in self.param_groups:
|
||||
for p in g['params']:
|
||||
dp = p.grad
|
||||
|
||||
if dp is None:
|
||||
continue
|
||||
|
||||
if p.ndim > 1: # if not normalization gamma/beta or bias
|
||||
dp = dp.add(p, alpha=g['weight_decay'])
|
||||
param_norm = torch.norm(p)
|
||||
update_norm = torch.norm(dp)
|
||||
one = torch.ones_like(param_norm)
|
||||
q = torch.where(param_norm > 0.,
|
||||
torch.where(update_norm > 0,
|
||||
(g['trust_coefficient'] * param_norm / update_norm), one),
|
||||
one)
|
||||
dp = dp.mul(q)
|
||||
|
||||
param_state = self.state[p]
|
||||
if 'mu' not in param_state:
|
||||
param_state['mu'] = torch.zeros_like(p)
|
||||
mu = param_state['mu']
|
||||
mu.mul_(g['momentum']).add_(dp)
|
||||
p.add_(mu, alpha=-g['lr'])
|
||||
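A minimal usage sketch with assumed hyper-parameters; note that parameters with `ndim <= 1` (biases, normalization scales) automatically skip weight decay and the trust-ratio scaling inside `step()`:

```
import torch

model = torch.nn.Linear(128, 10)
optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9)

loss = model(torch.randn(4, 128)).sum()   # placeholder objective
loss.backward()
optimizer.step()
optimizer.zero_grad()
```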
76
models/util/lr_decay.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# ELECTRA https://github.com/google-research/electra
|
||||
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
||||
# --------------------------------------------------------
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
|
||||
"""
|
||||
Parameter groups for layer-wise lr decay
|
||||
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
|
||||
"""
|
||||
param_group_names = {}
|
||||
param_groups = {}
|
||||
|
||||
num_layers = len(model.blocks) + 1
|
||||
|
||||
layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
|
||||
|
||||
for n, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
|
||||
# no decay: all 1D parameters and model specific ones
|
||||
if p.ndim == 1 or n in no_weight_decay_list:
|
||||
g_decay = "no_decay"
|
||||
this_decay = 0.
|
||||
else:
|
||||
g_decay = "decay"
|
||||
this_decay = weight_decay
|
||||
|
||||
layer_id = get_layer_id_for_vit(n, num_layers)
|
||||
group_name = "layer_%d_%s" % (layer_id, g_decay)
|
||||
|
||||
if group_name not in param_group_names:
|
||||
this_scale = layer_scales[layer_id]
|
||||
|
||||
param_group_names[group_name] = {
|
||||
"lr_scale": this_scale,
|
||||
"weight_decay": this_decay,
|
||||
"params": [],
|
||||
}
|
||||
param_groups[group_name] = {
|
||||
"lr_scale": this_scale,
|
||||
"weight_decay": this_decay,
|
||||
"params": [],
|
||||
}
|
||||
|
||||
param_group_names[group_name]["params"].append(n)
|
||||
param_groups[group_name]["params"].append(p)
|
||||
|
||||
# print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
||||
|
||||
return list(param_groups.values())
|
||||
|
||||
|
||||
def get_layer_id_for_vit(name, num_layers):
|
||||
"""
|
||||
Assign a parameter with its layer id
|
||||
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
|
||||
"""
|
||||
if name in ['cls_token', 'pos_embed']:
|
||||
return 0
|
||||
elif name.startswith('patch_embed'):
|
||||
return 0
|
||||
elif name.startswith('blocks'):
|
||||
return int(name.split('.')[1]) + 1
|
||||
else:
|
||||
return num_layers
|
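A minimal sketch of how these groups are usually consumed, assuming a timm-style ViT that exposes `.blocks` and a `no_weight_decay()` set (as in MAE fine-tuning):

```
import torch

param_groups = param_groups_lrd(
    model, weight_decay=0.05,
    no_weight_decay_list=model.no_weight_decay(),
    layer_decay=0.75,
)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
# Each group carries its own "lr_scale"; adjust_learning_rate() in lr_sched.py
# multiplies the scheduled base lr by that scale every step.
```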
||||
21
models/util/lr_sched.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch, args):
|
||||
"""Decay the learning rate with half-cycle cosine after warmup"""
|
||||
if epoch < args.warmup_epochs:
|
||||
lr = args.lr * epoch / args.warmup_epochs
|
||||
else:
|
||||
lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
|
||||
(1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
|
||||
for param_group in optimizer.param_groups:
|
||||
if "lr_scale" in param_group:
|
||||
param_group["lr"] = lr * param_group["lr_scale"]
|
||||
else:
|
||||
param_group["lr"] = lr
|
||||
return lr
|
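A worked example of the schedule, assuming `lr=1e-3`, `min_lr=0`, `warmup_epochs=5`, `epochs=100` (the epoch value may be fractional when called per iteration):

```
import math

lr, min_lr, warmup, total = 1e-3, 0.0, 5, 100

def sched(epoch):
    if epoch < warmup:
        return lr * epoch / warmup
    return min_lr + (lr - min_lr) * 0.5 * \
        (1. + math.cos(math.pi * (epoch - warmup) / (total - warmup)))

print(sched(2), sched(5), sched(52.5), sched(100))
# 0.0004  0.001  ~0.0005  ~0.0  (linear warmup, then half-cycle cosine decay)
```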
||||
340
models/util/misc.py
Normal file
@ -0,0 +1,340 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# DeiT: https://github.com/facebookresearch/deit
|
||||
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
||||
# --------------------------------------------------------
|
||||
|
||||
import builtins
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._six import inf
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value)
|
||||
|
||||
|
||||
class MetricLogger(object):
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if v is None:
|
||||
continue
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(
|
||||
"{}: {}".format(name, str(meter))
|
||||
)
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ''
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
data_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
||||
log_msg = [
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}'
|
||||
]
|
||||
if torch.cuda.is_available():
|
||||
log_msg.append('max mem: {memory:.0f}')
|
||||
log_msg = self.delimiter.join(log_msg)
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB))
|
||||
else:
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time)))
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('{} Total time: {} ({:.4f} s / it)'.format(
|
||||
header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in the master process
|
||||
"""
|
||||
builtin_print = builtins.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop('force', False)
|
||||
force = force or (get_world_size() > 8)
|
||||
if is_master or force:
|
||||
now = datetime.datetime.now().time()
|
||||
builtin_print('[{}] '.format(now), end='') # print with time stamp
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
builtins.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if args.dist_on_itp:
|
||||
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
||||
args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
|
||||
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
|
||||
os.environ['LOCAL_RANK'] = str(args.gpu)
|
||||
os.environ['RANK'] = str(args.rank)
|
||||
os.environ['WORLD_SIZE'] = str(args.world_size)
|
||||
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
|
||||
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print('Not using distributed mode')
|
||||
setup_for_distributed(is_master=True) # hack
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}, gpu {}'.format(
|
||||
args.rank, args.dist_url, args.gpu), flush=True)
|
||||
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
class NativeScalerWithGradNormCount:
|
||||
state_dict_key = "amp_scaler"
|
||||
|
||||
def __init__(self):
|
||||
self._scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
|
||||
self._scaler.scale(loss).backward(create_graph=create_graph)
|
||||
if update_grad:
|
||||
if clip_grad is not None:
|
||||
assert parameters is not None
|
||||
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
|
||||
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
|
||||
else:
|
||||
self._scaler.unscale_(optimizer)
|
||||
norm = get_grad_norm_(parameters)
|
||||
self._scaler.step(optimizer)
|
||||
self._scaler.update()
|
||||
else:
|
||||
norm = None
|
||||
return norm
|
||||
|
||||
def state_dict(self):
|
||||
return self._scaler.state_dict()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._scaler.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
parameters = [p for p in parameters if p.grad is not None]
|
||||
norm_type = float(norm_type)
|
||||
if len(parameters) == 0:
|
||||
return torch.tensor(0.)
|
||||
device = parameters[0].grad.device
|
||||
if norm_type == inf:
|
||||
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
|
||||
else:
|
||||
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
||||
return total_norm
|
||||
|
||||
|
||||
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
|
||||
output_dir = Path(args.output_dir)
|
||||
epoch_name = str(epoch)
|
||||
if loss_scaler is not None:
|
||||
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
|
||||
for checkpoint_path in checkpoint_paths:
|
||||
to_save = {
|
||||
'model': model_without_ddp.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'epoch': epoch,
|
||||
'scaler': loss_scaler.state_dict(),
|
||||
'args': args,
|
||||
}
|
||||
|
||||
save_on_master(to_save, checkpoint_path)
|
||||
else:
|
||||
client_state = {'epoch': epoch}
|
||||
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
|
||||
|
||||
|
||||
def load_model(args, model_without_ddp, optimizer, loss_scaler):
|
||||
if args.resume:
|
||||
if args.resume.startswith('https'):
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
args.resume, map_location='cpu', check_hash=True)
|
||||
else:
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
model_without_ddp.load_state_dict(checkpoint['model'])
|
||||
print("Resume checkpoint %s" % args.resume)
|
||||
if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
args.start_epoch = checkpoint['epoch'] + 1
|
||||
if 'scaler' in checkpoint:
|
||||
loss_scaler.load_state_dict(checkpoint['scaler'])
|
||||
print("With optim & sched!")
|
||||
|
||||
|
||||
def all_reduce_mean(x):
|
||||
world_size = get_world_size()
|
||||
if world_size > 1:
|
||||
x_reduce = torch.tensor(x).cuda()
|
||||
dist.all_reduce(x_reduce)
|
||||
x_reduce /= world_size
|
||||
return x_reduce.item()
|
||||
else:
|
||||
return x
|
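A minimal sketch of an AMP training step built from the helpers above; `model`, `optimizer`, `images`, `epoch` and `args` are assumed to exist, and the clip value is illustrative:

```
import torch

loss_scaler = NativeScalerWithGradNormCount()

adjust_learning_rate(optimizer, epoch, args)          # from lr_sched.py
with torch.cuda.amp.autocast():
    loss = model(images).mean()                       # placeholder objective
grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                        parameters=model.parameters(), update_grad=True)
optimizer.zero_grad()
loss_value = all_reduce_mean(loss.item())             # average across ranks for logging
```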
||||
96
models/util/pos_embed.py
Normal file
@ -0,0 +1,96 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# Position embedding utils
|
||||
# --------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
# --------------------------------------------------------
|
||||
# 2D sine-cosine position embedding
|
||||
# References:
|
||||
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
||||
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
||||
# --------------------------------------------------------
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token:
|
||||
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float is deprecated in recent NumPy; float64 matches the old alias
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Interpolate position embeddings for high-resolution
|
||||
# References:
|
||||
# DeiT: https://github.com/facebookresearch/deit
|
||||
# --------------------------------------------------------
|
||||
def interpolate_pos_embed(model, checkpoint_model):
|
||||
if 'pos_embed' in checkpoint_model:
|
||||
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_patches = model.patch_embed.num_patches
|
||||
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(num_patches ** 0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
checkpoint_model['pos_embed'] = new_pos_embed
|
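A minimal sketch of both utilities, assuming ViT-Base-like dimensions; the checkpoint path and `model` are hypothetical:

```
import torch

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos.shape)   # (1 + 14*14, 768) = (197, 768), a fixed (non-learned) table

checkpoint = torch.load('checkpoint.pth', map_location='cpu')   # hypothetical path
checkpoint_model = checkpoint['model']
interpolate_pos_embed(model, checkpoint_model)   # resamples 'pos_embed' in place if grids differ
model.load_state_dict(checkpoint_model, strict=False)
```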
||||
1
options/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
|
||||
BIN
options/__pycache__/__init__.cpython-36.pyc
Normal file
Binary file not shown.
BIN
options/__pycache__/base_options.cpython-36.pyc
Normal file
Binary file not shown.
BIN
options/__pycache__/test_options.cpython-36.pyc
Normal file
Binary file not shown.
BIN
options/__pycache__/train_options.cpython-36.pyc
Normal file
Binary file not shown.
167
options/base_options.py
Normal file
@ -0,0 +1,167 @@
|
||||
import argparse
|
||||
import os
|
||||
from util import util
|
||||
import torch
|
||||
import models
|
||||
import data
|
||||
|
||||
|
||||
class BaseOptions():
|
||||
"""This class defines options used during both training and test time.
|
||||
|
||||
It also implements several helper functions such as parsing, printing, and saving the options.
|
||||
It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
|
||||
"""
|
||||
|
||||
def __init__(self, cmd_line=None):
|
||||
"""Reset the class; indicates the class hasn't been initailized"""
|
||||
self.initialized = False
|
||||
self.cmd_line = None
|
||||
if cmd_line is not None:
|
||||
self.cmd_line = cmd_line.split()
|
||||
|
||||
def initialize(self, parser):
|
||||
"""Define the common options that are used in both training and test."""
|
||||
# basic parameters
|
||||
parser.add_argument('--dataroot', default='placeholder', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
|
||||
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
|
||||
parser.add_argument('--easy_label', type=str, default='experiment_name', help='Interpretable name')
|
||||
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
|
||||
parser.add_argument('--use_idt', action='store_true', help='use_idt')
|
||||
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
||||
# model parameters
|
||||
parser.add_argument('--model', type=str, default='cut', help='chooses which model to use.')
|
||||
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
|
||||
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
|
||||
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
|
||||
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
|
||||
parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
|
||||
parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat'], help='specify generator architecture')
|
||||
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
|
||||
parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
|
||||
parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
|
||||
parser.add_argument('--init_type', type=str, default='xavier', choices=['normal', 'xavier', 'kaiming', 'orthogonal'], help='network initialization')
|
||||
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
|
||||
parser.add_argument('--no_dropout', type=util.str2bool, nargs='?', const=True, default=True,
|
||||
help='no dropout for the generator')
|
||||
parser.add_argument('--no_antialias', action='store_true', help='if specified, use stride=2 convs instead of antialiased-downsampling (sad)')
|
||||
parser.add_argument('--no_antialias_up', action='store_true', help='if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]')
|
||||
# dataset parameters
|
||||
parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
|
||||
parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
|
||||
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
|
||||
parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
|
||||
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
|
||||
parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
|
||||
parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
|
||||
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
|
||||
parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
|
||||
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
|
||||
parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
|
||||
parser.add_argument('--random_scale_max', type=float, default=3.0,
|
||||
help='(used for single image translation) Randomly scale the image by the specified factor as data augmentation.')
|
||||
# additional parameters
|
||||
parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
|
||||
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
|
||||
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
|
||||
|
||||
# parameters related to StyleGAN2-based networks
|
||||
parser.add_argument('--stylegan2_G_num_downsampling',
|
||||
default=1, type=int,
|
||||
help='Number of downsampling layers used by StyleGAN2Generator')
|
||||
|
||||
self.initialized = True
|
||||
return parser
|
||||
|
||||
def gather_options(self):
|
||||
"""Initialize our parser with basic options(only once).
|
||||
Add additional model-specific and dataset-specific options.
|
||||
These options are defined in the <modify_commandline_options> function
|
||||
in model and dataset classes.
|
||||
"""
|
||||
if not self.initialized: # check if it has been initialized
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = self.initialize(parser)
|
||||
|
||||
# get the basic options
|
||||
if self.cmd_line is None:
|
||||
opt, _ = parser.parse_known_args()
|
||||
else:
|
||||
opt, _ = parser.parse_known_args(self.cmd_line)
|
||||
|
||||
# modify model-related parser options
|
||||
model_name = opt.model
|
||||
model_option_setter = models.get_option_setter(model_name)
|
||||
|
||||
parser = model_option_setter(parser, self.isTrain)
|
||||
if self.cmd_line is None:
|
||||
print(parser)
|
||||
opt, _ = parser.parse_known_args() # parse again with new defaults
|
||||
else:
|
||||
opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
|
||||
|
||||
# modify dataset-related parser options
|
||||
dataset_name = opt.dataset_mode
|
||||
dataset_option_setter = data.get_option_setter(dataset_name)
|
||||
parser = dataset_option_setter(parser, self.isTrain)
|
||||
|
||||
# save and return the parser
|
||||
self.parser = parser
|
||||
if self.cmd_line is None:
|
||||
return parser.parse_args()
|
||||
else:
|
||||
return parser.parse_args(self.cmd_line)
|
||||
|
||||
def print_options(self, opt):
|
||||
"""Print and save options
|
||||
|
||||
It will print both current options and default values (if different).
|
||||
It will save options into a text file / [checkpoints_dir] / opt.txt
|
||||
"""
|
||||
message = ''
|
||||
message += '----------------- Options ---------------\n'
|
||||
for k, v in sorted(vars(opt).items()):
|
||||
comment = ''
|
||||
default = self.parser.get_default(k)
|
||||
if v != default:
|
||||
comment = '\t[default: %s]' % str(default)
|
||||
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
||||
message += '----------------- End -------------------'
|
||||
print(message)
|
||||
|
||||
# save to the disk
|
||||
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
||||
util.mkdirs(expr_dir)
|
||||
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
|
||||
try:
|
||||
with open(file_name, 'wt') as opt_file:
|
||||
opt_file.write(message)
|
||||
opt_file.write('\n')
|
||||
except PermissionError as error:
|
||||
print("permission error {}".format(error))
|
||||
pass
|
||||
|
||||
def parse(self):
|
||||
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
|
||||
opt = self.gather_options()
|
||||
opt.isTrain = self.isTrain # train or test
|
||||
|
||||
# process opt.suffix
|
||||
if opt.suffix:
|
||||
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
|
||||
opt.name = opt.name + suffix
|
||||
|
||||
self.print_options(opt)
|
||||
|
||||
# set gpu ids
|
||||
str_ids = opt.gpu_ids.split(',')
|
||||
opt.gpu_ids = []
|
||||
for str_id in str_ids:
|
||||
id = int(str_id)
|
||||
if id >= 0:
|
||||
opt.gpu_ids.append(id)
|
||||
if len(opt.gpu_ids) > 0:
|
||||
torch.cuda.set_device(opt.gpu_ids[0])
|
||||
|
||||
self.opt = opt
|
||||
return self.opt
|
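A minimal sketch of how the two-stage parsing above is consumed (basic options first, then model- and dataset-specific options merged in); this is how train.py / test.py typically obtain `opt`:

```
from options.train_options import TrainOptions

opt = TrainOptions().parse()   # parses argv, prints the options, writes train_opt.txt
print(opt.model, opt.dataset_mode, opt.gpu_ids)
```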
||||
21
options/test_options.py
Normal file
@ -0,0 +1,21 @@
|
||||
from .base_options import BaseOptions
|
||||
|
||||
|
||||
class TestOptions(BaseOptions):
|
||||
"""This class includes test options.
|
||||
|
||||
It also includes shared options defined in BaseOptions.
|
||||
"""
|
||||
|
||||
def initialize(self, parser):
|
||||
parser = BaseOptions.initialize(self, parser) # define shared options
|
||||
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
|
||||
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
|
||||
# Dropout and Batchnorm have different behavior during training and test.
|
||||
parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
|
||||
parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
|
||||
|
||||
# To avoid cropping, the load_size should be the same as crop_size
|
||||
parser.set_defaults(load_size=parser.get_default('crop_size'))
|
||||
self.isTrain = False
|
||||
return parser
|
||||
47
options/train_options.py
Normal file
@ -0,0 +1,47 @@
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """This class includes training options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)
        # visdom and HTML visualization parameters
        parser.add_argument('--display_freq', type=int, default=50, help='frequency of showing training results on screen')
        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        parser.add_argument('--display_id', type=int, default=None, help='window id of the web display. Default is random window id')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        # network saving and loading parameters
        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation frequency')
        parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')

        # parser.add_argument('--use_mlp', action='store_true', help='use_mlp')
        # parser.add_argument('--use_tgt_style_src', action='store_true', help='use_tgt_style_src')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')

        # training parameters
        parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
        parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        self.isTrain = True
        return parser
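For orientation, here is a minimal, hedged sketch of how these training options are typically consumed by a `train.py` entry point. The parse / `create_dataset` / `create_model` / `set_input` calls mirror what `test.py` below actually does; `optimize_parameters` and `update_learning_rate` are assumptions carried over from the CycleGAN/CUT-style codebase this repository follows, not code shown in this diff.

```python
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model

if __name__ == '__main__':
    opt = TrainOptions().parse()      # shared + training options defined above
    dataset = create_dataset(opt)     # dataset given opt.dataset_mode (e.g. 'unaligned_double')
    model = create_model(opt)         # model given opt.model (e.g. 'roma')

    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
        for i, data in enumerate(dataset):
            if epoch == opt.epoch_count and i == 0:
                model.data_dependent_initialize(data)
                model.setup(opt)          # load networks, create schedulers
                model.parallelize()
            model.set_input(data)         # unpack data from the data loader
            model.optimize_parameters()   # assumed: forward/backward + optimizer steps
        model.update_learning_rate()      # assumed: step the lr schedulers once per epoch
```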
scripts/test.sh (new file, 1 line)
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0 python test.py --dataroot /path/of/test_dataset --checkpoints_dir ./checkpoints --name train1 --model roma_single --num_test 10000 --epoch latest
scripts/train.sh (new file, 5 lines)
@@ -0,0 +1,5 @@
# Train for video mode
CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned_double --no_flip --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --lambda_motion 1.0 --atten_layers 1,3,5 --lr 0.00001

# Train for image mode
CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --atten_layers 1,3,5 --lr 0.00001
test.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""General-purpose test script for image-to-image translation.

Once you have trained your model with train.py, you can use this script to test the model.
It will load a saved model from --checkpoints_dir and save the results to --results_dir.

It first creates a model and a dataset given the options. It will hard-code some parameters.
It then runs inference for --num_test images and saves the results to an HTML file.

Example (You need to train models first or download pre-trained models from our website):
    Test a CycleGAN model (both sides):
        python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan

    Test a CycleGAN model (one side only):
        python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout

    The option '--model test' is used for generating CycleGAN results only for one side.
    This option will automatically set '--dataset_mode single', which only loads the images from one set.
    On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,
    which is sometimes unnecessary. The results will be saved at ./results/.
    Use '--results_dir <directory_path_to_save_result>' to specify the results directory.

    Test a pix2pix model:
        python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA

See options/base_options.py and options/test_options.py for more test options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html
import util.util as util


if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0   # test code only supports num_threads = 0
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1   # no visdom display; the test code saves the results to an HTML file.
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    # train_dataset = create_dataset(util.copyconf(opt, phase="train"))
    model = create_model(opt)      # create a model given opt.model and other options
    # create a webpage for viewing the results
    web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch))  # define the website directory
    print('creating web directory', web_dir)
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))

    for i, data in enumerate(dataset):
        if i == 0:
            model.data_dependent_initialize(data)
            model.setup(opt)       # regular setup: load and print networks; create schedulers
            model.parallelize()
            if opt.eval:
                model.eval()
        if i >= opt.num_test:      # only apply our model to opt.num_test images.
            break
        model.set_input(data)      # unpack data from data loader
        model.test()               # run inference
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()     # get image paths
        if i % 5 == 0:  # save images to an HTML file
            print('processing (%04d)-th image... %s' % (i, img_path))
        save_images(webpage, visuals, img_path, width=opt.display_winsize)
    webpage.save()  # save the HTML
timm/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
    is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
    get_model_default_value, is_model_pretrained
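The repository vendors a copy of timm, so these factory helpers are importable locally. A quick, hedged illustration of the re-exported model factory; the architecture name is only an example, since the diff does not show which backbone ROMA actually builds:

```python
import timm

print(timm.list_models('resnet*')[:5])   # enumerate matching architectures in the vendored registry
backbone = timm.create_model('resnet50', pretrained=False, num_classes=0)  # num_classes=0 drops the classifier head
```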
BIN timm/__pycache__/__init__.cpython-36.pyc (new file, binary not shown)
BIN timm/__pycache__/version.cpython-36.pyc (new file, binary not shown)
timm/data/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
    rand_augment_transform, auto_augment_transform
from .config import resolve_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform
BIN timm/data/__pycache__/__init__.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/auto_augment.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/config.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/constants.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/dataset.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/dataset_factory.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/distributed_sampler.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/loader.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/mixup.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/random_erasing.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/real_labels.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/transforms.cpython-36.pyc (new file, binary not shown)
BIN timm/data/__pycache__/transforms_factory.cpython-36.pyc (new file, binary not shown)
timm/data/auto_augment.py (new file, 865 lines)
@@ -0,0 +1,865 @@
|
||||
""" AutoAugment, RandAugment, and AugMix for PyTorch
|
||||
|
||||
This code implements the searched ImageNet policies with various tweaks and improvements and
|
||||
does not include any of the search code.
|
||||
|
||||
AA and RA Implementation adapted from:
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
||||
|
||||
AugMix adapted from:
|
||||
https://github.com/google-research/augmix
|
||||
|
||||
Papers:
|
||||
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
|
||||
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
|
||||
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
|
||||
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import random
|
||||
import math
|
||||
import re
|
||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
|
||||
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
|
||||
|
||||
_FILL = (128, 128, 128)
|
||||
|
||||
_LEVEL_DENOM = 10. # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments
|
||||
|
||||
_HPARAMS_DEFAULT = dict(
|
||||
translate_const=250,
|
||||
img_mean=_FILL,
|
||||
)
|
||||
|
||||
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
||||
|
||||
|
||||
def _interpolation(kwargs):
|
||||
interpolation = kwargs.pop('resample', Image.BILINEAR)
|
||||
if isinstance(interpolation, (list, tuple)):
|
||||
return random.choice(interpolation)
|
||||
else:
|
||||
return interpolation
|
||||
|
||||
|
||||
def _check_args_tf(kwargs):
|
||||
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
|
||||
kwargs.pop('fillcolor')
|
||||
kwargs['resample'] = _interpolation(kwargs)
|
||||
|
||||
|
||||
def shear_x(img, factor, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def shear_y(img, factor, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def translate_x_rel(img, pct, **kwargs):
|
||||
pixels = pct * img.size[0]
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def translate_y_rel(img, pct, **kwargs):
|
||||
pixels = pct * img.size[1]
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
|
||||
|
||||
|
||||
def translate_x_abs(img, pixels, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def translate_y_abs(img, pixels, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
|
||||
|
||||
|
||||
def rotate(img, degrees, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
if _PIL_VER >= (5, 2):
|
||||
return img.rotate(degrees, **kwargs)
|
||||
elif _PIL_VER >= (5, 0):
|
||||
w, h = img.size
|
||||
post_trans = (0, 0)
|
||||
rotn_center = (w / 2.0, h / 2.0)
|
||||
angle = -math.radians(degrees)
|
||||
matrix = [
|
||||
round(math.cos(angle), 15),
|
||||
round(math.sin(angle), 15),
|
||||
0.0,
|
||||
round(-math.sin(angle), 15),
|
||||
round(math.cos(angle), 15),
|
||||
0.0,
|
||||
]
|
||||
|
||||
def transform(x, y, matrix):
|
||||
(a, b, c, d, e, f) = matrix
|
||||
return a * x + b * y + c, d * x + e * y + f
|
||||
|
||||
matrix[2], matrix[5] = transform(
|
||||
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
|
||||
)
|
||||
matrix[2] += rotn_center[0]
|
||||
matrix[5] += rotn_center[1]
|
||||
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
|
||||
else:
|
||||
return img.rotate(degrees, resample=kwargs['resample'])
|
||||
|
||||
|
||||
def auto_contrast(img, **__):
|
||||
return ImageOps.autocontrast(img)
|
||||
|
||||
|
||||
def invert(img, **__):
|
||||
return ImageOps.invert(img)
|
||||
|
||||
|
||||
def equalize(img, **__):
|
||||
return ImageOps.equalize(img)
|
||||
|
||||
|
||||
def solarize(img, thresh, **__):
|
||||
return ImageOps.solarize(img, thresh)
|
||||
|
||||
|
||||
def solarize_add(img, add, thresh=128, **__):
|
||||
lut = []
|
||||
for i in range(256):
|
||||
if i < thresh:
|
||||
lut.append(min(255, i + add))
|
||||
else:
|
||||
lut.append(i)
|
||||
if img.mode in ("L", "RGB"):
|
||||
if img.mode == "RGB" and len(lut) == 256:
|
||||
lut = lut + lut + lut
|
||||
return img.point(lut)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
def posterize(img, bits_to_keep, **__):
|
||||
if bits_to_keep >= 8:
|
||||
return img
|
||||
return ImageOps.posterize(img, bits_to_keep)
|
||||
|
||||
|
||||
def contrast(img, factor, **__):
|
||||
return ImageEnhance.Contrast(img).enhance(factor)
|
||||
|
||||
|
||||
def color(img, factor, **__):
|
||||
return ImageEnhance.Color(img).enhance(factor)
|
||||
|
||||
|
||||
def brightness(img, factor, **__):
|
||||
return ImageEnhance.Brightness(img).enhance(factor)
|
||||
|
||||
|
||||
def sharpness(img, factor, **__):
|
||||
return ImageEnhance.Sharpness(img).enhance(factor)
|
||||
|
||||
|
||||
def _randomly_negate(v):
|
||||
"""With 50% prob, negate the value"""
|
||||
return -v if random.random() > 0.5 else v
|
||||
|
||||
|
||||
def _rotate_level_to_arg(level, _hparams):
|
||||
# range [-30, 30]
|
||||
level = (level / _LEVEL_DENOM) * 30.
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _enhance_level_to_arg(level, _hparams):
|
||||
# range [0.1, 1.9]
|
||||
return (level / _LEVEL_DENOM) * 1.8 + 0.1,
|
||||
|
||||
|
||||
def _enhance_increasing_level_to_arg(level, _hparams):
|
||||
# the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
|
||||
# range [0.1, 1.9] if level <= _LEVEL_DENOM
|
||||
level = (level / _LEVEL_DENOM) * .9
|
||||
level = max(0.1, 1.0 + _randomly_negate(level)) # keep it >= 0.1
|
||||
return level,
|
||||
|
||||
|
||||
def _shear_level_to_arg(level, _hparams):
|
||||
# range [-0.3, 0.3]
|
||||
level = (level / _LEVEL_DENOM) * 0.3
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _translate_abs_level_to_arg(level, hparams):
|
||||
translate_const = hparams['translate_const']
|
||||
level = (level / _LEVEL_DENOM) * float(translate_const)
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _translate_rel_level_to_arg(level, hparams):
|
||||
# default range [-0.45, 0.45]
|
||||
translate_pct = hparams.get('translate_pct', 0.45)
|
||||
level = (level / _LEVEL_DENOM) * translate_pct
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _posterize_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow TPU EfficientNet impl
|
||||
# range [0, 4], 'keep 0 up to 4 MSB of original image'
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _LEVEL_DENOM) * 4),
|
||||
|
||||
|
||||
def _posterize_increasing_level_to_arg(level, hparams):
|
||||
# As per Tensorflow models research and UDA impl
|
||||
# range [4, 0], 'keep 4 down to 0 MSB of original image',
|
||||
# intensity/severity of augmentation increases with level
|
||||
return 4 - _posterize_level_to_arg(level, hparams)[0],
|
||||
|
||||
|
||||
def _posterize_original_level_to_arg(level, _hparams):
|
||||
# As per original AutoAugment paper description
|
||||
# range [4, 8], 'keep 4 up to 8 MSB of image'
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _LEVEL_DENOM) * 4) + 4,
|
||||
|
||||
|
||||
def _solarize_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _LEVEL_DENOM) * 256),
|
||||
|
||||
|
||||
def _solarize_increasing_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation increases with level
|
||||
return 256 - _solarize_level_to_arg(level, _hparams)[0],
|
||||
|
||||
|
||||
def _solarize_add_level_to_arg(level, _hparams):
|
||||
# range [0, 110]
|
||||
return int((level / _LEVEL_DENOM) * 110),
|
||||
|
||||
|
||||
LEVEL_TO_ARG = {
|
||||
'AutoContrast': None,
|
||||
'Equalize': None,
|
||||
'Invert': None,
|
||||
'Rotate': _rotate_level_to_arg,
|
||||
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
|
||||
'Posterize': _posterize_level_to_arg,
|
||||
'PosterizeIncreasing': _posterize_increasing_level_to_arg,
|
||||
'PosterizeOriginal': _posterize_original_level_to_arg,
|
||||
'Solarize': _solarize_level_to_arg,
|
||||
'SolarizeIncreasing': _solarize_increasing_level_to_arg,
|
||||
'SolarizeAdd': _solarize_add_level_to_arg,
|
||||
'Color': _enhance_level_to_arg,
|
||||
'ColorIncreasing': _enhance_increasing_level_to_arg,
|
||||
'Contrast': _enhance_level_to_arg,
|
||||
'ContrastIncreasing': _enhance_increasing_level_to_arg,
|
||||
'Brightness': _enhance_level_to_arg,
|
||||
'BrightnessIncreasing': _enhance_increasing_level_to_arg,
|
||||
'Sharpness': _enhance_level_to_arg,
|
||||
'SharpnessIncreasing': _enhance_increasing_level_to_arg,
|
||||
'ShearX': _shear_level_to_arg,
|
||||
'ShearY': _shear_level_to_arg,
|
||||
'TranslateX': _translate_abs_level_to_arg,
|
||||
'TranslateY': _translate_abs_level_to_arg,
|
||||
'TranslateXRel': _translate_rel_level_to_arg,
|
||||
'TranslateYRel': _translate_rel_level_to_arg,
|
||||
}
|
||||
|
||||
|
||||
NAME_TO_OP = {
|
||||
'AutoContrast': auto_contrast,
|
||||
'Equalize': equalize,
|
||||
'Invert': invert,
|
||||
'Rotate': rotate,
|
||||
'Posterize': posterize,
|
||||
'PosterizeIncreasing': posterize,
|
||||
'PosterizeOriginal': posterize,
|
||||
'Solarize': solarize,
|
||||
'SolarizeIncreasing': solarize,
|
||||
'SolarizeAdd': solarize_add,
|
||||
'Color': color,
|
||||
'ColorIncreasing': color,
|
||||
'Contrast': contrast,
|
||||
'ContrastIncreasing': contrast,
|
||||
'Brightness': brightness,
|
||||
'BrightnessIncreasing': brightness,
|
||||
'Sharpness': sharpness,
|
||||
'SharpnessIncreasing': sharpness,
|
||||
'ShearX': shear_x,
|
||||
'ShearY': shear_y,
|
||||
'TranslateX': translate_x_abs,
|
||||
'TranslateY': translate_y_abs,
|
||||
'TranslateXRel': translate_x_rel,
|
||||
'TranslateYRel': translate_y_rel,
|
||||
}
|
||||
|
||||
|
||||
class AugmentOp:
|
||||
|
||||
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
self.name = name
|
||||
self.aug_fn = NAME_TO_OP[name]
|
||||
self.level_fn = LEVEL_TO_ARG[name]
|
||||
self.prob = prob
|
||||
self.magnitude = magnitude
|
||||
self.hparams = hparams.copy()
|
||||
self.kwargs = dict(
|
||||
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
|
||||
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
|
||||
)
|
||||
|
||||
# If magnitude_std is > 0, we introduce some randomness
|
||||
# in the usually fixed policy and sample magnitude from a normal distribution
|
||||
# with mean `magnitude` and std-dev of `magnitude_std`.
|
||||
# NOTE This is my own hack, being tested, not in papers or reference impls.
|
||||
# If magnitude_std is inf, we sample magnitude from a uniform distribution
|
||||
self.magnitude_std = self.hparams.get('magnitude_std', 0)
|
||||
self.magnitude_max = self.hparams.get('magnitude_max', None)
|
||||
|
||||
def __call__(self, img):
|
||||
if self.prob < 1.0 and random.random() > self.prob:
|
||||
return img
|
||||
magnitude = self.magnitude
|
||||
if self.magnitude_std > 0:
|
||||
# magnitude randomization enabled
|
||||
if self.magnitude_std == float('inf'):
|
||||
magnitude = random.uniform(0, magnitude)
|
||||
elif self.magnitude_std > 0:
|
||||
magnitude = random.gauss(magnitude, self.magnitude_std)
|
||||
# default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
|
||||
# setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
|
||||
upper_bound = self.magnitude_max or _LEVEL_DENOM
|
||||
magnitude = max(0., min(magnitude, upper_bound))
|
||||
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
||||
return self.aug_fn(img, *level_args, **self.kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
|
||||
fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
|
||||
if self.magnitude_max is not None:
|
||||
fs += f', mmax={self.magnitude_max}'
|
||||
fs += ')'
|
||||
return fs
|
||||
|
||||
|
||||
def auto_augment_policy_v0(hparams):
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
|
||||
policy = [
|
||||
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
||||
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
||||
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
|
||||
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
|
||||
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
|
||||
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
|
||||
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
|
||||
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
|
||||
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
|
||||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
|
||||
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
|
||||
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
|
||||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in a black image with TPU posterize
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_v0r(hparams):
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
|
||||
# in Google research implementation (number of bits discarded increases with magnitude)
|
||||
policy = [
|
||||
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
||||
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
||||
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
|
||||
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
|
||||
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
|
||||
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
|
||||
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
|
||||
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
|
||||
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
|
||||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
|
||||
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
|
||||
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
|
||||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_original(hparams):
|
||||
# ImageNet policy from https://arxiv.org/abs/1805.09501
|
||||
policy = [
|
||||
[('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
[('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
||||
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
||||
[('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
|
||||
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
||||
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
||||
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
|
||||
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
|
||||
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
|
||||
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
|
||||
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_originalr(hparams):
|
||||
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
|
||||
policy = [
|
||||
[('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
[('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
||||
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
||||
[('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
|
||||
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
||||
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
||||
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
|
||||
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
|
||||
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
|
||||
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
|
||||
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy(name='v0', hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
if name == 'original':
|
||||
return auto_augment_policy_original(hparams)
|
||||
elif name == 'originalr':
|
||||
return auto_augment_policy_originalr(hparams)
|
||||
elif name == 'v0':
|
||||
return auto_augment_policy_v0(hparams)
|
||||
elif name == 'v0r':
|
||||
return auto_augment_policy_v0r(hparams)
|
||||
else:
|
||||
assert False, 'Unknown AA policy (%s)' % name
|
||||
|
||||
|
||||
class AutoAugment:
|
||||
|
||||
def __init__(self, policy):
|
||||
self.policy = policy
|
||||
|
||||
def __call__(self, img):
|
||||
sub_policy = random.choice(self.policy)
|
||||
for op in sub_policy:
|
||||
img = op(img)
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
fs = self.__class__.__name__ + f'(policy='
|
||||
for p in self.policy:
|
||||
fs += '\n\t['
|
||||
fs += ', '.join([str(op) for op in p])
|
||||
fs += ']'
|
||||
fs += ')'
|
||||
return fs
|
||||
|
||||
|
||||
def auto_augment_transform(config_str, hparams):
|
||||
"""
|
||||
Create a AutoAugment transform
|
||||
|
||||
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
||||
The remaining sections, which are not order specific, determine
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
"""
|
||||
config = config_str.split('-')
|
||||
policy_name = config[0]
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param injected via hparams for now
|
||||
hparams.setdefault('magnitude_std', float(val))
|
||||
else:
|
||||
assert False, 'Unknown AutoAugment config section'
|
||||
aa_policy = auto_augment_policy(policy_name, hparams=hparams)
|
||||
return AutoAugment(aa_policy)
|
||||
|
||||
|
||||
_RAND_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'Equalize',
|
||||
'Invert',
|
||||
'Rotate',
|
||||
'Posterize',
|
||||
'Solarize',
|
||||
'SolarizeAdd',
|
||||
'Color',
|
||||
'Contrast',
|
||||
'Brightness',
|
||||
'Sharpness',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
# 'Cutout'  # NOTE I've implemented this as random erasing separately
|
||||
]
|
||||
|
||||
|
||||
_RAND_INCREASING_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'Equalize',
|
||||
'Invert',
|
||||
'Rotate',
|
||||
'PosterizeIncreasing',
|
||||
'SolarizeIncreasing',
|
||||
'SolarizeAdd',
|
||||
'ColorIncreasing',
|
||||
'ContrastIncreasing',
|
||||
'BrightnessIncreasing',
|
||||
'SharpnessIncreasing',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
# 'Cutout'  # NOTE I've implemented this as random erasing separately
|
||||
]
|
||||
|
||||
|
||||
|
||||
# These experimental weights are based loosely on the relative improvements mentioned in the paper.
# They may not result in increased performance, but could likely be tuned to do so.
|
||||
_RAND_CHOICE_WEIGHTS_0 = {
|
||||
'Rotate': 0.3,
|
||||
'ShearX': 0.2,
|
||||
'ShearY': 0.2,
|
||||
'TranslateXRel': 0.1,
|
||||
'TranslateYRel': 0.1,
|
||||
'Color': .025,
|
||||
'Sharpness': 0.025,
|
||||
'AutoContrast': 0.025,
|
||||
'Solarize': .005,
|
||||
'SolarizeAdd': .005,
|
||||
'Contrast': .005,
|
||||
'Brightness': .005,
|
||||
'Equalize': .005,
|
||||
'Posterize': 0,
|
||||
'Invert': 0,
|
||||
}
|
||||
|
||||
|
||||
def _select_rand_weights(weight_idx=0, transforms=None):
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
assert weight_idx == 0 # only one set of weights currently
|
||||
rand_weights = _RAND_CHOICE_WEIGHTS_0
|
||||
probs = [rand_weights[k] for k in transforms]
|
||||
probs /= np.sum(probs)
|
||||
return probs
|
||||
|
||||
|
||||
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
return [AugmentOp(
|
||||
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
|
||||
|
||||
class RandAugment:
|
||||
def __init__(self, ops, num_layers=2, choice_weights=None):
|
||||
self.ops = ops
|
||||
self.num_layers = num_layers
|
||||
self.choice_weights = choice_weights
|
||||
|
||||
def __call__(self, img):
|
||||
# no replacement when using weighted choice
|
||||
ops = np.random.choice(
|
||||
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
|
||||
for op in ops:
|
||||
img = op(img)
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
|
||||
for op in self.ops:
|
||||
fs += f'\n\t{op}'
|
||||
fs += ')'
|
||||
return fs
|
||||
|
||||
|
||||
def rand_augment_transform(config_str, hparams):
|
||||
"""
|
||||
Create a RandAugment transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
||||
sections, which are not order specific, determine
|
||||
'm' - integer magnitude of rand augment
|
||||
'n' - integer num layers (number of transform ops selected per image)
|
||||
'w' - integer probability weight index (index of a set of weights to influence choice of op)
|
||||
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
|
||||
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
|
||||
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
||||
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
"""
|
||||
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
|
||||
num_layers = 2 # default to 2 ops per image
|
||||
weight_idx = None # default to no probability weights for op choice
|
||||
transforms = _RAND_TRANSFORMS
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'rand'
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param / randomization of magnitude values
|
||||
mstd = float(val)
|
||||
if mstd > 100:
|
||||
# use uniform sampling in 0 to magnitude if mstd is > 100
|
||||
mstd = float('inf')
|
||||
hparams.setdefault('magnitude_std', mstd)
|
||||
elif key == 'mmax':
|
||||
# clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
|
||||
hparams.setdefault('magnitude_max', int(val))
|
||||
elif key == 'inc':
|
||||
if bool(val):
|
||||
transforms = _RAND_INCREASING_TRANSFORMS
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'n':
|
||||
num_layers = int(val)
|
||||
elif key == 'w':
|
||||
weight_idx = int(val)
|
||||
else:
|
||||
assert False, 'Unknown RandAugment config section'
|
||||
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
|
||||
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
|
||||
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
||||
|
||||
|
||||
_AUGMIX_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'ColorIncreasing', # not in paper
|
||||
'ContrastIncreasing', # not in paper
|
||||
'BrightnessIncreasing', # not in paper
|
||||
'SharpnessIncreasing', # not in paper
|
||||
'Equalize',
|
||||
'Rotate',
|
||||
'PosterizeIncreasing',
|
||||
'SolarizeIncreasing',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
]
|
||||
|
||||
|
||||
def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _AUGMIX_TRANSFORMS
|
||||
return [AugmentOp(
|
||||
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
|
||||
|
||||
class AugMixAugment:
|
||||
""" AugMix Transform
|
||||
Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
||||
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
||||
https://arxiv.org/abs/1912.02781
|
||||
"""
|
||||
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
|
||||
self.ops = ops
|
||||
self.alpha = alpha
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self.blended = blended # blended mode is faster but not well tested
|
||||
|
||||
def _calc_blended_weights(self, ws, m):
|
||||
ws = ws * m
|
||||
cump = 1.
|
||||
rws = []
|
||||
for w in ws[::-1]:
|
||||
alpha = w / cump
|
||||
cump *= (1 - alpha)
|
||||
rws.append(alpha)
|
||||
return np.array(rws[::-1], dtype=np.float32)
|
||||
|
||||
def _apply_blended(self, img, mixing_weights, m):
|
||||
# This is my first crack at implementing a slightly faster mixed augmentation. Instead
|
||||
# of accumulating the mix for each chain in a Numpy array and then blending with original,
|
||||
# it recomputes the blending coefficients and applies one PIL image blend per chain.
|
||||
# TODO the results appear in the right ballpark but they differ by more than rounding.
|
||||
img_orig = img.copy()
|
||||
ws = self._calc_blended_weights(mixing_weights, m)
|
||||
for w in ws:
|
||||
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
||||
ops = np.random.choice(self.ops, depth, replace=True)
|
||||
img_aug = img_orig # no ops are in-place, deep copy not necessary
|
||||
for op in ops:
|
||||
img_aug = op(img_aug)
|
||||
img = Image.blend(img, img_aug, w)
|
||||
return img
|
||||
|
||||
def _apply_basic(self, img, mixing_weights, m):
|
||||
# This is a literal adaptation of the paper/official implementation without normalizations and
|
||||
# PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
|
||||
# typical augmentation transforms, could use a GPU / Kornia implementation.
|
||||
img_shape = img.size[0], img.size[1], len(img.getbands())
|
||||
mixed = np.zeros(img_shape, dtype=np.float32)
|
||||
for mw in mixing_weights:
|
||||
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
||||
ops = np.random.choice(self.ops, depth, replace=True)
|
||||
img_aug = img # no ops are in-place, deep copy not necessary
|
||||
for op in ops:
|
||||
img_aug = op(img_aug)
|
||||
mixed += mw * np.asarray(img_aug, dtype=np.float32)
|
||||
np.clip(mixed, 0, 255., out=mixed)
|
||||
mixed = Image.fromarray(mixed.astype(np.uint8))
|
||||
return Image.blend(img, mixed, m)
|
||||
|
||||
def __call__(self, img):
|
||||
mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
|
||||
m = np.float32(np.random.beta(self.alpha, self.alpha))
|
||||
if self.blended:
|
||||
mixed = self._apply_blended(img, mixing_weights, m)
|
||||
else:
|
||||
mixed = self._apply_basic(img, mixing_weights, m)
|
||||
return mixed
|
||||
|
||||
def __repr__(self):
|
||||
fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
|
||||
for op in self.ops:
|
||||
fs += f'\n\t{op}'
|
||||
fs += ')'
|
||||
return fs
|
||||
|
||||
|
||||
def augment_and_mix_transform(config_str, hparams):
|
||||
""" Create AugMix PyTorch transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the augmentation variant (currently only 'augmix' is supported). The remaining
sections, which are not order specific, determine
|
||||
'm' - integer magnitude (severity) of augmentation mix (default: 3)
|
||||
'w' - integer width of augmentation chain (default: 3)
|
||||
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
|
||||
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
|
||||
'mstd' - float std deviation of magnitude noise applied (default: 0)
|
||||
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the Augmentation transforms
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
"""
|
||||
magnitude = 3
|
||||
width = 3
|
||||
depth = -1
|
||||
alpha = 1.
|
||||
blended = False
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'augmix'
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param injected via hparams for now
|
||||
hparams.setdefault('magnitude_std', float(val))
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'w':
|
||||
width = int(val)
|
||||
elif key == 'd':
|
||||
depth = int(val)
|
||||
elif key == 'a':
|
||||
alpha = float(val)
|
||||
elif key == 'b':
|
||||
blended = bool(val)
|
||||
else:
|
||||
assert False, 'Unknown AugMix config section'
|
||||
hparams.setdefault('magnitude_std', float('inf')) # default to uniform sampling (if not set via mstd arg)
|
||||
ops = augmix_ops(magnitude=magnitude, hparams=hparams)
|
||||
return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
|
||||
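The three public factories above (`auto_augment_transform`, `rand_augment_transform`, `augment_and_mix_transform`) are driven by compact config strings described in their docstrings. A short, hedged usage sketch; the image path is a placeholder, and the config strings are the examples given in those docstrings:

```python
from PIL import Image
from timm.data.auto_augment import (
    auto_augment_transform, rand_augment_transform, augment_and_mix_transform)

img = Image.open('frame.png').convert('RGB')   # placeholder input image
hparams = dict(translate_const=100, img_mean=(128, 128, 128))

aa = auto_augment_transform('original-mstd0.5', hparams)    # AutoAugment, original policy, magnitude_std 0.5
ra = rand_augment_transform('rand-m9-n3-mstd0.5', hparams)  # RandAugment, magnitude 9, 3 ops per image
am = augment_and_mix_transform('augmix-m5-w4-d2', hparams)  # AugMix, severity 5, width 4, depth 2

out = am(ra(aa(img)))   # each transform maps a PIL image to a PIL image, so they compose freely
```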
timm/data/config.py (new file, 78 lines)
@@ -0,0 +1,78 @@
|
||||
import logging
|
||||
from .constants import *
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
|
||||
new_config = {}
|
||||
default_cfg = default_cfg
|
||||
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
|
||||
default_cfg = model.default_cfg
|
||||
|
||||
# Resolve input/image size
|
||||
in_chans = 3
|
||||
if 'chans' in args and args['chans'] is not None:
|
||||
in_chans = args['chans']
|
||||
|
||||
input_size = (in_chans, 224, 224)
|
||||
if 'input_size' in args and args['input_size'] is not None:
|
||||
assert isinstance(args['input_size'], (tuple, list))
|
||||
assert len(args['input_size']) == 3
|
||||
input_size = tuple(args['input_size'])
|
||||
in_chans = input_size[0] # input_size overrides in_chans
|
||||
elif 'img_size' in args and args['img_size'] is not None:
|
||||
assert isinstance(args['img_size'], int)
|
||||
input_size = (in_chans, args['img_size'], args['img_size'])
|
||||
else:
|
||||
if use_test_size and 'test_input_size' in default_cfg:
|
||||
input_size = default_cfg['test_input_size']
|
||||
elif 'input_size' in default_cfg:
|
||||
input_size = default_cfg['input_size']
|
||||
new_config['input_size'] = input_size
|
||||
|
||||
# resolve interpolation method
|
||||
new_config['interpolation'] = 'bicubic'
|
||||
if 'interpolation' in args and args['interpolation']:
|
||||
new_config['interpolation'] = args['interpolation']
|
||||
elif 'interpolation' in default_cfg:
|
||||
new_config['interpolation'] = default_cfg['interpolation']
|
||||
|
||||
# resolve dataset + model mean for normalization
|
||||
new_config['mean'] = IMAGENET_DEFAULT_MEAN
|
||||
if 'mean' in args and args['mean'] is not None:
|
||||
mean = tuple(args['mean'])
|
||||
if len(mean) == 1:
|
||||
mean = tuple(list(mean) * in_chans)
|
||||
else:
|
||||
assert len(mean) == in_chans
|
||||
new_config['mean'] = mean
|
||||
elif 'mean' in default_cfg:
|
||||
new_config['mean'] = default_cfg['mean']
|
||||
|
||||
# resolve dataset + model std deviation for normalization
|
||||
new_config['std'] = IMAGENET_DEFAULT_STD
|
||||
if 'std' in args and args['std'] is not None:
|
||||
std = tuple(args['std'])
|
||||
if len(std) == 1:
|
||||
std = tuple(list(std) * in_chans)
|
||||
else:
|
||||
assert len(std) == in_chans
|
||||
new_config['std'] = std
|
||||
elif 'std' in default_cfg:
|
||||
new_config['std'] = default_cfg['std']
|
||||
|
||||
# resolve default crop percentage
|
||||
new_config['crop_pct'] = DEFAULT_CROP_PCT
|
||||
if 'crop_pct' in args and args['crop_pct'] is not None:
|
||||
new_config['crop_pct'] = args['crop_pct']
|
||||
elif 'crop_pct' in default_cfg:
|
||||
new_config['crop_pct'] = default_cfg['crop_pct']
|
||||
|
||||
if verbose:
|
||||
_logger.info('Data processing configuration for current model + dataset:')
|
||||
for n, v in new_config.items():
|
||||
_logger.info('\t%s: %s' % (n, str(v)))
|
||||
|
||||
return new_config
|
||||
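`resolve_data_config` merges user arguments with a model's `default_cfg` to decide input size, interpolation, normalization statistics, and crop fraction. A brief, hedged example; the `args` dict is illustrative, and `create_transform` (exported from timm.data above) is assumed to accept the resolved keys as keyword arguments, which is the usual timm pattern:

```python
from timm.data import resolve_data_config, create_transform

args = {'input_size': (3, 224, 224), 'interpolation': None,
        'mean': None, 'std': None, 'crop_pct': None}
config = resolve_data_config(args, verbose=True)
# -> {'input_size': (3, 224, 224), 'interpolation': 'bicubic',
#     'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 0.875}

eval_transform = create_transform(**config)   # build the matching preprocessing pipeline
```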
timm/data/constants.py (new file, 7 lines)
@@ -0,0 +1,7 @@
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
timm/data/dataset.py (new file, 152 lines)
@@ -0,0 +1,152 @@
|
||||
""" Quick n Simple Image Folder, Tarfile based DataSet
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import torch.utils.data as data
|
||||
import os
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .parsers import create_parser
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_ERROR_RETRY = 50
|
||||
|
||||
|
||||
class ImageDataset(data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
parser=None,
|
||||
class_map=None,
|
||||
load_bytes=False,
|
||||
transform=None,
|
||||
target_transform=None,
|
||||
):
|
||||
if parser is None or isinstance(parser, str):
|
||||
parser = create_parser(parser or '', root=root, class_map=class_map)
|
||||
self.parser = parser
|
||||
self.load_bytes = load_bytes
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
self._consecutive_errors = 0
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, target = self.parser[index]
|
||||
try:
|
||||
img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
|
||||
except Exception as e:
|
||||
_logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
|
||||
self._consecutive_errors += 1
|
||||
if self._consecutive_errors < _ERROR_RETRY:
|
||||
return self.__getitem__((index + 1) % len(self.parser))
|
||||
else:
|
||||
raise e
|
||||
self._consecutive_errors = 0
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if target is None:
|
||||
target = -1
|
||||
elif self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
return img, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.parser)
|
||||
|
||||
def filename(self, index, basename=False, absolute=False):
|
||||
return self.parser.filename(index, basename, absolute)
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
return self.parser.filenames(basename, absolute)
|
||||
|
||||
|
||||
class IterableImageDataset(data.IterableDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
parser=None,
|
||||
split='train',
|
||||
is_training=False,
|
||||
batch_size=None,
|
||||
repeats=0,
|
||||
download=False,
|
||||
transform=None,
|
||||
target_transform=None,
|
||||
):
|
||||
assert parser is not None
|
||||
if isinstance(parser, str):
|
||||
self.parser = create_parser(
|
||||
parser, root=root, split=split, is_training=is_training,
|
||||
batch_size=batch_size, repeats=repeats, download=download)
|
||||
else:
|
||||
self.parser = parser
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
self._consecutive_errors = 0
|
||||
|
||||
def __iter__(self):
|
||||
for img, target in self.parser:
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
yield img, target
|
||||
|
||||
def __len__(self):
|
||||
if hasattr(self.parser, '__len__'):
|
||||
return len(self.parser)
|
||||
else:
|
||||
return 0
|
||||
|
||||
def filename(self, index, basename=False, absolute=False):
|
||||
assert False, 'Filename lookup by index not supported, use filenames().'
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
return self.parser.filenames(basename, absolute)
|
||||
|
||||
|
||||
class AugMixDataset(torch.utils.data.Dataset):
|
||||
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
|
||||
|
||||
def __init__(self, dataset, num_splits=2):
|
||||
self.augmentation = None
|
||||
self.normalize = None
|
||||
self.dataset = dataset
|
||||
if self.dataset.transform is not None:
|
||||
self._set_transforms(self.dataset.transform)
|
||||
self.num_splits = num_splits
|
||||
|
||||
def _set_transforms(self, x):
|
||||
assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
|
||||
self.dataset.transform = x[0]
|
||||
self.augmentation = x[1]
|
||||
self.normalize = x[2]
|
||||
|
||||
@property
|
||||
def transform(self):
|
||||
return self.dataset.transform
|
||||
|
||||
@transform.setter
|
||||
def transform(self, x):
|
||||
self._set_transforms(x)
|
||||
|
||||
def _normalize(self, x):
|
||||
return x if self.normalize is None else self.normalize(x)
|
||||
|
||||
def __getitem__(self, i):
|
||||
x, y = self.dataset[i] # all splits share the same dataset base transform
|
||||
x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
|
||||
# run the full augmentation on the remaining splits
|
||||
for _ in range(self.num_splits - 1):
|
||||
x_list.append(self._normalize(self.augmentation(x)))
|
||||
return tuple(x_list), y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
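`AugMixDataset` expects the wrapped dataset's transform to be a 3-tuple (shared base transform, augmentation, normalize) and returns `num_splits` views of each sample, the first of which is the "clean" one. A minimal, hedged sketch; the dataset path is a placeholder, and torchvision is assumed only to supply example transforms:

```python
import torchvision.transforms as T
from timm.data import ImageDataset, AugMixDataset
from timm.data.auto_augment import augment_and_mix_transform

base = T.Compose([T.RandomResizedCrop(224), T.RandomHorizontalFlip()])      # shared geometric transform
augmix = augment_and_mix_transform('augmix-m5-w3',
                                   dict(translate_const=100, img_mean=(128, 128, 128)))
normalize = T.Compose([T.ToTensor(),
                       T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

ds = ImageDataset('/path/to/images')          # placeholder image folder
ds.transform = (base, augmix, normalize)      # split into the three roles by AugMixDataset
aug_ds = AugMixDataset(ds, num_splits=3)      # each item: ((clean, aug1, aug2), target)
(clean, aug1, aug2), target = aug_ds[0]
```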
Some files were not shown because too many files have changed in this diff.