Spaces:
Runtime error
Runtime error
Added e4e code
Browse files- e4e/LICENSE +21 -0
- e4e/README.md +142 -0
- e4e/configs/__init__.py +0 -0
- e4e/configs/data_configs.py +41 -0
- e4e/configs/paths_config.py +28 -0
- e4e/configs/transforms_config.py +62 -0
- e4e/models/__init__.py +0 -0
- e4e/models/discriminator.py +20 -0
- e4e/models/encoders/__init__.py +0 -0
- e4e/models/encoders/helpers.py +140 -0
- e4e/models/encoders/model_irse.py +84 -0
- e4e/models/encoders/psp_encoders.py +200 -0
- e4e/models/latent_codes_pool.py +55 -0
- e4e/models/psp.py +99 -0
- e4e/models/stylegan2/__init__.py +0 -0
- e4e/models/stylegan2/model.py +678 -0
- e4e/models/stylegan2/op/__init__.py +0 -0
- e4e/models/stylegan2/op/fused_act.py +85 -0
- e4e/models/stylegan2/op/fused_bias_act.cpp +21 -0
- e4e/models/stylegan2/op/fused_bias_act_kernel.cu +99 -0
- e4e/models/stylegan2/op/upfirdn2d.cpp +23 -0
- e4e/models/stylegan2/op/upfirdn2d.py +184 -0
- e4e/models/stylegan2/op/upfirdn2d_kernel.cu +272 -0
- e4e/options/__init__.py +0 -0
- e4e/options/train_options.py +84 -0
- e4e/scripts/calc_losses_on_images.py +87 -0
- e4e/scripts/inference.py +133 -0
- e4e/scripts/train.py +88 -0
- e4e/utils/__init__.py +0 -0
- e4e/utils/alignment.py +115 -0
- e4e/utils/common.py +55 -0
- e4e/utils/data_utils.py +25 -0
- e4e/utils/model_utils.py +35 -0
- e4e/utils/train_utils.py +13 -0
e4e/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2021 omertov
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
e4e/README.md
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Designing an Encoder for StyleGAN Image Manipulation
|
2 |
+
<a href="https://arxiv.org/abs/2102.02766"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
|
3 |
+
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
|
4 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/omertov/encoder4editing/blob/main/notebooks/inference_playground.ipynb)
|
5 |
+
|
6 |
+
> Recently, there has been a surge of diverse methods for performing image editing by employing pre-trained unconditional generators. Applying these methods on real images, however, remains a challenge, as it necessarily requires the inversion of the images into their latent space. To successfully invert a real image, one needs to find a latent code that reconstructs the input image accurately, and more importantly, allows for its meaningful manipulation. In this paper, we carefully study the latent space of StyleGAN, the state-of-the-art unconditional generator. We identify and analyze the existence of a distortion-editability tradeoff and a distortion-perception tradeoff within the StyleGAN latent space. We then suggest two principles for designing encoders in a manner that allows one to control the proximity of the inversions to regions that StyleGAN was originally trained on. We present an encoder based on our two principles that is specifically designed for facilitating editing on real images by balancing these tradeoffs. By evaluating its performance qualitatively and quantitatively on numerous challenging domains, including cars and horses, we show that our inversion method, followed by common editing techniques, achieves superior real-image editing quality, with only a small reconstruction accuracy drop.
|
7 |
+
|
8 |
+
<p align="center">
|
9 |
+
<img src="docs/teaser.jpg" width="800px"/>
|
10 |
+
</p>
|
11 |
+
|
12 |
+
## Description
|
13 |
+
Official Implementation of "<a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>" paper for both training and evaluation.
|
14 |
+
The e4e encoder is specifically designed to complement existing image manipulation techniques performed over StyleGAN's latent space.
|
15 |
+
|
16 |
+
## Recent Updates
|
17 |
+
`2021.03.25`: Add pose editing direction.
|
18 |
+
|
19 |
+
## Getting Started
|
20 |
+
### Prerequisites
|
21 |
+
- Linux or macOS
|
22 |
+
- NVIDIA GPU + CUDA CuDNN (CPU may be possible with some modifications, but is not inherently supported)
|
23 |
+
- Python 3
|
24 |
+
|
25 |
+
### Installation
|
26 |
+
- Clone the repository:
|
27 |
+
```
|
28 |
+
git clone https://github.com/omertov/encoder4editing.git
|
29 |
+
cd encoder4editing
|
30 |
+
```
|
31 |
+
- Dependencies:
|
32 |
+
We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/).
|
33 |
+
All dependencies for defining the environment are provided in `environment/e4e_env.yaml`.
|
34 |
+
|
35 |
+
### Inference Notebook
|
36 |
+
We provide a Jupyter notebook found in `notebooks/inference_playground.ipynb` that allows one to encode and perform several editings on real images using StyleGAN.
|
37 |
+
|
38 |
+
### Pretrained Models
|
39 |
+
Please download the pre-trained models from the following links. Each e4e model contains the entire pSp framework architecture, including the encoder and decoder weights.
|
40 |
+
| Path | Description
|
41 |
+
| :--- | :----------
|
42 |
+
|[FFHQ Inversion](https://drive.google.com/file/d/1cUv_reLE6k3604or78EranS7XzuVMWeO/view?usp=sharing) | FFHQ e4e encoder.
|
43 |
+
|[Cars Inversion](https://drive.google.com/file/d/17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV/view?usp=sharing) | Cars e4e encoder.
|
44 |
+
|[Horse Inversion](https://drive.google.com/file/d/1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX/view?usp=sharing) | Horse e4e encoder.
|
45 |
+
|[Church Inversion](https://drive.google.com/file/d/1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa/view?usp=sharing) | Church e4e encoder.
|
46 |
+
|
47 |
+
If you wish to use one of the pretrained models for training or inference, you may do so using the flag `--checkpoint_path`.
|
48 |
+
|
49 |
+
In addition, we provide various auxiliary models needed for training your own e4e model from scratch.
|
50 |
+
| Path | Description
|
51 |
+
| :--- | :----------
|
52 |
+
|[FFHQ StyleGAN](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing) | StyleGAN model pretrained on FFHQ taken from [rosinality](https://github.com/rosinality/stylegan2-pytorch) with 1024x1024 output resolution.
|
53 |
+
|[IR-SE50 Model](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing) | Pretrained IR-SE50 model taken from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) for use in our ID loss during training.
|
54 |
+
|[MOCOv2 Model](https://drive.google.com/file/d/18rLcNGdteX5LwT7sv_F7HWr12HpVEzVe/view?usp=sharing) | Pretrained ResNet-50 model trained using MOCOv2 for use in our simmilarity loss for domains other then human faces during training.
|
55 |
+
|
56 |
+
By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`. However, you may use your own paths by changing the necessary values in `configs/path_configs.py`.
|
57 |
+
|
58 |
+
## Training
|
59 |
+
To train the e4e encoder, make sure the paths to the required models, as well as training and testing data is configured in `configs/path_configs.py` and `configs/data_configs.py`.
|
60 |
+
#### **Training the e4e Encoder**
|
61 |
+
```
|
62 |
+
python scripts/train.py \
|
63 |
+
--dataset_type cars_encode \
|
64 |
+
--exp_dir new/experiment/directory \
|
65 |
+
--start_from_latent_avg \
|
66 |
+
--use_w_pool \
|
67 |
+
--w_discriminator_lambda 0.1 \
|
68 |
+
--progressive_start 20000 \
|
69 |
+
--id_lambda 0.5 \
|
70 |
+
--val_interval 10000 \
|
71 |
+
--max_steps 200000 \
|
72 |
+
--stylegan_size 512 \
|
73 |
+
--stylegan_weights path/to/pretrained/stylegan.pt \
|
74 |
+
--workers 8 \
|
75 |
+
--batch_size 8 \
|
76 |
+
--test_batch_size 4 \
|
77 |
+
--test_workers 4
|
78 |
+
```
|
79 |
+
|
80 |
+
#### Training on your own dataset
|
81 |
+
In order to train the e4e encoder on a custom dataset, perform the following adjustments:
|
82 |
+
1. Insert the paths to your train and test data into the `dataset_paths` variable defined in `configs/paths_config.py`:
|
83 |
+
```
|
84 |
+
dataset_paths = {
|
85 |
+
'my_train_data': '/path/to/train/images/directory',
|
86 |
+
'my_test_data': '/path/to/test/images/directory'
|
87 |
+
}
|
88 |
+
```
|
89 |
+
2. Configure a new dataset under the DATASETS variable defined in `configs/data_configs.py`:
|
90 |
+
```
|
91 |
+
DATASETS = {
|
92 |
+
'my_data_encode': {
|
93 |
+
'transforms': transforms_config.EncodeTransforms,
|
94 |
+
'train_source_root': dataset_paths['my_train_data'],
|
95 |
+
'train_target_root': dataset_paths['my_train_data'],
|
96 |
+
'test_source_root': dataset_paths['my_test_data'],
|
97 |
+
'test_target_root': dataset_paths['my_test_data']
|
98 |
+
}
|
99 |
+
}
|
100 |
+
```
|
101 |
+
Refer to `configs/transforms_config.py` for the transformations applied to the train and test images during training.
|
102 |
+
|
103 |
+
3. Finally, run a training session with `--dataset_type my_data_encode`.
|
104 |
+
|
105 |
+
## Inference
|
106 |
+
Having trained your model, you can use `scripts/inference.py` to apply the model on a set of images.
|
107 |
+
For example,
|
108 |
+
```
|
109 |
+
python scripts/inference.py \
|
110 |
+
--images_dir=/path/to/images/directory \
|
111 |
+
--save_dir=/path/to/saving/directory \
|
112 |
+
path/to/checkpoint.pt
|
113 |
+
```
|
114 |
+
|
115 |
+
## Latent Editing Consistency (LEC)
|
116 |
+
As described in the paper, we suggest a new metric, Latent Editing Consistency (LEC), for evaluating the encoder's
|
117 |
+
performance.
|
118 |
+
We provide an example for calculating the metric over the FFHQ StyleGAN using the aging editing direction in
|
119 |
+
`metrics/LEC.py`.
|
120 |
+
|
121 |
+
To run the example:
|
122 |
+
```
|
123 |
+
cd metrics
|
124 |
+
python LEC.py \
|
125 |
+
--images_dir=/path/to/images/directory \
|
126 |
+
path/to/checkpoint.pt
|
127 |
+
```
|
128 |
+
|
129 |
+
## Acknowledgments
|
130 |
+
This code borrows heavily from [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel)
|
131 |
+
|
132 |
+
## Citation
|
133 |
+
If you use this code for your research, please cite our paper <a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>:
|
134 |
+
|
135 |
+
```
|
136 |
+
@article{tov2021designing,
|
137 |
+
title={Designing an Encoder for StyleGAN Image Manipulation},
|
138 |
+
author={Tov, Omer and Alaluf, Yuval and Nitzan, Yotam and Patashnik, Or and Cohen-Or, Daniel},
|
139 |
+
journal={arXiv preprint arXiv:2102.02766},
|
140 |
+
year={2021}
|
141 |
+
}
|
142 |
+
```
|
e4e/configs/__init__.py
ADDED
File without changes
|
e4e/configs/data_configs.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from configs import transforms_config
|
2 |
+
from configs.paths_config import dataset_paths
|
3 |
+
|
4 |
+
|
5 |
+
DATASETS = {
|
6 |
+
'ffhq_encode': {
|
7 |
+
'transforms': transforms_config.EncodeTransforms,
|
8 |
+
'train_source_root': dataset_paths['ffhq'],
|
9 |
+
'train_target_root': dataset_paths['ffhq'],
|
10 |
+
'test_source_root': dataset_paths['celeba_test'],
|
11 |
+
'test_target_root': dataset_paths['celeba_test'],
|
12 |
+
},
|
13 |
+
'cars_encode': {
|
14 |
+
'transforms': transforms_config.CarsEncodeTransforms,
|
15 |
+
'train_source_root': dataset_paths['cars_train'],
|
16 |
+
'train_target_root': dataset_paths['cars_train'],
|
17 |
+
'test_source_root': dataset_paths['cars_test'],
|
18 |
+
'test_target_root': dataset_paths['cars_test'],
|
19 |
+
},
|
20 |
+
'horse_encode': {
|
21 |
+
'transforms': transforms_config.EncodeTransforms,
|
22 |
+
'train_source_root': dataset_paths['horse_train'],
|
23 |
+
'train_target_root': dataset_paths['horse_train'],
|
24 |
+
'test_source_root': dataset_paths['horse_test'],
|
25 |
+
'test_target_root': dataset_paths['horse_test'],
|
26 |
+
},
|
27 |
+
'church_encode': {
|
28 |
+
'transforms': transforms_config.EncodeTransforms,
|
29 |
+
'train_source_root': dataset_paths['church_train'],
|
30 |
+
'train_target_root': dataset_paths['church_train'],
|
31 |
+
'test_source_root': dataset_paths['church_test'],
|
32 |
+
'test_target_root': dataset_paths['church_test'],
|
33 |
+
},
|
34 |
+
'cats_encode': {
|
35 |
+
'transforms': transforms_config.EncodeTransforms,
|
36 |
+
'train_source_root': dataset_paths['cats_train'],
|
37 |
+
'train_target_root': dataset_paths['cats_train'],
|
38 |
+
'test_source_root': dataset_paths['cats_test'],
|
39 |
+
'test_target_root': dataset_paths['cats_test'],
|
40 |
+
}
|
41 |
+
}
|
e4e/configs/paths_config.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_paths = {
|
2 |
+
# Face Datasets (In the paper: FFHQ - train, CelebAHQ - test)
|
3 |
+
'ffhq': '',
|
4 |
+
'celeba_test': '',
|
5 |
+
|
6 |
+
# Cars Dataset (In the paper: Stanford cars)
|
7 |
+
'cars_train': '',
|
8 |
+
'cars_test': '',
|
9 |
+
|
10 |
+
# Horse Dataset (In the paper: LSUN Horse)
|
11 |
+
'horse_train': '',
|
12 |
+
'horse_test': '',
|
13 |
+
|
14 |
+
# Church Dataset (In the paper: LSUN Church)
|
15 |
+
'church_train': '',
|
16 |
+
'church_test': '',
|
17 |
+
|
18 |
+
# Cats Dataset (In the paper: LSUN Cat)
|
19 |
+
'cats_train': '',
|
20 |
+
'cats_test': ''
|
21 |
+
}
|
22 |
+
|
23 |
+
model_paths = {
|
24 |
+
'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
|
25 |
+
'ir_se50': 'pretrained_models/model_ir_se50.pth',
|
26 |
+
'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat',
|
27 |
+
'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth'
|
28 |
+
}
|
e4e/configs/transforms_config.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
|
4 |
+
|
5 |
+
class TransformsConfig(object):
|
6 |
+
|
7 |
+
def __init__(self, opts):
|
8 |
+
self.opts = opts
|
9 |
+
|
10 |
+
@abstractmethod
|
11 |
+
def get_transforms(self):
|
12 |
+
pass
|
13 |
+
|
14 |
+
|
15 |
+
class EncodeTransforms(TransformsConfig):
|
16 |
+
|
17 |
+
def __init__(self, opts):
|
18 |
+
super(EncodeTransforms, self).__init__(opts)
|
19 |
+
|
20 |
+
def get_transforms(self):
|
21 |
+
transforms_dict = {
|
22 |
+
'transform_gt_train': transforms.Compose([
|
23 |
+
transforms.Resize((256, 256)),
|
24 |
+
transforms.RandomHorizontalFlip(0.5),
|
25 |
+
transforms.ToTensor(),
|
26 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
27 |
+
'transform_source': None,
|
28 |
+
'transform_test': transforms.Compose([
|
29 |
+
transforms.Resize((256, 256)),
|
30 |
+
transforms.ToTensor(),
|
31 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
32 |
+
'transform_inference': transforms.Compose([
|
33 |
+
transforms.Resize((256, 256)),
|
34 |
+
transforms.ToTensor(),
|
35 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
36 |
+
}
|
37 |
+
return transforms_dict
|
38 |
+
|
39 |
+
|
40 |
+
class CarsEncodeTransforms(TransformsConfig):
|
41 |
+
|
42 |
+
def __init__(self, opts):
|
43 |
+
super(CarsEncodeTransforms, self).__init__(opts)
|
44 |
+
|
45 |
+
def get_transforms(self):
|
46 |
+
transforms_dict = {
|
47 |
+
'transform_gt_train': transforms.Compose([
|
48 |
+
transforms.Resize((192, 256)),
|
49 |
+
transforms.RandomHorizontalFlip(0.5),
|
50 |
+
transforms.ToTensor(),
|
51 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
52 |
+
'transform_source': None,
|
53 |
+
'transform_test': transforms.Compose([
|
54 |
+
transforms.Resize((192, 256)),
|
55 |
+
transforms.ToTensor(),
|
56 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
57 |
+
'transform_inference': transforms.Compose([
|
58 |
+
transforms.Resize((192, 256)),
|
59 |
+
transforms.ToTensor(),
|
60 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
61 |
+
}
|
62 |
+
return transforms_dict
|
e4e/models/__init__.py
ADDED
File without changes
|
e4e/models/discriminator.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
|
4 |
+
class LatentCodesDiscriminator(nn.Module):
|
5 |
+
def __init__(self, style_dim, n_mlp):
|
6 |
+
super().__init__()
|
7 |
+
|
8 |
+
self.style_dim = style_dim
|
9 |
+
|
10 |
+
layers = []
|
11 |
+
for i in range(n_mlp-1):
|
12 |
+
layers.append(
|
13 |
+
nn.Linear(style_dim, style_dim)
|
14 |
+
)
|
15 |
+
layers.append(nn.LeakyReLU(0.2))
|
16 |
+
layers.append(nn.Linear(512, 1))
|
17 |
+
self.mlp = nn.Sequential(*layers)
|
18 |
+
|
19 |
+
def forward(self, w):
|
20 |
+
return self.mlp(w)
|
e4e/models/encoders/__init__.py
ADDED
File without changes
|
e4e/models/encoders/helpers.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
|
5 |
+
|
6 |
+
"""
|
7 |
+
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
8 |
+
"""
|
9 |
+
|
10 |
+
|
11 |
+
class Flatten(Module):
|
12 |
+
def forward(self, input):
|
13 |
+
return input.view(input.size(0), -1)
|
14 |
+
|
15 |
+
|
16 |
+
def l2_norm(input, axis=1):
|
17 |
+
norm = torch.norm(input, 2, axis, True)
|
18 |
+
output = torch.div(input, norm)
|
19 |
+
return output
|
20 |
+
|
21 |
+
|
22 |
+
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
|
23 |
+
""" A named tuple describing a ResNet block. """
|
24 |
+
|
25 |
+
|
26 |
+
def get_block(in_channel, depth, num_units, stride=2):
|
27 |
+
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
|
28 |
+
|
29 |
+
|
30 |
+
def get_blocks(num_layers):
|
31 |
+
if num_layers == 50:
|
32 |
+
blocks = [
|
33 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
34 |
+
get_block(in_channel=64, depth=128, num_units=4),
|
35 |
+
get_block(in_channel=128, depth=256, num_units=14),
|
36 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
37 |
+
]
|
38 |
+
elif num_layers == 100:
|
39 |
+
blocks = [
|
40 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
41 |
+
get_block(in_channel=64, depth=128, num_units=13),
|
42 |
+
get_block(in_channel=128, depth=256, num_units=30),
|
43 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
44 |
+
]
|
45 |
+
elif num_layers == 152:
|
46 |
+
blocks = [
|
47 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
48 |
+
get_block(in_channel=64, depth=128, num_units=8),
|
49 |
+
get_block(in_channel=128, depth=256, num_units=36),
|
50 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
51 |
+
]
|
52 |
+
else:
|
53 |
+
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
|
54 |
+
return blocks
|
55 |
+
|
56 |
+
|
57 |
+
class SEModule(Module):
|
58 |
+
def __init__(self, channels, reduction):
|
59 |
+
super(SEModule, self).__init__()
|
60 |
+
self.avg_pool = AdaptiveAvgPool2d(1)
|
61 |
+
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
|
62 |
+
self.relu = ReLU(inplace=True)
|
63 |
+
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
|
64 |
+
self.sigmoid = Sigmoid()
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
module_input = x
|
68 |
+
x = self.avg_pool(x)
|
69 |
+
x = self.fc1(x)
|
70 |
+
x = self.relu(x)
|
71 |
+
x = self.fc2(x)
|
72 |
+
x = self.sigmoid(x)
|
73 |
+
return module_input * x
|
74 |
+
|
75 |
+
|
76 |
+
class bottleneck_IR(Module):
|
77 |
+
def __init__(self, in_channel, depth, stride):
|
78 |
+
super(bottleneck_IR, self).__init__()
|
79 |
+
if in_channel == depth:
|
80 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
81 |
+
else:
|
82 |
+
self.shortcut_layer = Sequential(
|
83 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
84 |
+
BatchNorm2d(depth)
|
85 |
+
)
|
86 |
+
self.res_layer = Sequential(
|
87 |
+
BatchNorm2d(in_channel),
|
88 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
|
89 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
|
90 |
+
)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
shortcut = self.shortcut_layer(x)
|
94 |
+
res = self.res_layer(x)
|
95 |
+
return res + shortcut
|
96 |
+
|
97 |
+
|
98 |
+
class bottleneck_IR_SE(Module):
|
99 |
+
def __init__(self, in_channel, depth, stride):
|
100 |
+
super(bottleneck_IR_SE, self).__init__()
|
101 |
+
if in_channel == depth:
|
102 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
103 |
+
else:
|
104 |
+
self.shortcut_layer = Sequential(
|
105 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
106 |
+
BatchNorm2d(depth)
|
107 |
+
)
|
108 |
+
self.res_layer = Sequential(
|
109 |
+
BatchNorm2d(in_channel),
|
110 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
111 |
+
PReLU(depth),
|
112 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
113 |
+
BatchNorm2d(depth),
|
114 |
+
SEModule(depth, 16)
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
shortcut = self.shortcut_layer(x)
|
119 |
+
res = self.res_layer(x)
|
120 |
+
return res + shortcut
|
121 |
+
|
122 |
+
|
123 |
+
def _upsample_add(x, y):
|
124 |
+
"""Upsample and add two feature maps.
|
125 |
+
Args:
|
126 |
+
x: (Variable) top feature map to be upsampled.
|
127 |
+
y: (Variable) lateral feature map.
|
128 |
+
Returns:
|
129 |
+
(Variable) added feature map.
|
130 |
+
Note in PyTorch, when input size is odd, the upsampled feature map
|
131 |
+
with `F.upsample(..., scale_factor=2, mode='nearest')`
|
132 |
+
maybe not equal to the lateral feature map size.
|
133 |
+
e.g.
|
134 |
+
original input size: [N,_,15,15] ->
|
135 |
+
conv2d feature map size: [N,_,8,8] ->
|
136 |
+
upsampled feature map size: [N,_,16,16]
|
137 |
+
So we choose bilinear upsample which supports arbitrary output sizes.
|
138 |
+
"""
|
139 |
+
_, _, H, W = y.size()
|
140 |
+
return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
|
e4e/models/encoders/model_irse.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
2 |
+
from e4e.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
3 |
+
|
4 |
+
"""
|
5 |
+
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
6 |
+
"""
|
7 |
+
|
8 |
+
|
9 |
+
class Backbone(Module):
|
10 |
+
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
|
11 |
+
super(Backbone, self).__init__()
|
12 |
+
assert input_size in [112, 224], "input_size should be 112 or 224"
|
13 |
+
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
|
14 |
+
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
|
15 |
+
blocks = get_blocks(num_layers)
|
16 |
+
if mode == 'ir':
|
17 |
+
unit_module = bottleneck_IR
|
18 |
+
elif mode == 'ir_se':
|
19 |
+
unit_module = bottleneck_IR_SE
|
20 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
21 |
+
BatchNorm2d(64),
|
22 |
+
PReLU(64))
|
23 |
+
if input_size == 112:
|
24 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
25 |
+
Dropout(drop_ratio),
|
26 |
+
Flatten(),
|
27 |
+
Linear(512 * 7 * 7, 512),
|
28 |
+
BatchNorm1d(512, affine=affine))
|
29 |
+
else:
|
30 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
31 |
+
Dropout(drop_ratio),
|
32 |
+
Flatten(),
|
33 |
+
Linear(512 * 14 * 14, 512),
|
34 |
+
BatchNorm1d(512, affine=affine))
|
35 |
+
|
36 |
+
modules = []
|
37 |
+
for block in blocks:
|
38 |
+
for bottleneck in block:
|
39 |
+
modules.append(unit_module(bottleneck.in_channel,
|
40 |
+
bottleneck.depth,
|
41 |
+
bottleneck.stride))
|
42 |
+
self.body = Sequential(*modules)
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.input_layer(x)
|
46 |
+
x = self.body(x)
|
47 |
+
x = self.output_layer(x)
|
48 |
+
return l2_norm(x)
|
49 |
+
|
50 |
+
|
51 |
+
def IR_50(input_size):
|
52 |
+
"""Constructs a ir-50 model."""
|
53 |
+
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
54 |
+
return model
|
55 |
+
|
56 |
+
|
57 |
+
def IR_101(input_size):
|
58 |
+
"""Constructs a ir-101 model."""
|
59 |
+
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
60 |
+
return model
|
61 |
+
|
62 |
+
|
63 |
+
def IR_152(input_size):
|
64 |
+
"""Constructs a ir-152 model."""
|
65 |
+
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
66 |
+
return model
|
67 |
+
|
68 |
+
|
69 |
+
def IR_SE_50(input_size):
|
70 |
+
"""Constructs a ir_se-50 model."""
|
71 |
+
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
72 |
+
return model
|
73 |
+
|
74 |
+
|
75 |
+
def IR_SE_101(input_size):
|
76 |
+
"""Constructs a ir_se-101 model."""
|
77 |
+
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
78 |
+
return model
|
79 |
+
|
80 |
+
|
81 |
+
def IR_SE_152(input_size):
|
82 |
+
"""Constructs a ir_se-152 model."""
|
83 |
+
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
84 |
+
return model
|
e4e/models/encoders/psp_encoders.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
|
7 |
+
|
8 |
+
from e4e.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
|
9 |
+
from e4e.models.stylegan2.model import EqualLinear
|
10 |
+
|
11 |
+
|
12 |
+
class ProgressiveStage(Enum):
|
13 |
+
WTraining = 0
|
14 |
+
Delta1Training = 1
|
15 |
+
Delta2Training = 2
|
16 |
+
Delta3Training = 3
|
17 |
+
Delta4Training = 4
|
18 |
+
Delta5Training = 5
|
19 |
+
Delta6Training = 6
|
20 |
+
Delta7Training = 7
|
21 |
+
Delta8Training = 8
|
22 |
+
Delta9Training = 9
|
23 |
+
Delta10Training = 10
|
24 |
+
Delta11Training = 11
|
25 |
+
Delta12Training = 12
|
26 |
+
Delta13Training = 13
|
27 |
+
Delta14Training = 14
|
28 |
+
Delta15Training = 15
|
29 |
+
Delta16Training = 16
|
30 |
+
Delta17Training = 17
|
31 |
+
Inference = 18
|
32 |
+
|
33 |
+
|
34 |
+
class GradualStyleBlock(Module):
|
35 |
+
def __init__(self, in_c, out_c, spatial):
|
36 |
+
super(GradualStyleBlock, self).__init__()
|
37 |
+
self.out_c = out_c
|
38 |
+
self.spatial = spatial
|
39 |
+
num_pools = int(np.log2(spatial))
|
40 |
+
modules = []
|
41 |
+
modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
|
42 |
+
nn.LeakyReLU()]
|
43 |
+
for i in range(num_pools - 1):
|
44 |
+
modules += [
|
45 |
+
Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
|
46 |
+
nn.LeakyReLU()
|
47 |
+
]
|
48 |
+
self.convs = nn.Sequential(*modules)
|
49 |
+
self.linear = EqualLinear(out_c, out_c, lr_mul=1)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = self.convs(x)
|
53 |
+
x = x.view(-1, self.out_c)
|
54 |
+
x = self.linear(x)
|
55 |
+
return x
|
56 |
+
|
57 |
+
|
58 |
+
class GradualStyleEncoder(Module):
|
59 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
60 |
+
super(GradualStyleEncoder, self).__init__()
|
61 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
62 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
63 |
+
blocks = get_blocks(num_layers)
|
64 |
+
if mode == 'ir':
|
65 |
+
unit_module = bottleneck_IR
|
66 |
+
elif mode == 'ir_se':
|
67 |
+
unit_module = bottleneck_IR_SE
|
68 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
69 |
+
BatchNorm2d(64),
|
70 |
+
PReLU(64))
|
71 |
+
modules = []
|
72 |
+
for block in blocks:
|
73 |
+
for bottleneck in block:
|
74 |
+
modules.append(unit_module(bottleneck.in_channel,
|
75 |
+
bottleneck.depth,
|
76 |
+
bottleneck.stride))
|
77 |
+
self.body = Sequential(*modules)
|
78 |
+
|
79 |
+
self.styles = nn.ModuleList()
|
80 |
+
log_size = int(math.log(opts.stylegan_size, 2))
|
81 |
+
self.style_count = 2 * log_size - 2
|
82 |
+
self.coarse_ind = 3
|
83 |
+
self.middle_ind = 7
|
84 |
+
for i in range(self.style_count):
|
85 |
+
if i < self.coarse_ind:
|
86 |
+
style = GradualStyleBlock(512, 512, 16)
|
87 |
+
elif i < self.middle_ind:
|
88 |
+
style = GradualStyleBlock(512, 512, 32)
|
89 |
+
else:
|
90 |
+
style = GradualStyleBlock(512, 512, 64)
|
91 |
+
self.styles.append(style)
|
92 |
+
self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
|
93 |
+
self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
x = self.input_layer(x)
|
97 |
+
|
98 |
+
latents = []
|
99 |
+
modulelist = list(self.body._modules.values())
|
100 |
+
for i, l in enumerate(modulelist):
|
101 |
+
x = l(x)
|
102 |
+
if i == 6:
|
103 |
+
c1 = x
|
104 |
+
elif i == 20:
|
105 |
+
c2 = x
|
106 |
+
elif i == 23:
|
107 |
+
c3 = x
|
108 |
+
|
109 |
+
for j in range(self.coarse_ind):
|
110 |
+
latents.append(self.styles[j](c3))
|
111 |
+
|
112 |
+
p2 = _upsample_add(c3, self.latlayer1(c2))
|
113 |
+
for j in range(self.coarse_ind, self.middle_ind):
|
114 |
+
latents.append(self.styles[j](p2))
|
115 |
+
|
116 |
+
p1 = _upsample_add(p2, self.latlayer2(c1))
|
117 |
+
for j in range(self.middle_ind, self.style_count):
|
118 |
+
latents.append(self.styles[j](p1))
|
119 |
+
|
120 |
+
out = torch.stack(latents, dim=1)
|
121 |
+
return out
|
122 |
+
|
123 |
+
|
124 |
+
class Encoder4Editing(Module):
|
125 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
126 |
+
super(Encoder4Editing, self).__init__()
|
127 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
128 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
129 |
+
blocks = get_blocks(num_layers)
|
130 |
+
if mode == 'ir':
|
131 |
+
unit_module = bottleneck_IR
|
132 |
+
elif mode == 'ir_se':
|
133 |
+
unit_module = bottleneck_IR_SE
|
134 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
135 |
+
BatchNorm2d(64),
|
136 |
+
PReLU(64))
|
137 |
+
modules = []
|
138 |
+
for block in blocks:
|
139 |
+
for bottleneck in block:
|
140 |
+
modules.append(unit_module(bottleneck.in_channel,
|
141 |
+
bottleneck.depth,
|
142 |
+
bottleneck.stride))
|
143 |
+
self.body = Sequential(*modules)
|
144 |
+
|
145 |
+
self.styles = nn.ModuleList()
|
146 |
+
log_size = int(math.log(opts.stylegan_size, 2))
|
147 |
+
self.style_count = 2 * log_size - 2
|
148 |
+
self.coarse_ind = 3
|
149 |
+
self.middle_ind = 7
|
150 |
+
|
151 |
+
for i in range(self.style_count):
|
152 |
+
if i < self.coarse_ind:
|
153 |
+
style = GradualStyleBlock(512, 512, 16)
|
154 |
+
elif i < self.middle_ind:
|
155 |
+
style = GradualStyleBlock(512, 512, 32)
|
156 |
+
else:
|
157 |
+
style = GradualStyleBlock(512, 512, 64)
|
158 |
+
self.styles.append(style)
|
159 |
+
|
160 |
+
self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
|
161 |
+
self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
|
162 |
+
|
163 |
+
self.progressive_stage = ProgressiveStage.Inference
|
164 |
+
|
165 |
+
def get_deltas_starting_dimensions(self):
|
166 |
+
''' Get a list of the initial dimension of every delta from which it is applied '''
|
167 |
+
return list(range(self.style_count)) # Each dimension has a delta applied to it
|
168 |
+
|
169 |
+
def set_progressive_stage(self, new_stage: ProgressiveStage):
|
170 |
+
self.progressive_stage = new_stage
|
171 |
+
print('Changed progressive stage to: ', new_stage)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
x = self.input_layer(x)
|
175 |
+
|
176 |
+
modulelist = list(self.body._modules.values())
|
177 |
+
for i, l in enumerate(modulelist):
|
178 |
+
x = l(x)
|
179 |
+
if i == 6:
|
180 |
+
c1 = x
|
181 |
+
elif i == 20:
|
182 |
+
c2 = x
|
183 |
+
elif i == 23:
|
184 |
+
c3 = x
|
185 |
+
|
186 |
+
# Infer main W and duplicate it
|
187 |
+
w0 = self.styles[0](c3)
|
188 |
+
w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
|
189 |
+
stage = self.progressive_stage.value
|
190 |
+
features = c3
|
191 |
+
for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
|
192 |
+
if i == self.coarse_ind:
|
193 |
+
p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
|
194 |
+
features = p2
|
195 |
+
elif i == self.middle_ind:
|
196 |
+
p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
|
197 |
+
features = p1
|
198 |
+
delta_i = self.styles[i](features)
|
199 |
+
w[:, i] += delta_i
|
200 |
+
return w
|
e4e/models/latent_codes_pool.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class LatentCodesPool:
|
6 |
+
"""This class implements latent codes buffer that stores previously generated w latent codes.
|
7 |
+
This buffer enables us to update discriminators using a history of generated w's
|
8 |
+
rather than the ones produced by the latest encoder.
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(self, pool_size):
|
12 |
+
"""Initialize the ImagePool class
|
13 |
+
Parameters:
|
14 |
+
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
|
15 |
+
"""
|
16 |
+
self.pool_size = pool_size
|
17 |
+
if self.pool_size > 0: # create an empty pool
|
18 |
+
self.num_ws = 0
|
19 |
+
self.ws = []
|
20 |
+
|
21 |
+
def query(self, ws):
|
22 |
+
"""Return w's from the pool.
|
23 |
+
Parameters:
|
24 |
+
ws: the latest generated w's from the generator
|
25 |
+
Returns w's from the buffer.
|
26 |
+
By 50/100, the buffer will return input w's.
|
27 |
+
By 50/100, the buffer will return w's previously stored in the buffer,
|
28 |
+
and insert the current w's to the buffer.
|
29 |
+
"""
|
30 |
+
if self.pool_size == 0: # if the buffer size is 0, do nothing
|
31 |
+
return ws
|
32 |
+
return_ws = []
|
33 |
+
for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512)
|
34 |
+
# w = torch.unsqueeze(image.data, 0)
|
35 |
+
if w.ndim == 2:
|
36 |
+
i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate
|
37 |
+
w = w[i]
|
38 |
+
self.handle_w(w, return_ws)
|
39 |
+
return_ws = torch.stack(return_ws, 0) # collect all the images and return
|
40 |
+
return return_ws
|
41 |
+
|
42 |
+
def handle_w(self, w, return_ws):
|
43 |
+
if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer
|
44 |
+
self.num_ws = self.num_ws + 1
|
45 |
+
self.ws.append(w)
|
46 |
+
return_ws.append(w)
|
47 |
+
else:
|
48 |
+
p = random.uniform(0, 1)
|
49 |
+
if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer
|
50 |
+
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
|
51 |
+
tmp = self.ws[random_id].clone()
|
52 |
+
self.ws[random_id] = w
|
53 |
+
return_ws.append(tmp)
|
54 |
+
else: # by another 50% chance, the buffer will return the current image
|
55 |
+
return_ws.append(w)
|
e4e/models/psp.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
|
3 |
+
matplotlib.use('Agg')
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from e4e.models.encoders import psp_encoders
|
7 |
+
from e4e.models.stylegan2.model import Generator
|
8 |
+
from e4e.configs.paths_config import model_paths
|
9 |
+
|
10 |
+
|
11 |
+
def get_keys(d, name):
|
12 |
+
if 'state_dict' in d:
|
13 |
+
d = d['state_dict']
|
14 |
+
d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
|
15 |
+
return d_filt
|
16 |
+
|
17 |
+
|
18 |
+
class pSp(nn.Module):
|
19 |
+
|
20 |
+
def __init__(self, opts, device):
|
21 |
+
super(pSp, self).__init__()
|
22 |
+
self.opts = opts
|
23 |
+
self.device = device
|
24 |
+
# Define architecture
|
25 |
+
self.encoder = self.set_encoder()
|
26 |
+
self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
|
27 |
+
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
28 |
+
# Load weights if needed
|
29 |
+
self.load_weights()
|
30 |
+
|
31 |
+
def set_encoder(self):
|
32 |
+
if self.opts.encoder_type == 'GradualStyleEncoder':
|
33 |
+
encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
|
34 |
+
elif self.opts.encoder_type == 'Encoder4Editing':
|
35 |
+
encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
|
36 |
+
else:
|
37 |
+
raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
|
38 |
+
return encoder
|
39 |
+
|
40 |
+
def load_weights(self):
|
41 |
+
if self.opts.checkpoint_path is not None:
|
42 |
+
print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
|
43 |
+
ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
|
44 |
+
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
|
45 |
+
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
|
46 |
+
self.__load_latent_avg(ckpt)
|
47 |
+
else:
|
48 |
+
print('Loading encoders weights from irse50!')
|
49 |
+
encoder_ckpt = torch.load(model_paths['ir_se50'])
|
50 |
+
self.encoder.load_state_dict(encoder_ckpt, strict=False)
|
51 |
+
print('Loading decoder weights from pretrained!')
|
52 |
+
ckpt = torch.load(self.opts.stylegan_weights)
|
53 |
+
self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
|
54 |
+
self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)
|
55 |
+
|
56 |
+
def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
|
57 |
+
inject_latent=None, return_latents=False, alpha=None):
|
58 |
+
if input_code:
|
59 |
+
codes = x
|
60 |
+
else:
|
61 |
+
codes = self.encoder(x)
|
62 |
+
# normalize with respect to the center of an average face
|
63 |
+
if self.opts.start_from_latent_avg:
|
64 |
+
if codes.ndim == 2:
|
65 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
|
66 |
+
else:
|
67 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
|
68 |
+
|
69 |
+
if latent_mask is not None:
|
70 |
+
for i in latent_mask:
|
71 |
+
if inject_latent is not None:
|
72 |
+
if alpha is not None:
|
73 |
+
codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
|
74 |
+
else:
|
75 |
+
codes[:, i] = inject_latent[:, i]
|
76 |
+
else:
|
77 |
+
codes[:, i] = 0
|
78 |
+
|
79 |
+
input_is_latent = not input_code
|
80 |
+
images, result_latent = self.decoder([codes],
|
81 |
+
input_is_latent=input_is_latent,
|
82 |
+
randomize_noise=randomize_noise,
|
83 |
+
return_latents=return_latents)
|
84 |
+
|
85 |
+
if resize:
|
86 |
+
images = self.face_pool(images)
|
87 |
+
|
88 |
+
if return_latents:
|
89 |
+
return images, result_latent
|
90 |
+
else:
|
91 |
+
return images
|
92 |
+
|
93 |
+
def __load_latent_avg(self, ckpt, repeat=None):
|
94 |
+
if 'latent_avg' in ckpt:
|
95 |
+
self.latent_avg = ckpt['latent_avg'].to(self.device)
|
96 |
+
if repeat is not None:
|
97 |
+
self.latent_avg = self.latent_avg.repeat(repeat, 1)
|
98 |
+
else:
|
99 |
+
self.latent_avg = None
|
e4e/models/stylegan2/__init__.py
ADDED
File without changes
|
e4e/models/stylegan2/model.py
ADDED
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
if torch.cuda.is_available():
|
8 |
+
from op.fused_act import FusedLeakyReLU, fused_leaky_relu
|
9 |
+
from op.upfirdn2d import upfirdn2d
|
10 |
+
else:
|
11 |
+
from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
|
12 |
+
from op.upfirdn2d_cpu import upfirdn2d
|
13 |
+
|
14 |
+
|
15 |
+
class PixelNorm(nn.Module):
|
16 |
+
def __init__(self):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
def forward(self, input):
|
20 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
21 |
+
|
22 |
+
|
23 |
+
def make_kernel(k):
|
24 |
+
k = torch.tensor(k, dtype=torch.float32)
|
25 |
+
|
26 |
+
if k.ndim == 1:
|
27 |
+
k = k[None, :] * k[:, None]
|
28 |
+
|
29 |
+
k /= k.sum()
|
30 |
+
|
31 |
+
return k
|
32 |
+
|
33 |
+
|
34 |
+
class Upsample(nn.Module):
|
35 |
+
def __init__(self, kernel, factor=2):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.factor = factor
|
39 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
40 |
+
self.register_buffer('kernel', kernel)
|
41 |
+
|
42 |
+
p = kernel.shape[0] - factor
|
43 |
+
|
44 |
+
pad0 = (p + 1) // 2 + factor - 1
|
45 |
+
pad1 = p // 2
|
46 |
+
|
47 |
+
self.pad = (pad0, pad1)
|
48 |
+
|
49 |
+
def forward(self, input):
|
50 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
51 |
+
|
52 |
+
return out
|
53 |
+
|
54 |
+
|
55 |
+
class Downsample(nn.Module):
|
56 |
+
def __init__(self, kernel, factor=2):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
self.factor = factor
|
60 |
+
kernel = make_kernel(kernel)
|
61 |
+
self.register_buffer('kernel', kernel)
|
62 |
+
|
63 |
+
p = kernel.shape[0] - factor
|
64 |
+
|
65 |
+
pad0 = (p + 1) // 2
|
66 |
+
pad1 = p // 2
|
67 |
+
|
68 |
+
self.pad = (pad0, pad1)
|
69 |
+
|
70 |
+
def forward(self, input):
|
71 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
72 |
+
|
73 |
+
return out
|
74 |
+
|
75 |
+
|
76 |
+
class Blur(nn.Module):
|
77 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
78 |
+
super().__init__()
|
79 |
+
|
80 |
+
kernel = make_kernel(kernel)
|
81 |
+
|
82 |
+
if upsample_factor > 1:
|
83 |
+
kernel = kernel * (upsample_factor ** 2)
|
84 |
+
|
85 |
+
self.register_buffer('kernel', kernel)
|
86 |
+
|
87 |
+
self.pad = pad
|
88 |
+
|
89 |
+
def forward(self, input):
|
90 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
91 |
+
|
92 |
+
return out
|
93 |
+
|
94 |
+
|
95 |
+
class EqualConv2d(nn.Module):
|
96 |
+
def __init__(
|
97 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
self.weight = nn.Parameter(
|
102 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
103 |
+
)
|
104 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
105 |
+
|
106 |
+
self.stride = stride
|
107 |
+
self.padding = padding
|
108 |
+
|
109 |
+
if bias:
|
110 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
111 |
+
|
112 |
+
else:
|
113 |
+
self.bias = None
|
114 |
+
|
115 |
+
def forward(self, input):
|
116 |
+
out = F.conv2d(
|
117 |
+
input,
|
118 |
+
self.weight * self.scale,
|
119 |
+
bias=self.bias,
|
120 |
+
stride=self.stride,
|
121 |
+
padding=self.padding,
|
122 |
+
)
|
123 |
+
|
124 |
+
return out
|
125 |
+
|
126 |
+
def __repr__(self):
|
127 |
+
return (
|
128 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
129 |
+
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
class EqualLinear(nn.Module):
|
134 |
+
def __init__(
|
135 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
136 |
+
):
|
137 |
+
super().__init__()
|
138 |
+
|
139 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
140 |
+
|
141 |
+
if bias:
|
142 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
143 |
+
|
144 |
+
else:
|
145 |
+
self.bias = None
|
146 |
+
|
147 |
+
self.activation = activation
|
148 |
+
|
149 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
150 |
+
self.lr_mul = lr_mul
|
151 |
+
|
152 |
+
def forward(self, input):
|
153 |
+
if self.activation:
|
154 |
+
out = F.linear(input, self.weight * self.scale)
|
155 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
156 |
+
|
157 |
+
else:
|
158 |
+
out = F.linear(
|
159 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
160 |
+
)
|
161 |
+
|
162 |
+
return out
|
163 |
+
|
164 |
+
def __repr__(self):
|
165 |
+
return (
|
166 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
class ScaledLeakyReLU(nn.Module):
|
171 |
+
def __init__(self, negative_slope=0.2):
|
172 |
+
super().__init__()
|
173 |
+
|
174 |
+
self.negative_slope = negative_slope
|
175 |
+
|
176 |
+
def forward(self, input):
|
177 |
+
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
178 |
+
|
179 |
+
return out * math.sqrt(2)
|
180 |
+
|
181 |
+
|
182 |
+
class ModulatedConv2d(nn.Module):
|
183 |
+
def __init__(
|
184 |
+
self,
|
185 |
+
in_channel,
|
186 |
+
out_channel,
|
187 |
+
kernel_size,
|
188 |
+
style_dim,
|
189 |
+
demodulate=True,
|
190 |
+
upsample=False,
|
191 |
+
downsample=False,
|
192 |
+
blur_kernel=[1, 3, 3, 1],
|
193 |
+
):
|
194 |
+
super().__init__()
|
195 |
+
|
196 |
+
self.eps = 1e-8
|
197 |
+
self.kernel_size = kernel_size
|
198 |
+
self.in_channel = in_channel
|
199 |
+
self.out_channel = out_channel
|
200 |
+
self.upsample = upsample
|
201 |
+
self.downsample = downsample
|
202 |
+
|
203 |
+
if upsample:
|
204 |
+
factor = 2
|
205 |
+
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
206 |
+
pad0 = (p + 1) // 2 + factor - 1
|
207 |
+
pad1 = p // 2 + 1
|
208 |
+
|
209 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
210 |
+
|
211 |
+
if downsample:
|
212 |
+
factor = 2
|
213 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
214 |
+
pad0 = (p + 1) // 2
|
215 |
+
pad1 = p // 2
|
216 |
+
|
217 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
218 |
+
|
219 |
+
fan_in = in_channel * kernel_size ** 2
|
220 |
+
self.scale = 1 / math.sqrt(fan_in)
|
221 |
+
self.padding = kernel_size // 2
|
222 |
+
|
223 |
+
self.weight = nn.Parameter(
|
224 |
+
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
225 |
+
)
|
226 |
+
|
227 |
+
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
228 |
+
|
229 |
+
self.demodulate = demodulate
|
230 |
+
|
231 |
+
def __repr__(self):
|
232 |
+
return (
|
233 |
+
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
|
234 |
+
f'upsample={self.upsample}, downsample={self.downsample})'
|
235 |
+
)
|
236 |
+
|
237 |
+
def forward(self, input, style):
|
238 |
+
batch, in_channel, height, width = input.shape
|
239 |
+
|
240 |
+
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
241 |
+
weight = self.scale * self.weight * style
|
242 |
+
|
243 |
+
if self.demodulate:
|
244 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
245 |
+
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
246 |
+
|
247 |
+
weight = weight.view(
|
248 |
+
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
249 |
+
)
|
250 |
+
|
251 |
+
if self.upsample:
|
252 |
+
input = input.view(1, batch * in_channel, height, width)
|
253 |
+
weight = weight.view(
|
254 |
+
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
255 |
+
)
|
256 |
+
weight = weight.transpose(1, 2).reshape(
|
257 |
+
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
258 |
+
)
|
259 |
+
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
260 |
+
_, _, height, width = out.shape
|
261 |
+
out = out.view(batch, self.out_channel, height, width)
|
262 |
+
out = self.blur(out)
|
263 |
+
|
264 |
+
elif self.downsample:
|
265 |
+
input = self.blur(input)
|
266 |
+
_, _, height, width = input.shape
|
267 |
+
input = input.view(1, batch * in_channel, height, width)
|
268 |
+
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
269 |
+
_, _, height, width = out.shape
|
270 |
+
out = out.view(batch, self.out_channel, height, width)
|
271 |
+
|
272 |
+
else:
|
273 |
+
input = input.view(1, batch * in_channel, height, width)
|
274 |
+
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
275 |
+
_, _, height, width = out.shape
|
276 |
+
out = out.view(batch, self.out_channel, height, width)
|
277 |
+
|
278 |
+
return out
|
279 |
+
|
280 |
+
|
281 |
+
class NoiseInjection(nn.Module):
|
282 |
+
def __init__(self):
|
283 |
+
super().__init__()
|
284 |
+
|
285 |
+
self.weight = nn.Parameter(torch.zeros(1))
|
286 |
+
|
287 |
+
def forward(self, image, noise=None):
|
288 |
+
if noise is None:
|
289 |
+
batch, _, height, width = image.shape
|
290 |
+
noise = image.new_empty(batch, 1, height, width).normal_()
|
291 |
+
|
292 |
+
return image + self.weight * noise
|
293 |
+
|
294 |
+
|
295 |
+
class ConstantInput(nn.Module):
|
296 |
+
def __init__(self, channel, size=4):
|
297 |
+
super().__init__()
|
298 |
+
|
299 |
+
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
300 |
+
|
301 |
+
def forward(self, input):
|
302 |
+
batch = input.shape[0]
|
303 |
+
out = self.input.repeat(batch, 1, 1, 1)
|
304 |
+
|
305 |
+
return out
|
306 |
+
|
307 |
+
|
308 |
+
class StyledConv(nn.Module):
|
309 |
+
def __init__(
|
310 |
+
self,
|
311 |
+
in_channel,
|
312 |
+
out_channel,
|
313 |
+
kernel_size,
|
314 |
+
style_dim,
|
315 |
+
upsample=False,
|
316 |
+
blur_kernel=[1, 3, 3, 1],
|
317 |
+
demodulate=True,
|
318 |
+
):
|
319 |
+
super().__init__()
|
320 |
+
|
321 |
+
self.conv = ModulatedConv2d(
|
322 |
+
in_channel,
|
323 |
+
out_channel,
|
324 |
+
kernel_size,
|
325 |
+
style_dim,
|
326 |
+
upsample=upsample,
|
327 |
+
blur_kernel=blur_kernel,
|
328 |
+
demodulate=demodulate,
|
329 |
+
)
|
330 |
+
|
331 |
+
self.noise = NoiseInjection()
|
332 |
+
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
333 |
+
# self.activate = ScaledLeakyReLU(0.2)
|
334 |
+
self.activate = FusedLeakyReLU(out_channel)
|
335 |
+
|
336 |
+
def forward(self, input, style, noise=None):
|
337 |
+
out = self.conv(input, style)
|
338 |
+
out = self.noise(out, noise=noise)
|
339 |
+
# out = out + self.bias
|
340 |
+
out = self.activate(out)
|
341 |
+
|
342 |
+
return out
|
343 |
+
|
344 |
+
|
345 |
+
class ToRGB(nn.Module):
|
346 |
+
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
347 |
+
super().__init__()
|
348 |
+
|
349 |
+
if upsample:
|
350 |
+
self.upsample = Upsample(blur_kernel)
|
351 |
+
|
352 |
+
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
353 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
354 |
+
|
355 |
+
def forward(self, input, style, skip=None):
|
356 |
+
out = self.conv(input, style)
|
357 |
+
out = out + self.bias
|
358 |
+
|
359 |
+
if skip is not None:
|
360 |
+
skip = self.upsample(skip)
|
361 |
+
|
362 |
+
out = out + skip
|
363 |
+
|
364 |
+
return out
|
365 |
+
|
366 |
+
|
367 |
+
class Generator(nn.Module):
|
368 |
+
def __init__(
|
369 |
+
self,
|
370 |
+
size,
|
371 |
+
style_dim,
|
372 |
+
n_mlp,
|
373 |
+
channel_multiplier=2,
|
374 |
+
blur_kernel=[1, 3, 3, 1],
|
375 |
+
lr_mlp=0.01,
|
376 |
+
):
|
377 |
+
super().__init__()
|
378 |
+
|
379 |
+
self.size = size
|
380 |
+
|
381 |
+
self.style_dim = style_dim
|
382 |
+
|
383 |
+
layers = [PixelNorm()]
|
384 |
+
|
385 |
+
for i in range(n_mlp):
|
386 |
+
layers.append(
|
387 |
+
EqualLinear(
|
388 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
|
389 |
+
)
|
390 |
+
)
|
391 |
+
|
392 |
+
self.style = nn.Sequential(*layers)
|
393 |
+
|
394 |
+
self.channels = {
|
395 |
+
4: 512,
|
396 |
+
8: 512,
|
397 |
+
16: 512,
|
398 |
+
32: 512,
|
399 |
+
64: 256 * channel_multiplier,
|
400 |
+
128: 128 * channel_multiplier,
|
401 |
+
256: 64 * channel_multiplier,
|
402 |
+
512: 32 * channel_multiplier,
|
403 |
+
1024: 16 * channel_multiplier,
|
404 |
+
}
|
405 |
+
|
406 |
+
self.input = ConstantInput(self.channels[4])
|
407 |
+
self.conv1 = StyledConv(
|
408 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
409 |
+
)
|
410 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
411 |
+
|
412 |
+
self.log_size = int(math.log(size, 2))
|
413 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
414 |
+
|
415 |
+
self.convs = nn.ModuleList()
|
416 |
+
self.upsamples = nn.ModuleList()
|
417 |
+
self.to_rgbs = nn.ModuleList()
|
418 |
+
self.noises = nn.Module()
|
419 |
+
|
420 |
+
in_channel = self.channels[4]
|
421 |
+
|
422 |
+
for layer_idx in range(self.num_layers):
|
423 |
+
res = (layer_idx + 5) // 2
|
424 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
425 |
+
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
|
426 |
+
|
427 |
+
for i in range(3, self.log_size + 1):
|
428 |
+
out_channel = self.channels[2 ** i]
|
429 |
+
|
430 |
+
self.convs.append(
|
431 |
+
StyledConv(
|
432 |
+
in_channel,
|
433 |
+
out_channel,
|
434 |
+
3,
|
435 |
+
style_dim,
|
436 |
+
upsample=True,
|
437 |
+
blur_kernel=blur_kernel,
|
438 |
+
)
|
439 |
+
)
|
440 |
+
|
441 |
+
self.convs.append(
|
442 |
+
StyledConv(
|
443 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
444 |
+
)
|
445 |
+
)
|
446 |
+
|
447 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
448 |
+
|
449 |
+
in_channel = out_channel
|
450 |
+
|
451 |
+
self.n_latent = self.log_size * 2 - 2
|
452 |
+
|
453 |
+
def make_noise(self):
|
454 |
+
device = self.input.input.device
|
455 |
+
|
456 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
457 |
+
|
458 |
+
for i in range(3, self.log_size + 1):
|
459 |
+
for _ in range(2):
|
460 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
461 |
+
|
462 |
+
return noises
|
463 |
+
|
464 |
+
def mean_latent(self, n_latent):
|
465 |
+
latent_in = torch.randn(
|
466 |
+
n_latent, self.style_dim, device=self.input.input.device
|
467 |
+
)
|
468 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
469 |
+
|
470 |
+
return latent
|
471 |
+
|
472 |
+
def get_latent(self, input):
|
473 |
+
return self.style(input)
|
474 |
+
|
475 |
+
def forward(
|
476 |
+
self,
|
477 |
+
styles,
|
478 |
+
return_latents=False,
|
479 |
+
return_features=False,
|
480 |
+
inject_index=None,
|
481 |
+
truncation=1,
|
482 |
+
truncation_latent=None,
|
483 |
+
input_is_latent=False,
|
484 |
+
noise=None,
|
485 |
+
randomize_noise=True,
|
486 |
+
):
|
487 |
+
if not input_is_latent:
|
488 |
+
styles = [self.style(s) for s in styles]
|
489 |
+
|
490 |
+
if noise is None:
|
491 |
+
if randomize_noise:
|
492 |
+
noise = [None] * self.num_layers
|
493 |
+
else:
|
494 |
+
noise = [
|
495 |
+
getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
|
496 |
+
]
|
497 |
+
|
498 |
+
if truncation < 1:
|
499 |
+
style_t = []
|
500 |
+
|
501 |
+
for style in styles:
|
502 |
+
style_t.append(
|
503 |
+
truncation_latent + truncation * (style - truncation_latent)
|
504 |
+
)
|
505 |
+
|
506 |
+
styles = style_t
|
507 |
+
|
508 |
+
if len(styles) < 2:
|
509 |
+
inject_index = self.n_latent
|
510 |
+
|
511 |
+
if styles[0].ndim < 3:
|
512 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
513 |
+
else:
|
514 |
+
latent = styles[0]
|
515 |
+
|
516 |
+
else:
|
517 |
+
if inject_index is None:
|
518 |
+
inject_index = random.randint(1, self.n_latent - 1)
|
519 |
+
|
520 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
521 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
522 |
+
|
523 |
+
latent = torch.cat([latent, latent2], 1)
|
524 |
+
|
525 |
+
out = self.input(latent)
|
526 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
527 |
+
|
528 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
529 |
+
|
530 |
+
i = 1
|
531 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
532 |
+
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
533 |
+
):
|
534 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
535 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
536 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
537 |
+
|
538 |
+
i += 2
|
539 |
+
|
540 |
+
image = skip
|
541 |
+
|
542 |
+
if return_latents:
|
543 |
+
return image, latent
|
544 |
+
elif return_features:
|
545 |
+
return image, out
|
546 |
+
else:
|
547 |
+
return image, None
|
548 |
+
|
549 |
+
|
550 |
+
class ConvLayer(nn.Sequential):
|
551 |
+
def __init__(
|
552 |
+
self,
|
553 |
+
in_channel,
|
554 |
+
out_channel,
|
555 |
+
kernel_size,
|
556 |
+
downsample=False,
|
557 |
+
blur_kernel=[1, 3, 3, 1],
|
558 |
+
bias=True,
|
559 |
+
activate=True,
|
560 |
+
):
|
561 |
+
layers = []
|
562 |
+
|
563 |
+
if downsample:
|
564 |
+
factor = 2
|
565 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
566 |
+
pad0 = (p + 1) // 2
|
567 |
+
pad1 = p // 2
|
568 |
+
|
569 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
570 |
+
|
571 |
+
stride = 2
|
572 |
+
self.padding = 0
|
573 |
+
|
574 |
+
else:
|
575 |
+
stride = 1
|
576 |
+
self.padding = kernel_size // 2
|
577 |
+
|
578 |
+
layers.append(
|
579 |
+
EqualConv2d(
|
580 |
+
in_channel,
|
581 |
+
out_channel,
|
582 |
+
kernel_size,
|
583 |
+
padding=self.padding,
|
584 |
+
stride=stride,
|
585 |
+
bias=bias and not activate,
|
586 |
+
)
|
587 |
+
)
|
588 |
+
|
589 |
+
if activate:
|
590 |
+
if bias:
|
591 |
+
layers.append(FusedLeakyReLU(out_channel))
|
592 |
+
|
593 |
+
else:
|
594 |
+
layers.append(ScaledLeakyReLU(0.2))
|
595 |
+
|
596 |
+
super().__init__(*layers)
|
597 |
+
|
598 |
+
|
599 |
+
class ResBlock(nn.Module):
|
600 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
601 |
+
super().__init__()
|
602 |
+
|
603 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
604 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
605 |
+
|
606 |
+
self.skip = ConvLayer(
|
607 |
+
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
608 |
+
)
|
609 |
+
|
610 |
+
def forward(self, input):
|
611 |
+
out = self.conv1(input)
|
612 |
+
out = self.conv2(out)
|
613 |
+
|
614 |
+
skip = self.skip(input)
|
615 |
+
out = (out + skip) / math.sqrt(2)
|
616 |
+
|
617 |
+
return out
|
618 |
+
|
619 |
+
|
620 |
+
class Discriminator(nn.Module):
|
621 |
+
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
|
622 |
+
super().__init__()
|
623 |
+
|
624 |
+
channels = {
|
625 |
+
4: 512,
|
626 |
+
8: 512,
|
627 |
+
16: 512,
|
628 |
+
32: 512,
|
629 |
+
64: 256 * channel_multiplier,
|
630 |
+
128: 128 * channel_multiplier,
|
631 |
+
256: 64 * channel_multiplier,
|
632 |
+
512: 32 * channel_multiplier,
|
633 |
+
1024: 16 * channel_multiplier,
|
634 |
+
}
|
635 |
+
|
636 |
+
convs = [ConvLayer(3, channels[size], 1)]
|
637 |
+
|
638 |
+
log_size = int(math.log(size, 2))
|
639 |
+
|
640 |
+
in_channel = channels[size]
|
641 |
+
|
642 |
+
for i in range(log_size, 2, -1):
|
643 |
+
out_channel = channels[2 ** (i - 1)]
|
644 |
+
|
645 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
646 |
+
|
647 |
+
in_channel = out_channel
|
648 |
+
|
649 |
+
self.convs = nn.Sequential(*convs)
|
650 |
+
|
651 |
+
self.stddev_group = 4
|
652 |
+
self.stddev_feat = 1
|
653 |
+
|
654 |
+
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
655 |
+
self.final_linear = nn.Sequential(
|
656 |
+
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
|
657 |
+
EqualLinear(channels[4], 1),
|
658 |
+
)
|
659 |
+
|
660 |
+
def forward(self, input):
|
661 |
+
out = self.convs(input)
|
662 |
+
|
663 |
+
batch, channel, height, width = out.shape
|
664 |
+
group = min(batch, self.stddev_group)
|
665 |
+
stddev = out.view(
|
666 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
667 |
+
)
|
668 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
669 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
670 |
+
stddev = stddev.repeat(group, 1, height, width)
|
671 |
+
out = torch.cat([out, stddev], 1)
|
672 |
+
|
673 |
+
out = self.final_conv(out)
|
674 |
+
|
675 |
+
out = out.view(batch, -1)
|
676 |
+
out = self.final_linear(out)
|
677 |
+
|
678 |
+
return out
|
e4e/models/stylegan2/op/__init__.py
ADDED
File without changes
|
e4e/models/stylegan2/op/fused_act.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.autograd import Function
|
6 |
+
from torch.utils.cpp_extension import load
|
7 |
+
|
8 |
+
module_path = os.path.dirname(__file__)
|
9 |
+
fused = load(
|
10 |
+
'fused',
|
11 |
+
sources=[
|
12 |
+
os.path.join(module_path, 'fused_bias_act.cpp'),
|
13 |
+
os.path.join(module_path, 'fused_bias_act_kernel.cu'),
|
14 |
+
],
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class FusedLeakyReLUFunctionBackward(Function):
|
19 |
+
@staticmethod
|
20 |
+
def forward(ctx, grad_output, out, negative_slope, scale):
|
21 |
+
ctx.save_for_backward(out)
|
22 |
+
ctx.negative_slope = negative_slope
|
23 |
+
ctx.scale = scale
|
24 |
+
|
25 |
+
empty = grad_output.new_empty(0)
|
26 |
+
|
27 |
+
grad_input = fused.fused_bias_act(
|
28 |
+
grad_output, empty, out, 3, 1, negative_slope, scale
|
29 |
+
)
|
30 |
+
|
31 |
+
dim = [0]
|
32 |
+
|
33 |
+
if grad_input.ndim > 2:
|
34 |
+
dim += list(range(2, grad_input.ndim))
|
35 |
+
|
36 |
+
grad_bias = grad_input.sum(dim).detach()
|
37 |
+
|
38 |
+
return grad_input, grad_bias
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def backward(ctx, gradgrad_input, gradgrad_bias):
|
42 |
+
out, = ctx.saved_tensors
|
43 |
+
gradgrad_out = fused.fused_bias_act(
|
44 |
+
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
|
45 |
+
)
|
46 |
+
|
47 |
+
return gradgrad_out, None, None, None
|
48 |
+
|
49 |
+
|
50 |
+
class FusedLeakyReLUFunction(Function):
|
51 |
+
@staticmethod
|
52 |
+
def forward(ctx, input, bias, negative_slope, scale):
|
53 |
+
empty = input.new_empty(0)
|
54 |
+
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
|
55 |
+
ctx.save_for_backward(out)
|
56 |
+
ctx.negative_slope = negative_slope
|
57 |
+
ctx.scale = scale
|
58 |
+
|
59 |
+
return out
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def backward(ctx, grad_output):
|
63 |
+
out, = ctx.saved_tensors
|
64 |
+
|
65 |
+
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
66 |
+
grad_output, out, ctx.negative_slope, ctx.scale
|
67 |
+
)
|
68 |
+
|
69 |
+
return grad_input, grad_bias, None, None
|
70 |
+
|
71 |
+
|
72 |
+
class FusedLeakyReLU(nn.Module):
|
73 |
+
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
self.bias = nn.Parameter(torch.zeros(channel))
|
77 |
+
self.negative_slope = negative_slope
|
78 |
+
self.scale = scale
|
79 |
+
|
80 |
+
def forward(self, input):
|
81 |
+
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
82 |
+
|
83 |
+
|
84 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
85 |
+
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
e4e/models/stylegan2/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
5 |
+
int act, int grad, float alpha, float scale);
|
6 |
+
|
7 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
8 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
9 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
10 |
+
|
11 |
+
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
12 |
+
int act, int grad, float alpha, float scale) {
|
13 |
+
CHECK_CUDA(input);
|
14 |
+
CHECK_CUDA(bias);
|
15 |
+
|
16 |
+
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
|
17 |
+
}
|
18 |
+
|
19 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
20 |
+
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
|
21 |
+
}
|
e4e/models/stylegan2/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAContext.h>
|
12 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
|
18 |
+
template <typename scalar_t>
|
19 |
+
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
|
20 |
+
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
|
21 |
+
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
|
22 |
+
|
23 |
+
scalar_t zero = 0.0;
|
24 |
+
|
25 |
+
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
|
26 |
+
scalar_t x = p_x[xi];
|
27 |
+
|
28 |
+
if (use_bias) {
|
29 |
+
x += p_b[(xi / step_b) % size_b];
|
30 |
+
}
|
31 |
+
|
32 |
+
scalar_t ref = use_ref ? p_ref[xi] : zero;
|
33 |
+
|
34 |
+
scalar_t y;
|
35 |
+
|
36 |
+
switch (act * 10 + grad) {
|
37 |
+
default:
|
38 |
+
case 10: y = x; break;
|
39 |
+
case 11: y = x; break;
|
40 |
+
case 12: y = 0.0; break;
|
41 |
+
|
42 |
+
case 30: y = (x > 0.0) ? x : x * alpha; break;
|
43 |
+
case 31: y = (ref > 0.0) ? x : x * alpha; break;
|
44 |
+
case 32: y = 0.0; break;
|
45 |
+
}
|
46 |
+
|
47 |
+
out[xi] = y * scale;
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
53 |
+
int act, int grad, float alpha, float scale) {
|
54 |
+
int curDevice = -1;
|
55 |
+
cudaGetDevice(&curDevice);
|
56 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
57 |
+
|
58 |
+
auto x = input.contiguous();
|
59 |
+
auto b = bias.contiguous();
|
60 |
+
auto ref = refer.contiguous();
|
61 |
+
|
62 |
+
int use_bias = b.numel() ? 1 : 0;
|
63 |
+
int use_ref = ref.numel() ? 1 : 0;
|
64 |
+
|
65 |
+
int size_x = x.numel();
|
66 |
+
int size_b = b.numel();
|
67 |
+
int step_b = 1;
|
68 |
+
|
69 |
+
for (int i = 1 + 1; i < x.dim(); i++) {
|
70 |
+
step_b *= x.size(i);
|
71 |
+
}
|
72 |
+
|
73 |
+
int loop_x = 4;
|
74 |
+
int block_size = 4 * 32;
|
75 |
+
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
|
76 |
+
|
77 |
+
auto y = torch::empty_like(x);
|
78 |
+
|
79 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
|
80 |
+
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
81 |
+
y.data_ptr<scalar_t>(),
|
82 |
+
x.data_ptr<scalar_t>(),
|
83 |
+
b.data_ptr<scalar_t>(),
|
84 |
+
ref.data_ptr<scalar_t>(),
|
85 |
+
act,
|
86 |
+
grad,
|
87 |
+
alpha,
|
88 |
+
scale,
|
89 |
+
loop_x,
|
90 |
+
size_x,
|
91 |
+
step_b,
|
92 |
+
size_b,
|
93 |
+
use_bias,
|
94 |
+
use_ref
|
95 |
+
);
|
96 |
+
});
|
97 |
+
|
98 |
+
return y;
|
99 |
+
}
|
e4e/models/stylegan2/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
5 |
+
int up_x, int up_y, int down_x, int down_y,
|
6 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
|
7 |
+
|
8 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
9 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
10 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
11 |
+
|
12 |
+
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
|
13 |
+
int up_x, int up_y, int down_x, int down_y,
|
14 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
|
15 |
+
CHECK_CUDA(input);
|
16 |
+
CHECK_CUDA(kernel);
|
17 |
+
|
18 |
+
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
|
19 |
+
}
|
20 |
+
|
21 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
22 |
+
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
|
23 |
+
}
|
e4e/models/stylegan2/op/upfirdn2d.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.autograd import Function
|
5 |
+
from torch.utils.cpp_extension import load
|
6 |
+
|
7 |
+
module_path = os.path.dirname(__file__)
|
8 |
+
upfirdn2d_op = load(
|
9 |
+
'upfirdn2d',
|
10 |
+
sources=[
|
11 |
+
os.path.join(module_path, 'upfirdn2d.cpp'),
|
12 |
+
os.path.join(module_path, 'upfirdn2d_kernel.cu'),
|
13 |
+
],
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
class UpFirDn2dBackward(Function):
|
18 |
+
@staticmethod
|
19 |
+
def forward(
|
20 |
+
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
|
21 |
+
):
|
22 |
+
up_x, up_y = up
|
23 |
+
down_x, down_y = down
|
24 |
+
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
|
25 |
+
|
26 |
+
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
|
27 |
+
|
28 |
+
grad_input = upfirdn2d_op.upfirdn2d(
|
29 |
+
grad_output,
|
30 |
+
grad_kernel,
|
31 |
+
down_x,
|
32 |
+
down_y,
|
33 |
+
up_x,
|
34 |
+
up_y,
|
35 |
+
g_pad_x0,
|
36 |
+
g_pad_x1,
|
37 |
+
g_pad_y0,
|
38 |
+
g_pad_y1,
|
39 |
+
)
|
40 |
+
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
|
41 |
+
|
42 |
+
ctx.save_for_backward(kernel)
|
43 |
+
|
44 |
+
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
45 |
+
|
46 |
+
ctx.up_x = up_x
|
47 |
+
ctx.up_y = up_y
|
48 |
+
ctx.down_x = down_x
|
49 |
+
ctx.down_y = down_y
|
50 |
+
ctx.pad_x0 = pad_x0
|
51 |
+
ctx.pad_x1 = pad_x1
|
52 |
+
ctx.pad_y0 = pad_y0
|
53 |
+
ctx.pad_y1 = pad_y1
|
54 |
+
ctx.in_size = in_size
|
55 |
+
ctx.out_size = out_size
|
56 |
+
|
57 |
+
return grad_input
|
58 |
+
|
59 |
+
@staticmethod
|
60 |
+
def backward(ctx, gradgrad_input):
|
61 |
+
kernel, = ctx.saved_tensors
|
62 |
+
|
63 |
+
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
|
64 |
+
|
65 |
+
gradgrad_out = upfirdn2d_op.upfirdn2d(
|
66 |
+
gradgrad_input,
|
67 |
+
kernel,
|
68 |
+
ctx.up_x,
|
69 |
+
ctx.up_y,
|
70 |
+
ctx.down_x,
|
71 |
+
ctx.down_y,
|
72 |
+
ctx.pad_x0,
|
73 |
+
ctx.pad_x1,
|
74 |
+
ctx.pad_y0,
|
75 |
+
ctx.pad_y1,
|
76 |
+
)
|
77 |
+
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
|
78 |
+
gradgrad_out = gradgrad_out.view(
|
79 |
+
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
|
80 |
+
)
|
81 |
+
|
82 |
+
return gradgrad_out, None, None, None, None, None, None, None, None
|
83 |
+
|
84 |
+
|
85 |
+
class UpFirDn2d(Function):
|
86 |
+
@staticmethod
|
87 |
+
def forward(ctx, input, kernel, up, down, pad):
|
88 |
+
up_x, up_y = up
|
89 |
+
down_x, down_y = down
|
90 |
+
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
91 |
+
|
92 |
+
kernel_h, kernel_w = kernel.shape
|
93 |
+
batch, channel, in_h, in_w = input.shape
|
94 |
+
ctx.in_size = input.shape
|
95 |
+
|
96 |
+
input = input.reshape(-1, in_h, in_w, 1)
|
97 |
+
|
98 |
+
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
|
99 |
+
|
100 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
101 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
102 |
+
ctx.out_size = (out_h, out_w)
|
103 |
+
|
104 |
+
ctx.up = (up_x, up_y)
|
105 |
+
ctx.down = (down_x, down_y)
|
106 |
+
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
|
107 |
+
|
108 |
+
g_pad_x0 = kernel_w - pad_x0 - 1
|
109 |
+
g_pad_y0 = kernel_h - pad_y0 - 1
|
110 |
+
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
|
111 |
+
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
|
112 |
+
|
113 |
+
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
|
114 |
+
|
115 |
+
out = upfirdn2d_op.upfirdn2d(
|
116 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
117 |
+
)
|
118 |
+
# out = out.view(major, out_h, out_w, minor)
|
119 |
+
out = out.view(-1, channel, out_h, out_w)
|
120 |
+
|
121 |
+
return out
|
122 |
+
|
123 |
+
@staticmethod
|
124 |
+
def backward(ctx, grad_output):
|
125 |
+
kernel, grad_kernel = ctx.saved_tensors
|
126 |
+
|
127 |
+
grad_input = UpFirDn2dBackward.apply(
|
128 |
+
grad_output,
|
129 |
+
kernel,
|
130 |
+
grad_kernel,
|
131 |
+
ctx.up,
|
132 |
+
ctx.down,
|
133 |
+
ctx.pad,
|
134 |
+
ctx.g_pad,
|
135 |
+
ctx.in_size,
|
136 |
+
ctx.out_size,
|
137 |
+
)
|
138 |
+
|
139 |
+
return grad_input, None, None, None, None
|
140 |
+
|
141 |
+
|
142 |
+
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
143 |
+
out = UpFirDn2d.apply(
|
144 |
+
input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
|
145 |
+
)
|
146 |
+
|
147 |
+
return out
|
148 |
+
|
149 |
+
|
150 |
+
def upfirdn2d_native(
|
151 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
152 |
+
):
|
153 |
+
_, in_h, in_w, minor = input.shape
|
154 |
+
kernel_h, kernel_w = kernel.shape
|
155 |
+
|
156 |
+
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
157 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
158 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
159 |
+
|
160 |
+
out = F.pad(
|
161 |
+
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
162 |
+
)
|
163 |
+
out = out[
|
164 |
+
:,
|
165 |
+
max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
|
166 |
+
max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
|
167 |
+
:,
|
168 |
+
]
|
169 |
+
|
170 |
+
out = out.permute(0, 3, 1, 2)
|
171 |
+
out = out.reshape(
|
172 |
+
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
173 |
+
)
|
174 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
175 |
+
out = F.conv2d(out, w)
|
176 |
+
out = out.reshape(
|
177 |
+
-1,
|
178 |
+
minor,
|
179 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
180 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
181 |
+
)
|
182 |
+
out = out.permute(0, 2, 3, 1)
|
183 |
+
|
184 |
+
return out[:, ::down_y, ::down_x, :]
|
e4e/models/stylegan2/op/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAContext.h>
|
12 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
|
18 |
+
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
|
19 |
+
int c = a / b;
|
20 |
+
|
21 |
+
if (c * b > a) {
|
22 |
+
c--;
|
23 |
+
}
|
24 |
+
|
25 |
+
return c;
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
struct UpFirDn2DKernelParams {
|
30 |
+
int up_x;
|
31 |
+
int up_y;
|
32 |
+
int down_x;
|
33 |
+
int down_y;
|
34 |
+
int pad_x0;
|
35 |
+
int pad_x1;
|
36 |
+
int pad_y0;
|
37 |
+
int pad_y1;
|
38 |
+
|
39 |
+
int major_dim;
|
40 |
+
int in_h;
|
41 |
+
int in_w;
|
42 |
+
int minor_dim;
|
43 |
+
int kernel_h;
|
44 |
+
int kernel_w;
|
45 |
+
int out_h;
|
46 |
+
int out_w;
|
47 |
+
int loop_major;
|
48 |
+
int loop_x;
|
49 |
+
};
|
50 |
+
|
51 |
+
|
52 |
+
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
|
53 |
+
__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
|
54 |
+
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
|
55 |
+
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
|
56 |
+
|
57 |
+
__shared__ volatile float sk[kernel_h][kernel_w];
|
58 |
+
__shared__ volatile float sx[tile_in_h][tile_in_w];
|
59 |
+
|
60 |
+
int minor_idx = blockIdx.x;
|
61 |
+
int tile_out_y = minor_idx / p.minor_dim;
|
62 |
+
minor_idx -= tile_out_y * p.minor_dim;
|
63 |
+
tile_out_y *= tile_out_h;
|
64 |
+
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
|
65 |
+
int major_idx_base = blockIdx.z * p.loop_major;
|
66 |
+
|
67 |
+
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
|
68 |
+
return;
|
69 |
+
}
|
70 |
+
|
71 |
+
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
|
72 |
+
int ky = tap_idx / kernel_w;
|
73 |
+
int kx = tap_idx - ky * kernel_w;
|
74 |
+
scalar_t v = 0.0;
|
75 |
+
|
76 |
+
if (kx < p.kernel_w & ky < p.kernel_h) {
|
77 |
+
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
|
78 |
+
}
|
79 |
+
|
80 |
+
sk[ky][kx] = v;
|
81 |
+
}
|
82 |
+
|
83 |
+
for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
|
84 |
+
for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
|
85 |
+
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
|
86 |
+
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
|
87 |
+
int tile_in_x = floor_div(tile_mid_x, up_x);
|
88 |
+
int tile_in_y = floor_div(tile_mid_y, up_y);
|
89 |
+
|
90 |
+
__syncthreads();
|
91 |
+
|
92 |
+
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
|
93 |
+
int rel_in_y = in_idx / tile_in_w;
|
94 |
+
int rel_in_x = in_idx - rel_in_y * tile_in_w;
|
95 |
+
int in_x = rel_in_x + tile_in_x;
|
96 |
+
int in_y = rel_in_y + tile_in_y;
|
97 |
+
|
98 |
+
scalar_t v = 0.0;
|
99 |
+
|
100 |
+
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
|
101 |
+
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
|
102 |
+
}
|
103 |
+
|
104 |
+
sx[rel_in_y][rel_in_x] = v;
|
105 |
+
}
|
106 |
+
|
107 |
+
__syncthreads();
|
108 |
+
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
|
109 |
+
int rel_out_y = out_idx / tile_out_w;
|
110 |
+
int rel_out_x = out_idx - rel_out_y * tile_out_w;
|
111 |
+
int out_x = rel_out_x + tile_out_x;
|
112 |
+
int out_y = rel_out_y + tile_out_y;
|
113 |
+
|
114 |
+
int mid_x = tile_mid_x + rel_out_x * down_x;
|
115 |
+
int mid_y = tile_mid_y + rel_out_y * down_y;
|
116 |
+
int in_x = floor_div(mid_x, up_x);
|
117 |
+
int in_y = floor_div(mid_y, up_y);
|
118 |
+
int rel_in_x = in_x - tile_in_x;
|
119 |
+
int rel_in_y = in_y - tile_in_y;
|
120 |
+
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
|
121 |
+
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
|
122 |
+
|
123 |
+
scalar_t v = 0.0;
|
124 |
+
|
125 |
+
#pragma unroll
|
126 |
+
for (int y = 0; y < kernel_h / up_y; y++)
|
127 |
+
#pragma unroll
|
128 |
+
for (int x = 0; x < kernel_w / up_x; x++)
|
129 |
+
v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
|
130 |
+
|
131 |
+
if (out_x < p.out_w & out_y < p.out_h) {
|
132 |
+
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
|
133 |
+
}
|
134 |
+
}
|
135 |
+
}
|
136 |
+
}
|
137 |
+
}
|
138 |
+
|
139 |
+
|
140 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
141 |
+
int up_x, int up_y, int down_x, int down_y,
|
142 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
|
143 |
+
int curDevice = -1;
|
144 |
+
cudaGetDevice(&curDevice);
|
145 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
146 |
+
|
147 |
+
UpFirDn2DKernelParams p;
|
148 |
+
|
149 |
+
auto x = input.contiguous();
|
150 |
+
auto k = kernel.contiguous();
|
151 |
+
|
152 |
+
p.major_dim = x.size(0);
|
153 |
+
p.in_h = x.size(1);
|
154 |
+
p.in_w = x.size(2);
|
155 |
+
p.minor_dim = x.size(3);
|
156 |
+
p.kernel_h = k.size(0);
|
157 |
+
p.kernel_w = k.size(1);
|
158 |
+
p.up_x = up_x;
|
159 |
+
p.up_y = up_y;
|
160 |
+
p.down_x = down_x;
|
161 |
+
p.down_y = down_y;
|
162 |
+
p.pad_x0 = pad_x0;
|
163 |
+
p.pad_x1 = pad_x1;
|
164 |
+
p.pad_y0 = pad_y0;
|
165 |
+
p.pad_y1 = pad_y1;
|
166 |
+
|
167 |
+
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
|
168 |
+
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
|
169 |
+
|
170 |
+
auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
|
171 |
+
|
172 |
+
int mode = -1;
|
173 |
+
|
174 |
+
int tile_out_h;
|
175 |
+
int tile_out_w;
|
176 |
+
|
177 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
|
178 |
+
mode = 1;
|
179 |
+
tile_out_h = 16;
|
180 |
+
tile_out_w = 64;
|
181 |
+
}
|
182 |
+
|
183 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
|
184 |
+
mode = 2;
|
185 |
+
tile_out_h = 16;
|
186 |
+
tile_out_w = 64;
|
187 |
+
}
|
188 |
+
|
189 |
+
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
|
190 |
+
mode = 3;
|
191 |
+
tile_out_h = 16;
|
192 |
+
tile_out_w = 64;
|
193 |
+
}
|
194 |
+
|
195 |
+
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
|
196 |
+
mode = 4;
|
197 |
+
tile_out_h = 16;
|
198 |
+
tile_out_w = 64;
|
199 |
+
}
|
200 |
+
|
201 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
|
202 |
+
mode = 5;
|
203 |
+
tile_out_h = 8;
|
204 |
+
tile_out_w = 32;
|
205 |
+
}
|
206 |
+
|
207 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
|
208 |
+
mode = 6;
|
209 |
+
tile_out_h = 8;
|
210 |
+
tile_out_w = 32;
|
211 |
+
}
|
212 |
+
|
213 |
+
dim3 block_size;
|
214 |
+
dim3 grid_size;
|
215 |
+
|
216 |
+
if (tile_out_h > 0 && tile_out_w) {
|
217 |
+
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
218 |
+
p.loop_x = 1;
|
219 |
+
block_size = dim3(32 * 8, 1, 1);
|
220 |
+
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
|
221 |
+
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
|
222 |
+
(p.major_dim - 1) / p.loop_major + 1);
|
223 |
+
}
|
224 |
+
|
225 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
|
226 |
+
switch (mode) {
|
227 |
+
case 1:
|
228 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
229 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
230 |
+
);
|
231 |
+
|
232 |
+
break;
|
233 |
+
|
234 |
+
case 2:
|
235 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
236 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
237 |
+
);
|
238 |
+
|
239 |
+
break;
|
240 |
+
|
241 |
+
case 3:
|
242 |
+
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
243 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
244 |
+
);
|
245 |
+
|
246 |
+
break;
|
247 |
+
|
248 |
+
case 4:
|
249 |
+
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
250 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
251 |
+
);
|
252 |
+
|
253 |
+
break;
|
254 |
+
|
255 |
+
case 5:
|
256 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
|
257 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
258 |
+
);
|
259 |
+
|
260 |
+
break;
|
261 |
+
|
262 |
+
case 6:
|
263 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
|
264 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
265 |
+
);
|
266 |
+
|
267 |
+
break;
|
268 |
+
}
|
269 |
+
});
|
270 |
+
|
271 |
+
return out;
|
272 |
+
}
|
e4e/options/__init__.py
ADDED
File without changes
|
e4e/options/train_options.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import ArgumentParser
|
2 |
+
from configs.paths_config import model_paths
|
3 |
+
|
4 |
+
|
5 |
+
class TrainOptions:
|
6 |
+
|
7 |
+
def __init__(self):
|
8 |
+
self.parser = ArgumentParser()
|
9 |
+
self.initialize()
|
10 |
+
|
11 |
+
def initialize(self):
|
12 |
+
self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
|
13 |
+
self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str,
|
14 |
+
help='Type of dataset/experiment to run')
|
15 |
+
self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use')
|
16 |
+
|
17 |
+
self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
|
18 |
+
self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
|
19 |
+
self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
|
20 |
+
self.parser.add_argument('--test_workers', default=2, type=int,
|
21 |
+
help='Number of test/inference dataloader workers')
|
22 |
+
|
23 |
+
self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate')
|
24 |
+
self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
|
25 |
+
self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model')
|
26 |
+
self.parser.add_argument('--start_from_latent_avg', action='store_true',
|
27 |
+
help='Whether to add average latent vector to generate codes from encoder.')
|
28 |
+
self.parser.add_argument('--lpips_type', default='alex', type=str, help='LPIPS backbone')
|
29 |
+
|
30 |
+
self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor')
|
31 |
+
self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
|
32 |
+
self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor')
|
33 |
+
|
34 |
+
self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str,
|
35 |
+
help='Path to StyleGAN model weights')
|
36 |
+
self.parser.add_argument('--stylegan_size', default=1024, type=int,
|
37 |
+
help='size of pretrained StyleGAN Generator')
|
38 |
+
self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')
|
39 |
+
|
40 |
+
self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps')
|
41 |
+
self.parser.add_argument('--image_interval', default=100, type=int,
|
42 |
+
help='Interval for logging train images during training')
|
43 |
+
self.parser.add_argument('--board_interval', default=50, type=int,
|
44 |
+
help='Interval for logging metrics to tensorboard')
|
45 |
+
self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval')
|
46 |
+
self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval')
|
47 |
+
|
48 |
+
# Discriminator flags
|
49 |
+
self.parser.add_argument('--w_discriminator_lambda', default=0, type=float, help='Dw loss multiplier')
|
50 |
+
self.parser.add_argument('--w_discriminator_lr', default=2e-5, type=float, help='Dw learning rate')
|
51 |
+
self.parser.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization")
|
52 |
+
self.parser.add_argument("--d_reg_every", type=int, default=16,
|
53 |
+
help="interval for applying r1 regularization")
|
54 |
+
self.parser.add_argument('--use_w_pool', action='store_true',
|
55 |
+
help='Whether to store a latnet codes pool for the discriminator\'s training')
|
56 |
+
self.parser.add_argument("--w_pool_size", type=int, default=50,
|
57 |
+
help="W\'s pool size, depends on --use_w_pool")
|
58 |
+
|
59 |
+
# e4e specific
|
60 |
+
self.parser.add_argument('--delta_norm', type=int, default=2, help="norm type of the deltas")
|
61 |
+
self.parser.add_argument('--delta_norm_lambda', type=float, default=2e-4, help="lambda for delta norm loss")
|
62 |
+
|
63 |
+
# Progressive training
|
64 |
+
self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None,
|
65 |
+
help="The training steps of training new deltas. steps[i] starts the delta_i training")
|
66 |
+
self.parser.add_argument('--progressive_start', type=int, default=None,
|
67 |
+
help="The training step to start training the deltas, overrides progressive_steps")
|
68 |
+
self.parser.add_argument('--progressive_step_every', type=int, default=2_000,
|
69 |
+
help="Amount of training steps for each progressive step")
|
70 |
+
|
71 |
+
# Save additional training info to enable future training continuation from produced checkpoints
|
72 |
+
self.parser.add_argument('--save_training_data', action='store_true',
|
73 |
+
help='Save intermediate training data to resume training from the checkpoint')
|
74 |
+
self.parser.add_argument('--sub_exp_dir', default=None, type=str, help='Name of sub experiment directory')
|
75 |
+
self.parser.add_argument('--keep_optimizer', action='store_true',
|
76 |
+
help='Whether to continue from the checkpoint\'s optimizer')
|
77 |
+
self.parser.add_argument('--resume_training_from_ckpt', default=None, type=str,
|
78 |
+
help='Path to training checkpoint, works when --save_training_data was set to True')
|
79 |
+
self.parser.add_argument('--update_param_list', nargs='+', type=str, default=None,
|
80 |
+
help="Name of training parameters to update the loaded training checkpoint")
|
81 |
+
|
82 |
+
def parse(self):
|
83 |
+
opts = self.parser.parse_args()
|
84 |
+
return opts
|
e4e/scripts/calc_losses_on_images.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import ArgumentParser
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import sys
|
5 |
+
from tqdm import tqdm
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
|
11 |
+
sys.path.append(".")
|
12 |
+
sys.path.append("..")
|
13 |
+
|
14 |
+
from criteria.lpips.lpips import LPIPS
|
15 |
+
from datasets.gt_res_dataset import GTResDataset
|
16 |
+
|
17 |
+
|
18 |
+
def parse_args():
|
19 |
+
parser = ArgumentParser(add_help=False)
|
20 |
+
parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2'])
|
21 |
+
parser.add_argument('--data_path', type=str, default='results')
|
22 |
+
parser.add_argument('--gt_path', type=str, default='gt_images')
|
23 |
+
parser.add_argument('--workers', type=int, default=4)
|
24 |
+
parser.add_argument('--batch_size', type=int, default=4)
|
25 |
+
parser.add_argument('--is_cars', action='store_true')
|
26 |
+
args = parser.parse_args()
|
27 |
+
return args
|
28 |
+
|
29 |
+
|
30 |
+
def run(args):
|
31 |
+
resize_dims = (256, 256)
|
32 |
+
if args.is_cars:
|
33 |
+
resize_dims = (192, 256)
|
34 |
+
transform = transforms.Compose([transforms.Resize(resize_dims),
|
35 |
+
transforms.ToTensor(),
|
36 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
37 |
+
|
38 |
+
print('Loading dataset')
|
39 |
+
dataset = GTResDataset(root_path=args.data_path,
|
40 |
+
gt_dir=args.gt_path,
|
41 |
+
transform=transform)
|
42 |
+
|
43 |
+
dataloader = DataLoader(dataset,
|
44 |
+
batch_size=args.batch_size,
|
45 |
+
shuffle=False,
|
46 |
+
num_workers=int(args.workers),
|
47 |
+
drop_last=True)
|
48 |
+
|
49 |
+
if args.mode == 'lpips':
|
50 |
+
loss_func = LPIPS(net_type='alex')
|
51 |
+
elif args.mode == 'l2':
|
52 |
+
loss_func = torch.nn.MSELoss()
|
53 |
+
else:
|
54 |
+
raise Exception('Not a valid mode!')
|
55 |
+
loss_func.cuda()
|
56 |
+
|
57 |
+
global_i = 0
|
58 |
+
scores_dict = {}
|
59 |
+
all_scores = []
|
60 |
+
for result_batch, gt_batch in tqdm(dataloader):
|
61 |
+
for i in range(args.batch_size):
|
62 |
+
loss = float(loss_func(result_batch[i:i + 1].cuda(), gt_batch[i:i + 1].cuda()))
|
63 |
+
all_scores.append(loss)
|
64 |
+
im_path = dataset.pairs[global_i][0]
|
65 |
+
scores_dict[os.path.basename(im_path)] = loss
|
66 |
+
global_i += 1
|
67 |
+
|
68 |
+
all_scores = list(scores_dict.values())
|
69 |
+
mean = np.mean(all_scores)
|
70 |
+
std = np.std(all_scores)
|
71 |
+
result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
|
72 |
+
print('Finished with ', args.data_path)
|
73 |
+
print(result_str)
|
74 |
+
|
75 |
+
out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
|
76 |
+
if not os.path.exists(out_path):
|
77 |
+
os.makedirs(out_path)
|
78 |
+
|
79 |
+
with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f:
|
80 |
+
f.write(result_str)
|
81 |
+
with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f:
|
82 |
+
json.dump(scores_dict, f)
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == '__main__':
|
86 |
+
args = parse_args()
|
87 |
+
run(args)
|
e4e/scripts/inference.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import sys
|
6 |
+
import os
|
7 |
+
import dlib
|
8 |
+
|
9 |
+
sys.path.append(".")
|
10 |
+
sys.path.append("..")
|
11 |
+
|
12 |
+
from configs import data_configs, paths_config
|
13 |
+
from datasets.inference_dataset import InferenceDataset
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from utils.model_utils import setup_model
|
16 |
+
from utils.common import tensor2im
|
17 |
+
from utils.alignment import align_face
|
18 |
+
from PIL import Image
|
19 |
+
|
20 |
+
|
21 |
+
def main(args):
|
22 |
+
net, opts = setup_model(args.ckpt, device)
|
23 |
+
is_cars = 'cars_' in opts.dataset_type
|
24 |
+
generator = net.decoder
|
25 |
+
generator.eval()
|
26 |
+
args, data_loader = setup_data_loader(args, opts)
|
27 |
+
|
28 |
+
# Check if latents exist
|
29 |
+
latents_file_path = os.path.join(args.save_dir, 'latents.pt')
|
30 |
+
if os.path.exists(latents_file_path):
|
31 |
+
latent_codes = torch.load(latents_file_path).to(device)
|
32 |
+
else:
|
33 |
+
latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars)
|
34 |
+
torch.save(latent_codes, latents_file_path)
|
35 |
+
|
36 |
+
if not args.latents_only:
|
37 |
+
generate_inversions(args, generator, latent_codes, is_cars=is_cars)
|
38 |
+
|
39 |
+
|
40 |
+
def setup_data_loader(args, opts):
|
41 |
+
dataset_args = data_configs.DATASETS[opts.dataset_type]
|
42 |
+
transforms_dict = dataset_args['transforms'](opts).get_transforms()
|
43 |
+
images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root']
|
44 |
+
print(f"images path: {images_path}")
|
45 |
+
align_function = None
|
46 |
+
if args.align:
|
47 |
+
align_function = run_alignment
|
48 |
+
test_dataset = InferenceDataset(root=images_path,
|
49 |
+
transform=transforms_dict['transform_test'],
|
50 |
+
preprocess=align_function,
|
51 |
+
opts=opts)
|
52 |
+
|
53 |
+
data_loader = DataLoader(test_dataset,
|
54 |
+
batch_size=args.batch,
|
55 |
+
shuffle=False,
|
56 |
+
num_workers=2,
|
57 |
+
drop_last=True)
|
58 |
+
|
59 |
+
print(f'dataset length: {len(test_dataset)}')
|
60 |
+
|
61 |
+
if args.n_sample is None:
|
62 |
+
args.n_sample = len(test_dataset)
|
63 |
+
return args, data_loader
|
64 |
+
|
65 |
+
|
66 |
+
def get_latents(net, x, is_cars=False):
|
67 |
+
codes = net.encoder(x)
|
68 |
+
if net.opts.start_from_latent_avg:
|
69 |
+
if codes.ndim == 2:
|
70 |
+
codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
|
71 |
+
else:
|
72 |
+
codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)
|
73 |
+
if codes.shape[1] == 18 and is_cars:
|
74 |
+
codes = codes[:, :16, :]
|
75 |
+
return codes
|
76 |
+
|
77 |
+
|
78 |
+
def get_all_latents(net, data_loader, n_images=None, is_cars=False):
|
79 |
+
all_latents = []
|
80 |
+
i = 0
|
81 |
+
with torch.no_grad():
|
82 |
+
for batch in data_loader:
|
83 |
+
if n_images is not None and i > n_images:
|
84 |
+
break
|
85 |
+
x = batch
|
86 |
+
inputs = x.to(device).float()
|
87 |
+
latents = get_latents(net, inputs, is_cars)
|
88 |
+
all_latents.append(latents)
|
89 |
+
i += len(latents)
|
90 |
+
return torch.cat(all_latents)
|
91 |
+
|
92 |
+
|
93 |
+
def save_image(img, save_dir, idx):
|
94 |
+
result = tensor2im(img)
|
95 |
+
im_save_path = os.path.join(save_dir, f"{idx:05d}.jpg")
|
96 |
+
Image.fromarray(np.array(result)).save(im_save_path)
|
97 |
+
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def generate_inversions(args, g, latent_codes, is_cars):
|
101 |
+
print('Saving inversion images')
|
102 |
+
inversions_directory_path = os.path.join(args.save_dir, 'inversions')
|
103 |
+
os.makedirs(inversions_directory_path, exist_ok=True)
|
104 |
+
for i in range(args.n_sample):
|
105 |
+
imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True)
|
106 |
+
if is_cars:
|
107 |
+
imgs = imgs[:, :, 64:448, :]
|
108 |
+
save_image(imgs[0], inversions_directory_path, i + 1)
|
109 |
+
|
110 |
+
|
111 |
+
def run_alignment(image_path):
|
112 |
+
predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor'])
|
113 |
+
aligned_image = align_face(filepath=image_path, predictor=predictor)
|
114 |
+
print("Aligned image has shape: {}".format(aligned_image.size))
|
115 |
+
return aligned_image
|
116 |
+
|
117 |
+
|
118 |
+
if __name__ == "__main__":
|
119 |
+
device = "cuda"
|
120 |
+
|
121 |
+
parser = argparse.ArgumentParser(description="Inference")
|
122 |
+
parser.add_argument("--images_dir", type=str, default=None,
|
123 |
+
help="The directory of the images to be inverted")
|
124 |
+
parser.add_argument("--save_dir", type=str, default=None,
|
125 |
+
help="The directory to save the latent codes and inversion images. (default: images_dir")
|
126 |
+
parser.add_argument("--batch", type=int, default=1, help="batch size for the generator")
|
127 |
+
parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
|
128 |
+
parser.add_argument("--latents_only", action="store_true", help="infer only the latent codes of the directory")
|
129 |
+
parser.add_argument("--align", action="store_true", help="align face images before inference")
|
130 |
+
parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to generator checkpoint")
|
131 |
+
|
132 |
+
args = parser.parse_args()
|
133 |
+
main(args)
|
e4e/scripts/train.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file runs the main training/val loop
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import math
|
7 |
+
import sys
|
8 |
+
import pprint
|
9 |
+
import torch
|
10 |
+
from argparse import Namespace
|
11 |
+
|
12 |
+
sys.path.append(".")
|
13 |
+
sys.path.append("..")
|
14 |
+
|
15 |
+
from options.train_options import TrainOptions
|
16 |
+
from training.coach import Coach
|
17 |
+
|
18 |
+
|
19 |
+
def main():
|
20 |
+
opts = TrainOptions().parse()
|
21 |
+
previous_train_ckpt = None
|
22 |
+
if opts.resume_training_from_ckpt:
|
23 |
+
opts, previous_train_ckpt = load_train_checkpoint(opts)
|
24 |
+
else:
|
25 |
+
setup_progressive_steps(opts)
|
26 |
+
create_initial_experiment_dir(opts)
|
27 |
+
|
28 |
+
coach = Coach(opts, previous_train_ckpt)
|
29 |
+
coach.train()
|
30 |
+
|
31 |
+
|
32 |
+
def load_train_checkpoint(opts):
|
33 |
+
train_ckpt_path = opts.resume_training_from_ckpt
|
34 |
+
previous_train_ckpt = torch.load(opts.resume_training_from_ckpt, map_location='cpu')
|
35 |
+
new_opts_dict = vars(opts)
|
36 |
+
opts = previous_train_ckpt['opts']
|
37 |
+
opts['resume_training_from_ckpt'] = train_ckpt_path
|
38 |
+
update_new_configs(opts, new_opts_dict)
|
39 |
+
pprint.pprint(opts)
|
40 |
+
opts = Namespace(**opts)
|
41 |
+
if opts.sub_exp_dir is not None:
|
42 |
+
sub_exp_dir = opts.sub_exp_dir
|
43 |
+
opts.exp_dir = os.path.join(opts.exp_dir, sub_exp_dir)
|
44 |
+
create_initial_experiment_dir(opts)
|
45 |
+
return opts, previous_train_ckpt
|
46 |
+
|
47 |
+
|
48 |
+
def setup_progressive_steps(opts):
|
49 |
+
log_size = int(math.log(opts.stylegan_size, 2))
|
50 |
+
num_style_layers = 2*log_size - 2
|
51 |
+
num_deltas = num_style_layers - 1
|
52 |
+
if opts.progressive_start is not None: # If progressive delta training
|
53 |
+
opts.progressive_steps = [0]
|
54 |
+
next_progressive_step = opts.progressive_start
|
55 |
+
for i in range(num_deltas):
|
56 |
+
opts.progressive_steps.append(next_progressive_step)
|
57 |
+
next_progressive_step += opts.progressive_step_every
|
58 |
+
|
59 |
+
assert opts.progressive_steps is None or is_valid_progressive_steps(opts, num_style_layers), \
|
60 |
+
"Invalid progressive training input"
|
61 |
+
|
62 |
+
|
63 |
+
def is_valid_progressive_steps(opts, num_style_layers):
|
64 |
+
return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0
|
65 |
+
|
66 |
+
|
67 |
+
def create_initial_experiment_dir(opts):
|
68 |
+
if os.path.exists(opts.exp_dir):
|
69 |
+
raise Exception('Oops... {} already exists'.format(opts.exp_dir))
|
70 |
+
os.makedirs(opts.exp_dir)
|
71 |
+
|
72 |
+
opts_dict = vars(opts)
|
73 |
+
pprint.pprint(opts_dict)
|
74 |
+
with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
|
75 |
+
json.dump(opts_dict, f, indent=4, sort_keys=True)
|
76 |
+
|
77 |
+
|
78 |
+
def update_new_configs(ckpt_opts, new_opts):
|
79 |
+
for k, v in new_opts.items():
|
80 |
+
if k not in ckpt_opts:
|
81 |
+
ckpt_opts[k] = v
|
82 |
+
if new_opts['update_param_list']:
|
83 |
+
for param in new_opts['update_param_list']:
|
84 |
+
ckpt_opts[param] = new_opts[param]
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == '__main__':
|
88 |
+
main()
|
e4e/utils/__init__.py
ADDED
File without changes
|
e4e/utils/alignment.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import PIL
|
3 |
+
import PIL.Image
|
4 |
+
import scipy
|
5 |
+
import scipy.ndimage
|
6 |
+
import dlib
|
7 |
+
|
8 |
+
|
9 |
+
def get_landmark(filepath, predictor):
|
10 |
+
"""get landmark with dlib
|
11 |
+
:return: np.array shape=(68, 2)
|
12 |
+
"""
|
13 |
+
detector = dlib.get_frontal_face_detector()
|
14 |
+
|
15 |
+
img = dlib.load_rgb_image(filepath)
|
16 |
+
dets = detector(img, 1)
|
17 |
+
|
18 |
+
for k, d in enumerate(dets):
|
19 |
+
shape = predictor(img, d)
|
20 |
+
|
21 |
+
t = list(shape.parts())
|
22 |
+
a = []
|
23 |
+
for tt in t:
|
24 |
+
a.append([tt.x, tt.y])
|
25 |
+
lm = np.array(a)
|
26 |
+
return lm
|
27 |
+
|
28 |
+
|
29 |
+
def align_face(filepath, predictor):
|
30 |
+
"""
|
31 |
+
:param filepath: str
|
32 |
+
:return: PIL Image
|
33 |
+
"""
|
34 |
+
|
35 |
+
lm = get_landmark(filepath, predictor)
|
36 |
+
|
37 |
+
lm_chin = lm[0: 17] # left-right
|
38 |
+
lm_eyebrow_left = lm[17: 22] # left-right
|
39 |
+
lm_eyebrow_right = lm[22: 27] # left-right
|
40 |
+
lm_nose = lm[27: 31] # top-down
|
41 |
+
lm_nostrils = lm[31: 36] # top-down
|
42 |
+
lm_eye_left = lm[36: 42] # left-clockwise
|
43 |
+
lm_eye_right = lm[42: 48] # left-clockwise
|
44 |
+
lm_mouth_outer = lm[48: 60] # left-clockwise
|
45 |
+
lm_mouth_inner = lm[60: 68] # left-clockwise
|
46 |
+
|
47 |
+
# Calculate auxiliary vectors.
|
48 |
+
eye_left = np.mean(lm_eye_left, axis=0)
|
49 |
+
eye_right = np.mean(lm_eye_right, axis=0)
|
50 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
51 |
+
eye_to_eye = eye_right - eye_left
|
52 |
+
mouth_left = lm_mouth_outer[0]
|
53 |
+
mouth_right = lm_mouth_outer[6]
|
54 |
+
mouth_avg = (mouth_left + mouth_right) * 0.5
|
55 |
+
eye_to_mouth = mouth_avg - eye_avg
|
56 |
+
|
57 |
+
# Choose oriented crop rectangle.
|
58 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
59 |
+
x /= np.hypot(*x)
|
60 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
|
61 |
+
y = np.flipud(x) * [-1, 1]
|
62 |
+
c = eye_avg + eye_to_mouth * 0.1
|
63 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
64 |
+
qsize = np.hypot(*x) * 2
|
65 |
+
|
66 |
+
# read image
|
67 |
+
img = PIL.Image.open(filepath)
|
68 |
+
|
69 |
+
output_size = 256
|
70 |
+
transform_size = 256
|
71 |
+
enable_padding = True
|
72 |
+
|
73 |
+
# Shrink.
|
74 |
+
shrink = int(np.floor(qsize / output_size * 0.5))
|
75 |
+
if shrink > 1:
|
76 |
+
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
|
77 |
+
img = img.resize(rsize, PIL.Image.ANTIALIAS)
|
78 |
+
quad /= shrink
|
79 |
+
qsize /= shrink
|
80 |
+
|
81 |
+
# Crop.
|
82 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
83 |
+
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
84 |
+
int(np.ceil(max(quad[:, 1]))))
|
85 |
+
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
|
86 |
+
min(crop[3] + border, img.size[1]))
|
87 |
+
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
|
88 |
+
img = img.crop(crop)
|
89 |
+
quad -= crop[0:2]
|
90 |
+
|
91 |
+
# Pad.
|
92 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
93 |
+
int(np.ceil(max(quad[:, 1]))))
|
94 |
+
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
|
95 |
+
max(pad[3] - img.size[1] + border, 0))
|
96 |
+
if enable_padding and max(pad) > border - 4:
|
97 |
+
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
|
98 |
+
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
|
99 |
+
h, w, _ = img.shape
|
100 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
101 |
+
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
|
102 |
+
1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
|
103 |
+
blur = qsize * 0.02
|
104 |
+
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
105 |
+
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
|
106 |
+
img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
|
107 |
+
quad += pad[:2]
|
108 |
+
|
109 |
+
# Transform.
|
110 |
+
img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
|
111 |
+
if output_size < transform_size:
|
112 |
+
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
|
113 |
+
|
114 |
+
# Return aligned image.
|
115 |
+
return img
|
e4e/utils/common.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
|
4 |
+
|
5 |
+
# Log images
|
6 |
+
def log_input_image(x, opts):
|
7 |
+
return tensor2im(x)
|
8 |
+
|
9 |
+
|
10 |
+
def tensor2im(var):
|
11 |
+
# var shape: (3, H, W)
|
12 |
+
var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
|
13 |
+
var = ((var + 1) / 2)
|
14 |
+
var[var < 0] = 0
|
15 |
+
var[var > 1] = 1
|
16 |
+
var = var * 255
|
17 |
+
return Image.fromarray(var.astype('uint8'))
|
18 |
+
|
19 |
+
|
20 |
+
def vis_faces(log_hooks):
|
21 |
+
display_count = len(log_hooks)
|
22 |
+
fig = plt.figure(figsize=(8, 4 * display_count))
|
23 |
+
gs = fig.add_gridspec(display_count, 3)
|
24 |
+
for i in range(display_count):
|
25 |
+
hooks_dict = log_hooks[i]
|
26 |
+
fig.add_subplot(gs[i, 0])
|
27 |
+
if 'diff_input' in hooks_dict:
|
28 |
+
vis_faces_with_id(hooks_dict, fig, gs, i)
|
29 |
+
else:
|
30 |
+
vis_faces_no_id(hooks_dict, fig, gs, i)
|
31 |
+
plt.tight_layout()
|
32 |
+
return fig
|
33 |
+
|
34 |
+
|
35 |
+
def vis_faces_with_id(hooks_dict, fig, gs, i):
|
36 |
+
plt.imshow(hooks_dict['input_face'])
|
37 |
+
plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input'])))
|
38 |
+
fig.add_subplot(gs[i, 1])
|
39 |
+
plt.imshow(hooks_dict['target_face'])
|
40 |
+
plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']),
|
41 |
+
float(hooks_dict['diff_target'])))
|
42 |
+
fig.add_subplot(gs[i, 2])
|
43 |
+
plt.imshow(hooks_dict['output_face'])
|
44 |
+
plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target'])))
|
45 |
+
|
46 |
+
|
47 |
+
def vis_faces_no_id(hooks_dict, fig, gs, i):
|
48 |
+
plt.imshow(hooks_dict['input_face'], cmap="gray")
|
49 |
+
plt.title('Input')
|
50 |
+
fig.add_subplot(gs[i, 1])
|
51 |
+
plt.imshow(hooks_dict['target_face'])
|
52 |
+
plt.title('Target')
|
53 |
+
fig.add_subplot(gs[i, 2])
|
54 |
+
plt.imshow(hooks_dict['output_face'])
|
55 |
+
plt.title('Output')
|
e4e/utils/data_utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code adopted from pix2pixHD:
|
3 |
+
https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
|
7 |
+
IMG_EXTENSIONS = [
|
8 |
+
'.jpg', '.JPG', '.jpeg', '.JPEG',
|
9 |
+
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
|
10 |
+
]
|
11 |
+
|
12 |
+
|
13 |
+
def is_image_file(filename):
|
14 |
+
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
15 |
+
|
16 |
+
|
17 |
+
def make_dataset(dir):
|
18 |
+
images = []
|
19 |
+
assert os.path.isdir(dir), '%s is not a valid directory' % dir
|
20 |
+
for root, _, fnames in sorted(os.walk(dir)):
|
21 |
+
for fname in fnames:
|
22 |
+
if is_image_file(fname):
|
23 |
+
path = os.path.join(root, fname)
|
24 |
+
images.append(path)
|
25 |
+
return images
|
e4e/utils/model_utils.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import argparse
|
3 |
+
from models.psp import pSp
|
4 |
+
from models.encoders.psp_encoders import Encoder4Editing
|
5 |
+
|
6 |
+
|
7 |
+
def setup_model(checkpoint_path, device='cuda'):
|
8 |
+
ckpt = torch.load(checkpoint_path, map_location='cpu')
|
9 |
+
opts = ckpt['opts']
|
10 |
+
|
11 |
+
opts['checkpoint_path'] = checkpoint_path
|
12 |
+
opts['device'] = device
|
13 |
+
opts = argparse.Namespace(**opts)
|
14 |
+
|
15 |
+
net = pSp(opts)
|
16 |
+
net.eval()
|
17 |
+
net = net.to(device)
|
18 |
+
return net, opts
|
19 |
+
|
20 |
+
|
21 |
+
def load_e4e_standalone(checkpoint_path, device='cuda'):
|
22 |
+
ckpt = torch.load(checkpoint_path, map_location='cpu')
|
23 |
+
opts = argparse.Namespace(**ckpt['opts'])
|
24 |
+
e4e = Encoder4Editing(50, 'ir_se', opts)
|
25 |
+
e4e_dict = {k.replace('encoder.', ''): v for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')}
|
26 |
+
e4e.load_state_dict(e4e_dict)
|
27 |
+
e4e.eval()
|
28 |
+
e4e = e4e.to(device)
|
29 |
+
latent_avg = ckpt['latent_avg'].to(device)
|
30 |
+
|
31 |
+
def add_latent_avg(model, inputs, outputs):
|
32 |
+
return outputs + latent_avg.repeat(outputs.shape[0], 1, 1)
|
33 |
+
|
34 |
+
e4e.register_forward_hook(add_latent_avg)
|
35 |
+
return e4e
|
e4e/utils/train_utils.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
def aggregate_loss_dict(agg_loss_dict):
|
3 |
+
mean_vals = {}
|
4 |
+
for output in agg_loss_dict:
|
5 |
+
for key in output:
|
6 |
+
mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]]
|
7 |
+
for key in mean_vals:
|
8 |
+
if len(mean_vals[key]) > 0:
|
9 |
+
mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key])
|
10 |
+
else:
|
11 |
+
print('{} has no value'.format(key))
|
12 |
+
mean_vals[key] = 0
|
13 |
+
return mean_vals
|