Add files
- .gitattributes +1 -0
- .gitmodules +6 -0
- .style.yapf +5 -0
- HairCLIP +1 -0
- app.py +151 -0
- encoder4editing +1 -0
- images/95UF6LXe-Lo.jpg +3 -0
- images/ILip77SbmOE.jpg +3 -0
- images/README.md +7 -0
- images/et_78QkMMQs.jpg +3 -0
- images/rDEOVtE7vOs.jpg +3 -0
- model.py +151 -0
- packages.txt +2 -0
- patch.e4e +131 -0
- patch.hairclip +61 -0
- requirements.txt +8 -0
.gitattributes
CHANGED
@@ -1,3 +1,4 @@
+*.jpg filter=lfs diff=lfs merge=lfs -text
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
.gitmodules
ADDED
@@ -0,0 +1,6 @@
[submodule "HairCLIP"]
	path = HairCLIP
	url = https://github.com/wty-ustc/HairCLIP
[submodule "encoder4editing"]
	path = encoder4editing
	url = https://github.com/omertov/encoder4editing
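As a side note (not part of this commit), the two submodules declared above have to be fetched before app.py can import from them. A minimal, hypothetical setup helper in Python, mirroring how app.py itself shells out to git, might look like this:

# Hypothetical setup step: fetch the HairCLIP and encoder4editing submodules.
import subprocess

subprocess.run('git submodule update --init --recursive'.split(), check=True)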
.style.yapf
ADDED
@@ -0,0 +1,5 @@
[style]
based_on_style = pep8
blank_line_before_nested_class_or_def = false
spaces_before_comment = 2
split_before_logical_operator = true
HairCLIP
ADDED
@@ -0,0 +1 @@
Subproject commit 29290cf5bdca0f21ff27e0ec2e93bdd1ebbe3605
app.py
ADDED
@@ -0,0 +1,151 @@
#!/usr/bin/env python

from __future__ import annotations

import argparse
import os
import pathlib
import subprocess

import gradio as gr

if os.getenv('SYSTEM') == 'spaces':
    subprocess.call('git apply ../patch.e4e'.split(), cwd='encoder4editing')
    subprocess.call('git apply ../patch.hairclip'.split(), cwd='HairCLIP')

from model import Model


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    return parser.parse_args()


def load_hairstyle_list() -> list[str]:
    with open('HairCLIP/mapper/hairstyle_list.txt') as f:
        lines = [line.strip() for line in f.readlines()]
    lines = [line[:-10] for line in lines]
    return lines


def set_example_image(example: list) -> dict:
    return gr.Image.update(value=example[0])


def update_step2_components(choice: str) -> tuple[dict, dict]:
    return (
        gr.Dropdown.update(visible=choice in ['hairstyle', 'both']),
        gr.Textbox.update(visible=choice in ['color', 'both']),
    )


def main():
    args = parse_args()
    model = Model(device=args.device)

    css = '''
h1#title {
  text-align: center;
}
img#teaser {
  max-width: 1000px;
  max-height: 600px;
}
'''

    with gr.Blocks(theme=args.theme, css=css) as demo:
        gr.Markdown('''<h1 id="title">HairCLIP</h1>

This is an unofficial demo for <a href="https://github.com/wty-ustc/HairCLIP">https://github.com/wty-ustc/HairCLIP</a>.

<center><img id="teaser" src="https://raw.githubusercontent.com/wty-ustc/HairCLIP/main/assets/teaser.png" alt="teaser"></center>
''')
        with gr.Box():
            gr.Markdown('## Step 1')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        input_image = gr.Image(label='Input Image',
                                               type='file')
                    with gr.Row():
                        preprocess_button = gr.Button('Preprocess')
                with gr.Column():
                    aligned_face = gr.Image(label='Aligned Face',
                                            type='pil',
                                            interactive=False)
                with gr.Column():
                    reconstructed_face = gr.Image(label='Reconstructed Face',
                                                  type='numpy')
                    latent = gr.Variable()

            with gr.Row():
                paths = sorted(pathlib.Path('images').glob('*.jpg'))
                example_images = gr.Dataset(components=[input_image],
                                            samples=[[path.as_posix()]
                                                     for path in paths])

        with gr.Box():
            gr.Markdown('## Step 2')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        editing_type = gr.Radio(['hairstyle', 'color', 'both'],
                                                value='both',
                                                type='value',
                                                label='Editing Type')
                    with gr.Row():
                        hairstyles = load_hairstyle_list()
                        hairstyle_index = gr.Dropdown(hairstyles,
                                                      value='afro',
                                                      type='index',
                                                      label='Hairstyle')
                    with gr.Row():
                        color_description = gr.Textbox(value='red',
                                                       label='Color')
                    with gr.Row():
                        run_button = gr.Button('Run')

                with gr.Column():
                    result = gr.Image(label='Result')

        gr.Markdown(
            '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.hairclip" alt="visitor badge"/></center>'
        )

        preprocess_button.click(fn=model.detect_and_align_face,
                                inputs=[input_image],
                                outputs=[aligned_face])
        aligned_face.change(fn=model.reconstruct_face,
                            inputs=[aligned_face],
                            outputs=[reconstructed_face, latent])
        editing_type.change(fn=update_step2_components,
                            inputs=[editing_type],
                            outputs=[hairstyle_index, color_description])
        run_button.click(fn=model.generate,
                         inputs=[
                             editing_type,
                             hairstyle_index,
                             color_description,
                             latent,
                         ],
                         outputs=[result])
        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=example_images.components)

    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()
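One non-obvious detail in app.py is the `line[:-10]` slice in `load_hairstyle_list()`. Assuming the entries in HairCLIP/mapper/hairstyle_list.txt end with the literal suffix " hairstyle" (which the dropdown's default of 'afro' suggests), the slice simply strips that 10-character suffix; a small illustration with an assumed entry:

# Assumed example entry from hairstyle_list.txt (not verified against the file).
line = 'afro hairstyle'
print(line[:-10])  # -> 'afro'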
encoder4editing
ADDED
@@ -0,0 +1 @@
Subproject commit 99ea50578695d2e8a1cf7259d8ee89b23eea942b
images/95UF6LXe-Lo.jpg
ADDED
Git LFS Details
images/ILip77SbmOE.jpg
ADDED
Git LFS Details
images/README.md
ADDED
@@ -0,0 +1,7 @@
These images are freely-usable ones from [Unsplash](https://unsplash.com/).

- https://unsplash.com/photos/rDEOVtE7vOs
- https://unsplash.com/photos/et_78QkMMQs
- https://unsplash.com/photos/ILip77SbmOE
- https://unsplash.com/photos/95UF6LXe-Lo
images/et_78QkMMQs.jpg
ADDED
Git LFS Details
images/rDEOVtE7vOs.jpg
ADDED
Git LFS Details
model.py
ADDED
@@ -0,0 +1,151 @@
from __future__ import annotations

import argparse
import os
import sys
from typing import Callable, Union

import dlib
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torchvision.transforms as T

sys.path.insert(0, 'encoder4editing')

from models.psp import pSp
from utils.alignment import align_face

sys.path.insert(0, 'HairCLIP/')
sys.path.insert(0, 'HairCLIP/mapper/')

from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
from mapper.hairclip_mapper import HairCLIPMapper

TOKEN = os.environ['TOKEN']


class Model:
    def __init__(self, device: Union[torch.device, str]):
        self.device = torch.device(device)
        self.landmark_model = self._create_dlib_landmark_model()
        self.e4e = self._load_e4e()
        self.hairclip = self._load_hairclip()
        self.transform = self._create_transform()

    @staticmethod
    def _create_dlib_landmark_model():
        path = huggingface_hub.hf_hub_download(
            'hysts/dlib_face_landmark_model',
            'shape_predictor_68_face_landmarks.dat',
            use_auth_token=TOKEN)
        return dlib.shape_predictor(path)

    def _load_e4e(self) -> nn.Module:
        ckpt_path = huggingface_hub.hf_hub_download('hysts/e4e',
                                                    'e4e_ffhq_encode.pt',
                                                    use_auth_token=TOKEN)
        ckpt = torch.load(ckpt_path, map_location='cpu')
        opts = ckpt['opts']
        opts['device'] = self.device.type
        opts['checkpoint_path'] = ckpt_path
        opts = argparse.Namespace(**opts)
        model = pSp(opts)
        model.to(self.device)
        model.eval()
        return model

    def _load_hairclip(self) -> nn.Module:
        ckpt_path = huggingface_hub.hf_hub_download('hysts/HairCLIP',
                                                    'hairclip.pt',
                                                    use_auth_token=TOKEN)
        ckpt = torch.load(ckpt_path, map_location='cpu')
        opts = ckpt['opts']
        opts['device'] = self.device.type
        opts['checkpoint_path'] = ckpt_path
        opts['editing_type'] = 'both'
        opts['input_type'] = 'text'
        opts['hairstyle_description'] = 'HairCLIP/mapper/hairstyle_list.txt'
        opts['color_description'] = 'red'
        opts = argparse.Namespace(**opts)
        model = HairCLIPMapper(opts)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _create_transform() -> Callable:
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        return transform

    def detect_and_align_face(self, image) -> PIL.Image.Image:
        image = align_face(filepath=image.name, predictor=self.landmark_model)
        return image

    @staticmethod
    def denormalize(tensor: torch.Tensor) -> torch.Tensor:
        return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)

    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
        tensor = self.denormalize(tensor)
        return tensor.cpu().numpy().transpose(1, 2, 0)

    @torch.inference_mode()
    def reconstruct_face(
            self, image: PIL.Image.Image) -> tuple[np.ndarray, torch.Tensor]:
        input_data = self.transform(image).unsqueeze(0).to(self.device)
        reconstructed_images, latents = self.e4e(input_data,
                                                 randomize_noise=False,
                                                 return_latents=True)
        reconstructed = torch.clamp(reconstructed_images[0].detach(), -1, 1)
        reconstructed = self.postprocess(reconstructed)
        return reconstructed, latents[0]

    @torch.inference_mode()
    def generate(self, editing_type: str, hairstyle_index: int,
                 color_description: str, latent: torch.Tensor) -> np.ndarray:
        opts = self.hairclip.opts
        opts.editing_type = editing_type
        opts.color_description = color_description

        if editing_type == 'color':
            hairstyle_index = 0

        device = torch.device(opts.device)

        dataset = LatentsDatasetInference(latents=latent.unsqueeze(0).cpu(),
                                          opts=opts)
        w, hairstyle_text_inputs_list, color_text_inputs_list = dataset[0][:3]

        w = w.unsqueeze(0).to(device)
        hairstyle_text_inputs = hairstyle_text_inputs_list[
            hairstyle_index].unsqueeze(0).to(device)
        color_text_inputs = color_text_inputs_list[0].unsqueeze(0).to(device)

        hairstyle_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)
        color_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)

        w_hat = w + 0.1 * self.hairclip.mapper(
            w,
            hairstyle_text_inputs,
            color_text_inputs,
            hairstyle_tensor_hairmasked,
            color_tensor_hairmasked,
        )
        x_hat, _ = self.hairclip.decoder(
            [w_hat],
            input_is_latent=True,
            return_latents=True,
            randomize_noise=False,
            truncation=1,
        )
        res = torch.clamp(x_hat[0].detach(), -1, 1)
        res = self.postprocess(res)
        return res
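For reference, a minimal sketch (not part of the commit) of how this Model class could be driven outside the Gradio UI. It assumes the TOKEN environment variable is set for the checkpoint downloads and uses one of the bundled example images; the SimpleNamespace stand-in mimics the `.name` attribute of Gradio's `type='file'` objects that detect_and_align_face expects.

# Minimal standalone sketch of the align -> invert -> edit pipeline.
import types

from model import Model

model = Model(device='cpu')

# Stand-in for a Gradio file object: only `.name` (a path) is read.
image_file = types.SimpleNamespace(name='images/ILip77SbmOE.jpg')
aligned = model.detect_and_align_face(image_file)

# e4e inversion: returns the reconstructed face (numpy) and its latent code.
reconstructed, latent = model.reconstruct_face(aligned)

# HairCLIP edit: index 0 refers to the first entry of hairstyle_list.txt.
result = model.generate(editing_type='both',
                        hairstyle_index=0,
                        color_description='red',
                        latent=latent)
print(result.shape)  # H x W x 3 uint8 array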
packages.txt
ADDED
@@ -0,0 +1,2 @@
cmake
ninja-build
patch.e4e
ADDED
@@ -0,0 +1,131 @@
diff --git a/models/stylegan2/op/fused_act.py b/models/stylegan2/op/fused_act.py
index 973a84f..6854b97 100644
--- a/models/stylegan2/op/fused_act.py
+++ b/models/stylegan2/op/fused_act.py
@@ -2,17 +2,18 @@ import os

 import torch
 from torch import nn
+from torch.nn import functional as F
 from torch.autograd import Function
 from torch.utils.cpp_extension import load

-module_path = os.path.dirname(__file__)
-fused = load(
-    'fused',
-    sources=[
-        os.path.join(module_path, 'fused_bias_act.cpp'),
-        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
-    ],
-)
+#module_path = os.path.dirname(__file__)
+#fused = load(
+#    'fused',
+#    sources=[
+#        os.path.join(module_path, 'fused_bias_act.cpp'),
+#        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+#    ],
+#)


 class FusedLeakyReLUFunctionBackward(Function):
@@ -82,4 +83,18 @@ class FusedLeakyReLU(nn.Module):


 def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
-    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+    if input.device.type == "cpu":
+        if bias is not None:
+            rest_dim = [1] * (input.ndim - bias.ndim - 1)
+            return (
+                F.leaky_relu(
+                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
+                )
+                * scale
+            )
+
+        else:
+            return F.leaky_relu(input, negative_slope=0.2) * scale
+
+    else:
+        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
diff --git a/models/stylegan2/op/upfirdn2d.py b/models/stylegan2/op/upfirdn2d.py
index 7bc5a1e..5465d1a 100644
--- a/models/stylegan2/op/upfirdn2d.py
+++ b/models/stylegan2/op/upfirdn2d.py
@@ -1,17 +1,18 @@
 import os

 import torch
+from torch.nn import functional as F
 from torch.autograd import Function
 from torch.utils.cpp_extension import load

-module_path = os.path.dirname(__file__)
-upfirdn2d_op = load(
-    'upfirdn2d',
-    sources=[
-        os.path.join(module_path, 'upfirdn2d.cpp'),
-        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
-    ],
-)
+#module_path = os.path.dirname(__file__)
+#upfirdn2d_op = load(
+#    'upfirdn2d',
+#    sources=[
+#        os.path.join(module_path, 'upfirdn2d.cpp'),
+#        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+#    ],
+#)


 class UpFirDn2dBackward(Function):
@@ -97,8 +98,8 @@ class UpFirDn2d(Function):

         ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

-        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
         ctx.out_size = (out_h, out_w)

         ctx.up = (up_x, up_y)
@@ -140,9 +141,13 @@ class UpFirDn2d(Function):


 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    out = UpFirDn2d.apply(
-        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
-    )
+    if input.device.type == "cpu":
+        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+
+    else:
+        out = UpFirDn2d.apply(
+            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+        )

     return out

@@ -150,6 +155,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
 def upfirdn2d_native(
     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
 ):
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+
     _, in_h, in_w, minor = input.shape
     kernel_h, kernel_w = kernel.shape

@@ -180,5 +188,9 @@ def upfirdn2d_native(
         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
     )
     out = out.permute(0, 2, 3, 1)
+    out = out[:, ::down_y, ::down_x, :]
+
+    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

-    return out[:, ::down_y, ::down_x, :]
+    return out.view(-1, channel, out_h, out_w)
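The fused_act.py hunk above replaces the fused CUDA bias/activation op with a plain PyTorch expression when the input is on CPU. A small standalone check of that expression (illustrative only, with made-up tensor shapes) looks like this:

# Illustrative check of the CPU fallback: bias add -> leaky ReLU -> scale.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)   # NCHW activations (shapes are assumptions)
bias = torch.randn(8)         # per-channel bias
scale = 2 ** 0.5

rest_dim = [1] * (x.ndim - bias.ndim - 1)
out = F.leaky_relu(x + bias.view(1, bias.shape[0], *rest_dim),
                   negative_slope=0.2) * scale
print(out.shape)  # torch.Size([2, 8, 4, 4])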
patch.hairclip
ADDED
@@ -0,0 +1,61 @@
diff --git a/mapper/latent_mappers.py b/mapper/latent_mappers.py
index 56b9c55..f0dd005 100644
--- a/mapper/latent_mappers.py
+++ b/mapper/latent_mappers.py
@@ -19,7 +19,7 @@ class ModulationModule(Module):

     def forward(self, x, embedding, cut_flag):
         x = self.fc(x)
-        x = self.norm(x)
+        x = self.norm(x)
         if cut_flag == 1:
             return x
         gamma = self.gamma_function(embedding.float())
@@ -39,20 +39,20 @@ class SubHairMapper(Module):
     def forward(self, x, embedding, cut_flag=0):
         x = self.pixelnorm(x)
         for modulation_module in self.modulation_module_list:
-            x = modulation_module(x, embedding, cut_flag)
+            x = modulation_module(x, embedding, cut_flag)
         return x

-class HairMapper(Module):
+class HairMapper(Module):
     def __init__(self, opts):
         super(HairMapper, self).__init__()
         self.opts = opts
-        self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=opts.device)
         self.transform = transforms.Compose([transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
         self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
         self.hairstyle_cut_flag = 0
         self.color_cut_flag = 0

-        if not opts.no_coarse_mapper:
+        if not opts.no_coarse_mapper:
             self.course_mapping = SubHairMapper(opts, 4)
         if not opts.no_medium_mapper:
             self.medium_mapping = SubHairMapper(opts, 4)
@@ -70,13 +70,13 @@ class HairMapper(Module):
         elif hairstyle_tensor.shape[1] != 1:
             hairstyle_embedding = self.gen_image_embedding(hairstyle_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
         else:
-            hairstyle_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+            hairstyle_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)
         if color_text_inputs.shape[1] != 1:
             color_embedding = self.clip_model.encode_text(color_text_inputs).unsqueeze(1).repeat(1, 18, 1).detach()
         elif color_tensor.shape[1] != 1:
             color_embedding = self.gen_image_embedding(color_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
         else:
-            color_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+            color_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)


         if (hairstyle_text_inputs.shape[1] == 1) and (hairstyle_tensor.shape[1] == 1):
@@ -106,4 +106,4 @@ class HairMapper(Module):
             x_fine = torch.zeros_like(x_fine)

         out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
-        return out

+        return out
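The substantive change in this patch is swapping hard-coded `.cuda()` / `device="cuda"` calls for the device named in the options, so the mapper can run on the Space's CPU hardware. A tiny illustration (not part of the commit) of what that device change means:

# Tensors follow whatever device the options specify instead of being pinned to CUDA.
import torch

opts_device = 'cpu'  # assumed value; the Space runs the model on CPU
embedding = torch.ones(1, 18, 512).to(opts_device)
print(embedding.device)  # cpu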
requirements.txt
ADDED
@@ -0,0 +1,8 @@
dlib==19.23.0
numpy==1.22.3
opencv-python-headless==4.5.5.64
Pillow==9.1.0
scipy==1.8.0
torch==1.11.0
torchvision==0.12.0
git+https://github.com/openai/CLIP.git