Spaces:
Running
on
Zero
Running
on
Zero
♥️
Browse files- README.md +6 -6
- app.py +102 -0
- image0.jpeg +0 -0
- image1.jpeg +0 -0
- lama_cleaner/__init__.py +0 -0
- lama_cleaner/__pycache__/__init__.cpython-310.pyc +0 -0
- lama_cleaner/__pycache__/helper.cpython-310.pyc +0 -0
- lama_cleaner/__pycache__/model_manager.cpython-310.pyc +0 -0
- lama_cleaner/__pycache__/schema.cpython-310.pyc +0 -0
- lama_cleaner/__pycache__/settings.cpython-310.pyc +0 -0
- lama_cleaner/__pycache__/urls.cpython-310.pyc +0 -0
- lama_cleaner/__pycache__/wsgi.cpython-310.pyc +0 -0
- lama_cleaner/asgi.py +16 -0
- lama_cleaner/helper.py +182 -0
- lama_cleaner/model/__init__.py +0 -0
- lama_cleaner/model/__pycache__/__init__.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/base.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/ddim_sampler.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/fcf.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/lama.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/ldm.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/mat.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/opencv2.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/plms_sampler.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/sd.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/utils.cpython-310.pyc +0 -0
- lama_cleaner/model/__pycache__/zits.cpython-310.pyc +0 -0
- lama_cleaner/model/base.py +183 -0
- lama_cleaner/model/ddim_sampler.py +193 -0
- lama_cleaner/model/fcf.py +1214 -0
- lama_cleaner/model/lama.py +61 -0
- lama_cleaner/model/ldm.py +312 -0
- lama_cleaner/model/mat.py +1444 -0
- lama_cleaner/model/opencv2.py +24 -0
- lama_cleaner/model/plms_sampler.py +225 -0
- lama_cleaner/model/sd.py +215 -0
- lama_cleaner/model/sd_pipeline.py +310 -0
- lama_cleaner/model/utils.py +709 -0
- lama_cleaner/model/zits.py +427 -0
- lama_cleaner/model_manager.py +43 -0
- lama_cleaner/schema.py +50 -0
- lama_cleaner/settings.py +124 -0
- lama_cleaner/urls.py +22 -0
- lama_cleaner/wsgi.py +16 -0
- requirements.txt +12 -0
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
---
|
2 |
-
title: Remove
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.38.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Remove-WM
|
3 |
+
emoji: 🔍🗑️
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.38.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
---
|
12 |
+
https://github.com/sponsors/Damarcreative
|
|
app.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from lama_cleaner.model_manager import ModelManager
|
4 |
+
from lama_cleaner.schema import Config, HDStrategy, LDMSampler
|
5 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image, ImageDraw
|
9 |
+
import spaces
|
10 |
+
import subprocess
|
11 |
+
|
12 |
+
# Install necessary packages
|
13 |
+
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
14 |
+
|
15 |
+
# Initialize Florence model
|
16 |
+
model_id = 'microsoft/Florence-2-large'
|
17 |
+
florence_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
|
18 |
+
florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
19 |
+
|
20 |
+
# Initialize Llama Cleaner model
|
21 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
22 |
+
|
23 |
+
@spaces.GPU()
|
24 |
+
def process_image(image, mask, strategy, sampler, fx=1, fy=1):
|
25 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
26 |
+
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
|
27 |
+
|
28 |
+
if fx != 1 or fy != 1:
|
29 |
+
image = cv2.resize(image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
30 |
+
mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
|
31 |
+
|
32 |
+
config = Config(
|
33 |
+
ldm_steps=1,
|
34 |
+
ldm_sampler=sampler,
|
35 |
+
hd_strategy=strategy,
|
36 |
+
hd_strategy_crop_margin=32,
|
37 |
+
hd_strategy_crop_trigger_size=200,
|
38 |
+
hd_strategy_resize_limit=200,
|
39 |
+
)
|
40 |
+
|
41 |
+
model = ModelManager(name="lama", device=device)
|
42 |
+
result = model(image, mask, config)
|
43 |
+
return result
|
44 |
+
|
45 |
+
def create_mask(image, prediction):
|
46 |
+
mask = Image.new("RGBA", image.size, (0, 0, 0, 255)) # Black background
|
47 |
+
draw = ImageDraw.Draw(mask)
|
48 |
+
scale = 1
|
49 |
+
for polygons in prediction['polygons']:
|
50 |
+
for _polygon in polygons:
|
51 |
+
_polygon = np.array(_polygon).reshape(-1, 2)
|
52 |
+
if len(_polygon) < 3:
|
53 |
+
continue
|
54 |
+
_polygon = (_polygon * scale).reshape(-1).tolist()
|
55 |
+
draw.polygon(_polygon, fill=(255, 255, 255, 255)) # Make selected area white
|
56 |
+
return mask
|
57 |
+
|
58 |
+
@spaces.GPU()
|
59 |
+
def process_images_florence_lama(image):
|
60 |
+
# Convert image to OpenCV format
|
61 |
+
image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
62 |
+
|
63 |
+
# Run Florence to get mask
|
64 |
+
text_input = 'watermark' # Teks untuk Florence agar mengenali watermark
|
65 |
+
task_prompt = '<REGION_TO_SEGMENTATION>'
|
66 |
+
image_pil = Image.fromarray(image_cv) # Convert array to PIL Image
|
67 |
+
inputs = florence_processor(text=task_prompt + text_input, images=image_pil, return_tensors="pt").to("cuda")
|
68 |
+
generated_ids = florence_model.generate(
|
69 |
+
input_ids=inputs["input_ids"],
|
70 |
+
pixel_values=inputs["pixel_values"],
|
71 |
+
max_new_tokens=1024,
|
72 |
+
early_stopping=False,
|
73 |
+
do_sample=False,
|
74 |
+
num_beams=3,
|
75 |
+
)
|
76 |
+
generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
77 |
+
parsed_answer = florence_processor.post_process_generation(
|
78 |
+
generated_text,
|
79 |
+
task=task_prompt,
|
80 |
+
image_size=(image_pil.width, image_pil.height)
|
81 |
+
)
|
82 |
+
|
83 |
+
# Create mask and process image with Llama Cleaner
|
84 |
+
mask_image = create_mask(image_pil, parsed_answer['<REGION_TO_SEGMENTATION>'])
|
85 |
+
result_image = process_image(image_cv, np.array(mask_image), HDStrategy.RESIZE, LDMSampler.ddim)
|
86 |
+
|
87 |
+
# Convert result back to PIL Image
|
88 |
+
result_image_pil = Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
|
89 |
+
|
90 |
+
return result_image_pil
|
91 |
+
|
92 |
+
# Define Gradio interface
|
93 |
+
demo = gr.Interface(
|
94 |
+
fn=process_images_florence_lama,
|
95 |
+
inputs=gr.Image(type="pil", label="Input Image"),
|
96 |
+
outputs=gr.Image(type="pil", label="Output Image"),
|
97 |
+
title="Watermark Remover.",
|
98 |
+
description="Upload images and remove selected watermarks using Florence and Llama Cleaner."
|
99 |
+
)
|
100 |
+
# Launch Gradio interface with example images
|
101 |
+
if __name__ == "__main__":
|
102 |
+
demo.launch()
|
image0.jpeg
ADDED
image1.jpeg
ADDED
lama_cleaner/__init__.py
ADDED
File without changes
|
lama_cleaner/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (145 Bytes). View file
|
|
lama_cleaner/__pycache__/helper.cpython-310.pyc
ADDED
Binary file (4.86 kB). View file
|
|
lama_cleaner/__pycache__/model_manager.cpython-310.pyc
ADDED
Binary file (1.92 kB). View file
|
|
lama_cleaner/__pycache__/schema.cpython-310.pyc
ADDED
Binary file (1.68 kB). View file
|
|
lama_cleaner/__pycache__/settings.cpython-310.pyc
ADDED
Binary file (2.3 kB). View file
|
|
lama_cleaner/__pycache__/urls.cpython-310.pyc
ADDED
Binary file (989 Bytes). View file
|
|
lama_cleaner/__pycache__/wsgi.cpython-310.pyc
ADDED
Binary file (558 Bytes). View file
|
|
lama_cleaner/asgi.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ASGI config for lama_cleaner project.
|
3 |
+
|
4 |
+
It exposes the ASGI callable as a module-level variable named ``application``.
|
5 |
+
|
6 |
+
For more information on this file, see
|
7 |
+
https://docs.djangoproject.com/en/4.1/howto/deployment/asgi/
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
|
12 |
+
from django.core.asgi import get_asgi_application
|
13 |
+
|
14 |
+
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'lama_cleaner.settings')
|
15 |
+
|
16 |
+
application = get_asgi_application()
|
lama_cleaner/helper.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
from urllib.parse import urlparse
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from loguru import logger
|
10 |
+
from torch.hub import download_url_to_file, get_dir
|
11 |
+
|
12 |
+
|
13 |
+
def get_cache_path_by_url(url):
|
14 |
+
parts = urlparse(url)
|
15 |
+
hub_dir = get_dir()
|
16 |
+
model_dir = os.path.join(hub_dir, "checkpoints")
|
17 |
+
if not os.path.isdir(model_dir):
|
18 |
+
os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
|
19 |
+
filename = os.path.basename(parts.path)
|
20 |
+
cached_file = os.path.join(model_dir, filename)
|
21 |
+
return cached_file
|
22 |
+
|
23 |
+
|
24 |
+
def download_model(url):
|
25 |
+
cached_file = get_cache_path_by_url(url)
|
26 |
+
if not os.path.exists(cached_file):
|
27 |
+
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
28 |
+
hash_prefix = None
|
29 |
+
download_url_to_file(url, cached_file, hash_prefix, progress=True)
|
30 |
+
return cached_file
|
31 |
+
|
32 |
+
|
33 |
+
def ceil_modulo(x, mod):
|
34 |
+
if x % mod == 0:
|
35 |
+
return x
|
36 |
+
return (x // mod + 1) * mod
|
37 |
+
|
38 |
+
|
39 |
+
def load_jit_model(url_or_path, device):
|
40 |
+
if os.path.exists(url_or_path):
|
41 |
+
model_path = url_or_path
|
42 |
+
else:
|
43 |
+
model_path = download_model(url_or_path)
|
44 |
+
logger.info(f"Load model from: {model_path}")
|
45 |
+
try:
|
46 |
+
model = torch.jit.load(model_path).to(device)
|
47 |
+
except:
|
48 |
+
logger.error(
|
49 |
+
f"Failed to load {model_path}, delete model and restart lama-cleaner"
|
50 |
+
)
|
51 |
+
exit(-1)
|
52 |
+
model.eval()
|
53 |
+
return model
|
54 |
+
|
55 |
+
|
56 |
+
def load_model(model: torch.nn.Module, url_or_path, device):
|
57 |
+
if os.path.exists(url_or_path):
|
58 |
+
model_path = url_or_path
|
59 |
+
else:
|
60 |
+
model_path = download_model(url_or_path)
|
61 |
+
|
62 |
+
try:
|
63 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
64 |
+
model.load_state_dict(state_dict, strict=True)
|
65 |
+
model.to(device)
|
66 |
+
logger.info(f"Load model from: {model_path}")
|
67 |
+
except:
|
68 |
+
logger.error(
|
69 |
+
f"Failed to load {model_path}, delete model and restart lama-cleaner"
|
70 |
+
)
|
71 |
+
exit(-1)
|
72 |
+
model.eval()
|
73 |
+
return model
|
74 |
+
|
75 |
+
|
76 |
+
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
77 |
+
data = cv2.imencode(
|
78 |
+
f".{ext}",
|
79 |
+
image_numpy,
|
80 |
+
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
81 |
+
)[1]
|
82 |
+
image_bytes = data.tobytes()
|
83 |
+
return image_bytes
|
84 |
+
|
85 |
+
|
86 |
+
def load_img(img_bytes, gray: bool = False):
|
87 |
+
alpha_channel = None
|
88 |
+
nparr = np.frombuffer(img_bytes, np.uint8)
|
89 |
+
if gray:
|
90 |
+
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
|
91 |
+
else:
|
92 |
+
np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
|
93 |
+
if len(np_img.shape) == 3 and np_img.shape[2] == 4:
|
94 |
+
alpha_channel = np_img[:, :, -1]
|
95 |
+
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
|
96 |
+
else:
|
97 |
+
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
98 |
+
|
99 |
+
return np_img, alpha_channel
|
100 |
+
|
101 |
+
|
102 |
+
def norm_img(np_img):
|
103 |
+
if len(np_img.shape) == 2:
|
104 |
+
np_img = np_img[:, :, np.newaxis]
|
105 |
+
np_img = np.transpose(np_img, (2, 0, 1))
|
106 |
+
np_img = np_img.astype("float32") / 255
|
107 |
+
return np_img
|
108 |
+
|
109 |
+
|
110 |
+
def resize_max_size(
|
111 |
+
np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
|
112 |
+
) -> np.ndarray:
|
113 |
+
# Resize image's longer size to size_limit if longer size larger than size_limit
|
114 |
+
h, w = np_img.shape[:2]
|
115 |
+
if max(h, w) > size_limit:
|
116 |
+
ratio = size_limit / max(h, w)
|
117 |
+
new_w = int(w * ratio + 0.5)
|
118 |
+
new_h = int(h * ratio + 0.5)
|
119 |
+
return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
|
120 |
+
else:
|
121 |
+
return np_img
|
122 |
+
|
123 |
+
|
124 |
+
def pad_img_to_modulo(
|
125 |
+
img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
|
126 |
+
):
|
127 |
+
"""
|
128 |
+
|
129 |
+
Args:
|
130 |
+
img: [H, W, C]
|
131 |
+
mod:
|
132 |
+
square: 是否为正方形
|
133 |
+
min_size:
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
|
137 |
+
"""
|
138 |
+
if len(img.shape) == 2:
|
139 |
+
img = img[:, :, np.newaxis]
|
140 |
+
height, width = img.shape[:2]
|
141 |
+
out_height = ceil_modulo(height, mod)
|
142 |
+
out_width = ceil_modulo(width, mod)
|
143 |
+
|
144 |
+
if min_size is not None:
|
145 |
+
assert min_size % mod == 0
|
146 |
+
out_width = max(min_size, out_width)
|
147 |
+
out_height = max(min_size, out_height)
|
148 |
+
|
149 |
+
if square:
|
150 |
+
max_size = max(out_height, out_width)
|
151 |
+
out_height = max_size
|
152 |
+
out_width = max_size
|
153 |
+
|
154 |
+
return np.pad(
|
155 |
+
img,
|
156 |
+
((0, out_height - height), (0, out_width - width), (0, 0)),
|
157 |
+
mode="symmetric",
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
|
162 |
+
"""
|
163 |
+
Args:
|
164 |
+
mask: (h, w, 1) 0~255
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
|
168 |
+
"""
|
169 |
+
height, width = mask.shape[:2]
|
170 |
+
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
171 |
+
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
172 |
+
|
173 |
+
boxes = []
|
174 |
+
for cnt in contours:
|
175 |
+
x, y, w, h = cv2.boundingRect(cnt)
|
176 |
+
box = np.array([x, y, x + w, y + h]).astype(int)
|
177 |
+
|
178 |
+
box[::2] = np.clip(box[::2], 0, width)
|
179 |
+
box[1::2] = np.clip(box[1::2], 0, height)
|
180 |
+
boxes.append(box)
|
181 |
+
|
182 |
+
return boxes
|
lama_cleaner/model/__init__.py
ADDED
File without changes
|
lama_cleaner/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (184 Bytes). View file
|
|
lama_cleaner/model/__pycache__/base.cpython-310.pyc
ADDED
Binary file (4.76 kB). View file
|
|
lama_cleaner/model/__pycache__/ddim_sampler.cpython-310.pyc
ADDED
Binary file (4.74 kB). View file
|
|
lama_cleaner/model/__pycache__/fcf.cpython-310.pyc
ADDED
Binary file (33.4 kB). View file
|
|
lama_cleaner/model/__pycache__/lama.cpython-310.pyc
ADDED
Binary file (2.16 kB). View file
|
|
lama_cleaner/model/__pycache__/ldm.cpython-310.pyc
ADDED
Binary file (7.86 kB). View file
|
|
lama_cleaner/model/__pycache__/mat.cpython-310.pyc
ADDED
Binary file (38 kB). View file
|
|
lama_cleaner/model/__pycache__/opencv2.cpython-310.pyc
ADDED
Binary file (1.14 kB). View file
|
|
lama_cleaner/model/__pycache__/plms_sampler.cpython-310.pyc
ADDED
Binary file (7.08 kB). View file
|
|
lama_cleaner/model/__pycache__/sd.cpython-310.pyc
ADDED
Binary file (5.83 kB). View file
|
|
lama_cleaner/model/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (26 kB). View file
|
|
lama_cleaner/model/__pycache__/zits.cpython-310.pyc
ADDED
Binary file (10.5 kB). View file
|
|
lama_cleaner/model/base.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
from loguru import logger
|
7 |
+
|
8 |
+
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
|
9 |
+
from lama_cleaner.schema import Config, HDStrategy
|
10 |
+
|
11 |
+
|
12 |
+
class InpaintModel:
|
13 |
+
min_size: Optional[int] = None
|
14 |
+
pad_mod = 8
|
15 |
+
pad_to_square = False
|
16 |
+
|
17 |
+
def __init__(self, device, **kwargs):
|
18 |
+
"""
|
19 |
+
|
20 |
+
Args:
|
21 |
+
device:
|
22 |
+
"""
|
23 |
+
self.device = device
|
24 |
+
self.init_model(device, **kwargs)
|
25 |
+
|
26 |
+
@abc.abstractmethod
|
27 |
+
def init_model(self, device, **kwargs):
|
28 |
+
...
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
@abc.abstractmethod
|
32 |
+
def is_downloaded() -> bool:
|
33 |
+
...
|
34 |
+
|
35 |
+
@abc.abstractmethod
|
36 |
+
def forward(self, image, mask, config: Config):
|
37 |
+
"""Input images and output images have same size
|
38 |
+
images: [H, W, C] RGB
|
39 |
+
masks: [H, W, 1] 255 为 masks 区域
|
40 |
+
return: BGR IMAGE
|
41 |
+
"""
|
42 |
+
...
|
43 |
+
|
44 |
+
def _pad_forward(self, image, mask, config: Config):
|
45 |
+
origin_height, origin_width = image.shape[:2]
|
46 |
+
pad_image = pad_img_to_modulo(
|
47 |
+
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
48 |
+
)
|
49 |
+
pad_mask = pad_img_to_modulo(
|
50 |
+
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
51 |
+
)
|
52 |
+
|
53 |
+
logger.info(f"final forward pad size: {pad_image.shape}")
|
54 |
+
|
55 |
+
result = self.forward(pad_image, pad_mask, config)
|
56 |
+
result = result[0:origin_height, 0:origin_width, :]
|
57 |
+
|
58 |
+
original_pixel_indices = mask < 127
|
59 |
+
result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
60 |
+
return result
|
61 |
+
|
62 |
+
@torch.no_grad()
|
63 |
+
def __call__(self, image, mask, config: Config):
|
64 |
+
"""
|
65 |
+
images: [H, W, C] RGB, not normalized
|
66 |
+
masks: [H, W]
|
67 |
+
return: BGR IMAGE
|
68 |
+
"""
|
69 |
+
inpaint_result = None
|
70 |
+
logger.info(f"hd_strategy: {config.hd_strategy}")
|
71 |
+
if config.hd_strategy == HDStrategy.CROP:
|
72 |
+
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
73 |
+
logger.info(f"Run crop strategy")
|
74 |
+
boxes = boxes_from_mask(mask)
|
75 |
+
crop_result = []
|
76 |
+
for box in boxes:
|
77 |
+
crop_image, crop_box = self._run_box(image, mask, box, config)
|
78 |
+
crop_result.append((crop_image, crop_box))
|
79 |
+
|
80 |
+
inpaint_result = image[:, :, ::-1]
|
81 |
+
for crop_image, crop_box in crop_result:
|
82 |
+
x1, y1, x2, y2 = crop_box
|
83 |
+
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
84 |
+
|
85 |
+
elif config.hd_strategy == HDStrategy.RESIZE:
|
86 |
+
if max(image.shape) > config.hd_strategy_resize_limit:
|
87 |
+
origin_size = image.shape[:2]
|
88 |
+
downsize_image = resize_max_size(
|
89 |
+
image, size_limit=config.hd_strategy_resize_limit
|
90 |
+
)
|
91 |
+
downsize_mask = resize_max_size(
|
92 |
+
mask, size_limit=config.hd_strategy_resize_limit
|
93 |
+
)
|
94 |
+
|
95 |
+
logger.info(
|
96 |
+
f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
|
97 |
+
)
|
98 |
+
inpaint_result = self._pad_forward(
|
99 |
+
downsize_image, downsize_mask, config
|
100 |
+
)
|
101 |
+
|
102 |
+
# only paste masked area result
|
103 |
+
inpaint_result = cv2.resize(
|
104 |
+
inpaint_result,
|
105 |
+
(origin_size[1], origin_size[0]),
|
106 |
+
interpolation=cv2.INTER_CUBIC,
|
107 |
+
)
|
108 |
+
original_pixel_indices = mask < 127
|
109 |
+
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
110 |
+
original_pixel_indices
|
111 |
+
]
|
112 |
+
|
113 |
+
if inpaint_result is None:
|
114 |
+
inpaint_result = self._pad_forward(image, mask, config)
|
115 |
+
|
116 |
+
return inpaint_result
|
117 |
+
|
118 |
+
def _crop_box(self, image, mask, box, config: Config):
|
119 |
+
"""
|
120 |
+
|
121 |
+
Args:
|
122 |
+
image: [H, W, C] RGB
|
123 |
+
mask: [H, W, 1]
|
124 |
+
box: [left,top,right,bottom]
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
BGR IMAGE, (l, r, r, b)
|
128 |
+
"""
|
129 |
+
box_h = box[3] - box[1]
|
130 |
+
box_w = box[2] - box[0]
|
131 |
+
cx = (box[0] + box[2]) // 2
|
132 |
+
cy = (box[1] + box[3]) // 2
|
133 |
+
img_h, img_w = image.shape[:2]
|
134 |
+
|
135 |
+
w = box_w + config.hd_strategy_crop_margin * 2
|
136 |
+
h = box_h + config.hd_strategy_crop_margin * 2
|
137 |
+
|
138 |
+
_l = cx - w // 2
|
139 |
+
_r = cx + w // 2
|
140 |
+
_t = cy - h // 2
|
141 |
+
_b = cy + h // 2
|
142 |
+
|
143 |
+
l = max(_l, 0)
|
144 |
+
r = min(_r, img_w)
|
145 |
+
t = max(_t, 0)
|
146 |
+
b = min(_b, img_h)
|
147 |
+
|
148 |
+
# try to get more context when crop around image edge
|
149 |
+
if _l < 0:
|
150 |
+
r += abs(_l)
|
151 |
+
if _r > img_w:
|
152 |
+
l -= _r - img_w
|
153 |
+
if _t < 0:
|
154 |
+
b += abs(_t)
|
155 |
+
if _b > img_h:
|
156 |
+
t -= _b - img_h
|
157 |
+
|
158 |
+
l = max(l, 0)
|
159 |
+
r = min(r, img_w)
|
160 |
+
t = max(t, 0)
|
161 |
+
b = min(b, img_h)
|
162 |
+
|
163 |
+
crop_img = image[t:b, l:r, :]
|
164 |
+
crop_mask = mask[t:b, l:r]
|
165 |
+
|
166 |
+
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
167 |
+
|
168 |
+
return crop_img, crop_mask, [l, t, r, b]
|
169 |
+
|
170 |
+
def _run_box(self, image, mask, box, config: Config):
|
171 |
+
"""
|
172 |
+
|
173 |
+
Args:
|
174 |
+
image: [H, W, C] RGB
|
175 |
+
mask: [H, W, 1]
|
176 |
+
box: [left,top,right,bottom]
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
BGR IMAGE
|
180 |
+
"""
|
181 |
+
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
|
182 |
+
|
183 |
+
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
lama_cleaner/model/ddim_sampler.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
|
10 |
+
class DDIMSampler(object):
|
11 |
+
def __init__(self, model, schedule="linear"):
|
12 |
+
super().__init__()
|
13 |
+
self.model = model
|
14 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
15 |
+
self.schedule = schedule
|
16 |
+
|
17 |
+
def register_buffer(self, name, attr):
|
18 |
+
setattr(self, name, attr)
|
19 |
+
|
20 |
+
def make_schedule(
|
21 |
+
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
22 |
+
):
|
23 |
+
self.ddim_timesteps = make_ddim_timesteps(
|
24 |
+
ddim_discr_method=ddim_discretize,
|
25 |
+
num_ddim_timesteps=ddim_num_steps,
|
26 |
+
# array([1])
|
27 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
28 |
+
verbose=verbose,
|
29 |
+
)
|
30 |
+
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
31 |
+
assert (
|
32 |
+
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
33 |
+
), "alphas have to be defined for each timestep"
|
34 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
35 |
+
|
36 |
+
self.register_buffer("betas", to_torch(self.model.betas))
|
37 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
38 |
+
self.register_buffer(
|
39 |
+
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
40 |
+
)
|
41 |
+
|
42 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
43 |
+
self.register_buffer(
|
44 |
+
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
45 |
+
)
|
46 |
+
self.register_buffer(
|
47 |
+
"sqrt_one_minus_alphas_cumprod",
|
48 |
+
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
49 |
+
)
|
50 |
+
self.register_buffer(
|
51 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
52 |
+
)
|
53 |
+
self.register_buffer(
|
54 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
55 |
+
)
|
56 |
+
self.register_buffer(
|
57 |
+
"sqrt_recipm1_alphas_cumprod",
|
58 |
+
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
59 |
+
)
|
60 |
+
|
61 |
+
# ddim sampling parameters
|
62 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
63 |
+
alphacums=alphas_cumprod.cpu(),
|
64 |
+
ddim_timesteps=self.ddim_timesteps,
|
65 |
+
eta=ddim_eta,
|
66 |
+
verbose=verbose,
|
67 |
+
)
|
68 |
+
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
69 |
+
self.register_buffer("ddim_alphas", ddim_alphas)
|
70 |
+
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
71 |
+
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
72 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
73 |
+
(1 - self.alphas_cumprod_prev)
|
74 |
+
/ (1 - self.alphas_cumprod)
|
75 |
+
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
76 |
+
)
|
77 |
+
self.register_buffer(
|
78 |
+
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
79 |
+
)
|
80 |
+
|
81 |
+
@torch.no_grad()
|
82 |
+
def sample(self, steps, conditioning, batch_size, shape):
|
83 |
+
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
|
84 |
+
# sampling
|
85 |
+
C, H, W = shape
|
86 |
+
size = (batch_size, C, H, W)
|
87 |
+
|
88 |
+
# samples: 1,3,128,128
|
89 |
+
return self.ddim_sampling(
|
90 |
+
conditioning,
|
91 |
+
size,
|
92 |
+
quantize_denoised=False,
|
93 |
+
ddim_use_original_steps=False,
|
94 |
+
noise_dropout=0,
|
95 |
+
temperature=1.0,
|
96 |
+
)
|
97 |
+
|
98 |
+
@torch.no_grad()
|
99 |
+
def ddim_sampling(
|
100 |
+
self,
|
101 |
+
cond,
|
102 |
+
shape,
|
103 |
+
ddim_use_original_steps=False,
|
104 |
+
quantize_denoised=False,
|
105 |
+
temperature=1.0,
|
106 |
+
noise_dropout=0.0,
|
107 |
+
):
|
108 |
+
device = self.model.betas.device
|
109 |
+
b = shape[0]
|
110 |
+
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
111 |
+
timesteps = (
|
112 |
+
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
113 |
+
)
|
114 |
+
|
115 |
+
time_range = (
|
116 |
+
reversed(range(0, timesteps))
|
117 |
+
if ddim_use_original_steps
|
118 |
+
else np.flip(timesteps)
|
119 |
+
)
|
120 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
121 |
+
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
122 |
+
|
123 |
+
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
124 |
+
|
125 |
+
for i, step in enumerate(iterator):
|
126 |
+
index = total_steps - i - 1
|
127 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
128 |
+
|
129 |
+
outs = self.p_sample_ddim(
|
130 |
+
img,
|
131 |
+
cond,
|
132 |
+
ts,
|
133 |
+
index=index,
|
134 |
+
use_original_steps=ddim_use_original_steps,
|
135 |
+
quantize_denoised=quantize_denoised,
|
136 |
+
temperature=temperature,
|
137 |
+
noise_dropout=noise_dropout,
|
138 |
+
)
|
139 |
+
img, _ = outs
|
140 |
+
|
141 |
+
return img
|
142 |
+
|
143 |
+
@torch.no_grad()
|
144 |
+
def p_sample_ddim(
|
145 |
+
self,
|
146 |
+
x,
|
147 |
+
c,
|
148 |
+
t,
|
149 |
+
index,
|
150 |
+
repeat_noise=False,
|
151 |
+
use_original_steps=False,
|
152 |
+
quantize_denoised=False,
|
153 |
+
temperature=1.0,
|
154 |
+
noise_dropout=0.0,
|
155 |
+
):
|
156 |
+
b, *_, device = *x.shape, x.device
|
157 |
+
e_t = self.model.apply_model(x, t, c)
|
158 |
+
|
159 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
160 |
+
alphas_prev = (
|
161 |
+
self.model.alphas_cumprod_prev
|
162 |
+
if use_original_steps
|
163 |
+
else self.ddim_alphas_prev
|
164 |
+
)
|
165 |
+
sqrt_one_minus_alphas = (
|
166 |
+
self.model.sqrt_one_minus_alphas_cumprod
|
167 |
+
if use_original_steps
|
168 |
+
else self.ddim_sqrt_one_minus_alphas
|
169 |
+
)
|
170 |
+
sigmas = (
|
171 |
+
self.model.ddim_sigmas_for_original_num_steps
|
172 |
+
if use_original_steps
|
173 |
+
else self.ddim_sigmas
|
174 |
+
)
|
175 |
+
# select parameters corresponding to the currently considered timestep
|
176 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
177 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
178 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
179 |
+
sqrt_one_minus_at = torch.full(
|
180 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
181 |
+
)
|
182 |
+
|
183 |
+
# current prediction for x_0
|
184 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
185 |
+
if quantize_denoised: # 没用
|
186 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
187 |
+
# direction pointing to x_t
|
188 |
+
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
189 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
190 |
+
if noise_dropout > 0.0: # 没用
|
191 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
192 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
193 |
+
return x_prev, pred_x0
|
lama_cleaner/model/fcf.py
ADDED
@@ -0,0 +1,1214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import torch.fft as fft
|
8 |
+
|
9 |
+
from lama_cleaner.schema import Config
|
10 |
+
|
11 |
+
from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img, boxes_from_mask, resize_max_size
|
12 |
+
from lama_cleaner.model.base import InpaintModel
|
13 |
+
from torch import conv2d, nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from lama_cleaner.model.utils import setup_filter, _parse_scaling, _parse_padding, Conv2dLayer, FullyConnectedLayer, \
|
17 |
+
MinibatchStdLayer, activation_funcs, conv2d_resample, bias_act, upsample2d, normalize_2nd_moment, downsample2d
|
18 |
+
|
19 |
+
|
20 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
21 |
+
assert isinstance(x, torch.Tensor)
|
22 |
+
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
23 |
+
|
24 |
+
|
25 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
26 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
27 |
+
"""
|
28 |
+
# Validate arguments.
|
29 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
30 |
+
if f is None:
|
31 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
32 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
33 |
+
assert f.dtype == torch.float32 and not f.requires_grad
|
34 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
35 |
+
upx, upy = _parse_scaling(up)
|
36 |
+
downx, downy = _parse_scaling(down)
|
37 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
38 |
+
|
39 |
+
# Upsample by inserting zeros.
|
40 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
41 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
42 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
43 |
+
|
44 |
+
# Pad or crop.
|
45 |
+
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
46 |
+
x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]
|
47 |
+
|
48 |
+
# Setup filter.
|
49 |
+
f = f * (gain ** (f.ndim / 2))
|
50 |
+
f = f.to(x.dtype)
|
51 |
+
if not flip_filter:
|
52 |
+
f = f.flip(list(range(f.ndim)))
|
53 |
+
|
54 |
+
# Convolve with the filter.
|
55 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
56 |
+
if f.ndim == 4:
|
57 |
+
x = conv2d(input=x, weight=f, groups=num_channels)
|
58 |
+
else:
|
59 |
+
x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
60 |
+
x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
61 |
+
|
62 |
+
# Downsample by throwing away pixels.
|
63 |
+
x = x[:, :, ::downy, ::downx]
|
64 |
+
return x
|
65 |
+
|
66 |
+
|
67 |
+
class EncoderEpilogue(torch.nn.Module):
|
68 |
+
def __init__(self,
|
69 |
+
in_channels, # Number of input channels.
|
70 |
+
cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
|
71 |
+
z_dim, # Output Latent (Z) dimensionality.
|
72 |
+
resolution, # Resolution of this block.
|
73 |
+
img_channels, # Number of input color channels.
|
74 |
+
architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
|
75 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
76 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
77 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
78 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
79 |
+
):
|
80 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
81 |
+
super().__init__()
|
82 |
+
self.in_channels = in_channels
|
83 |
+
self.cmap_dim = cmap_dim
|
84 |
+
self.resolution = resolution
|
85 |
+
self.img_channels = img_channels
|
86 |
+
self.architecture = architecture
|
87 |
+
|
88 |
+
if architecture == 'skip':
|
89 |
+
self.fromrgb = Conv2dLayer(self.img_channels, in_channels, kernel_size=1, activation=activation)
|
90 |
+
self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size,
|
91 |
+
num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
|
92 |
+
self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation,
|
93 |
+
conv_clamp=conv_clamp)
|
94 |
+
self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), z_dim, activation=activation)
|
95 |
+
self.dropout = torch.nn.Dropout(p=0.5)
|
96 |
+
|
97 |
+
def forward(self, x, cmap, force_fp32=False):
|
98 |
+
_ = force_fp32 # unused
|
99 |
+
dtype = torch.float32
|
100 |
+
memory_format = torch.contiguous_format
|
101 |
+
|
102 |
+
# FromRGB.
|
103 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
104 |
+
|
105 |
+
# Main layers.
|
106 |
+
if self.mbstd is not None:
|
107 |
+
x = self.mbstd(x)
|
108 |
+
const_e = self.conv(x)
|
109 |
+
x = self.fc(const_e.flatten(1))
|
110 |
+
x = self.dropout(x)
|
111 |
+
|
112 |
+
# Conditioning.
|
113 |
+
if self.cmap_dim > 0:
|
114 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
115 |
+
|
116 |
+
assert x.dtype == dtype
|
117 |
+
return x, const_e
|
118 |
+
|
119 |
+
|
120 |
+
class EncoderBlock(torch.nn.Module):
|
121 |
+
def __init__(self,
|
122 |
+
in_channels, # Number of input channels, 0 = first block.
|
123 |
+
tmp_channels, # Number of intermediate channels.
|
124 |
+
out_channels, # Number of output channels.
|
125 |
+
resolution, # Resolution of this block.
|
126 |
+
img_channels, # Number of input color channels.
|
127 |
+
first_layer_idx, # Index of the first layer.
|
128 |
+
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
|
129 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
130 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
131 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
132 |
+
use_fp16=False, # Use FP16 for this block?
|
133 |
+
fp16_channels_last=False, # Use channels-last memory format with FP16?
|
134 |
+
freeze_layers=0, # Freeze-D: Number of layers to freeze.
|
135 |
+
):
|
136 |
+
assert in_channels in [0, tmp_channels]
|
137 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
138 |
+
super().__init__()
|
139 |
+
self.in_channels = in_channels
|
140 |
+
self.resolution = resolution
|
141 |
+
self.img_channels = img_channels + 1
|
142 |
+
self.first_layer_idx = first_layer_idx
|
143 |
+
self.architecture = architecture
|
144 |
+
self.use_fp16 = use_fp16
|
145 |
+
self.channels_last = (use_fp16 and fp16_channels_last)
|
146 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
147 |
+
|
148 |
+
self.num_layers = 0
|
149 |
+
|
150 |
+
def trainable_gen():
|
151 |
+
while True:
|
152 |
+
layer_idx = self.first_layer_idx + self.num_layers
|
153 |
+
trainable = (layer_idx >= freeze_layers)
|
154 |
+
self.num_layers += 1
|
155 |
+
yield trainable
|
156 |
+
|
157 |
+
trainable_iter = trainable_gen()
|
158 |
+
|
159 |
+
if in_channels == 0:
|
160 |
+
self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation,
|
161 |
+
trainable=next(trainable_iter), conv_clamp=conv_clamp,
|
162 |
+
channels_last=self.channels_last)
|
163 |
+
|
164 |
+
self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
|
165 |
+
trainable=next(trainable_iter), conv_clamp=conv_clamp,
|
166 |
+
channels_last=self.channels_last)
|
167 |
+
|
168 |
+
self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
|
169 |
+
trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp,
|
170 |
+
channels_last=self.channels_last)
|
171 |
+
|
172 |
+
if architecture == 'resnet':
|
173 |
+
self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
|
174 |
+
trainable=next(trainable_iter), resample_filter=resample_filter,
|
175 |
+
channels_last=self.channels_last)
|
176 |
+
|
177 |
+
def forward(self, x, img, force_fp32=False):
|
178 |
+
# dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
179 |
+
dtype = torch.float32
|
180 |
+
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
181 |
+
|
182 |
+
# Input.
|
183 |
+
if x is not None:
|
184 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
185 |
+
|
186 |
+
# FromRGB.
|
187 |
+
if self.in_channels == 0:
|
188 |
+
img = img.to(dtype=dtype, memory_format=memory_format)
|
189 |
+
y = self.fromrgb(img)
|
190 |
+
x = x + y if x is not None else y
|
191 |
+
img = downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
|
192 |
+
|
193 |
+
# Main layers.
|
194 |
+
if self.architecture == 'resnet':
|
195 |
+
y = self.skip(x, gain=np.sqrt(0.5))
|
196 |
+
x = self.conv0(x)
|
197 |
+
feat = x.clone()
|
198 |
+
x = self.conv1(x, gain=np.sqrt(0.5))
|
199 |
+
x = y.add_(x)
|
200 |
+
else:
|
201 |
+
x = self.conv0(x)
|
202 |
+
feat = x.clone()
|
203 |
+
x = self.conv1(x)
|
204 |
+
|
205 |
+
assert x.dtype == dtype
|
206 |
+
return x, img, feat
|
207 |
+
|
208 |
+
|
209 |
+
class EncoderNetwork(torch.nn.Module):
|
210 |
+
def __init__(self,
|
211 |
+
c_dim, # Conditioning label (C) dimensionality.
|
212 |
+
z_dim, # Input latent (Z) dimensionality.
|
213 |
+
img_resolution, # Input resolution.
|
214 |
+
img_channels, # Number of input color channels.
|
215 |
+
architecture='orig', # Architecture: 'orig', 'skip', 'resnet'.
|
216 |
+
channel_base=16384, # Overall multiplier for the number of channels.
|
217 |
+
channel_max=512, # Maximum number of channels in any layer.
|
218 |
+
num_fp16_res=0, # Use FP16 for the N highest resolutions.
|
219 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
220 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
221 |
+
block_kwargs={}, # Arguments for DiscriminatorBlock.
|
222 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
223 |
+
epilogue_kwargs={}, # Arguments for EncoderEpilogue.
|
224 |
+
):
|
225 |
+
super().__init__()
|
226 |
+
self.c_dim = c_dim
|
227 |
+
self.z_dim = z_dim
|
228 |
+
self.img_resolution = img_resolution
|
229 |
+
self.img_resolution_log2 = int(np.log2(img_resolution))
|
230 |
+
self.img_channels = img_channels
|
231 |
+
self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
|
232 |
+
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
|
233 |
+
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
234 |
+
|
235 |
+
if cmap_dim is None:
|
236 |
+
cmap_dim = channels_dict[4]
|
237 |
+
if c_dim == 0:
|
238 |
+
cmap_dim = 0
|
239 |
+
|
240 |
+
common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
|
241 |
+
cur_layer_idx = 0
|
242 |
+
for res in self.block_resolutions:
|
243 |
+
in_channels = channels_dict[res] if res < img_resolution else 0
|
244 |
+
tmp_channels = channels_dict[res]
|
245 |
+
out_channels = channels_dict[res // 2]
|
246 |
+
use_fp16 = (res >= fp16_resolution)
|
247 |
+
use_fp16 = False
|
248 |
+
block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res,
|
249 |
+
first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
|
250 |
+
setattr(self, f'b{res}', block)
|
251 |
+
cur_layer_idx += block.num_layers
|
252 |
+
if c_dim > 0:
|
253 |
+
self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None,
|
254 |
+
**mapping_kwargs)
|
255 |
+
self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs,
|
256 |
+
**common_kwargs)
|
257 |
+
|
258 |
+
def forward(self, img, c, **block_kwargs):
|
259 |
+
x = None
|
260 |
+
feats = {}
|
261 |
+
for res in self.block_resolutions:
|
262 |
+
block = getattr(self, f'b{res}')
|
263 |
+
x, img, feat = block(x, img, **block_kwargs)
|
264 |
+
feats[res] = feat
|
265 |
+
|
266 |
+
cmap = None
|
267 |
+
if self.c_dim > 0:
|
268 |
+
cmap = self.mapping(None, c)
|
269 |
+
x, const_e = self.b4(x, cmap)
|
270 |
+
feats[4] = const_e
|
271 |
+
|
272 |
+
B, _ = x.shape
|
273 |
+
z = torch.zeros((B, self.z_dim), requires_grad=False, dtype=x.dtype,
|
274 |
+
device=x.device) ## Noise for Co-Modulation
|
275 |
+
return x, z, feats
|
276 |
+
|
277 |
+
|
278 |
+
def fma(a, b, c): # => a * b + c
|
279 |
+
return _FusedMultiplyAdd.apply(a, b, c)
|
280 |
+
|
281 |
+
|
282 |
+
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
283 |
+
@staticmethod
|
284 |
+
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
285 |
+
out = torch.addcmul(c, a, b)
|
286 |
+
ctx.save_for_backward(a, b)
|
287 |
+
ctx.c_shape = c.shape
|
288 |
+
return out
|
289 |
+
|
290 |
+
@staticmethod
|
291 |
+
def backward(ctx, dout): # pylint: disable=arguments-differ
|
292 |
+
a, b = ctx.saved_tensors
|
293 |
+
c_shape = ctx.c_shape
|
294 |
+
da = None
|
295 |
+
db = None
|
296 |
+
dc = None
|
297 |
+
|
298 |
+
if ctx.needs_input_grad[0]:
|
299 |
+
da = _unbroadcast(dout * b, a.shape)
|
300 |
+
|
301 |
+
if ctx.needs_input_grad[1]:
|
302 |
+
db = _unbroadcast(dout * a, b.shape)
|
303 |
+
|
304 |
+
if ctx.needs_input_grad[2]:
|
305 |
+
dc = _unbroadcast(dout, c_shape)
|
306 |
+
|
307 |
+
return da, db, dc
|
308 |
+
|
309 |
+
|
310 |
+
def _unbroadcast(x, shape):
|
311 |
+
extra_dims = x.ndim - len(shape)
|
312 |
+
assert extra_dims >= 0
|
313 |
+
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
314 |
+
if len(dim):
|
315 |
+
x = x.sum(dim=dim, keepdim=True)
|
316 |
+
if extra_dims:
|
317 |
+
x = x.reshape(-1, *x.shape[extra_dims + 1:])
|
318 |
+
assert x.shape == shape
|
319 |
+
return x
|
320 |
+
|
321 |
+
|
322 |
+
def modulated_conv2d(
|
323 |
+
x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
|
324 |
+
weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
|
325 |
+
styles, # Modulation coefficients of shape [batch_size, in_channels].
|
326 |
+
noise=None, # Optional noise tensor to add to the output activations.
|
327 |
+
up=1, # Integer upsampling factor.
|
328 |
+
down=1, # Integer downsampling factor.
|
329 |
+
padding=0, # Padding with respect to the upsampled image.
|
330 |
+
resample_filter=None,
|
331 |
+
# Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
|
332 |
+
demodulate=True, # Apply weight demodulation?
|
333 |
+
flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
|
334 |
+
fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation?
|
335 |
+
):
|
336 |
+
batch_size = x.shape[0]
|
337 |
+
out_channels, in_channels, kh, kw = weight.shape
|
338 |
+
|
339 |
+
# Pre-normalize inputs to avoid FP16 overflow.
|
340 |
+
if x.dtype == torch.float16 and demodulate:
|
341 |
+
weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3],
|
342 |
+
keepdim=True)) # max_Ikk
|
343 |
+
styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
|
344 |
+
|
345 |
+
# Calculate per-sample weights and demodulation coefficients.
|
346 |
+
w = None
|
347 |
+
dcoefs = None
|
348 |
+
if demodulate or fused_modconv:
|
349 |
+
w = weight.unsqueeze(0) # [NOIkk]
|
350 |
+
w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
|
351 |
+
if demodulate:
|
352 |
+
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
|
353 |
+
if demodulate and fused_modconv:
|
354 |
+
w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
|
355 |
+
# Execute by scaling the activations before and after the convolution.
|
356 |
+
if not fused_modconv:
|
357 |
+
x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
358 |
+
x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down,
|
359 |
+
padding=padding, flip_weight=flip_weight)
|
360 |
+
if demodulate and noise is not None:
|
361 |
+
x = fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
|
362 |
+
elif demodulate:
|
363 |
+
x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
364 |
+
elif noise is not None:
|
365 |
+
x = x.add_(noise.to(x.dtype))
|
366 |
+
return x
|
367 |
+
|
368 |
+
# Execute as one fused op using grouped convolution.
|
369 |
+
batch_size = int(batch_size)
|
370 |
+
x = x.reshape(1, -1, *x.shape[2:])
|
371 |
+
w = w.reshape(-1, in_channels, kh, kw)
|
372 |
+
x = conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding,
|
373 |
+
groups=batch_size, flip_weight=flip_weight)
|
374 |
+
x = x.reshape(batch_size, -1, *x.shape[2:])
|
375 |
+
if noise is not None:
|
376 |
+
x = x.add_(noise)
|
377 |
+
return x
|
378 |
+
|
379 |
+
|
380 |
+
class SynthesisLayer(torch.nn.Module):
|
381 |
+
def __init__(self,
|
382 |
+
in_channels, # Number of input channels.
|
383 |
+
out_channels, # Number of output channels.
|
384 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
385 |
+
resolution, # Resolution of this layer.
|
386 |
+
kernel_size=3, # Convolution kernel size.
|
387 |
+
up=1, # Integer upsampling factor.
|
388 |
+
use_noise=True, # Enable noise input?
|
389 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
390 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
391 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
392 |
+
channels_last=False, # Use channels_last format for the weights?
|
393 |
+
):
|
394 |
+
super().__init__()
|
395 |
+
self.resolution = resolution
|
396 |
+
self.up = up
|
397 |
+
self.use_noise = use_noise
|
398 |
+
self.activation = activation
|
399 |
+
self.conv_clamp = conv_clamp
|
400 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
401 |
+
self.padding = kernel_size // 2
|
402 |
+
self.act_gain = activation_funcs[activation].def_gain
|
403 |
+
|
404 |
+
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
405 |
+
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
406 |
+
self.weight = torch.nn.Parameter(
|
407 |
+
torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
|
408 |
+
if use_noise:
|
409 |
+
self.register_buffer('noise_const', torch.randn([resolution, resolution]))
|
410 |
+
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
|
411 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
412 |
+
|
413 |
+
def forward(self, x, w, noise_mode='none', fused_modconv=True, gain=1):
|
414 |
+
assert noise_mode in ['random', 'const', 'none']
|
415 |
+
in_resolution = self.resolution // self.up
|
416 |
+
styles = self.affine(w)
|
417 |
+
|
418 |
+
noise = None
|
419 |
+
if self.use_noise and noise_mode == 'random':
|
420 |
+
noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution],
|
421 |
+
device=x.device) * self.noise_strength
|
422 |
+
if self.use_noise and noise_mode == 'const':
|
423 |
+
noise = self.noise_const * self.noise_strength
|
424 |
+
|
425 |
+
flip_weight = (self.up == 1) # slightly faster
|
426 |
+
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
|
427 |
+
padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight,
|
428 |
+
fused_modconv=fused_modconv)
|
429 |
+
|
430 |
+
act_gain = self.act_gain * gain
|
431 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
432 |
+
x = F.leaky_relu(x, negative_slope=0.2, inplace=False)
|
433 |
+
if act_gain != 1:
|
434 |
+
x = x * act_gain
|
435 |
+
if act_clamp is not None:
|
436 |
+
x = x.clamp(-act_clamp, act_clamp)
|
437 |
+
return x
|
438 |
+
|
439 |
+
|
440 |
+
class ToRGBLayer(torch.nn.Module):
|
441 |
+
def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
|
442 |
+
super().__init__()
|
443 |
+
self.conv_clamp = conv_clamp
|
444 |
+
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
445 |
+
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
446 |
+
self.weight = torch.nn.Parameter(
|
447 |
+
torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
|
448 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
449 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
450 |
+
|
451 |
+
def forward(self, x, w, fused_modconv=True):
|
452 |
+
styles = self.affine(w) * self.weight_gain
|
453 |
+
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
|
454 |
+
x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
|
455 |
+
return x
|
456 |
+
|
457 |
+
|
458 |
+
class SynthesisForeword(torch.nn.Module):
|
459 |
+
def __init__(self,
|
460 |
+
z_dim, # Output Latent (Z) dimensionality.
|
461 |
+
resolution, # Resolution of this block.
|
462 |
+
in_channels,
|
463 |
+
img_channels, # Number of input color channels.
|
464 |
+
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
|
465 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
466 |
+
|
467 |
+
):
|
468 |
+
super().__init__()
|
469 |
+
self.in_channels = in_channels
|
470 |
+
self.z_dim = z_dim
|
471 |
+
self.resolution = resolution
|
472 |
+
self.img_channels = img_channels
|
473 |
+
self.architecture = architecture
|
474 |
+
|
475 |
+
self.fc = FullyConnectedLayer(self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation)
|
476 |
+
self.conv = SynthesisLayer(self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4)
|
477 |
+
|
478 |
+
if architecture == 'skip':
|
479 |
+
self.torgb = ToRGBLayer(self.in_channels, self.img_channels, kernel_size=1, w_dim=(z_dim // 2) * 3)
|
480 |
+
|
481 |
+
def forward(self, x, ws, feats, img, force_fp32=False):
|
482 |
+
_ = force_fp32 # unused
|
483 |
+
dtype = torch.float32
|
484 |
+
memory_format = torch.contiguous_format
|
485 |
+
|
486 |
+
x_global = x.clone()
|
487 |
+
# ToRGB.
|
488 |
+
x = self.fc(x)
|
489 |
+
x = x.view(-1, self.z_dim // 2, 4, 4)
|
490 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
491 |
+
|
492 |
+
# Main layers.
|
493 |
+
x_skip = feats[4].clone()
|
494 |
+
x = x + x_skip
|
495 |
+
|
496 |
+
mod_vector = []
|
497 |
+
mod_vector.append(ws[:, 0])
|
498 |
+
mod_vector.append(x_global.clone())
|
499 |
+
mod_vector = torch.cat(mod_vector, dim=1)
|
500 |
+
|
501 |
+
x = self.conv(x, mod_vector)
|
502 |
+
|
503 |
+
mod_vector = []
|
504 |
+
mod_vector.append(ws[:, 2 * 2 - 3])
|
505 |
+
mod_vector.append(x_global.clone())
|
506 |
+
mod_vector = torch.cat(mod_vector, dim=1)
|
507 |
+
|
508 |
+
if self.architecture == 'skip':
|
509 |
+
img = self.torgb(x, mod_vector)
|
510 |
+
img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
511 |
+
|
512 |
+
assert x.dtype == dtype
|
513 |
+
return x, img
|
514 |
+
|
515 |
+
|
516 |
+
class SELayer(nn.Module):
|
517 |
+
def __init__(self, channel, reduction=16):
|
518 |
+
super(SELayer, self).__init__()
|
519 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
520 |
+
self.fc = nn.Sequential(
|
521 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
522 |
+
nn.ReLU(inplace=False),
|
523 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
524 |
+
nn.Sigmoid()
|
525 |
+
)
|
526 |
+
|
527 |
+
def forward(self, x):
|
528 |
+
b, c, _, _ = x.size()
|
529 |
+
y = self.avg_pool(x).view(b, c)
|
530 |
+
y = self.fc(y).view(b, c, 1, 1)
|
531 |
+
res = x * y.expand_as(x)
|
532 |
+
return res
|
533 |
+
|
534 |
+
|
535 |
+
class FourierUnit(nn.Module):
|
536 |
+
|
537 |
+
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
|
538 |
+
spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
|
539 |
+
# bn_layer not used
|
540 |
+
super(FourierUnit, self).__init__()
|
541 |
+
self.groups = groups
|
542 |
+
|
543 |
+
self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
|
544 |
+
out_channels=out_channels * 2,
|
545 |
+
kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
|
546 |
+
self.relu = torch.nn.ReLU(inplace=False)
|
547 |
+
|
548 |
+
# squeeze and excitation block
|
549 |
+
self.use_se = use_se
|
550 |
+
if use_se:
|
551 |
+
if se_kwargs is None:
|
552 |
+
se_kwargs = {}
|
553 |
+
self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
|
554 |
+
|
555 |
+
self.spatial_scale_factor = spatial_scale_factor
|
556 |
+
self.spatial_scale_mode = spatial_scale_mode
|
557 |
+
self.spectral_pos_encoding = spectral_pos_encoding
|
558 |
+
self.ffc3d = ffc3d
|
559 |
+
self.fft_norm = fft_norm
|
560 |
+
|
561 |
+
def forward(self, x):
|
562 |
+
batch = x.shape[0]
|
563 |
+
|
564 |
+
if self.spatial_scale_factor is not None:
|
565 |
+
orig_size = x.shape[-2:]
|
566 |
+
x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
|
567 |
+
align_corners=False)
|
568 |
+
|
569 |
+
r_size = x.size()
|
570 |
+
# (batch, c, h, w/2+1, 2)
|
571 |
+
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
|
572 |
+
ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
|
573 |
+
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
574 |
+
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
575 |
+
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
|
576 |
+
|
577 |
+
if self.spectral_pos_encoding:
|
578 |
+
height, width = ffted.shape[-2:]
|
579 |
+
coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
|
580 |
+
coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
|
581 |
+
ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
|
582 |
+
|
583 |
+
if self.use_se:
|
584 |
+
ffted = self.se(ffted)
|
585 |
+
|
586 |
+
ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
|
587 |
+
ffted = self.relu(ffted)
|
588 |
+
|
589 |
+
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
|
590 |
+
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
|
591 |
+
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
592 |
+
|
593 |
+
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
|
594 |
+
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
|
595 |
+
|
596 |
+
if self.spatial_scale_factor is not None:
|
597 |
+
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
|
598 |
+
|
599 |
+
return output
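
The FourierUnit above is the global branch of the fast Fourier convolution: it takes a real FFT over the spatial dims, folds the real/imaginary parts into channels, mixes them with a 1x1 convolution, then inverts the FFT, so every output pixel sees the whole feature map. A minimal shape sketch, assuming the FourierUnit class defined above and made-up sizes:

import torch

fu = FourierUnit(in_channels=8, out_channels=8)   # class defined above
x = torch.randn(1, 8, 64, 64)                     # (batch, channels, H, W)
y = fu(x)
# rfftn halves the last dim (64 -> 33), real/imag fold into channels (8 -> 16),
# the 1x1 conv + ReLU mix them, and irfftn restores the 64x64 spatial size.
assert y.shape == x.shape
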
|
600 |
+
|
601 |
+
|
602 |
+
class SpectralTransform(nn.Module):
|
603 |
+
|
604 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
|
605 |
+
# bn_layer not used
|
606 |
+
super(SpectralTransform, self).__init__()
|
607 |
+
self.enable_lfu = enable_lfu
|
608 |
+
if stride == 2:
|
609 |
+
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
|
610 |
+
else:
|
611 |
+
self.downsample = nn.Identity()
|
612 |
+
|
613 |
+
self.stride = stride
|
614 |
+
self.conv1 = nn.Sequential(
|
615 |
+
nn.Conv2d(in_channels, out_channels //
|
616 |
+
2, kernel_size=1, groups=groups, bias=False),
|
617 |
+
# nn.BatchNorm2d(out_channels // 2),
|
618 |
+
nn.ReLU(inplace=True)
|
619 |
+
)
|
620 |
+
self.fu = FourierUnit(
|
621 |
+
out_channels // 2, out_channels // 2, groups, **fu_kwargs)
|
622 |
+
if self.enable_lfu:
|
623 |
+
self.lfu = FourierUnit(
|
624 |
+
out_channels // 2, out_channels // 2, groups)
|
625 |
+
self.conv2 = torch.nn.Conv2d(
|
626 |
+
out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
|
627 |
+
|
628 |
+
def forward(self, x):
|
629 |
+
|
630 |
+
x = self.downsample(x)
|
631 |
+
x = self.conv1(x)
|
632 |
+
output = self.fu(x)
|
633 |
+
|
634 |
+
if self.enable_lfu:
|
635 |
+
n, c, h, w = x.shape
|
636 |
+
split_no = 2
|
637 |
+
split_s = h // split_no
|
638 |
+
xs = torch.cat(torch.split(
|
639 |
+
x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
|
640 |
+
xs = torch.cat(torch.split(xs, split_s, dim=-1),
|
641 |
+
dim=1).contiguous()
|
642 |
+
xs = self.lfu(xs)
|
643 |
+
xs = xs.repeat(1, 1, split_no, split_no).contiguous()
|
644 |
+
else:
|
645 |
+
xs = 0
|
646 |
+
|
647 |
+
output = self.conv2(x + output + xs)
|
648 |
+
|
649 |
+
return output
|
650 |
+
|
651 |
+
|
652 |
+
class FFC(nn.Module):
|
653 |
+
|
654 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
655 |
+
ratio_gin, ratio_gout, stride=1, padding=0,
|
656 |
+
dilation=1, groups=1, bias=False, enable_lfu=True,
|
657 |
+
padding_type='reflect', gated=False, **spectral_kwargs):
|
658 |
+
super(FFC, self).__init__()
|
659 |
+
|
660 |
+
assert stride == 1 or stride == 2, "Stride should be 1 or 2."
|
661 |
+
self.stride = stride
|
662 |
+
|
663 |
+
in_cg = int(in_channels * ratio_gin)
|
664 |
+
in_cl = in_channels - in_cg
|
665 |
+
out_cg = int(out_channels * ratio_gout)
|
666 |
+
out_cl = out_channels - out_cg
|
667 |
+
# groups_g = 1 if groups == 1 else int(groups * ratio_gout)
|
668 |
+
# groups_l = 1 if groups == 1 else groups - groups_g
|
669 |
+
|
670 |
+
self.ratio_gin = ratio_gin
|
671 |
+
self.ratio_gout = ratio_gout
|
672 |
+
self.global_in_num = in_cg
|
673 |
+
|
674 |
+
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
|
675 |
+
self.convl2l = module(in_cl, out_cl, kernel_size,
|
676 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
677 |
+
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
|
678 |
+
self.convl2g = module(in_cl, out_cg, kernel_size,
|
679 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
680 |
+
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
|
681 |
+
self.convg2l = module(in_cg, out_cl, kernel_size,
|
682 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
683 |
+
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
|
684 |
+
self.convg2g = module(
|
685 |
+
in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
|
686 |
+
|
687 |
+
self.gated = gated
|
688 |
+
module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
|
689 |
+
self.gate = module(in_channels, 2, 1)
|
690 |
+
|
691 |
+
def forward(self, x, fname=None):
|
692 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
693 |
+
out_xl, out_xg = 0, 0
|
694 |
+
|
695 |
+
if self.gated:
|
696 |
+
total_input_parts = [x_l]
|
697 |
+
if torch.is_tensor(x_g):
|
698 |
+
total_input_parts.append(x_g)
|
699 |
+
total_input = torch.cat(total_input_parts, dim=1)
|
700 |
+
|
701 |
+
gates = torch.sigmoid(self.gate(total_input))
|
702 |
+
g2l_gate, l2g_gate = gates.chunk(2, dim=1)
|
703 |
+
else:
|
704 |
+
g2l_gate, l2g_gate = 1, 1
|
705 |
+
|
706 |
+
spec_x = self.convg2g(x_g)
|
707 |
+
|
708 |
+
if self.ratio_gout != 1:
|
709 |
+
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
|
710 |
+
if self.ratio_gout != 0:
|
711 |
+
out_xg = self.convl2g(x_l) * l2g_gate + spec_x
|
712 |
+
|
713 |
+
return out_xl, out_xg
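
FFC splits its channels into a local branch (plain convolutions) and a global branch (SpectralTransform), with ratio_gin/ratio_gout controlling the split; the four cross paths (local-to-local, local-to-global, global-to-local, global-to-global) are summed per output branch. A small sketch of how the split falls out, assuming the FFC class above and made-up sizes:

import torch

ffc = FFC(in_channels=64, out_channels=64, kernel_size=3,
          ratio_gin=0.75, ratio_gout=0.75, padding=1)
x_local = torch.randn(1, 16, 32, 32)    # 64 * (1 - 0.75) = 16 local channels
x_global = torch.randn(1, 48, 32, 32)   # 64 * 0.75       = 48 global channels
out_local, out_global = ffc((x_local, x_global))
assert out_local.shape[1] == 16 and out_global.shape[1] == 48
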
|
714 |
+
|
715 |
+
|
716 |
+
class FFC_BN_ACT(nn.Module):
|
717 |
+
|
718 |
+
def __init__(self, in_channels, out_channels,
|
719 |
+
kernel_size, ratio_gin, ratio_gout,
|
720 |
+
stride=1, padding=0, dilation=1, groups=1, bias=False,
|
721 |
+
norm_layer=nn.SyncBatchNorm, activation_layer=nn.Identity,
|
722 |
+
padding_type='reflect',
|
723 |
+
enable_lfu=True, **kwargs):
|
724 |
+
super(FFC_BN_ACT, self).__init__()
|
725 |
+
self.ffc = FFC(in_channels, out_channels, kernel_size,
|
726 |
+
ratio_gin, ratio_gout, stride, padding, dilation,
|
727 |
+
groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
|
728 |
+
lnorm = nn.Identity if ratio_gout == 1 else norm_layer
|
729 |
+
gnorm = nn.Identity if ratio_gout == 0 else norm_layer
|
730 |
+
global_channels = int(out_channels * ratio_gout)
|
731 |
+
# self.bn_l = lnorm(out_channels - global_channels)
|
732 |
+
# self.bn_g = gnorm(global_channels)
|
733 |
+
|
734 |
+
lact = nn.Identity if ratio_gout == 1 else activation_layer
|
735 |
+
gact = nn.Identity if ratio_gout == 0 else activation_layer
|
736 |
+
self.act_l = lact(inplace=True)
|
737 |
+
self.act_g = gact(inplace=True)
|
738 |
+
|
739 |
+
def forward(self, x, fname=None):
|
740 |
+
x_l, x_g = self.ffc(x, fname=fname, )
|
741 |
+
x_l = self.act_l(x_l)
|
742 |
+
x_g = self.act_g(x_g)
|
743 |
+
return x_l, x_g
|
744 |
+
|
745 |
+
|
746 |
+
class FFCResnetBlock(nn.Module):
|
747 |
+
def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
|
748 |
+
spatial_transform_kwargs=None, inline=False, ratio_gin=0.75, ratio_gout=0.75):
|
749 |
+
super().__init__()
|
750 |
+
self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
|
751 |
+
norm_layer=norm_layer,
|
752 |
+
activation_layer=activation_layer,
|
753 |
+
padding_type=padding_type,
|
754 |
+
ratio_gin=ratio_gin, ratio_gout=ratio_gout)
|
755 |
+
self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
|
756 |
+
norm_layer=norm_layer,
|
757 |
+
activation_layer=activation_layer,
|
758 |
+
padding_type=padding_type,
|
759 |
+
ratio_gin=ratio_gin, ratio_gout=ratio_gout)
|
760 |
+
self.inline = inline
|
761 |
+
|
762 |
+
def forward(self, x, fname=None):
|
763 |
+
if self.inline:
|
764 |
+
x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
|
765 |
+
else:
|
766 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
767 |
+
|
768 |
+
id_l, id_g = x_l, x_g
|
769 |
+
|
770 |
+
x_l, x_g = self.conv1((x_l, x_g), fname=fname)
|
771 |
+
x_l, x_g = self.conv2((x_l, x_g), fname=fname)
|
772 |
+
|
773 |
+
x_l, x_g = id_l + x_l, id_g + x_g
|
774 |
+
out = x_l, x_g
|
775 |
+
if self.inline:
|
776 |
+
out = torch.cat(out, dim=1)
|
777 |
+
return out
|
778 |
+
|
779 |
+
|
780 |
+
class ConcatTupleLayer(nn.Module):
|
781 |
+
def forward(self, x):
|
782 |
+
assert isinstance(x, tuple)
|
783 |
+
x_l, x_g = x
|
784 |
+
assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
|
785 |
+
if not torch.is_tensor(x_g):
|
786 |
+
return x_l
|
787 |
+
return torch.cat(x, dim=1)
|
788 |
+
|
789 |
+
|
790 |
+
class FFCBlock(torch.nn.Module):
|
791 |
+
def __init__(self,
|
792 |
+
dim, # Number of output/input channels.
|
793 |
+
kernel_size, # Width and height of the convolution kernel.
|
794 |
+
padding,
|
795 |
+
ratio_gin=0.75,
|
796 |
+
ratio_gout=0.75,
|
797 |
+
activation='linear', # Activation function: 'relu', 'lrelu', etc.
|
798 |
+
):
|
799 |
+
super().__init__()
|
800 |
+
if activation == 'linear':
|
801 |
+
self.activation = nn.Identity
|
802 |
+
else:
|
803 |
+
self.activation = nn.ReLU
|
804 |
+
self.padding = padding
|
805 |
+
self.kernel_size = kernel_size
|
806 |
+
self.ffc_block = FFCResnetBlock(dim=dim,
|
807 |
+
padding_type='reflect',
|
808 |
+
norm_layer=nn.SyncBatchNorm,
|
809 |
+
activation_layer=self.activation,
|
810 |
+
dilation=1,
|
811 |
+
ratio_gin=ratio_gin,
|
812 |
+
ratio_gout=ratio_gout)
|
813 |
+
|
814 |
+
self.concat_layer = ConcatTupleLayer()
|
815 |
+
|
816 |
+
def forward(self, gen_ft, mask, fname=None):
|
817 |
+
x = gen_ft.float()
|
818 |
+
|
819 |
+
x_l, x_g = x[:, :-self.ffc_block.conv1.ffc.global_in_num], x[:, -self.ffc_block.conv1.ffc.global_in_num:]
|
820 |
+
id_l, id_g = x_l, x_g
|
821 |
+
|
822 |
+
x_l, x_g = self.ffc_block((x_l, x_g), fname=fname)
|
823 |
+
x_l, x_g = id_l + x_l, id_g + x_g
|
824 |
+
x = self.concat_layer((x_l, x_g))
|
825 |
+
|
826 |
+
return x + gen_ft.float()
|
827 |
+
|
828 |
+
|
829 |
+
class FFCSkipLayer(torch.nn.Module):
|
830 |
+
def __init__(self,
|
831 |
+
dim, # Number of input/output channels.
|
832 |
+
kernel_size=3, # Convolution kernel size.
|
833 |
+
ratio_gin=0.75,
|
834 |
+
ratio_gout=0.75,
|
835 |
+
):
|
836 |
+
super().__init__()
|
837 |
+
self.padding = kernel_size // 2
|
838 |
+
|
839 |
+
self.ffc_act = FFCBlock(dim=dim, kernel_size=kernel_size, activation=nn.ReLU,
|
840 |
+
padding=self.padding, ratio_gin=ratio_gin, ratio_gout=ratio_gout)
|
841 |
+
|
842 |
+
def forward(self, gen_ft, mask, fname=None):
|
843 |
+
x = self.ffc_act(gen_ft, mask, fname=fname)
|
844 |
+
return x
|
845 |
+
|
846 |
+
|
847 |
+
class SynthesisBlock(torch.nn.Module):
|
848 |
+
def __init__(self,
|
849 |
+
in_channels, # Number of input channels, 0 = first block.
|
850 |
+
out_channels, # Number of output channels.
|
851 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
852 |
+
resolution, # Resolution of this block.
|
853 |
+
img_channels, # Number of output color channels.
|
854 |
+
is_last, # Is this the last block?
|
855 |
+
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
|
856 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
857 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
858 |
+
use_fp16=False, # Use FP16 for this block?
|
859 |
+
fp16_channels_last=False, # Use channels-last memory format with FP16?
|
860 |
+
**layer_kwargs, # Arguments for SynthesisLayer.
|
861 |
+
):
|
862 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
863 |
+
super().__init__()
|
864 |
+
self.in_channels = in_channels
|
865 |
+
self.w_dim = w_dim
|
866 |
+
self.resolution = resolution
|
867 |
+
self.img_channels = img_channels
|
868 |
+
self.is_last = is_last
|
869 |
+
self.architecture = architecture
|
870 |
+
self.use_fp16 = use_fp16
|
871 |
+
self.channels_last = (use_fp16 and fp16_channels_last)
|
872 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
873 |
+
self.num_conv = 0
|
874 |
+
self.num_torgb = 0
|
875 |
+
self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}
|
876 |
+
|
877 |
+
if in_channels != 0 and resolution >= 8:
|
878 |
+
self.ffc_skip = nn.ModuleList()
|
879 |
+
for _ in range(self.res_ffc[resolution]):
|
880 |
+
self.ffc_skip.append(FFCSkipLayer(dim=out_channels))
|
881 |
+
|
882 |
+
if in_channels == 0:
|
883 |
+
self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
|
884 |
+
|
885 |
+
if in_channels != 0:
|
886 |
+
self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, up=2,
|
887 |
+
resample_filter=resample_filter, conv_clamp=conv_clamp,
|
888 |
+
channels_last=self.channels_last, **layer_kwargs)
|
889 |
+
self.num_conv += 1
|
890 |
+
|
891 |
+
self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim * 3, resolution=resolution,
|
892 |
+
conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
|
893 |
+
self.num_conv += 1
|
894 |
+
|
895 |
+
if is_last or architecture == 'skip':
|
896 |
+
self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim * 3,
|
897 |
+
conv_clamp=conv_clamp, channels_last=self.channels_last)
|
898 |
+
self.num_torgb += 1
|
899 |
+
|
900 |
+
if in_channels != 0 and architecture == 'resnet':
|
901 |
+
self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
|
902 |
+
resample_filter=resample_filter, channels_last=self.channels_last)
|
903 |
+
|
904 |
+
def forward(self, x, mask, feats, img, ws, fname=None, force_fp32=False, fused_modconv=None, **layer_kwargs):
|
905 |
+
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
906 |
+
dtype = torch.float32  # note: the fp16 choice above is overridden; this block always runs in fp32
|
907 |
+
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
908 |
+
if fused_modconv is None:
|
909 |
+
fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
|
910 |
+
|
911 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
912 |
+
x_skip = feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)
|
913 |
+
|
914 |
+
# Main layers.
|
915 |
+
if self.in_channels == 0:
|
916 |
+
x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
|
917 |
+
elif self.architecture == 'resnet':
|
918 |
+
y = self.skip(x, gain=np.sqrt(0.5))
|
919 |
+
x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
|
920 |
+
if len(self.ffc_skip) > 0:
|
921 |
+
mask = F.interpolate(mask, size=x_skip.shape[2:], )
|
922 |
+
z = x + x_skip
|
923 |
+
for fres in self.ffc_skip:
|
924 |
+
z = fres(z, mask)
|
925 |
+
x = x + z
|
926 |
+
else:
|
927 |
+
x = x + x_skip
|
928 |
+
x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
|
929 |
+
x = y.add_(x)
|
930 |
+
else:
|
931 |
+
x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
|
932 |
+
if len(self.ffc_skip) > 0:
|
933 |
+
mask = F.interpolate(mask, size=x_skip.shape[2:], )
|
934 |
+
z = x + x_skip
|
935 |
+
for fres in self.ffc_skip:
|
936 |
+
z = fres(z, mask)
|
937 |
+
x = x + z
|
938 |
+
else:
|
939 |
+
x = x + x_skip
|
940 |
+
x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs)
|
941 |
+
# ToRGB.
|
942 |
+
if img is not None:
|
943 |
+
img = upsample2d(img, self.resample_filter)
|
944 |
+
if self.is_last or self.architecture == 'skip':
|
945 |
+
y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
|
946 |
+
y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
947 |
+
img = img.add_(y) if img is not None else y
|
948 |
+
|
949 |
+
x = x.to(dtype=dtype)
|
950 |
+
assert x.dtype == dtype
|
951 |
+
assert img is None or img.dtype == torch.float32
|
952 |
+
return x, img
|
953 |
+
|
954 |
+
|
955 |
+
class SynthesisNetwork(torch.nn.Module):
|
956 |
+
def __init__(self,
|
957 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
958 |
+
z_dim, # Output Latent (Z) dimensionality.
|
959 |
+
img_resolution, # Output image resolution.
|
960 |
+
img_channels, # Number of color channels.
|
961 |
+
channel_base=16384, # Overall multiplier for the number of channels.
|
962 |
+
channel_max=512, # Maximum number of channels in any layer.
|
963 |
+
num_fp16_res=0, # Use FP16 for the N highest resolutions.
|
964 |
+
**block_kwargs, # Arguments for SynthesisBlock.
|
965 |
+
):
|
966 |
+
assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
|
967 |
+
super().__init__()
|
968 |
+
self.w_dim = w_dim
|
969 |
+
self.img_resolution = img_resolution
|
970 |
+
self.img_resolution_log2 = int(np.log2(img_resolution))
|
971 |
+
self.img_channels = img_channels
|
972 |
+
self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)]
|
973 |
+
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
|
974 |
+
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
975 |
+
|
976 |
+
self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max),
|
977 |
+
z_dim=z_dim * 2, resolution=4)
|
978 |
+
|
979 |
+
self.num_ws = self.img_resolution_log2 * 2 - 2
|
980 |
+
for res in self.block_resolutions:
|
981 |
+
if res // 2 in channels_dict.keys():
|
982 |
+
in_channels = channels_dict[res // 2] if res > 4 else 0
|
983 |
+
else:
|
984 |
+
in_channels = min(channel_base // (res // 2), channel_max)
|
985 |
+
out_channels = channels_dict[res]
|
986 |
+
use_fp16 = (res >= fp16_resolution)
|
987 |
+
use_fp16 = False  # note: fp16 is force-disabled for every block, regardless of fp16_resolution
|
988 |
+
is_last = (res == self.img_resolution)
|
989 |
+
block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
|
990 |
+
img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
|
991 |
+
setattr(self, f'b{res}', block)
|
992 |
+
|
993 |
+
def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
|
994 |
+
|
995 |
+
img = None
|
996 |
+
|
997 |
+
x, img = self.foreword(x_global, ws, feats, img)
|
998 |
+
|
999 |
+
for res in self.block_resolutions:
|
1000 |
+
block = getattr(self, f'b{res}')
|
1001 |
+
mod_vector0 = []
|
1002 |
+
mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5])
|
1003 |
+
mod_vector0.append(x_global.clone())
|
1004 |
+
mod_vector0 = torch.cat(mod_vector0, dim=1)
|
1005 |
+
|
1006 |
+
mod_vector1 = []
|
1007 |
+
mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4])
|
1008 |
+
mod_vector1.append(x_global.clone())
|
1009 |
+
mod_vector1 = torch.cat(mod_vector1, dim=1)
|
1010 |
+
|
1011 |
+
mod_vector_rgb = []
|
1012 |
+
mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3])
|
1013 |
+
mod_vector_rgb.append(x_global.clone())
|
1014 |
+
mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1)
|
1015 |
+
x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs)
|
1016 |
+
return img
|
1017 |
+
|
1018 |
+
|
1019 |
+
class MappingNetwork(torch.nn.Module):
|
1020 |
+
def __init__(self,
|
1021 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
1022 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
1023 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1024 |
+
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
1025 |
+
num_layers=8, # Number of mapping layers.
|
1026 |
+
embed_features=None, # Label embedding dimensionality, None = same as w_dim.
|
1027 |
+
layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
1028 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
1029 |
+
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
|
1030 |
+
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
1031 |
+
):
|
1032 |
+
super().__init__()
|
1033 |
+
self.z_dim = z_dim
|
1034 |
+
self.c_dim = c_dim
|
1035 |
+
self.w_dim = w_dim
|
1036 |
+
self.num_ws = num_ws
|
1037 |
+
self.num_layers = num_layers
|
1038 |
+
self.w_avg_beta = w_avg_beta
|
1039 |
+
|
1040 |
+
if embed_features is None:
|
1041 |
+
embed_features = w_dim
|
1042 |
+
if c_dim == 0:
|
1043 |
+
embed_features = 0
|
1044 |
+
if layer_features is None:
|
1045 |
+
layer_features = w_dim
|
1046 |
+
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
1047 |
+
|
1048 |
+
if c_dim > 0:
|
1049 |
+
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
1050 |
+
for idx in range(num_layers):
|
1051 |
+
in_features = features_list[idx]
|
1052 |
+
out_features = features_list[idx + 1]
|
1053 |
+
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
|
1054 |
+
setattr(self, f'fc{idx}', layer)
|
1055 |
+
|
1056 |
+
if num_ws is not None and w_avg_beta is not None:
|
1057 |
+
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
1058 |
+
|
1059 |
+
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
|
1060 |
+
# Embed, normalize, and concat inputs.
|
1061 |
+
x = None
|
1062 |
+
with torch.autograd.profiler.record_function('input'):
|
1063 |
+
if self.z_dim > 0:
|
1064 |
+
x = normalize_2nd_moment(z.to(torch.float32))
|
1065 |
+
if self.c_dim > 0:
|
1066 |
+
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
|
1067 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
1068 |
+
|
1069 |
+
# Main layers.
|
1070 |
+
for idx in range(self.num_layers):
|
1071 |
+
layer = getattr(self, f'fc{idx}')
|
1072 |
+
x = layer(x)
|
1073 |
+
|
1074 |
+
# Update moving average of W.
|
1075 |
+
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
1076 |
+
with torch.autograd.profiler.record_function('update_w_avg'):
|
1077 |
+
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
1078 |
+
|
1079 |
+
# Broadcast.
|
1080 |
+
if self.num_ws is not None:
|
1081 |
+
with torch.autograd.profiler.record_function('broadcast'):
|
1082 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
1083 |
+
|
1084 |
+
# Apply truncation.
|
1085 |
+
if truncation_psi != 1:
|
1086 |
+
with torch.autograd.profiler.record_function('truncate'):
|
1087 |
+
assert self.w_avg_beta is not None
|
1088 |
+
if self.num_ws is None or truncation_cutoff is None:
|
1089 |
+
x = self.w_avg.lerp(x, truncation_psi)
|
1090 |
+
else:
|
1091 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
1092 |
+
return x
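
Truncation here is a plain linear interpolation toward the tracked average latent, w' = w_avg + psi * (w - w_avg), which is what self.w_avg.lerp(x, truncation_psi) computes (FcF below calls the generator with truncation_psi=0.1). A tiny numeric sketch with made-up values:

import torch

w_avg = torch.tensor([0.0, 0.0])
w = torch.tensor([2.0, -4.0])
w_trunc = w_avg.lerp(w, 0.1)   # w_avg + 0.1 * (w - w_avg)
assert torch.allclose(w_trunc, torch.tensor([0.2, -0.4]))
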
|
1093 |
+
|
1094 |
+
|
1095 |
+
class Generator(torch.nn.Module):
|
1096 |
+
def __init__(self,
|
1097 |
+
z_dim, # Input latent (Z) dimensionality.
|
1098 |
+
c_dim, # Conditioning label (C) dimensionality.
|
1099 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1100 |
+
img_resolution, # Output resolution.
|
1101 |
+
img_channels, # Number of output color channels.
|
1102 |
+
encoder_kwargs={}, # Arguments for EncoderNetwork.
|
1103 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
1104 |
+
synthesis_kwargs={}, # Arguments for SynthesisNetwork.
|
1105 |
+
):
|
1106 |
+
super().__init__()
|
1107 |
+
self.z_dim = z_dim
|
1108 |
+
self.c_dim = c_dim
|
1109 |
+
self.w_dim = w_dim
|
1110 |
+
self.img_resolution = img_resolution
|
1111 |
+
self.img_channels = img_channels
|
1112 |
+
self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution,
|
1113 |
+
img_channels=img_channels, **encoder_kwargs)
|
1114 |
+
self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution,
|
1115 |
+
img_channels=img_channels, **synthesis_kwargs)
|
1116 |
+
self.num_ws = self.synthesis.num_ws
|
1117 |
+
self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
|
1118 |
+
|
1119 |
+
def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
|
1120 |
+
mask = img[:, -1].unsqueeze(1)
|
1121 |
+
x_global, z, feats = self.encoder(img, c)
|
1122 |
+
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
|
1123 |
+
img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
|
1124 |
+
return img


FCF_MODEL_URL = os.environ.get(
    "FCF_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
)


class FcF(InpaintModel):
    min_size = 512
    pad_mod = 512
    pad_to_square = True

    def init_model(self, device, **kwargs):
        seed = 0
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        kwargs = {'channel_base': 1 * 32768, 'channel_max': 512, 'num_fp16_res': 4, 'conv_clamp': 256}
        G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3,
                      synthesis_kwargs=kwargs, encoder_kwargs=kwargs, mapping_kwargs={'num_layers': 2})
        self.model = load_model(G, FCF_MODEL_URL, device)
        self.label = torch.zeros([1, self.model.c_dim], device=device)

    @staticmethod
    def is_downloaded() -> bool:
        return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        if image.shape[0] == 512 and image.shape[1] == 512:
            return self._pad_forward(image, mask, config)

        boxes = boxes_from_mask(mask)
        crop_result = []
        config.hd_strategy_crop_margin = 128
        for box in boxes:
            crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
            origin_size = crop_image.shape[:2]
            resize_image = resize_max_size(crop_image, size_limit=512)
            resize_mask = resize_max_size(crop_mask, size_limit=512)
            inpaint_result = self._pad_forward(resize_image, resize_mask, config)

            # only paste masked area result
            inpaint_result = cv2.resize(inpaint_result, (origin_size[1], origin_size[0]), interpolation=cv2.INTER_CUBIC)

            original_pixel_indices = crop_mask < 127
            inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][original_pixel_indices]

            crop_result.append((inpaint_result, crop_box))

        inpaint_result = image[:, :, ::-1]
        for crop_image, crop_box in crop_result:
            x1, y1, x2, y2 = crop_box
            inpaint_result[y1:y2, x1:x2, :] = crop_image

        return inpaint_result

    def forward(self, image, mask, config: Config):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W] mask area == 255
        return: BGR IMAGE
        """

        image = norm_img(image)  # [0, 1]
        image = image * 2 - 1  # [0, 1] -> [-1, 1]
        mask = (mask > 120) * 255
        mask = norm_img(mask)

        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        erased_img = image * (1 - mask)
        input_image = torch.cat([0.5 - mask, erased_img], dim=1)

        output = self.model(input_image, self.label, truncation_psi=0.1, noise_mode='none')
        output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
        output = output[0].cpu().numpy()
        cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return cur_res
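
The forward() above feeds the generator a 4-channel tensor: a mask channel shifted into {-0.5, 0.5} concatenated with the hole-erased RGB image in [-1, 1]. A shape sketch with made-up tensors:

import torch

image = torch.rand(1, 3, 512, 512) * 2 - 1           # RGB in [-1, 1]
mask = (torch.rand(1, 1, 512, 512) > 0.9).float()    # 1 = area to inpaint
erased_img = image * (1 - mask)                      # holes zeroed out
input_image = torch.cat([0.5 - mask, erased_img], dim=1)
assert input_image.shape == (1, 4, 512, 512)
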
lama_cleaner/model/lama.py
ADDED
@@ -0,0 +1,61 @@
import os

import cv2
import numpy as np
import torch
from loguru import logger

from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config

LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)

#"https://drive.google.com/file/d/1bMD06F9hkkS1oi8cEmb4cSjXz54Pxs6A/view?usp=sharing" #big-lama.pt file


class LaMa(InpaintModel):
    pad_mod = 8

    def init_model(self, device, **kwargs):
        if os.environ.get("LAMA_MODEL"):
            model_path = os.environ.get("LAMA_MODEL")
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"lama torchscript model not found: {model_path}"
                )
        else:
            model_path = download_model(LAMA_MODEL_URL)
        logger.info(f"Load LaMa model from: {model_path}")
        model = torch.jit.load(model_path, map_location="cpu")
        model = model.to(device)
        model.eval()
        self.model = model
        self.model_path = model_path

    @staticmethod
    def is_downloaded() -> bool:
        return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        image = norm_img(image)
        mask = norm_img(mask)

        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        inpainted_image = self.model(image, mask)

        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
        return cur_res
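
In this Space the wrappers are normally reached through ModelManager rather than instantiated directly. A hedged usage sketch; the model name "lama" and the exact Config fields are assumptions based on the model_manager.py and schema.py added in this commit, not a verified API:

import cv2
import torch
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, LDMSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModelManager(name="lama", device=device)   # "fcf", "ldm", ... select other wrappers

image_rgb = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB)   # [H, W, 3] RGB
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)                    # [H, W], 255 = inpaint

config = Config(
    ldm_steps=25,
    ldm_sampler=LDMSampler.plms,
    hd_strategy=HDStrategy.ORIGINAL,
    hd_strategy_crop_margin=128,
    hd_strategy_crop_trigger_size=800,
    hd_strategy_resize_limit=2048,
)
result_bgr = model(image_rgb, mask, config)   # wrappers in this repo return BGR
cv2.imwrite("result.png", result_bgr)
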
lama_cleaner/model/ldm.py
ADDED
@@ -0,0 +1,312 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
from lama_cleaner.model.base import InpaintModel
|
8 |
+
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
9 |
+
from lama_cleaner.model.plms_sampler import PLMSSampler
|
10 |
+
from lama_cleaner.schema import Config, LDMSampler
|
11 |
+
|
12 |
+
torch.manual_seed(42)
|
13 |
+
import torch.nn as nn
|
14 |
+
from lama_cleaner.helper import (
|
15 |
+
download_model,
|
16 |
+
norm_img,
|
17 |
+
get_cache_path_by_url,
|
18 |
+
load_jit_model,
|
19 |
+
)
|
20 |
+
from lama_cleaner.model.utils import (
|
21 |
+
make_beta_schedule,
|
22 |
+
timestep_embedding,
|
23 |
+
)
|
24 |
+
|
25 |
+
LDM_ENCODE_MODEL_URL = os.environ.get(
|
26 |
+
"LDM_ENCODE_MODEL_URL",
|
27 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
|
28 |
+
)
|
29 |
+
|
30 |
+
LDM_DECODE_MODEL_URL = os.environ.get(
|
31 |
+
"LDM_DECODE_MODEL_URL",
|
32 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
|
33 |
+
)
|
34 |
+
|
35 |
+
LDM_DIFFUSION_MODEL_URL = os.environ.get(
|
36 |
+
"LDM_DIFFUSION_MODEL_URL",
|
37 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
class DDPM(nn.Module):
|
42 |
+
# classic DDPM with Gaussian diffusion, in image space
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
device,
|
46 |
+
timesteps=1000,
|
47 |
+
beta_schedule="linear",
|
48 |
+
linear_start=0.0015,
|
49 |
+
linear_end=0.0205,
|
50 |
+
cosine_s=0.008,
|
51 |
+
original_elbo_weight=0.0,
|
52 |
+
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
53 |
+
l_simple_weight=1.0,
|
54 |
+
parameterization="eps", # all assuming fixed variance schedules
|
55 |
+
use_positional_encodings=False,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.device = device
|
59 |
+
self.parameterization = parameterization
|
60 |
+
self.use_positional_encodings = use_positional_encodings
|
61 |
+
|
62 |
+
self.v_posterior = v_posterior
|
63 |
+
self.original_elbo_weight = original_elbo_weight
|
64 |
+
self.l_simple_weight = l_simple_weight
|
65 |
+
|
66 |
+
self.register_schedule(
|
67 |
+
beta_schedule=beta_schedule,
|
68 |
+
timesteps=timesteps,
|
69 |
+
linear_start=linear_start,
|
70 |
+
linear_end=linear_end,
|
71 |
+
cosine_s=cosine_s,
|
72 |
+
)
|
73 |
+
|
74 |
+
def register_schedule(
|
75 |
+
self,
|
76 |
+
given_betas=None,
|
77 |
+
beta_schedule="linear",
|
78 |
+
timesteps=1000,
|
79 |
+
linear_start=1e-4,
|
80 |
+
linear_end=2e-2,
|
81 |
+
cosine_s=8e-3,
|
82 |
+
):
|
83 |
+
betas = make_beta_schedule(
|
84 |
+
self.device,
|
85 |
+
beta_schedule,
|
86 |
+
timesteps,
|
87 |
+
linear_start=linear_start,
|
88 |
+
linear_end=linear_end,
|
89 |
+
cosine_s=cosine_s,
|
90 |
+
)
|
91 |
+
alphas = 1.0 - betas
|
92 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
93 |
+
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
94 |
+
|
95 |
+
(timesteps,) = betas.shape
|
96 |
+
self.num_timesteps = int(timesteps)
|
97 |
+
self.linear_start = linear_start
|
98 |
+
self.linear_end = linear_end
|
99 |
+
assert (
|
100 |
+
alphas_cumprod.shape[0] == self.num_timesteps
|
101 |
+
), "alphas have to be defined for each timestep"
|
102 |
+
|
103 |
+
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
104 |
+
|
105 |
+
self.register_buffer("betas", to_torch(betas))
|
106 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
107 |
+
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
|
108 |
+
|
109 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
110 |
+
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
|
111 |
+
self.register_buffer(
|
112 |
+
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
|
113 |
+
)
|
114 |
+
self.register_buffer(
|
115 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
|
116 |
+
)
|
117 |
+
self.register_buffer(
|
118 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
|
119 |
+
)
|
120 |
+
self.register_buffer(
|
121 |
+
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
|
122 |
+
)
|
123 |
+
|
124 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
125 |
+
posterior_variance = (1 - self.v_posterior) * betas * (
|
126 |
+
1.0 - alphas_cumprod_prev
|
127 |
+
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
128 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
129 |
+
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
130 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
131 |
+
self.register_buffer(
|
132 |
+
"posterior_log_variance_clipped",
|
133 |
+
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
|
134 |
+
)
|
135 |
+
self.register_buffer(
|
136 |
+
"posterior_mean_coef1",
|
137 |
+
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
|
138 |
+
)
|
139 |
+
self.register_buffer(
|
140 |
+
"posterior_mean_coef2",
|
141 |
+
to_torch(
|
142 |
+
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
|
143 |
+
),
|
144 |
+
)
|
145 |
+
|
146 |
+
if self.parameterization == "eps":
|
147 |
+
lvlb_weights = self.betas**2 / (
|
148 |
+
2
|
149 |
+
* self.posterior_variance
|
150 |
+
* to_torch(alphas)
|
151 |
+
* (1 - self.alphas_cumprod)
|
152 |
+
)
|
153 |
+
elif self.parameterization == "x0":
|
154 |
+
lvlb_weights = (
|
155 |
+
0.5
|
156 |
+
* np.sqrt(torch.Tensor(alphas_cumprod))
|
157 |
+
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
raise NotImplementedError("mu not supported")
|
161 |
+
# TODO how to choose this term
|
162 |
+
lvlb_weights[0] = lvlb_weights[1]
|
163 |
+
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
|
164 |
+
assert not torch.isnan(self.lvlb_weights).all()
|
165 |
+
|
166 |
+
|
167 |
+
class LatentDiffusion(DDPM):
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
diffusion_model,
|
171 |
+
device,
|
172 |
+
cond_stage_key="image",
|
173 |
+
cond_stage_trainable=False,
|
174 |
+
concat_mode=True,
|
175 |
+
scale_factor=1.0,
|
176 |
+
scale_by_std=False,
|
177 |
+
*args,
|
178 |
+
**kwargs,
|
179 |
+
):
|
180 |
+
self.num_timesteps_cond = 1
|
181 |
+
self.scale_by_std = scale_by_std
|
182 |
+
super().__init__(device, *args, **kwargs)
|
183 |
+
self.diffusion_model = diffusion_model
|
184 |
+
self.concat_mode = concat_mode
|
185 |
+
self.cond_stage_trainable = cond_stage_trainable
|
186 |
+
self.cond_stage_key = cond_stage_key
|
187 |
+
self.num_downs = 2
|
188 |
+
self.scale_factor = scale_factor
|
189 |
+
|
190 |
+
def make_cond_schedule(
|
191 |
+
self,
|
192 |
+
):
|
193 |
+
self.cond_ids = torch.full(
|
194 |
+
size=(self.num_timesteps,),
|
195 |
+
fill_value=self.num_timesteps - 1,
|
196 |
+
dtype=torch.long,
|
197 |
+
)
|
198 |
+
ids = torch.round(
|
199 |
+
torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
|
200 |
+
).long()
|
201 |
+
self.cond_ids[: self.num_timesteps_cond] = ids
|
202 |
+
|
203 |
+
def register_schedule(
|
204 |
+
self,
|
205 |
+
given_betas=None,
|
206 |
+
beta_schedule="linear",
|
207 |
+
timesteps=1000,
|
208 |
+
linear_start=1e-4,
|
209 |
+
linear_end=2e-2,
|
210 |
+
cosine_s=8e-3,
|
211 |
+
):
|
212 |
+
super().register_schedule(
|
213 |
+
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
|
214 |
+
)
|
215 |
+
|
216 |
+
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
217 |
+
if self.shorten_cond_schedule:
|
218 |
+
self.make_cond_schedule()
|
219 |
+
|
220 |
+
def apply_model(self, x_noisy, t, cond):
|
221 |
+
# x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128
|
222 |
+
t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
|
223 |
+
x_recon = self.diffusion_model(x_noisy, t_emb, cond)
|
224 |
+
return x_recon
|
225 |
+
|
226 |
+
|
227 |
+
class LDM(InpaintModel):
|
228 |
+
pad_mod = 32
|
229 |
+
|
230 |
+
def __init__(self, device, fp16: bool = True, **kwargs):
|
231 |
+
self.fp16 = fp16
|
232 |
+
super().__init__(device)
|
233 |
+
self.device = device
|
234 |
+
|
235 |
+
def init_model(self, device, **kwargs):
|
236 |
+
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
|
237 |
+
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
|
238 |
+
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
|
239 |
+
if self.fp16 and "cuda" in str(device):
|
240 |
+
self.diffusion_model = self.diffusion_model.half()
|
241 |
+
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
242 |
+
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
243 |
+
|
244 |
+
self.model = LatentDiffusion(self.diffusion_model, device)
|
245 |
+
|
246 |
+
@staticmethod
|
247 |
+
def is_downloaded() -> bool:
|
248 |
+
model_paths = [
|
249 |
+
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
|
250 |
+
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
|
251 |
+
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
|
252 |
+
]
|
253 |
+
return all([os.path.exists(it) for it in model_paths])
|
254 |
+
|
255 |
+
@torch.cuda.amp.autocast()
|
256 |
+
def forward(self, image, mask, config: Config):
|
257 |
+
"""
|
258 |
+
image: [H, W, C] RGB
|
259 |
+
mask: [H, W, 1]
|
260 |
+
return: BGR IMAGE
|
261 |
+
"""
|
262 |
+
# image [1,3,512,512] float32
|
263 |
+
# mask: [1,1,512,512] float32
|
264 |
+
# masked_image: [1,3,512,512] float32
|
265 |
+
if config.ldm_sampler == LDMSampler.ddim:
|
266 |
+
sampler = DDIMSampler(self.model)
|
267 |
+
elif config.ldm_sampler == LDMSampler.plms:
|
268 |
+
sampler = PLMSSampler(self.model)
|
269 |
+
else:
|
270 |
+
raise ValueError()
|
271 |
+
|
272 |
+
steps = config.ldm_steps
|
273 |
+
image = norm_img(image)
|
274 |
+
mask = norm_img(mask)
|
275 |
+
|
276 |
+
mask[mask < 0.5] = 0
|
277 |
+
mask[mask >= 0.5] = 1
|
278 |
+
|
279 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
280 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
281 |
+
masked_image = (1 - mask) * image
|
282 |
+
|
283 |
+
mask = self._norm(mask)
|
284 |
+
masked_image = self._norm(masked_image)
|
285 |
+
|
286 |
+
c = self.cond_stage_model_encode(masked_image)
|
287 |
+
torch.cuda.empty_cache()
|
288 |
+
|
289 |
+
cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
|
290 |
+
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
291 |
+
|
292 |
+
shape = (c.shape[1] - 1,) + c.shape[2:]
|
293 |
+
samples_ddim = sampler.sample(
|
294 |
+
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
295 |
+
)
|
296 |
+
torch.cuda.empty_cache()
|
297 |
+
x_samples_ddim = self.cond_stage_model_decode(
|
298 |
+
samples_ddim
|
299 |
+
) # samples_ddim: 1, 3, 128, 128 float32
|
300 |
+
torch.cuda.empty_cache()
|
301 |
+
|
302 |
+
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
303 |
+
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
304 |
+
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
305 |
+
|
306 |
+
# inpainted = (1 - mask) * image + mask * predicted_image
|
307 |
+
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
308 |
+
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
|
309 |
+
return inpainted_image
|
310 |
+
|
311 |
+
def _norm(self, tensor):
|
312 |
+
return tensor * 2.0 - 1.0
|
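
LDM.forward conditions the diffusion model on the encoded masked image plus a mask downscaled to the latent resolution, and samples a latent whose channel count is one less than the conditioning tensor. A shape sketch mirroring those lines, with made-up sizes:

import torch
import torch.nn.functional as F

c = torch.randn(1, 3, 128, 128)                # encoded masked image (512 / 4 = 128)
mask = torch.rand(1, 1, 512, 512)
cc = F.interpolate(mask, size=c.shape[-2:])    # mask at the latent resolution
cond = torch.cat((c, cc), dim=1)               # 1 x 4 x 128 x 128 conditioning tensor
shape = (cond.shape[1] - 1,) + tuple(cond.shape[2:])   # (3, 128, 128): latent to sample
assert cond.shape == (1, 4, 128, 128) and shape == (3, 128, 128)
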
lama_cleaner/model/mat.py
ADDED
@@ -0,0 +1,1444 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.utils.checkpoint as checkpoint
|
10 |
+
|
11 |
+
from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
|
12 |
+
from lama_cleaner.model.base import InpaintModel
|
13 |
+
from lama_cleaner.model.utils import setup_filter, Conv2dLayer, FullyConnectedLayer, conv2d_resample, bias_act, \
|
14 |
+
upsample2d, activation_funcs, MinibatchStdLayer, to_2tuple, normalize_2nd_moment
|
15 |
+
from lama_cleaner.schema import Config
|
16 |
+
|
17 |
+
|
18 |
+
class ModulatedConv2d(nn.Module):
|
19 |
+
def __init__(self,
|
20 |
+
in_channels, # Number of input channels.
|
21 |
+
out_channels, # Number of output channels.
|
22 |
+
kernel_size, # Width and height of the convolution kernel.
|
23 |
+
style_dim, # dimension of the style code
|
24 |
+
demodulate=True, # perform demodulation
|
25 |
+
up=1, # Integer upsampling factor.
|
26 |
+
down=1, # Integer downsampling factor.
|
27 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
28 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
self.demodulate = demodulate
|
32 |
+
|
33 |
+
self.weight = torch.nn.Parameter(torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]))
|
34 |
+
self.out_channels = out_channels
|
35 |
+
self.kernel_size = kernel_size
|
36 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
37 |
+
self.padding = self.kernel_size // 2
|
38 |
+
self.up = up
|
39 |
+
self.down = down
|
40 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
41 |
+
self.conv_clamp = conv_clamp
|
42 |
+
|
43 |
+
self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)
|
44 |
+
|
45 |
+
def forward(self, x, style):
|
46 |
+
batch, in_channels, height, width = x.shape
|
47 |
+
style = self.affine(style).view(batch, 1, in_channels, 1, 1)
|
48 |
+
weight = self.weight * self.weight_gain * style
|
49 |
+
|
50 |
+
if self.demodulate:
|
51 |
+
decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
|
52 |
+
weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1)
|
53 |
+
|
54 |
+
weight = weight.view(batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size)
|
55 |
+
x = x.view(1, batch * in_channels, height, width)
|
56 |
+
x = conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down,
|
57 |
+
padding=self.padding, groups=batch)
|
58 |
+
out = x.view(batch, self.out_channels, *x.shape[2:])
|
59 |
+
|
60 |
+
return out
|
61 |
+
|
62 |
+
|
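(Editor's aside) ModulatedConv2d above folds a per-sample style vector into the convolution weights, optionally demodulates them toward unit norm, and then runs one grouped convolution over the whole batch. A minimal standalone sketch of that modulation/demodulation trick in plain PyTorch, for illustration only: it uses F.conv2d instead of the repo's conv2d_resample and ignores up/down-sampling and the resample filter.

import torch
import torch.nn.functional as F

def modulated_conv2d(x, weight, style, demodulate=True, eps=1e-8):
    # x: [B, Cin, H, W], weight: [Cout, Cin, k, k], style: [B, Cin]
    b, cin, h, w = x.shape
    cout, _, k, _ = weight.shape
    w_mod = weight.unsqueeze(0) * style.view(b, 1, cin, 1, 1)      # per-sample modulation
    if demodulate:                                                  # rescale each output channel to ~unit norm
        d = (w_mod.pow(2).sum(dim=[2, 3, 4]) + eps).rsqrt()
        w_mod = w_mod * d.view(b, cout, 1, 1, 1)
    x = x.reshape(1, b * cin, h, w)                                 # fold the batch into the channel dim
    w_mod = w_mod.reshape(b * cout, cin, k, k)
    out = F.conv2d(x, w_mod, padding=k // 2, groups=b)              # one grouped conv = B independent convs
    return out.reshape(b, cout, h, w)

out = modulated_conv2d(torch.randn(2, 8, 16, 16), torch.randn(4, 8, 3, 3), torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 4, 16, 16])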
63 |
+
class StyleConv(torch.nn.Module):
|
64 |
+
def __init__(self,
|
65 |
+
in_channels, # Number of input channels.
|
66 |
+
out_channels, # Number of output channels.
|
67 |
+
style_dim, # Intermediate latent (W) dimensionality.
|
68 |
+
resolution, # Resolution of this layer.
|
69 |
+
kernel_size=3, # Convolution kernel size.
|
70 |
+
up=1, # Integer upsampling factor.
|
71 |
+
use_noise=False, # Enable noise input?
|
72 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
73 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
74 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
75 |
+
demodulate=True, # perform demodulation
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
self.conv = ModulatedConv2d(in_channels=in_channels,
|
80 |
+
out_channels=out_channels,
|
81 |
+
kernel_size=kernel_size,
|
82 |
+
style_dim=style_dim,
|
83 |
+
demodulate=demodulate,
|
84 |
+
up=up,
|
85 |
+
resample_filter=resample_filter,
|
86 |
+
conv_clamp=conv_clamp)
|
87 |
+
|
88 |
+
self.use_noise = use_noise
|
89 |
+
self.resolution = resolution
|
90 |
+
if use_noise:
|
91 |
+
self.register_buffer('noise_const', torch.randn([resolution, resolution]))
|
92 |
+
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
|
93 |
+
|
94 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
95 |
+
self.activation = activation
|
96 |
+
self.act_gain = activation_funcs[activation].def_gain
|
97 |
+
self.conv_clamp = conv_clamp
|
98 |
+
|
99 |
+
def forward(self, x, style, noise_mode='random', gain=1):
|
100 |
+
x = self.conv(x, style)
|
101 |
+
|
102 |
+
assert noise_mode in ['random', 'const', 'none']
|
103 |
+
|
104 |
+
if self.use_noise:
|
105 |
+
if noise_mode == 'random':
|
106 |
+
xh, xw = x.size()[-2:]
|
107 |
+
noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \
|
108 |
+
* self.noise_strength
|
109 |
+
elif noise_mode == 'const':
|
110 |
+
noise = self.noise_const * self.noise_strength
|
111 |
+
else:  # noise_mode == 'none': skip the additive noise
    noise = 0
x = x + noise
|
112 |
+
|
113 |
+
act_gain = self.act_gain * gain
|
114 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
115 |
+
out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
|
116 |
+
|
117 |
+
return out
|
118 |
+
|
119 |
+
|
120 |
+
class ToRGB(torch.nn.Module):
|
121 |
+
def __init__(self,
|
122 |
+
in_channels,
|
123 |
+
out_channels,
|
124 |
+
style_dim,
|
125 |
+
kernel_size=1,
|
126 |
+
resample_filter=[1, 3, 3, 1],
|
127 |
+
conv_clamp=None,
|
128 |
+
demodulate=False):
|
129 |
+
super().__init__()
|
130 |
+
|
131 |
+
self.conv = ModulatedConv2d(in_channels=in_channels,
|
132 |
+
out_channels=out_channels,
|
133 |
+
kernel_size=kernel_size,
|
134 |
+
style_dim=style_dim,
|
135 |
+
demodulate=demodulate,
|
136 |
+
resample_filter=resample_filter,
|
137 |
+
conv_clamp=conv_clamp)
|
138 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
139 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
140 |
+
self.conv_clamp = conv_clamp
|
141 |
+
|
142 |
+
def forward(self, x, style, skip=None):
|
143 |
+
x = self.conv(x, style)
|
144 |
+
out = bias_act(x, self.bias, clamp=self.conv_clamp)
|
145 |
+
|
146 |
+
if skip is not None:
|
147 |
+
if skip.shape != out.shape:
|
148 |
+
skip = upsample2d(skip, self.resample_filter)
|
149 |
+
out = out + skip
|
150 |
+
|
151 |
+
return out
|
152 |
+
|
153 |
+
|
154 |
+
def get_style_code(a, b):
|
155 |
+
return torch.cat([a, b], dim=1)
|
156 |
+
|
157 |
+
|
158 |
+
class DecBlockFirst(nn.Module):
|
159 |
+
def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
160 |
+
super().__init__()
|
161 |
+
self.fc = FullyConnectedLayer(in_features=in_channels * 2,
|
162 |
+
out_features=in_channels * 4 ** 2,
|
163 |
+
activation=activation)
|
164 |
+
self.conv = StyleConv(in_channels=in_channels,
|
165 |
+
out_channels=out_channels,
|
166 |
+
style_dim=style_dim,
|
167 |
+
resolution=4,
|
168 |
+
kernel_size=3,
|
169 |
+
use_noise=use_noise,
|
170 |
+
activation=activation,
|
171 |
+
demodulate=demodulate,
|
172 |
+
)
|
173 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
174 |
+
out_channels=img_channels,
|
175 |
+
style_dim=style_dim,
|
176 |
+
kernel_size=1,
|
177 |
+
demodulate=False,
|
178 |
+
)
|
179 |
+
|
180 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
181 |
+
x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
182 |
+
x = x + E_features[2]
|
183 |
+
style = get_style_code(ws[:, 0], gs)
|
184 |
+
x = self.conv(x, style, noise_mode=noise_mode)
|
185 |
+
style = get_style_code(ws[:, 1], gs)
|
186 |
+
img = self.toRGB(x, style, skip=None)
|
187 |
+
|
188 |
+
return x, img
|
189 |
+
|
190 |
+
|
191 |
+
class DecBlockFirstV2(nn.Module):
|
192 |
+
def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
193 |
+
super().__init__()
|
194 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
195 |
+
out_channels=in_channels,
|
196 |
+
kernel_size=3,
|
197 |
+
activation=activation,
|
198 |
+
)
|
199 |
+
self.conv1 = StyleConv(in_channels=in_channels,
|
200 |
+
out_channels=out_channels,
|
201 |
+
style_dim=style_dim,
|
202 |
+
resolution=4,
|
203 |
+
kernel_size=3,
|
204 |
+
use_noise=use_noise,
|
205 |
+
activation=activation,
|
206 |
+
demodulate=demodulate,
|
207 |
+
)
|
208 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
209 |
+
out_channels=img_channels,
|
210 |
+
style_dim=style_dim,
|
211 |
+
kernel_size=1,
|
212 |
+
demodulate=False,
|
213 |
+
)
|
214 |
+
|
215 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
216 |
+
# x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
217 |
+
x = self.conv0(x)
|
218 |
+
x = x + E_features[2]
|
219 |
+
style = get_style_code(ws[:, 0], gs)
|
220 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
221 |
+
style = get_style_code(ws[:, 1], gs)
|
222 |
+
img = self.toRGB(x, style, skip=None)
|
223 |
+
|
224 |
+
return x, img
|
225 |
+
|
226 |
+
|
227 |
+
class DecBlock(nn.Module):
|
228 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
|
229 |
+
img_channels): # res = 2, ..., resolution_log2
|
230 |
+
super().__init__()
|
231 |
+
self.res = res
|
232 |
+
|
233 |
+
self.conv0 = StyleConv(in_channels=in_channels,
|
234 |
+
out_channels=out_channels,
|
235 |
+
style_dim=style_dim,
|
236 |
+
resolution=2 ** res,
|
237 |
+
kernel_size=3,
|
238 |
+
up=2,
|
239 |
+
use_noise=use_noise,
|
240 |
+
activation=activation,
|
241 |
+
demodulate=demodulate,
|
242 |
+
)
|
243 |
+
self.conv1 = StyleConv(in_channels=out_channels,
|
244 |
+
out_channels=out_channels,
|
245 |
+
style_dim=style_dim,
|
246 |
+
resolution=2 ** res,
|
247 |
+
kernel_size=3,
|
248 |
+
use_noise=use_noise,
|
249 |
+
activation=activation,
|
250 |
+
demodulate=demodulate,
|
251 |
+
)
|
252 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
253 |
+
out_channels=img_channels,
|
254 |
+
style_dim=style_dim,
|
255 |
+
kernel_size=1,
|
256 |
+
demodulate=False,
|
257 |
+
)
|
258 |
+
|
259 |
+
def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
|
260 |
+
style = get_style_code(ws[:, self.res * 2 - 5], gs)
|
261 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
262 |
+
x = x + E_features[self.res]
|
263 |
+
style = get_style_code(ws[:, self.res * 2 - 4], gs)
|
264 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
265 |
+
style = get_style_code(ws[:, self.res * 2 - 3], gs)
|
266 |
+
img = self.toRGB(x, style, skip=img)
|
267 |
+
|
268 |
+
return x, img
|
269 |
+
|
270 |
+
|
271 |
+
class MappingNet(torch.nn.Module):
|
272 |
+
def __init__(self,
|
273 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
274 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
275 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
276 |
+
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
277 |
+
num_layers=8, # Number of mapping layers.
|
278 |
+
embed_features=None, # Label embedding dimensionality, None = same as w_dim.
|
279 |
+
layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
280 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
281 |
+
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
|
282 |
+
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
283 |
+
):
|
284 |
+
super().__init__()
|
285 |
+
self.z_dim = z_dim
|
286 |
+
self.c_dim = c_dim
|
287 |
+
self.w_dim = w_dim
|
288 |
+
self.num_ws = num_ws
|
289 |
+
self.num_layers = num_layers
|
290 |
+
self.w_avg_beta = w_avg_beta
|
291 |
+
|
292 |
+
if embed_features is None:
|
293 |
+
embed_features = w_dim
|
294 |
+
if c_dim == 0:
|
295 |
+
embed_features = 0
|
296 |
+
if layer_features is None:
|
297 |
+
layer_features = w_dim
|
298 |
+
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
299 |
+
|
300 |
+
if c_dim > 0:
|
301 |
+
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
302 |
+
for idx in range(num_layers):
|
303 |
+
in_features = features_list[idx]
|
304 |
+
out_features = features_list[idx + 1]
|
305 |
+
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
|
306 |
+
setattr(self, f'fc{idx}', layer)
|
307 |
+
|
308 |
+
if num_ws is not None and w_avg_beta is not None:
|
309 |
+
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
310 |
+
|
311 |
+
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
|
312 |
+
# Embed, normalize, and concat inputs.
|
313 |
+
x = None
|
314 |
+
with torch.autograd.profiler.record_function('input'):
|
315 |
+
if self.z_dim > 0:
|
316 |
+
x = normalize_2nd_moment(z.to(torch.float32))
|
317 |
+
if self.c_dim > 0:
|
318 |
+
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
|
319 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
320 |
+
|
321 |
+
# Main layers.
|
322 |
+
for idx in range(self.num_layers):
|
323 |
+
layer = getattr(self, f'fc{idx}')
|
324 |
+
x = layer(x)
|
325 |
+
|
326 |
+
# Update moving average of W.
|
327 |
+
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
328 |
+
with torch.autograd.profiler.record_function('update_w_avg'):
|
329 |
+
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
330 |
+
|
331 |
+
# Broadcast.
|
332 |
+
if self.num_ws is not None:
|
333 |
+
with torch.autograd.profiler.record_function('broadcast'):
|
334 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
335 |
+
|
336 |
+
# Apply truncation.
|
337 |
+
if truncation_psi != 1:
|
338 |
+
with torch.autograd.profiler.record_function('truncate'):
|
339 |
+
assert self.w_avg_beta is not None
|
340 |
+
if self.num_ws is None or truncation_cutoff is None:
|
341 |
+
x = self.w_avg.lerp(x, truncation_psi)
|
342 |
+
else:
|
343 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
344 |
+
|
345 |
+
return x
|
346 |
+
|
347 |
+
|
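(Editor's aside) MappingNet keeps a running average w_avg of the mapped latents, and truncation pulls each w toward that average: w_trunc = w_avg + psi * (w - w_avg), which is exactly what Tensor.lerp computes in the forward pass above. A quick standalone check of the identity, with arbitrary illustrative shapes:

import torch

w_avg = torch.zeros(512)                 # running average of W
w = torch.randn(3, 8, 512)               # [batch, num_ws, w_dim]
psi = 0.7                                # truncation strength; 1.0 = no truncation
w_trunc = w_avg.lerp(w, psi)             # w_avg broadcasts over the leading dims
assert torch.allclose(w_trunc, w_avg + psi * (w - w_avg))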
348 |
+
class DisFromRGB(nn.Module):
|
349 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
|
350 |
+
super().__init__()
|
351 |
+
self.conv = Conv2dLayer(in_channels=in_channels,
|
352 |
+
out_channels=out_channels,
|
353 |
+
kernel_size=1,
|
354 |
+
activation=activation,
|
355 |
+
)
|
356 |
+
|
357 |
+
def forward(self, x):
|
358 |
+
return self.conv(x)
|
359 |
+
|
360 |
+
|
361 |
+
class DisBlock(nn.Module):
|
362 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
|
363 |
+
super().__init__()
|
364 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
365 |
+
out_channels=in_channels,
|
366 |
+
kernel_size=3,
|
367 |
+
activation=activation,
|
368 |
+
)
|
369 |
+
self.conv1 = Conv2dLayer(in_channels=in_channels,
|
370 |
+
out_channels=out_channels,
|
371 |
+
kernel_size=3,
|
372 |
+
down=2,
|
373 |
+
activation=activation,
|
374 |
+
)
|
375 |
+
self.skip = Conv2dLayer(in_channels=in_channels,
|
376 |
+
out_channels=out_channels,
|
377 |
+
kernel_size=1,
|
378 |
+
down=2,
|
379 |
+
bias=False,
|
380 |
+
)
|
381 |
+
|
382 |
+
def forward(self, x):
|
383 |
+
skip = self.skip(x, gain=np.sqrt(0.5))
|
384 |
+
x = self.conv0(x)
|
385 |
+
x = self.conv1(x, gain=np.sqrt(0.5))
|
386 |
+
out = skip + x
|
387 |
+
|
388 |
+
return out
|
389 |
+
|
390 |
+
|
391 |
+
class Discriminator(torch.nn.Module):
|
392 |
+
def __init__(self,
|
393 |
+
c_dim, # Conditioning label (C) dimensionality.
|
394 |
+
img_resolution, # Input resolution.
|
395 |
+
img_channels, # Number of input color channels.
|
396 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
397 |
+
channel_max=512, # Maximum number of channels in any layer.
|
398 |
+
channel_decay=1,
|
399 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
400 |
+
activation='lrelu',
|
401 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
402 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
403 |
+
):
|
404 |
+
super().__init__()
|
405 |
+
self.c_dim = c_dim
|
406 |
+
self.img_resolution = img_resolution
|
407 |
+
self.img_channels = img_channels
|
408 |
+
|
409 |
+
resolution_log2 = int(np.log2(img_resolution))
|
410 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
411 |
+
self.resolution_log2 = resolution_log2
|
412 |
+
|
413 |
+
def nf(stage):
|
414 |
+
return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max)
|
415 |
+
|
416 |
+
if cmap_dim is None:
|
417 |
+
cmap_dim = nf(2)
|
418 |
+
if c_dim == 0:
|
419 |
+
cmap_dim = 0
|
420 |
+
self.cmap_dim = cmap_dim
|
421 |
+
|
422 |
+
if c_dim > 0:
|
423 |
+
self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
|
424 |
+
|
425 |
+
Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
|
426 |
+
for res in range(resolution_log2, 2, -1):
|
427 |
+
Dis.append(DisBlock(nf(res), nf(res - 1), activation))
|
428 |
+
|
429 |
+
if mbstd_num_channels > 0:
|
430 |
+
Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
|
431 |
+
Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
|
432 |
+
self.Dis = nn.Sequential(*Dis)
|
433 |
+
|
434 |
+
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
435 |
+
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
436 |
+
|
437 |
+
def forward(self, images_in, masks_in, c):
|
438 |
+
x = torch.cat([masks_in - 0.5, images_in], dim=1)
|
439 |
+
x = self.Dis(x)
|
440 |
+
x = self.fc1(self.fc0(x.flatten(start_dim=1)))
|
441 |
+
|
442 |
+
if self.c_dim > 0:
|
443 |
+
cmap = self.mapping(None, c)
|
444 |
+
|
445 |
+
if self.cmap_dim > 0:
|
446 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
447 |
+
|
448 |
+
return x
|
449 |
+
|
450 |
+
|
451 |
+
def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
|
452 |
+
NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
|
453 |
+
return NF[2 ** stage]
|
454 |
+
|
455 |
+
|
456 |
+
class Mlp(nn.Module):
|
457 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
458 |
+
super().__init__()
|
459 |
+
out_features = out_features or in_features
|
460 |
+
hidden_features = hidden_features or in_features
|
461 |
+
self.fc1 = FullyConnectedLayer(in_features=in_features, out_features=hidden_features, activation='lrelu')
|
462 |
+
self.fc2 = FullyConnectedLayer(in_features=hidden_features, out_features=out_features)
|
463 |
+
|
464 |
+
def forward(self, x):
|
465 |
+
x = self.fc1(x)
|
466 |
+
x = self.fc2(x)
|
467 |
+
return x
|
468 |
+
|
469 |
+
|
470 |
+
def window_partition(x, window_size):
|
471 |
+
"""
|
472 |
+
Args:
|
473 |
+
x: (B, H, W, C)
|
474 |
+
window_size (int): window size
|
475 |
+
Returns:
|
476 |
+
windows: (num_windows*B, window_size, window_size, C)
|
477 |
+
"""
|
478 |
+
B, H, W, C = x.shape
|
479 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
480 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
481 |
+
return windows
|
482 |
+
|
483 |
+
|
484 |
+
def window_reverse(windows, window_size: int, H: int, W: int):
|
485 |
+
"""
|
486 |
+
Args:
|
487 |
+
windows: (num_windows*B, window_size, window_size, C)
|
488 |
+
window_size (int): Window size
|
489 |
+
H (int): Height of image
|
490 |
+
W (int): Width of image
|
491 |
+
Returns:
|
492 |
+
x: (B, H, W, C)
|
493 |
+
"""
|
494 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
495 |
+
# B = windows.shape[0] / (H * W / window_size / window_size)
|
496 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
497 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
498 |
+
return x
|
499 |
+
|
500 |
+
|
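(Editor's aside) window_partition and window_reverse are exact inverses whenever H and W are multiples of the window size, which is what lets the Swin blocks move between image layout and per-window tokens. A quick shape check, assuming the two helpers above are in scope and torch is imported:

x = torch.randn(2, 16, 16, 32)                 # [B, H, W, C], H and W divisible by the window size
wins = window_partition(x, 8)                  # [8, 8, 8, 32]: 4 windows per image, 2 images
restored = window_reverse(wins, 8, 16, 16)     # back to [2, 16, 16, 32]
assert torch.equal(restored, x)                # exact round trip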
501 |
+
class Conv2dLayerPartial(nn.Module):
|
502 |
+
def __init__(self,
|
503 |
+
in_channels, # Number of input channels.
|
504 |
+
out_channels, # Number of output channels.
|
505 |
+
kernel_size, # Width and height of the convolution kernel.
|
506 |
+
bias=True, # Apply additive bias before the activation function?
|
507 |
+
activation='linear', # Activation function: 'relu', 'lrelu', etc.
|
508 |
+
up=1, # Integer upsampling factor.
|
509 |
+
down=1, # Integer downsampling factor.
|
510 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
511 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
512 |
+
trainable=True, # Update the weights of this layer during training?
|
513 |
+
):
|
514 |
+
super().__init__()
|
515 |
+
self.conv = Conv2dLayer(in_channels, out_channels, kernel_size, bias, activation, up, down, resample_filter,
|
516 |
+
conv_clamp, trainable)
|
517 |
+
|
518 |
+
self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
|
519 |
+
self.slide_winsize = kernel_size ** 2
|
520 |
+
self.stride = down
|
521 |
+
self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
|
522 |
+
|
523 |
+
def forward(self, x, mask=None):
|
524 |
+
if mask is not None:
|
525 |
+
with torch.no_grad():
|
526 |
+
if self.weight_maskUpdater.type() != x.type():
|
527 |
+
self.weight_maskUpdater = self.weight_maskUpdater.to(x)
|
528 |
+
update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
|
529 |
+
padding=self.padding)
|
530 |
+
mask_ratio = self.slide_winsize / (update_mask + 1e-8)
|
531 |
+
update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1
|
532 |
+
mask_ratio = torch.mul(mask_ratio, update_mask)
|
533 |
+
x = self.conv(x)
|
534 |
+
x = torch.mul(x, mask_ratio)
|
535 |
+
return x, update_mask
|
536 |
+
else:
|
537 |
+
x = self.conv(x)
|
538 |
+
return x, None
|
539 |
+
|
540 |
+
|
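(Editor's aside) Conv2dLayerPartial follows the partial-convolution recipe: count how many valid (mask = 1) pixels fall under each kernel window, rescale the convolution response by window_area / valid_count, and mark an output location valid if the window touched any valid pixel. The mask-update rule in isolation, as a standalone sketch with a plain box filter:

import torch
import torch.nn.functional as F

mask = torch.zeros(1, 1, 6, 6)
mask[..., 2:, 2:] = 1.0                          # known (valid) region
ones = torch.ones(1, 1, 3, 3)                    # same role as weight_maskUpdater above
coverage = F.conv2d(mask, ones, padding=1)       # valid pixels under each 3x3 window
ratio = (3 * 3) / (coverage + 1e-8)              # rescale factor for the conv output
updated_mask = coverage.clamp(0, 1)              # 1 wherever the window saw any valid pixel
ratio = ratio * updated_mask                     # zero out fully-invalid locations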
541 |
+
class WindowAttention(nn.Module):
|
542 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
543 |
+
It supports both of shifted and non-shifted window.
|
544 |
+
Args:
|
545 |
+
dim (int): Number of input channels.
|
546 |
+
window_size (tuple[int]): The height and width of the window.
|
547 |
+
num_heads (int): Number of attention heads.
|
548 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
549 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
550 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
551 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
552 |
+
"""
|
553 |
+
|
554 |
+
def __init__(self, dim, window_size, num_heads, down_ratio=1, qkv_bias=True, qk_scale=None, attn_drop=0.,
|
555 |
+
proj_drop=0.):
|
556 |
+
|
557 |
+
super().__init__()
|
558 |
+
self.dim = dim
|
559 |
+
self.window_size = window_size # Wh, Ww
|
560 |
+
self.num_heads = num_heads
|
561 |
+
head_dim = dim // num_heads
|
562 |
+
self.scale = qk_scale or head_dim ** -0.5
|
563 |
+
|
564 |
+
self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
|
565 |
+
self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
|
566 |
+
self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
|
567 |
+
self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)
|
568 |
+
|
569 |
+
self.softmax = nn.Softmax(dim=-1)
|
570 |
+
|
571 |
+
def forward(self, x, mask_windows=None, mask=None):
|
572 |
+
"""
|
573 |
+
Args:
|
574 |
+
x: input features with shape of (num_windows*B, N, C)
|
575 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
576 |
+
"""
|
577 |
+
B_, N, C = x.shape
|
578 |
+
norm_x = F.normalize(x, p=2.0, dim=-1)
|
579 |
+
q = self.q(norm_x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
580 |
+
k = self.k(norm_x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 3, 1)
|
581 |
+
v = self.v(x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
582 |
+
|
583 |
+
attn = (q @ k) * self.scale
|
584 |
+
|
585 |
+
if mask is not None:
|
586 |
+
nW = mask.shape[0]
|
587 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
588 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
589 |
+
|
590 |
+
if mask_windows is not None:
|
591 |
+
attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
|
592 |
+
attn = attn + attn_mask_windows.masked_fill(attn_mask_windows == 0, float(-100.0)).masked_fill(
|
593 |
+
attn_mask_windows == 1, float(0.0))
|
594 |
+
with torch.no_grad():
|
595 |
+
mask_windows = torch.clamp(torch.sum(mask_windows, dim=1, keepdim=True), 0, 1).repeat(1, N, 1)
|
596 |
+
|
597 |
+
attn = self.softmax(attn)
|
598 |
+
|
599 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
600 |
+
x = self.proj(x)
|
601 |
+
return x, mask_windows
|
602 |
+
|
603 |
+
|
604 |
+
class SwinTransformerBlock(nn.Module):
|
605 |
+
r""" Swin Transformer Block.
|
606 |
+
Args:
|
607 |
+
dim (int): Number of input channels.
|
608 |
+
input_resolution (tuple[int]): Input resolution.
|
609 |
+
num_heads (int): Number of attention heads.
|
610 |
+
window_size (int): Window size.
|
611 |
+
shift_size (int): Shift size for SW-MSA.
|
612 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
613 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
614 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
615 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
616 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
617 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
618 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
619 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
620 |
+
"""
|
621 |
+
|
622 |
+
def __init__(self, dim, input_resolution, num_heads, down_ratio=1, window_size=7, shift_size=0,
|
623 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
624 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
625 |
+
super().__init__()
|
626 |
+
self.dim = dim
|
627 |
+
self.input_resolution = input_resolution
|
628 |
+
self.num_heads = num_heads
|
629 |
+
self.window_size = window_size
|
630 |
+
self.shift_size = shift_size
|
631 |
+
self.mlp_ratio = mlp_ratio
|
632 |
+
if min(self.input_resolution) <= self.window_size:
|
633 |
+
# if window size is larger than input resolution, we don't partition windows
|
634 |
+
self.shift_size = 0
|
635 |
+
self.window_size = min(self.input_resolution)
|
636 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
|
637 |
+
|
638 |
+
if self.shift_size > 0:
|
639 |
+
down_ratio = 1
|
640 |
+
self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
641 |
+
down_ratio=down_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
642 |
+
proj_drop=drop)
|
643 |
+
|
644 |
+
self.fuse = FullyConnectedLayer(in_features=dim * 2, out_features=dim, activation='lrelu')
|
645 |
+
|
646 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
647 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
648 |
+
|
649 |
+
if self.shift_size > 0:
|
650 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
651 |
+
else:
|
652 |
+
attn_mask = None
|
653 |
+
|
654 |
+
self.register_buffer("attn_mask", attn_mask)
|
655 |
+
|
656 |
+
def calculate_mask(self, x_size):
|
657 |
+
# calculate attention mask for SW-MSA
|
658 |
+
H, W = x_size
|
659 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
660 |
+
h_slices = (slice(0, -self.window_size),
|
661 |
+
slice(-self.window_size, -self.shift_size),
|
662 |
+
slice(-self.shift_size, None))
|
663 |
+
w_slices = (slice(0, -self.window_size),
|
664 |
+
slice(-self.window_size, -self.shift_size),
|
665 |
+
slice(-self.shift_size, None))
|
666 |
+
cnt = 0
|
667 |
+
for h in h_slices:
|
668 |
+
for w in w_slices:
|
669 |
+
img_mask[:, h, w, :] = cnt
|
670 |
+
cnt += 1
|
671 |
+
|
672 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
673 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
674 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
675 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
676 |
+
|
677 |
+
return attn_mask
|
678 |
+
|
679 |
+
def forward(self, x, x_size, mask=None):
|
680 |
+
# H, W = self.input_resolution
|
681 |
+
H, W = x_size
|
682 |
+
B, L, C = x.shape
|
683 |
+
# assert L == H * W, "input feature has wrong size"
|
684 |
+
|
685 |
+
shortcut = x
|
686 |
+
x = x.view(B, H, W, C)
|
687 |
+
if mask is not None:
|
688 |
+
mask = mask.view(B, H, W, 1)
|
689 |
+
|
690 |
+
# cyclic shift
|
691 |
+
if self.shift_size > 0:
|
692 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
693 |
+
if mask is not None:
|
694 |
+
shifted_mask = torch.roll(mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
695 |
+
else:
|
696 |
+
shifted_x = x
|
697 |
+
if mask is not None:
|
698 |
+
shifted_mask = mask
|
699 |
+
|
700 |
+
# partition windows
|
701 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
702 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
703 |
+
if mask is not None:
|
704 |
+
mask_windows = window_partition(shifted_mask, self.window_size)
|
705 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1)
|
706 |
+
else:
|
707 |
+
mask_windows = None
|
708 |
+
|
709 |
+
# W-MSA/SW-MSA; recompute the attention mask when the test image size differs from the training resolution
|
710 |
+
if self.input_resolution == x_size:
|
711 |
+
attn_windows, mask_windows = self.attn(x_windows, mask_windows,
|
712 |
+
mask=self.attn_mask) # nW*B, window_size*window_size, C
|
713 |
+
else:
|
714 |
+
attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=self.calculate_mask(x_size).to(
|
715 |
+
x.device)) # nW*B, window_size*window_size, C
|
716 |
+
|
717 |
+
# merge windows
|
718 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
719 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
720 |
+
if mask is not None:
|
721 |
+
mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1)
|
722 |
+
shifted_mask = window_reverse(mask_windows, self.window_size, H, W)
|
723 |
+
|
724 |
+
# reverse cyclic shift
|
725 |
+
if self.shift_size > 0:
|
726 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
727 |
+
if mask is not None:
|
728 |
+
mask = torch.roll(shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
729 |
+
else:
|
730 |
+
x = shifted_x
|
731 |
+
if mask is not None:
|
732 |
+
mask = shifted_mask
|
733 |
+
x = x.view(B, H * W, C)
|
734 |
+
if mask is not None:
|
735 |
+
mask = mask.view(B, H * W, 1)
|
736 |
+
|
737 |
+
# FFN
|
738 |
+
x = self.fuse(torch.cat([shortcut, x], dim=-1))
|
739 |
+
x = self.mlp(x)
|
740 |
+
|
741 |
+
return x, mask
|
742 |
+
|
743 |
+
|
744 |
+
class PatchMerging(nn.Module):
|
745 |
+
def __init__(self, in_channels, out_channels, down=2):
|
746 |
+
super().__init__()
|
747 |
+
self.conv = Conv2dLayerPartial(in_channels=in_channels,
|
748 |
+
out_channels=out_channels,
|
749 |
+
kernel_size=3,
|
750 |
+
activation='lrelu',
|
751 |
+
down=down,
|
752 |
+
)
|
753 |
+
self.down = down
|
754 |
+
|
755 |
+
def forward(self, x, x_size, mask=None):
|
756 |
+
x = token2feature(x, x_size)
|
757 |
+
if mask is not None:
|
758 |
+
mask = token2feature(mask, x_size)
|
759 |
+
x, mask = self.conv(x, mask)
|
760 |
+
if self.down != 1:
|
761 |
+
ratio = 1 / self.down
|
762 |
+
x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio))
|
763 |
+
x = feature2token(x)
|
764 |
+
if mask is not None:
|
765 |
+
mask = feature2token(mask)
|
766 |
+
return x, x_size, mask
|
767 |
+
|
768 |
+
|
769 |
+
class PatchUpsampling(nn.Module):
|
770 |
+
def __init__(self, in_channels, out_channels, up=2):
|
771 |
+
super().__init__()
|
772 |
+
self.conv = Conv2dLayerPartial(in_channels=in_channels,
|
773 |
+
out_channels=out_channels,
|
774 |
+
kernel_size=3,
|
775 |
+
activation='lrelu',
|
776 |
+
up=up,
|
777 |
+
)
|
778 |
+
self.up = up
|
779 |
+
|
780 |
+
def forward(self, x, x_size, mask=None):
|
781 |
+
x = token2feature(x, x_size)
|
782 |
+
if mask is not None:
|
783 |
+
mask = token2feature(mask, x_size)
|
784 |
+
x, mask = self.conv(x, mask)
|
785 |
+
if self.up != 1:
|
786 |
+
x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up))
|
787 |
+
x = feature2token(x)
|
788 |
+
if mask is not None:
|
789 |
+
mask = feature2token(mask)
|
790 |
+
return x, x_size, mask
|
791 |
+
|
792 |
+
|
793 |
+
class BasicLayer(nn.Module):
|
794 |
+
""" A basic Swin Transformer layer for one stage.
|
795 |
+
Args:
|
796 |
+
dim (int): Number of input channels.
|
797 |
+
input_resolution (tuple[int]): Input resolution.
|
798 |
+
depth (int): Number of blocks.
|
799 |
+
num_heads (int): Number of attention heads.
|
800 |
+
window_size (int): Local window size.
|
801 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
802 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
803 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
804 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
805 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
806 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
807 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
808 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
809 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
810 |
+
"""
|
811 |
+
|
812 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size, down_ratio=1,
|
813 |
+
mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
814 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
815 |
+
|
816 |
+
super().__init__()
|
817 |
+
self.dim = dim
|
818 |
+
self.input_resolution = input_resolution
|
819 |
+
self.depth = depth
|
820 |
+
self.use_checkpoint = use_checkpoint
|
821 |
+
|
822 |
+
# patch merging layer
|
823 |
+
if downsample is not None:
|
824 |
+
# self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
825 |
+
self.downsample = downsample
|
826 |
+
else:
|
827 |
+
self.downsample = None
|
828 |
+
|
829 |
+
# build blocks
|
830 |
+
self.blocks = nn.ModuleList([
|
831 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
832 |
+
num_heads=num_heads, down_ratio=down_ratio, window_size=window_size,
|
833 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
834 |
+
mlp_ratio=mlp_ratio,
|
835 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
836 |
+
drop=drop, attn_drop=attn_drop,
|
837 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
838 |
+
norm_layer=norm_layer)
|
839 |
+
for i in range(depth)])
|
840 |
+
|
841 |
+
self.conv = Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, activation='lrelu')
|
842 |
+
|
843 |
+
def forward(self, x, x_size, mask=None):
|
844 |
+
if self.downsample is not None:
|
845 |
+
x, x_size, mask = self.downsample(x, x_size, mask)
|
846 |
+
identity = x
|
847 |
+
for blk in self.blocks:
|
848 |
+
if self.use_checkpoint:
|
849 |
+
x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
|
850 |
+
else:
|
851 |
+
x, mask = blk(x, x_size, mask)
|
852 |
+
if mask is not None:
|
853 |
+
mask = token2feature(mask, x_size)
|
854 |
+
x, mask = self.conv(token2feature(x, x_size), mask)
|
855 |
+
x = feature2token(x) + identity
|
856 |
+
if mask is not None:
|
857 |
+
mask = feature2token(mask)
|
858 |
+
return x, x_size, mask
|
859 |
+
|
860 |
+
|
861 |
+
class ToToken(nn.Module):
|
862 |
+
def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
|
863 |
+
super().__init__()
|
864 |
+
|
865 |
+
self.proj = Conv2dLayerPartial(in_channels=in_channels, out_channels=dim, kernel_size=kernel_size,
|
866 |
+
activation='lrelu')
|
867 |
+
|
868 |
+
def forward(self, x, mask):
|
869 |
+
x, mask = self.proj(x, mask)
|
870 |
+
|
871 |
+
return x, mask
|
872 |
+
|
873 |
+
|
874 |
+
class EncFromRGB(nn.Module):
|
875 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
|
876 |
+
super().__init__()
|
877 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
878 |
+
out_channels=out_channels,
|
879 |
+
kernel_size=1,
|
880 |
+
activation=activation,
|
881 |
+
)
|
882 |
+
self.conv1 = Conv2dLayer(in_channels=out_channels,
|
883 |
+
out_channels=out_channels,
|
884 |
+
kernel_size=3,
|
885 |
+
activation=activation,
|
886 |
+
)
|
887 |
+
|
888 |
+
def forward(self, x):
|
889 |
+
x = self.conv0(x)
|
890 |
+
x = self.conv1(x)
|
891 |
+
|
892 |
+
return x
|
893 |
+
|
894 |
+
|
895 |
+
class ConvBlockDown(nn.Module):
|
896 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log
|
897 |
+
super().__init__()
|
898 |
+
|
899 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
900 |
+
out_channels=out_channels,
|
901 |
+
kernel_size=3,
|
902 |
+
activation=activation,
|
903 |
+
down=2,
|
904 |
+
)
|
905 |
+
self.conv1 = Conv2dLayer(in_channels=out_channels,
|
906 |
+
out_channels=out_channels,
|
907 |
+
kernel_size=3,
|
908 |
+
activation=activation,
|
909 |
+
)
|
910 |
+
|
911 |
+
def forward(self, x):
|
912 |
+
x = self.conv0(x)
|
913 |
+
x = self.conv1(x)
|
914 |
+
|
915 |
+
return x
|
916 |
+
|
917 |
+
|
918 |
+
def token2feature(x, x_size):
|
919 |
+
B, N, C = x.shape
|
920 |
+
h, w = x_size
|
921 |
+
x = x.permute(0, 2, 1).reshape(B, C, h, w)
|
922 |
+
return x
|
923 |
+
|
924 |
+
|
925 |
+
def feature2token(x):
|
926 |
+
B, C, H, W = x.shape
|
927 |
+
x = x.view(B, C, -1).transpose(1, 2)
|
928 |
+
return x
|
929 |
+
|
930 |
+
|
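(Editor's aside) token2feature and feature2token simply reshape between the transformer token layout [B, N, C] and the convolutional layout [B, C, H, W] with N = H * W, so the two stages can exchange activations. A one-line sanity check, assuming the helpers above are in scope and torch is imported:

feat = torch.randn(1, 180, 64, 64)                       # [B, C, H, W]
tok = feature2token(feat)                                # [1, 64 * 64, 180]
assert torch.equal(token2feature(tok, (64, 64)), feat)   # exact round trip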
931 |
+
class Encoder(nn.Module):
|
932 |
+
def __init__(self, res_log2, img_channels, activation, patch_size=5, channels=16, drop_path_rate=0.1):
|
933 |
+
super().__init__()
|
934 |
+
|
935 |
+
self.resolution = []
|
936 |
+
|
937 |
+
for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
|
938 |
+
res = 2 ** i
|
939 |
+
self.resolution.append(res)
|
940 |
+
if i == res_log2:
|
941 |
+
block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
|
942 |
+
else:
|
943 |
+
block = ConvBlockDown(nf(i + 1), nf(i), activation)
|
944 |
+
setattr(self, 'EncConv_Block_%dx%d' % (res, res), block)
|
945 |
+
|
946 |
+
def forward(self, x):
|
947 |
+
out = {}
|
948 |
+
for res in self.resolution:
|
949 |
+
res_log2 = int(np.log2(res))
|
950 |
+
x = getattr(self, 'EncConv_Block_%dx%d' % (res, res))(x)
|
951 |
+
out[res_log2] = x
|
952 |
+
|
953 |
+
return out
|
954 |
+
|
955 |
+
|
956 |
+
class ToStyle(nn.Module):
|
957 |
+
def __init__(self, in_channels, out_channels, activation, drop_rate):
|
958 |
+
super().__init__()
|
959 |
+
self.conv = nn.Sequential(
|
960 |
+
Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
|
961 |
+
down=2),
|
962 |
+
Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
|
963 |
+
down=2),
|
964 |
+
Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
|
965 |
+
down=2),
|
966 |
+
)
|
967 |
+
|
968 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
969 |
+
self.fc = FullyConnectedLayer(in_features=in_channels,
|
970 |
+
out_features=out_channels,
|
971 |
+
activation=activation)
|
972 |
+
# self.dropout = nn.Dropout(drop_rate)
|
973 |
+
|
974 |
+
def forward(self, x):
|
975 |
+
x = self.conv(x)
|
976 |
+
x = self.pool(x)
|
977 |
+
x = self.fc(x.flatten(start_dim=1))
|
978 |
+
# x = self.dropout(x)
|
979 |
+
|
980 |
+
return x
|
981 |
+
|
982 |
+
|
983 |
+
class DecBlockFirstV2(nn.Module):
|
984 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
985 |
+
super().__init__()
|
986 |
+
self.res = res
|
987 |
+
|
988 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
989 |
+
out_channels=in_channels,
|
990 |
+
kernel_size=3,
|
991 |
+
activation=activation,
|
992 |
+
)
|
993 |
+
self.conv1 = StyleConv(in_channels=in_channels,
|
994 |
+
out_channels=out_channels,
|
995 |
+
style_dim=style_dim,
|
996 |
+
resolution=2 ** res,
|
997 |
+
kernel_size=3,
|
998 |
+
use_noise=use_noise,
|
999 |
+
activation=activation,
|
1000 |
+
demodulate=demodulate,
|
1001 |
+
)
|
1002 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
1003 |
+
out_channels=img_channels,
|
1004 |
+
style_dim=style_dim,
|
1005 |
+
kernel_size=1,
|
1006 |
+
demodulate=False,
|
1007 |
+
)
|
1008 |
+
|
1009 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
1010 |
+
# x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
1011 |
+
x = self.conv0(x)
|
1012 |
+
x = x + E_features[self.res]
|
1013 |
+
style = get_style_code(ws[:, 0], gs)
|
1014 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
1015 |
+
style = get_style_code(ws[:, 1], gs)
|
1016 |
+
img = self.toRGB(x, style, skip=None)
|
1017 |
+
|
1018 |
+
return x, img
|
1019 |
+
|
1020 |
+
|
1021 |
+
class DecBlock(nn.Module):
|
1022 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
|
1023 |
+
img_channels): # res = 4, ..., resolution_log2
|
1024 |
+
super().__init__()
|
1025 |
+
self.res = res
|
1026 |
+
|
1027 |
+
self.conv0 = StyleConv(in_channels=in_channels,
|
1028 |
+
out_channels=out_channels,
|
1029 |
+
style_dim=style_dim,
|
1030 |
+
resolution=2 ** res,
|
1031 |
+
kernel_size=3,
|
1032 |
+
up=2,
|
1033 |
+
use_noise=use_noise,
|
1034 |
+
activation=activation,
|
1035 |
+
demodulate=demodulate,
|
1036 |
+
)
|
1037 |
+
self.conv1 = StyleConv(in_channels=out_channels,
|
1038 |
+
out_channels=out_channels,
|
1039 |
+
style_dim=style_dim,
|
1040 |
+
resolution=2 ** res,
|
1041 |
+
kernel_size=3,
|
1042 |
+
use_noise=use_noise,
|
1043 |
+
activation=activation,
|
1044 |
+
demodulate=demodulate,
|
1045 |
+
)
|
1046 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
1047 |
+
out_channels=img_channels,
|
1048 |
+
style_dim=style_dim,
|
1049 |
+
kernel_size=1,
|
1050 |
+
demodulate=False,
|
1051 |
+
)
|
1052 |
+
|
1053 |
+
def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
|
1054 |
+
style = get_style_code(ws[:, self.res * 2 - 9], gs)
|
1055 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
1056 |
+
x = x + E_features[self.res]
|
1057 |
+
style = get_style_code(ws[:, self.res * 2 - 8], gs)
|
1058 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
1059 |
+
style = get_style_code(ws[:, self.res * 2 - 7], gs)
|
1060 |
+
img = self.toRGB(x, style, skip=img)
|
1061 |
+
|
1062 |
+
return x, img
|
1063 |
+
|
1064 |
+
|
1065 |
+
class Decoder(nn.Module):
|
1066 |
+
def __init__(self, res_log2, activation, style_dim, use_noise, demodulate, img_channels):
|
1067 |
+
super().__init__()
|
1068 |
+
self.Dec_16x16 = DecBlockFirstV2(4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels)
|
1069 |
+
for res in range(5, res_log2 + 1):
|
1070 |
+
setattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res),
|
1071 |
+
DecBlock(res, nf(res - 1), nf(res), activation, style_dim, use_noise, demodulate, img_channels))
|
1072 |
+
self.res_log2 = res_log2
|
1073 |
+
|
1074 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
1075 |
+
x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
|
1076 |
+
for res in range(5, self.res_log2 + 1):
|
1077 |
+
block = getattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res))
|
1078 |
+
x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
|
1079 |
+
|
1080 |
+
return img
|
1081 |
+
|
1082 |
+
|
1083 |
+
class DecStyleBlock(nn.Module):
|
1084 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
1085 |
+
super().__init__()
|
1086 |
+
self.res = res
|
1087 |
+
|
1088 |
+
self.conv0 = StyleConv(in_channels=in_channels,
|
1089 |
+
out_channels=out_channels,
|
1090 |
+
style_dim=style_dim,
|
1091 |
+
resolution=2 ** res,
|
1092 |
+
kernel_size=3,
|
1093 |
+
up=2,
|
1094 |
+
use_noise=use_noise,
|
1095 |
+
activation=activation,
|
1096 |
+
demodulate=demodulate,
|
1097 |
+
)
|
1098 |
+
self.conv1 = StyleConv(in_channels=out_channels,
|
1099 |
+
out_channels=out_channels,
|
1100 |
+
style_dim=style_dim,
|
1101 |
+
resolution=2 ** res,
|
1102 |
+
kernel_size=3,
|
1103 |
+
use_noise=use_noise,
|
1104 |
+
activation=activation,
|
1105 |
+
demodulate=demodulate,
|
1106 |
+
)
|
1107 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
1108 |
+
out_channels=img_channels,
|
1109 |
+
style_dim=style_dim,
|
1110 |
+
kernel_size=1,
|
1111 |
+
demodulate=False,
|
1112 |
+
)
|
1113 |
+
|
1114 |
+
def forward(self, x, img, style, skip, noise_mode='random'):
|
1115 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
1116 |
+
x = x + skip
|
1117 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
1118 |
+
img = self.toRGB(x, style, skip=img)
|
1119 |
+
|
1120 |
+
return x, img
|
1121 |
+
|
1122 |
+
|
1123 |
+
class FirstStage(nn.Module):
|
1124 |
+
def __init__(self, img_channels, img_resolution=256, dim=180, w_dim=512, use_noise=False, demodulate=True,
|
1125 |
+
activation='lrelu'):
|
1126 |
+
super().__init__()
|
1127 |
+
res = 64
|
1128 |
+
|
1129 |
+
self.conv_first = Conv2dLayerPartial(in_channels=img_channels + 1, out_channels=dim, kernel_size=3,
|
1130 |
+
activation=activation)
|
1131 |
+
self.enc_conv = nn.ModuleList()
|
1132 |
+
down_time = int(np.log2(img_resolution // res))
|
1133 |
+
# build the Swin Transformer stages according to the input image size
|
1134 |
+
for i in range(down_time): # from input size to 64
|
1135 |
+
self.enc_conv.append(
|
1136 |
+
Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation)
|
1137 |
+
)
|
1138 |
+
|
1139 |
+
# from 64 -> 16 -> 64
|
1140 |
+
depths = [2, 3, 4, 3, 2]
|
1141 |
+
ratios = [1, 1 / 2, 1 / 2, 2, 2]
|
1142 |
+
num_heads = 6
|
1143 |
+
window_sizes = [8, 16, 16, 16, 8]
|
1144 |
+
drop_path_rate = 0.1
|
1145 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
1146 |
+
|
1147 |
+
self.tran = nn.ModuleList()
|
1148 |
+
for i, depth in enumerate(depths):
|
1149 |
+
res = int(res * ratios[i])
|
1150 |
+
if ratios[i] < 1:
|
1151 |
+
merge = PatchMerging(dim, dim, down=int(1 / ratios[i]))
|
1152 |
+
elif ratios[i] > 1:
|
1153 |
+
merge = PatchUpsampling(dim, dim, up=ratios[i])
|
1154 |
+
else:
|
1155 |
+
merge = None
|
1156 |
+
self.tran.append(
|
1157 |
+
BasicLayer(dim=dim, input_resolution=[res, res], depth=depth, num_heads=num_heads,
|
1158 |
+
window_size=window_sizes[i], drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
1159 |
+
downsample=merge)
|
1160 |
+
)
|
1161 |
+
|
1162 |
+
# global style
|
1163 |
+
down_conv = []
|
1164 |
+
for i in range(int(np.log2(16))):
|
1165 |
+
down_conv.append(
|
1166 |
+
Conv2dLayer(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation))
|
1167 |
+
down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
|
1168 |
+
self.down_conv = nn.Sequential(*down_conv)
|
1169 |
+
self.to_style = FullyConnectedLayer(in_features=dim, out_features=dim * 2, activation=activation)
|
1170 |
+
self.ws_style = FullyConnectedLayer(in_features=w_dim, out_features=dim, activation=activation)
|
1171 |
+
self.to_square = FullyConnectedLayer(in_features=dim, out_features=16 * 16, activation=activation)
|
1172 |
+
|
1173 |
+
style_dim = dim * 3
|
1174 |
+
self.dec_conv = nn.ModuleList()
|
1175 |
+
for i in range(down_time): # from 64 to input size
|
1176 |
+
res = res * 2
|
1177 |
+
self.dec_conv.append(
|
1178 |
+
DecStyleBlock(res, dim, dim, activation, style_dim, use_noise, demodulate, img_channels))
|
1179 |
+
|
1180 |
+
def forward(self, images_in, masks_in, ws, noise_mode='random'):
|
1181 |
+
x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)
|
1182 |
+
|
1183 |
+
skips = []
|
1184 |
+
x, mask = self.conv_first(x, masks_in) # input size
|
1185 |
+
skips.append(x)
|
1186 |
+
for i, block in enumerate(self.enc_conv): # input size to 64
|
1187 |
+
x, mask = block(x, mask)
|
1188 |
+
if i != len(self.enc_conv) - 1:
|
1189 |
+
skips.append(x)
|
1190 |
+
|
1191 |
+
x_size = x.size()[-2:]
|
1192 |
+
x = feature2token(x)
|
1193 |
+
mask = feature2token(mask)
|
1194 |
+
mid = len(self.tran) // 2
|
1195 |
+
for i, block in enumerate(self.tran): # 64 to 16
|
1196 |
+
if i < mid:
|
1197 |
+
x, x_size, mask = block(x, x_size, mask)
|
1198 |
+
skips.append(x)
|
1199 |
+
elif i > mid:
|
1200 |
+
x, x_size, mask = block(x, x_size, None)
|
1201 |
+
x = x + skips[mid - i]
|
1202 |
+
else:
|
1203 |
+
x, x_size, mask = block(x, x_size, None)
|
1204 |
+
|
1205 |
+
mul_map = torch.ones_like(x) * 0.5
|
1206 |
+
mul_map = F.dropout(mul_map, training=True)
|
1207 |
+
ws = self.ws_style(ws[:, -1])
|
1208 |
+
add_n = self.to_square(ws).unsqueeze(1)
|
1209 |
+
add_n = F.interpolate(add_n, size=x.size(1), mode='linear', align_corners=False).squeeze(1).unsqueeze(
|
1210 |
+
-1)
|
1211 |
+
x = x * mul_map + add_n * (1 - mul_map)
|
1212 |
+
gs = self.to_style(self.down_conv(token2feature(x, x_size)).flatten(start_dim=1))
|
1213 |
+
style = torch.cat([gs, ws], dim=1)
|
1214 |
+
|
1215 |
+
x = token2feature(x, x_size).contiguous()
|
1216 |
+
img = None
|
1217 |
+
for i, block in enumerate(self.dec_conv):
|
1218 |
+
x, img = block(x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode)
|
1219 |
+
|
1220 |
+
# ensemble
|
1221 |
+
img = img * (1 - masks_in) + images_in * masks_in
|
1222 |
+
|
1223 |
+
return img
|
1224 |
+
|
1225 |
+
|
1226 |
+
class SynthesisNet(nn.Module):
|
1227 |
+
def __init__(self,
|
1228 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1229 |
+
img_resolution, # Output image resolution.
|
1230 |
+
img_channels=3, # Number of color channels.
|
1231 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
1232 |
+
channel_decay=1.0,
|
1233 |
+
channel_max=512, # Maximum number of channels in any layer.
|
1234 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
1235 |
+
drop_rate=0.5,
|
1236 |
+
use_noise=False,
|
1237 |
+
demodulate=True,
|
1238 |
+
):
|
1239 |
+
super().__init__()
|
1240 |
+
resolution_log2 = int(np.log2(img_resolution))
|
1241 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
1242 |
+
|
1243 |
+
self.num_layers = resolution_log2 * 2 - 3 * 2
|
1244 |
+
self.img_resolution = img_resolution
|
1245 |
+
self.resolution_log2 = resolution_log2
|
1246 |
+
|
1247 |
+
# first stage
|
1248 |
+
self.first_stage = FirstStage(img_channels, img_resolution=img_resolution, w_dim=w_dim, use_noise=False,
|
1249 |
+
demodulate=demodulate)
|
1250 |
+
|
1251 |
+
# second stage
|
1252 |
+
self.enc = Encoder(resolution_log2, img_channels, activation, patch_size=5, channels=16)
|
1253 |
+
self.to_square = FullyConnectedLayer(in_features=w_dim, out_features=16 * 16, activation=activation)
|
1254 |
+
self.to_style = ToStyle(in_channels=nf(4), out_channels=nf(2) * 2, activation=activation, drop_rate=drop_rate)
|
1255 |
+
style_dim = w_dim + nf(2) * 2
|
1256 |
+
self.dec = Decoder(resolution_log2, activation, style_dim, use_noise, demodulate, img_channels)
|
1257 |
+
|
1258 |
+
def forward(self, images_in, masks_in, ws, noise_mode='random', return_stg1=False):
|
1259 |
+
out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
|
1260 |
+
|
1261 |
+
# encoder
|
1262 |
+
x = images_in * masks_in + out_stg1 * (1 - masks_in)
|
1263 |
+
x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1)
|
1264 |
+
E_features = self.enc(x)
|
1265 |
+
|
1266 |
+
fea_16 = E_features[4]
|
1267 |
+
mul_map = torch.ones_like(fea_16) * 0.5
|
1268 |
+
mul_map = F.dropout(mul_map, training=True)
|
1269 |
+
add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
|
1270 |
+
add_n = F.interpolate(add_n, size=fea_16.size()[-2:], mode='bilinear', align_corners=False)
|
1271 |
+
fea_16 = fea_16 * mul_map + add_n * (1 - mul_map)
|
1272 |
+
E_features[4] = fea_16
|
1273 |
+
|
1274 |
+
# style
|
1275 |
+
gs = self.to_style(fea_16)
|
1276 |
+
|
1277 |
+
# decoder
|
1278 |
+
img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode)
|
1279 |
+
|
1280 |
+
# ensemble
|
1281 |
+
img = img * (1 - masks_in) + images_in * masks_in
|
1282 |
+
|
1283 |
+
if not return_stg1:
|
1284 |
+
return img
|
1285 |
+
else:
|
1286 |
+
return img, out_stg1
|
1287 |
+
|
1288 |
+
|
1289 |
+
class Generator(nn.Module):
|
1290 |
+
def __init__(self,
|
1291 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
1292 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
1293 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1294 |
+
img_resolution, # resolution of generated image
|
1295 |
+
img_channels, # Number of input color channels.
|
1296 |
+
synthesis_kwargs={}, # Arguments for SynthesisNetwork.
|
1297 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
1298 |
+
):
|
1299 |
+
super().__init__()
|
1300 |
+
self.z_dim = z_dim
|
1301 |
+
self.c_dim = c_dim
|
1302 |
+
self.w_dim = w_dim
|
1303 |
+
self.img_resolution = img_resolution
|
1304 |
+
self.img_channels = img_channels
|
1305 |
+
|
1306 |
+
self.synthesis = SynthesisNet(w_dim=w_dim,
|
1307 |
+
img_resolution=img_resolution,
|
1308 |
+
img_channels=img_channels,
|
1309 |
+
**synthesis_kwargs)
|
1310 |
+
self.mapping = MappingNet(z_dim=z_dim,
|
1311 |
+
c_dim=c_dim,
|
1312 |
+
w_dim=w_dim,
|
1313 |
+
num_ws=self.synthesis.num_layers,
|
1314 |
+
**mapping_kwargs)
|
1315 |
+
|
1316 |
+
def forward(self, images_in, masks_in, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False,
|
1317 |
+
noise_mode='none', return_stg1=False):
|
1318 |
+
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
|
1319 |
+
skip_w_avg_update=skip_w_avg_update)
|
1320 |
+
img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
|
1321 |
+
return img
|
1322 |
+
|
1323 |
+
|
1324 |
+
class Discriminator(torch.nn.Module):
|
1325 |
+
def __init__(self,
|
1326 |
+
c_dim, # Conditioning label (C) dimensionality.
|
1327 |
+
img_resolution, # Input resolution.
|
1328 |
+
img_channels, # Number of input color channels.
|
1329 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
1330 |
+
channel_max=512, # Maximum number of channels in any layer.
|
1331 |
+
channel_decay=1,
|
1332 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
1333 |
+
activation='lrelu',
|
1334 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
1335 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
1336 |
+
):
|
1337 |
+
super().__init__()
|
1338 |
+
self.c_dim = c_dim
|
1339 |
+
self.img_resolution = img_resolution
|
1340 |
+
self.img_channels = img_channels
|
1341 |
+
|
1342 |
+
resolution_log2 = int(np.log2(img_resolution))
|
1343 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
1344 |
+
self.resolution_log2 = resolution_log2
|
1345 |
+
|
1346 |
+
if cmap_dim is None:
|
1347 |
+
cmap_dim = nf(2)
|
1348 |
+
if c_dim == 0:
|
1349 |
+
cmap_dim = 0
|
1350 |
+
self.cmap_dim = cmap_dim
|
1351 |
+
|
1352 |
+
if c_dim > 0:
|
1353 |
+
self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
|
1354 |
+
|
1355 |
+
Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
|
1356 |
+
for res in range(resolution_log2, 2, -1):
|
1357 |
+
Dis.append(DisBlock(nf(res), nf(res - 1), activation))
|
1358 |
+
|
1359 |
+
if mbstd_num_channels > 0:
|
1360 |
+
Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
|
1361 |
+
Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
|
1362 |
+
self.Dis = nn.Sequential(*Dis)
|
1363 |
+
|
1364 |
+
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
1365 |
+
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
1366 |
+
|
1367 |
+
# for 64x64
|
1368 |
+
Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)]
|
1369 |
+
for res in range(resolution_log2, 2, -1):
|
1370 |
+
Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation))
|
1371 |
+
|
1372 |
+
if mbstd_num_channels > 0:
|
1373 |
+
Dis_stg1.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
|
1374 |
+
Dis_stg1.append(Conv2dLayer(nf(2) // 2 + mbstd_num_channels, nf(2) // 2, kernel_size=3, activation=activation))
|
1375 |
+
self.Dis_stg1 = nn.Sequential(*Dis_stg1)
|
1376 |
+
|
1377 |
+
self.fc0_stg1 = FullyConnectedLayer(nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation)
|
1378 |
+
self.fc1_stg1 = FullyConnectedLayer(nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim)
|
1379 |
+
|
1380 |
+
def forward(self, images_in, masks_in, images_stg1, c):
|
1381 |
+
x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1))
|
1382 |
+
x = self.fc1(self.fc0(x.flatten(start_dim=1)))
|
1383 |
+
|
1384 |
+
x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1))
|
1385 |
+
x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1)))
|
1386 |
+
|
1387 |
+
if self.c_dim > 0:
|
1388 |
+
cmap = self.mapping(None, c)
|
1389 |
+
|
1390 |
+
if self.cmap_dim > 0:
|
1391 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
1392 |
+
x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
1393 |
+
|
1394 |
+
return x, x_stg1
|
1395 |
+
|
1396 |
+
|
1397 |
+
MAT_MODEL_URL = os.environ.get(
|
1398 |
+
"MAT_MODEL_URL",
|
1399 |
+
"https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth",
|
1400 |
+
)
|
1401 |
+
|
1402 |
+
|
1403 |
+
class MAT(InpaintModel):
|
1404 |
+
min_size = 512
|
1405 |
+
pad_mod = 512
|
1406 |
+
pad_to_square = True
|
1407 |
+
|
1408 |
+
def init_model(self, device, **kwargs):
|
1409 |
+
seed = 240 # fixed seed so the sampled z (and thus the inpainting result) is reproducible
|
1410 |
+
random.seed(seed)
|
1411 |
+
np.random.seed(seed)
|
1412 |
+
torch.manual_seed(seed)
|
1413 |
+
|
1414 |
+
G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3)
|
1415 |
+
self.model = load_model(G, MAT_MODEL_URL, device)
|
1416 |
+
self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device) # [1., 512]
|
1417 |
+
self.label = torch.zeros([1, self.model.c_dim], device=device)
|
1418 |
+
|
1419 |
+
@staticmethod
|
1420 |
+
def is_downloaded() -> bool:
|
1421 |
+
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
|
1422 |
+
|
1423 |
+
def forward(self, image, mask, config: Config):
|
1424 |
+
"""Input images and output images have same size
|
1425 |
+
images: [H, W, C] RGB
|
1426 |
+
masks: [H, W] mask area == 255
|
1427 |
+
return: BGR IMAGE
|
1428 |
+
"""
|
1429 |
+
|
1430 |
+
image = norm_img(image) # [0, 1]
|
1431 |
+
image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
1432 |
+
|
1433 |
+
mask = (mask > 127) * 255
|
1434 |
+
mask = 255 - mask
|
1435 |
+
mask = norm_img(mask)
|
1436 |
+
|
1437 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
1438 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
1439 |
+
|
1440 |
+
output = self.model(image, mask, self.z, self.label, truncation_psi=1, noise_mode='none')
|
1441 |
+
output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
|
1442 |
+
output = output[0].cpu().numpy()
|
1443 |
+
cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
1444 |
+
return cur_res
|
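For orientation, below is a minimal usage sketch of the MAT backend added above. The constructor/`__call__` contract and the `Config` fields shown here are assumptions based on `lama_cleaner/model/base.py` and `lama_cleaner/schema.py` elsewhere in this commit, not part of the hunk above; treat the names, file paths, and values as illustrative only.

# Hypothetical usage sketch for the MAT inpainting model defined above.
# Assumes InpaintModel.__init__(device) and __call__(image, mask, config),
# as provided by the base class in this repository.
import cv2
import numpy as np
import torch

from lama_cleaner.model.mat import MAT
from lama_cleaner.schema import Config, HDStrategy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MAT(device)  # downloads Places_512_FullData_G.pth on first use

image = cv2.cvtColor(cv2.imread("input.jpeg"), cv2.COLOR_BGR2RGB)  # [H, W, C] RGB
mask = np.zeros(image.shape[:2], dtype=np.uint8)
mask[100:200, 100:200] = 255  # 255 marks the region to repaint

config = Config(
    ldm_steps=25,
    hd_strategy=HDStrategy.ORIGINAL,
    hd_strategy_crop_margin=128,
    hd_strategy_crop_trigger_size=512,
    hd_strategy_resize_limit=512,
)
result_bgr = model(image, mask, config)  # forward() above returns a BGR image
cv2.imwrite("output.jpeg", result_bgr)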
lama_cleaner/model/opencv2.py
ADDED
@@ -0,0 +1,24 @@
import cv2
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config

flag_map = {
    "INPAINT_NS": cv2.INPAINT_NS,
    "INPAINT_TELEA": cv2.INPAINT_TELEA
}

class OpenCV2(InpaintModel):
    pad_mod = 1

    @staticmethod
    def is_downloaded() -> bool:
        return True

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W, 1]
        return: BGR IMAGE
        """
        cur_res = cv2.inpaint(image[:, :, ::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag])
        return cur_res
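The `OpenCV2` backend above is a thin wrapper around `cv2.inpaint`. A self-contained sketch of the same call on synthetic data follows; the radius and flag values are arbitrary, and `cv2.cvtColor` is used here in place of the `image[:, :, ::-1]` slice for the same RGB-to-BGR conversion.

# Standalone sketch of the cv2.inpaint call wrapped by the OpenCV2 model above.
import cv2
import numpy as np

rgb = np.full((256, 256, 3), 200, dtype=np.uint8)  # toy RGB image
rgb[96:160, 96:160] = (255, 0, 0)                  # a "defect" to remove
mask = np.zeros((256, 256), dtype=np.uint8)
mask[96:160, 96:160] = 255                         # 255 marks pixels to fill

bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)         # same conversion as image[:, :, ::-1]
bgr_result = cv2.inpaint(bgr, mask, inpaintRadius=5, flags=cv2.INPAINT_NS)
print(bgr_result.shape)                            # (256, 256, 3), BGR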
lama_cleaner/model/plms_sampler.py
ADDED
@@ -0,0 +1,225 @@
# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
import torch
import numpy as np
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
from tqdm import tqdm


class PLMSSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               steps,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=False,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples = self.plms_sampling(conditioning, size,
                                     callback=callback,
                                     img_callback=img_callback,
                                     quantize_denoised=quantize_x0,
                                     mask=mask, x0=x0,
                                     ddim_use_original_steps=False,
                                     noise_dropout=noise_dropout,
                                     temperature=temperature,
                                     score_corrector=score_corrector,
                                     corrector_kwargs=corrector_kwargs,
                                     x_T=x_T,
                                     log_every_t=log_every_t,
                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                     unconditional_conditioning=unconditional_conditioning,
                                     )
        return samples

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, ):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

        return img

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
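The branches at the end of `p_sample_plms` above are the standard Adams-Bashforth linear-multistep weights. As a quick sanity check, the 4th-order coefficients sum to 1, so a constant noise history leaves the combined estimate unchanged; the toy values below are arbitrary.

# Toy check of the pseudo linear multistep weights used in p_sample_plms above.
import numpy as np

e_t = np.array([1.0, 2.0, 3.0])                   # current eps prediction (toy values)
old_eps = [e_t.copy(), e_t.copy(), e_t.copy()]    # three previous predictions, all equal

e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
assert np.allclose(e_t_prime, e_t)                # (55 - 59 + 37 - 9) / 24 == 1
print(e_t_prime)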
lama_cleaner/model/sd.py
ADDED
@@ -0,0 +1,215 @@
import random

import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import PNDMScheduler, DDIMScheduler
from loguru import logger
from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin

from lama_cleaner.helper import norm_img

from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config, SDSampler


#
#
# def preprocess_image(image):
#     w, h = image.size
#     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
#     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
#     image = np.array(image).astype(np.float32) / 255.0
#     image = image[None].transpose(0, 3, 1, 2)
#     image = torch.from_numpy(image)
#     # [-1, 1]
#     return 2.0 * image - 1.0
#
#
# def preprocess_mask(mask):
#     mask = mask.convert("L")
#     w, h = mask.size
#     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
#     mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
#     mask = np.array(mask).astype(np.float32) / 255.0
#     mask = np.tile(mask, (4, 1, 1))
#     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
#     mask = 1 - mask  # repaint white, keep black
#     mask = torch.from_numpy(mask)
#     return mask

class DummyFeatureExtractorOutput:
    def __init__(self, pixel_values):
        self.pixel_values = pixel_values

    def to(self, device):
        return self


class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, *args, **kwargs):
        return DummyFeatureExtractorOutput(torch.empty(0, 3))


class DummySafetyChecker:
    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, clip_input, images):
        return images, False


class SD(InpaintModel):
    pad_mod = 64  # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
    min_size = 512

    def init_model(self, device: torch.device, **kwargs):
        from .sd_pipeline import StableDiffusionInpaintPipeline

        model_kwargs = {"local_files_only": kwargs['sd_run_local']}
        if kwargs['sd_disable_nsfw']:
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(dict(
                feature_extractor=DummyFeatureExtractor(),
                safety_checker=DummySafetyChecker(),
            ))

        self.model = StableDiffusionInpaintPipeline.from_pretrained(
            self.model_id_or_path,
            revision="fp16" if torch.cuda.is_available() else "main",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            use_auth_token=kwargs["hf_access_token"],
            **model_kwargs
        )
        # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
        self.model.enable_attention_slicing()
        self.model = self.model.to(device)

        if kwargs['sd_cpu_textencoder']:
            logger.info("Run Stable Diffusion TextEncoder on CPU")
            self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
            self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True)

        self.callbacks = kwargs.pop("callbacks", None)

    @torch.cuda.amp.autocast()
    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE
        """

        # image = norm_img(image)  # [0, 1]
        # image = image * 2 - 1  # [0, 1] -> [-1, 1]

        # resize to latent feature map size
        # h, w = mask.shape[:2]
        # mask = cv2.resize(mask, (h // 8, w // 8), interpolation=cv2.INTER_AREA)
        # mask = norm_img(mask)
        #
        # image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        # mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        if config.sd_sampler == SDSampler.ddim:
            scheduler = DDIMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
            )
        elif config.sd_sampler == SDSampler.pndm:
            PNDM_kwargs = {
                "tensor_format": "pt",
                "beta_schedule": "scaled_linear",
                "beta_start": 0.00085,
                "beta_end": 0.012,
                "num_train_timesteps": 1000,
                "skip_prk_steps": True,
            }
            scheduler = PNDMScheduler(**PNDM_kwargs)
        else:
            raise ValueError(config.sd_sampler)

        self.model.scheduler = scheduler

        seed = config.sd_seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]

        output = self.model(
            prompt=config.prompt,
            init_image=PIL.Image.fromarray(image),
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            strength=config.sd_strength,
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            output_type="np.array",
            callbacks=self.callbacks,
        ).images[0]

        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        img_h, img_w = image.shape[:2]

        # boxes = boxes_from_mask(mask)
        if config.use_croper:
            logger.info("use croper")
            l, t, w, h = (
                config.croper_x,
                config.croper_y,
                config.croper_width,
                config.croper_height,
            )
            r = l + w
            b = t + h

            l = max(l, 0)
            r = min(r, img_w)
            t = max(t, 0)
            b = min(b, img_h)

            crop_img = image[t:b, l:r, :]
            crop_mask = mask[t:b, l:r]

            crop_image = self._pad_forward(crop_img, crop_mask, config)

            inpaint_result = image[:, :, ::-1]
            inpaint_result[t:b, l:r, :] = crop_image
        else:
            inpaint_result = self._pad_forward(image, mask, config)

        return inpaint_result

    @staticmethod
    def is_downloaded() -> bool:
        # model will be downloaded when app start, and can't switch in frontend settings
        return True


class SD14(SD):
    model_id_or_path = "CompVis/stable-diffusion-v1-4"


class SD15(SD):
    model_id_or_path = "CompVis/stable-diffusion-v1-5"
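The `sd_mask_blur` handling in `SD.forward` above only turns the blur amount into an odd-sized Gaussian kernel and softens the mask edge before the mask is handed to the pipeline. A standalone sketch with an arbitrary blur value:

# Standalone sketch of the sd_mask_blur step in SD.forward above.
import cv2
import numpy as np

mask = np.zeros((128, 128), dtype=np.uint8)
mask[32:96, 32:96] = 255            # hard-edged repaint region

sd_mask_blur = 4                    # arbitrary value for illustration
k = 2 * sd_mask_blur + 1            # kernel size must be odd for GaussianBlur
soft_mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]  # [H, W, 1]
print(soft_mask.shape, soft_mask.min(), soft_mask.max())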
lama_cleaner/model/sd_pipeline.py
ADDED
@@ -0,0 +1,310 @@
import inspect
from typing import List, Optional, Union, Callable

import numpy as np
import torch

import PIL
from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.utils import logging
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

logger = logging.get_logger(__name__)


def preprocess_image(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def preprocess_mask(mask):
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
    mask = 1 - mask  # repaint white, keep black
    mask = torch.from_numpy(mask)
    return mask


class StableDiffusionInpaintPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offsensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `set_attention_slice`
        self.enable_attention_slice(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callbacks: List[Callable[[int], None]] = None
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. This is the image whose masked region will be inpainted.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
                converted to a single channel (luminance) before use.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                is 1, the denoising process will be run on the masked area for the full number of iterations specified
                in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # preprocess image
        init_image = preprocess_image(init_image).to(self.device)

        # encode the init image into latents and scale the latents
        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
        init_latents = init_latent_dist.sample(generator=generator)

        init_latents = 0.18215 * init_latents

        # Expand init_latents for batch_size
        init_latents = torch.cat([init_latents] * batch_size)
        init_latents_orig = init_latents

        # preprocess mask
        mask = preprocess_mask(mask_image).to(self.device)
        mask = torch.cat([mask] * batch_size)

        # check sizes
        if not mask.shape == init_latents.shape:
            raise ValueError("The mask and init_image should be the same size!")

        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_encoder_device = self.text_encoder.device

        text_embeddings = self.text_encoder(text_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        latents = init_latents
        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # masking
            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
            latents = (init_latents_proper * mask) + (latents * (1 - mask))

            if callbacks is not None:
                for callback in callbacks:
                    callback(i)

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
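To make the mask shapes in the pipeline above concrete, here is `preprocess_mask` walked through on a toy 64x64 mask: the mask is downscaled to the 8x-smaller latent grid, tiled over the 4 latent channels, and inverted so white pixels are repainted. The toy size is arbitrary.

# Walk-through of the preprocess_mask steps above with a toy all-black 64x64 mask.
import numpy as np
import PIL.Image
import torch

mask_img = PIL.Image.fromarray(np.zeros((64, 64), dtype=np.uint8), mode="L")

mask = mask_img.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h))
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)  # latent resolution
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))       # one copy per latent channel
mask = 1 - mask                       # repaint white, keep black
mask = torch.from_numpy(mask[None])   # add batch dimension
print(mask.shape)                     # torch.Size([1, 4, 8, 8])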
lama_cleaner/model/utils.py
ADDED
@@ -0,0 +1,709 @@
import math
from typing import Any

import torch
import numpy as np
import collections
from itertools import repeat

from torch import conv2d, conv_transpose2d


def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device)
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2).to(device)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according the the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()


def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=device)

    args = timesteps[:, None].float() * freqs[None]

    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


###### MAT and FcF #######


def normalize_2nd_moment(x, dim=1, eps=1e-8):
    return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp)  # pylint: disable=invalid-unary-operand-type
    return x


def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x: Input activation tensor. Can be of any shape.
        b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
            as `x`. The shape must be known, and it must match the dimension of `x`
            corresponding to `dim`.
        dim: The dimension in `x` corresponding to the elements of `b`.
            The value of `dim` is ignored if `b` is not specified.
        act: Name of the activation function to evaluate, or `"linear"` to disable.
            Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
            See `activation_funcs` for a full list. `None` is not allowed.
        alpha: Shape parameter for the activation function, or `None` to use the default.
        gain: Scaling factor for the output tensor, or `None` to use default.
            See `activation_funcs` for the default scaling of each activation function.
            If unsure, consider specifying 1.
        clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
            the clamping (default).
        impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)


def _get_filter_size(f):
    if f is None:
        return 1, 1

    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    fw = f.shape[-1]
    fh = f.shape[0]

    fw = int(fw)
    fh = int(fh)
    assert fw >= 1 and fh >= 1
    return fw, fh


def _get_weight_shape(w):
    shape = [int(sz) for sz in w.shape]
    return shape


def _parse_scaling(scaling):
    if isinstance(scaling, int):
        scaling = [scaling, scaling]
    assert isinstance(scaling, (list, tuple))
    assert all(isinstance(x, int) for x in scaling)
    sx, sy = scaling
    assert sx >= 1 and sy >= 1
    return sx, sy


def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, int) for x in padding)
    if len(padding) == 2:
        padx, pady = padding
        padding = [padx, padx, pady, pady]
    padx0, padx1, pady0, pady1 = padding
    return padx0, padx1, pady0, pady1


def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f: Torch tensor, numpy array, or python list of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable),
            `[]` (impulse), or
            `None` (identity).
        device: Result device (default: cpu).
        normalize: Normalize the filter so that it retains the magnitude
            for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        separable: Return a separable filter? (default: select automatically).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable).
    """
    # Validate.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable?
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        f /= f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)

activation_funcs = {
    'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
    'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2,
                     ref='y', has_2nd_grad=False),
    'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2,
                      def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y',
                     has_2nd_grad=True),
    'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y',
                        has_2nd_grad=True),
    'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y',
                    has_2nd_grad=True),
    'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y',
                     has_2nd_grad=True),
    'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8,
                         ref='y', has_2nd_grad=True),
    'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x',
                      has_2nd_grad=True),
}


def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Pad, upsample, filter, and downsample a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    2. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by keeping every Nth pixel (`down`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x: Float32/float64/float16 input tensor of the shape
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), or
            `None` (identity).
        up: Integer upsampling factor. Can be a single int or a list/tuple
            `[x, y]` (default: 1).
        down: Integer downsampling factor. Can be a single int or a list/tuple
            `[x, y]` (default: 1).
        padding: Padding with respect to the upsampled image. Can be a single number
            or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # assert isinstance(x, torch.Tensor)
    # assert impl in ['ref', 'cuda']
    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)


def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    # upx, upy = _parse_scaling(up)
    # downx, downy = _parse_scaling(down)

    upx, upy = up, up
    downx, downy = down, down

    # padx0, padx1, pady0, pady1 = _parse_padding(padding)
    padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]

    # Upsample by inserting zeros.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop.
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter.
|
384 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
385 |
+
if f.ndim == 4:
|
386 |
+
x = conv2d(input=x, weight=f, groups=num_channels)
|
387 |
+
else:
|
388 |
+
x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
389 |
+
x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
390 |
+
|
391 |
+
# Downsample by throwing away pixels.
|
392 |
+
x = x[:, :, ::downy, ::downx]
|
393 |
+
return x
|
394 |
+
|
395 |
+
|
396 |
+
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
397 |
+
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
398 |
+
|
399 |
+
By default, the result is padded so that its shape is a fraction of the input.
|
400 |
+
User-specified padding is applied on top of that, with negative values
|
401 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
402 |
+
|
403 |
+
Args:
|
404 |
+
x: Float32/float64/float16 input tensor of the shape
|
405 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
406 |
+
f: Float32 FIR filter of the shape
|
407 |
+
`[filter_height, filter_width]` (non-separable),
|
408 |
+
`[filter_taps]` (separable), or
|
409 |
+
`None` (identity).
|
410 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
411 |
+
`[x, y]` (default: 1).
|
412 |
+
padding: Padding with respect to the input. Can be a single number or a
|
413 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
414 |
+
(default: 0).
|
415 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
416 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
417 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
421 |
+
"""
|
422 |
+
downx, downy = _parse_scaling(down)
|
423 |
+
# padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
424 |
+
padx0, padx1, pady0, pady1 = padding, padding, padding, padding
|
425 |
+
|
426 |
+
fw, fh = _get_filter_size(f)
|
427 |
+
p = [
|
428 |
+
padx0 + (fw - downx + 1) // 2,
|
429 |
+
padx1 + (fw - downx) // 2,
|
430 |
+
pady0 + (fh - downy + 1) // 2,
|
431 |
+
pady1 + (fh - downy) // 2,
|
432 |
+
]
|
433 |
+
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
434 |
+
|
435 |
+
|
436 |
+
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
437 |
+
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
438 |
+
|
439 |
+
By default, the result is padded so that its shape is a multiple of the input.
|
440 |
+
User-specified padding is applied on top of that, with negative values
|
441 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
442 |
+
|
443 |
+
Args:
|
444 |
+
x: Float32/float64/float16 input tensor of the shape
|
445 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
446 |
+
f: Float32 FIR filter of the shape
|
447 |
+
`[filter_height, filter_width]` (non-separable),
|
448 |
+
`[filter_taps]` (separable), or
|
449 |
+
`None` (identity).
|
450 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
451 |
+
`[x, y]` (default: 1).
|
452 |
+
padding: Padding with respect to the output. Can be a single number or a
|
453 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
454 |
+
(default: 0).
|
455 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
456 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
457 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
458 |
+
|
459 |
+
Returns:
|
460 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
461 |
+
"""
|
462 |
+
upx, upy = _parse_scaling(up)
|
463 |
+
# upx, upy = up, up
|
464 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
465 |
+
# padx0, padx1, pady0, pady1 = padding, padding, padding, padding
|
466 |
+
fw, fh = _get_filter_size(f)
|
467 |
+
p = [
|
468 |
+
padx0 + (fw + upx - 1) // 2,
|
469 |
+
padx1 + (fw - upx) // 2,
|
470 |
+
pady0 + (fh + upy - 1) // 2,
|
471 |
+
pady1 + (fh - upy) // 2,
|
472 |
+
]
|
473 |
+
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl)
|
474 |
+
|
475 |
+
|
476 |
+
class MinibatchStdLayer(torch.nn.Module):
|
477 |
+
def __init__(self, group_size, num_channels=1):
|
478 |
+
super().__init__()
|
479 |
+
self.group_size = group_size
|
480 |
+
self.num_channels = num_channels
|
481 |
+
|
482 |
+
def forward(self, x):
|
483 |
+
N, C, H, W = x.shape
|
484 |
+
G = torch.min(torch.as_tensor(self.group_size),
|
485 |
+
torch.as_tensor(N)) if self.group_size is not None else N
|
486 |
+
F = self.num_channels
|
487 |
+
c = C // F
|
488 |
+
|
489 |
+
y = x.reshape(G, -1, F, c, H,
|
490 |
+
W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
|
491 |
+
y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
|
492 |
+
y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
|
493 |
+
y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
|
494 |
+
y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.
|
495 |
+
y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
|
496 |
+
y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
|
497 |
+
x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
|
498 |
+
return x
|
499 |
+
|
500 |
+
|
501 |
+
class FullyConnectedLayer(torch.nn.Module):
|
502 |
+
def __init__(self,
|
503 |
+
in_features, # Number of input features.
|
504 |
+
out_features, # Number of output features.
|
505 |
+
bias=True, # Apply additive bias before the activation function?
|
506 |
+
activation='linear', # Activation function: 'relu', 'lrelu', etc.
|
507 |
+
lr_multiplier=1, # Learning rate multiplier.
|
508 |
+
bias_init=0, # Initial value for the additive bias.
|
509 |
+
):
|
510 |
+
super().__init__()
|
511 |
+
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
|
512 |
+
self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
|
513 |
+
self.activation = activation
|
514 |
+
|
515 |
+
self.weight_gain = lr_multiplier / np.sqrt(in_features)
|
516 |
+
self.bias_gain = lr_multiplier
|
517 |
+
|
518 |
+
def forward(self, x):
|
519 |
+
w = self.weight * self.weight_gain
|
520 |
+
b = self.bias
|
521 |
+
if b is not None and self.bias_gain != 1:
|
522 |
+
b = b * self.bias_gain
|
523 |
+
|
524 |
+
if self.activation == 'linear' and b is not None:
|
525 |
+
# out = torch.addmm(b.unsqueeze(0), x, w.t())
|
526 |
+
x = x.matmul(w.t())
|
527 |
+
out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
|
528 |
+
else:
|
529 |
+
x = x.matmul(w.t())
|
530 |
+
out = bias_act(x, b, act=self.activation, dim=x.ndim - 1)
|
531 |
+
return out
|
532 |
+
|
533 |
+
|
534 |
+
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
535 |
+
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
536 |
+
"""
|
537 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
538 |
+
|
539 |
+
# Flip weight if requested.
|
540 |
+
if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
541 |
+
w = w.flip([2, 3])
|
542 |
+
|
543 |
+
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
|
544 |
+
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
|
545 |
+
if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
|
546 |
+
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
|
547 |
+
if out_channels <= 4 and groups == 1:
|
548 |
+
in_shape = x.shape
|
549 |
+
x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
|
550 |
+
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
|
551 |
+
else:
|
552 |
+
x = x.to(memory_format=torch.contiguous_format)
|
553 |
+
w = w.to(memory_format=torch.contiguous_format)
|
554 |
+
x = conv2d(x, w, groups=groups)
|
555 |
+
return x.to(memory_format=torch.channels_last)
|
556 |
+
|
557 |
+
# Otherwise => execute using conv2d_gradfix.
|
558 |
+
op = conv_transpose2d if transpose else conv2d
|
559 |
+
return op(x, w, stride=stride, padding=padding, groups=groups)
|
560 |
+
|
561 |
+
|
562 |
+
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
563 |
+
r"""2D convolution with optional up/downsampling.
|
564 |
+
|
565 |
+
Padding is performed only once at the beginning, not between the operations.
|
566 |
+
|
567 |
+
Args:
|
568 |
+
x: Input tensor of shape
|
569 |
+
`[batch_size, in_channels, in_height, in_width]`.
|
570 |
+
w: Weight tensor of shape
|
571 |
+
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
572 |
+
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
573 |
+
calling setup_filter(). None = identity (default).
|
574 |
+
up: Integer upsampling factor (default: 1).
|
575 |
+
down: Integer downsampling factor (default: 1).
|
576 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
577 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
578 |
+
(default: 0).
|
579 |
+
groups: Split input channels into N groups (default: 1).
|
580 |
+
flip_weight: False = convolution, True = correlation (default: True).
|
581 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
582 |
+
|
583 |
+
Returns:
|
584 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
585 |
+
"""
|
586 |
+
# Validate arguments.
|
587 |
+
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
588 |
+
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
589 |
+
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
590 |
+
assert isinstance(up, int) and (up >= 1)
|
591 |
+
assert isinstance(down, int) and (down >= 1)
|
592 |
+
# assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
|
593 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
594 |
+
fw, fh = _get_filter_size(f)
|
595 |
+
# px0, px1, py0, py1 = _parse_padding(padding)
|
596 |
+
px0, px1, py0, py1 = padding, padding, padding, padding
|
597 |
+
|
598 |
+
# Adjust padding to account for up/downsampling.
|
599 |
+
if up > 1:
|
600 |
+
px0 += (fw + up - 1) // 2
|
601 |
+
px1 += (fw - up) // 2
|
602 |
+
py0 += (fh + up - 1) // 2
|
603 |
+
py1 += (fh - up) // 2
|
604 |
+
if down > 1:
|
605 |
+
px0 += (fw - down + 1) // 2
|
606 |
+
px1 += (fw - down) // 2
|
607 |
+
py0 += (fh - down + 1) // 2
|
608 |
+
py1 += (fh - down) // 2
|
609 |
+
|
610 |
+
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
611 |
+
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
612 |
+
x = upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
|
613 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
614 |
+
return x
|
615 |
+
|
616 |
+
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
617 |
+
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
618 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
619 |
+
x = upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
|
620 |
+
return x
|
621 |
+
|
622 |
+
# Fast path: downsampling only => use strided convolution.
|
623 |
+
if down > 1 and up == 1:
|
624 |
+
x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
|
625 |
+
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
626 |
+
return x
|
627 |
+
|
628 |
+
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
629 |
+
if up > 1:
|
630 |
+
if groups == 1:
|
631 |
+
w = w.transpose(0, 1)
|
632 |
+
else:
|
633 |
+
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
634 |
+
w = w.transpose(1, 2)
|
635 |
+
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
636 |
+
px0 -= kw - 1
|
637 |
+
px1 -= kw - up
|
638 |
+
py0 -= kh - 1
|
639 |
+
py1 -= kh - up
|
640 |
+
pxt = max(min(-px0, -px1), 0)
|
641 |
+
pyt = max(min(-py0, -py1), 0)
|
642 |
+
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True,
|
643 |
+
flip_weight=(not flip_weight))
|
644 |
+
x = upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2,
|
645 |
+
flip_filter=flip_filter)
|
646 |
+
if down > 1:
|
647 |
+
x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
648 |
+
return x
|
649 |
+
|
650 |
+
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
651 |
+
if up == 1 and down == 1:
|
652 |
+
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
653 |
+
return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight)
|
654 |
+
|
655 |
+
# Fallback: Generic reference implementation.
|
656 |
+
x = upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2,
|
657 |
+
flip_filter=flip_filter)
|
658 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
659 |
+
if down > 1:
|
660 |
+
x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
661 |
+
return x
|
662 |
+
|
663 |
+
|
664 |
+
class Conv2dLayer(torch.nn.Module):
|
665 |
+
def __init__(self,
|
666 |
+
in_channels, # Number of input channels.
|
667 |
+
out_channels, # Number of output channels.
|
668 |
+
kernel_size, # Width and height of the convolution kernel.
|
669 |
+
bias=True, # Apply additive bias before the activation function?
|
670 |
+
activation='linear', # Activation function: 'relu', 'lrelu', etc.
|
671 |
+
up=1, # Integer upsampling factor.
|
672 |
+
down=1, # Integer downsampling factor.
|
673 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
674 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
675 |
+
channels_last=False, # Expect the input to have memory_format=channels_last?
|
676 |
+
trainable=True, # Update the weights of this layer during training?
|
677 |
+
):
|
678 |
+
super().__init__()
|
679 |
+
self.activation = activation
|
680 |
+
self.up = up
|
681 |
+
self.down = down
|
682 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
683 |
+
self.conv_clamp = conv_clamp
|
684 |
+
self.padding = kernel_size // 2
|
685 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
686 |
+
self.act_gain = activation_funcs[activation].def_gain
|
687 |
+
|
688 |
+
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
689 |
+
weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
|
690 |
+
bias = torch.zeros([out_channels]) if bias else None
|
691 |
+
if trainable:
|
692 |
+
self.weight = torch.nn.Parameter(weight)
|
693 |
+
self.bias = torch.nn.Parameter(bias) if bias is not None else None
|
694 |
+
else:
|
695 |
+
self.register_buffer('weight', weight)
|
696 |
+
if bias is not None:
|
697 |
+
self.register_buffer('bias', bias)
|
698 |
+
else:
|
699 |
+
self.bias = None
|
700 |
+
|
701 |
+
def forward(self, x, gain=1):
|
702 |
+
w = self.weight * self.weight_gain
|
703 |
+
x = conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down,
|
704 |
+
padding=self.padding)
|
705 |
+
|
706 |
+
act_gain = self.act_gain * gain
|
707 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
708 |
+
out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
|
709 |
+
return out
|
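The FIR helpers above mirror the StyleGAN2/MAT `upfirdn2d` utilities: `setup_filter` builds a normalized low-pass filter, and `upsample2d`/`downsample2d` apply it around zero-insertion or decimation. As a quick illustration, a minimal sketch (not part of the committed files) of anti-aliased 2x resampling; the import path and tensor sizes are assumptions:

# Illustrative sketch, not part of the diff. Assumes this module is importable
# as lama_cleaner.model.utils and exercises the pure-PyTorch reference path.
import torch
from lama_cleaner.model.utils import setup_filter, upsample2d, downsample2d

x = torch.randn(1, 3, 64, 64)        # [batch, channels, height, width]
f = setup_filter([1, 3, 3, 1])       # normalized low-pass FIR filter
up = upsample2d(x, f, up=2)          # -> [1, 3, 128, 128]
down = downsample2d(up, f, down=2)   # -> [1, 3, 64, 64]
print(up.shape, down.shape)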
lama_cleaner/model/zits.py
ADDED
@@ -0,0 +1,427 @@
import os
import time

import cv2
import skimage
import torch
import torch.nn.functional as F

from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
from lama_cleaner.schema import Config
import numpy as np

from lama_cleaner.model.base import InpaintModel

ZITS_INPAINT_MODEL_URL = os.environ.get(
    "ZITS_INPAINT_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt",
)

ZITS_EDGE_LINE_MODEL_URL = os.environ.get(
    "ZITS_EDGE_LINE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt",
)

ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get(
    "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt",
)

ZITS_WIRE_FRAME_MODEL_URL = os.environ.get(
    "ZITS_WIRE_FRAME_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt",
)


def resize(img, height, width, center_crop=False):
    imgh, imgw = img.shape[0:2]

    if center_crop and imgh != imgw:
        # center crop
        side = np.minimum(imgh, imgw)
        j = (imgh - side) // 2
        i = (imgw - side) // 2
        img = img[j : j + side, i : i + side, ...]

    if imgh > height and imgw > width:
        inter = cv2.INTER_AREA
    else:
        inter = cv2.INTER_LINEAR
    img = cv2.resize(img, (height, width), interpolation=inter)

    return img


def to_tensor(img, scale=True, norm=False):
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    c = img.shape[-1]

    if scale:
        img_t = torch.from_numpy(img).permute(2, 0, 1).float().div(255)
    else:
        img_t = torch.from_numpy(img).permute(2, 0, 1).float()

    if norm:
        mean = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
        std = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
        img_t = (img_t - mean) / std
    return img_t


def load_masked_position_encoding(mask):
    ones_filter = np.ones((3, 3), dtype=np.float32)
    d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32)
    d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32)
    d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32)
    d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32)
    str_size = 256
    pos_num = 128

    ori_mask = mask.copy()
    ori_h, ori_w = ori_mask.shape[0:2]
    ori_mask = ori_mask / 255
    mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA)
    mask[mask > 0] = 255
    h, w = mask.shape[0:2]
    mask3 = mask.copy()
    mask3 = 1.0 - (mask3 / 255.0)
    pos = np.zeros((h, w), dtype=np.int32)
    direct = np.zeros((h, w, 4), dtype=np.int32)
    i = 0
    while np.sum(1 - mask3) > 0:
        i += 1
        mask3_ = cv2.filter2D(mask3, -1, ones_filter)
        mask3_[mask3_ > 0] = 1
        sub_mask = mask3_ - mask3
        pos[sub_mask == 1] = i

        m = cv2.filter2D(mask3, -1, d_filter1)
        m[m > 0] = 1
        m = m - mask3
        direct[m == 1, 0] = 1

        m = cv2.filter2D(mask3, -1, d_filter2)
        m[m > 0] = 1
        m = m - mask3
        direct[m == 1, 1] = 1

        m = cv2.filter2D(mask3, -1, d_filter3)
        m[m > 0] = 1
        m = m - mask3
        direct[m == 1, 2] = 1

        m = cv2.filter2D(mask3, -1, d_filter4)
        m[m > 0] = 1
        m = m - mask3
        direct[m == 1, 3] = 1

        mask3 = mask3_

    abs_pos = pos.copy()
    rel_pos = pos / (str_size / 2)  # to 0~1, maybe larger than 1
    rel_pos = (rel_pos * pos_num).astype(np.int32)
    rel_pos = np.clip(rel_pos, 0, pos_num - 1)

    if ori_w != w or ori_h != h:
        rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
        rel_pos[ori_mask == 0] = 0
        direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
        direct[ori_mask == 0, :] = 0

    return rel_pos, abs_pos, direct


def load_image(img, mask, device, sigma256=3.0):
    """
    Args:
        img: [H, W, C] RGB
        mask: [H, W], 255 marks the region to inpaint
        sigma256:

    Returns:

    """
    h, w, _ = img.shape
    imgh, imgw = img.shape[0:2]
    img_256 = resize(img, 256, 256)

    mask = (mask > 127).astype(np.uint8) * 255
    mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA)
    mask_256[mask_256 > 0] = 255

    mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA)
    mask_512[mask_512 > 0] = 255

    # original skimage implementation
    # https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny
    # low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype's max.
    # high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype's max.
    gray_256 = skimage.color.rgb2gray(img_256)
    edge_256 = skimage.feature.canny(gray_256, sigma=sigma256, mask=None).astype(float)
    # cv2.imwrite("skimage_gray.jpg", (_gray_256*255).astype(np.uint8))
    # cv2.imwrite("skimage_edge.jpg", (_edge_256*255).astype(np.uint8))

    # gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY)
    # gray_256_blured = cv2.GaussianBlur(gray_256, ksize=(3,3), sigmaX=sigma256, sigmaY=sigma256)
    # edge_256 = cv2.Canny(gray_256_blured, threshold1=int(255*0.1), threshold2=int(255*0.2))
    # cv2.imwrite("edge.jpg", edge_256)

    # line
    img_512 = resize(img, 512, 512)

    rel_pos, abs_pos, direct = load_masked_position_encoding(mask)

    batch = dict()
    batch["images"] = to_tensor(img.copy()).unsqueeze(0).to(device)
    batch["img_256"] = to_tensor(img_256, norm=True).unsqueeze(0).to(device)
    batch["masks"] = to_tensor(mask).unsqueeze(0).to(device)
    batch["mask_256"] = to_tensor(mask_256).unsqueeze(0).to(device)
    batch["mask_512"] = to_tensor(mask_512).unsqueeze(0).to(device)
    batch["edge_256"] = to_tensor(edge_256, scale=False).unsqueeze(0).to(device)
    batch["img_512"] = to_tensor(img_512).unsqueeze(0).to(device)
    batch["rel_pos"] = torch.LongTensor(rel_pos).unsqueeze(0).to(device)
    batch["abs_pos"] = torch.LongTensor(abs_pos).unsqueeze(0).to(device)
    batch["direct"] = torch.LongTensor(direct).unsqueeze(0).to(device)
    batch["h"] = imgh
    batch["w"] = imgw

    return batch


def to_device(data, device):
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        for key in data:
            if isinstance(data[key], torch.Tensor):
                data[key] = data[key].to(device)
        return data
    if isinstance(data, list):
        return [to_device(d, device) for d in data]


class ZITS(InpaintModel):
    min_size = 256
    pad_mod = 32
    pad_to_square = True

    def __init__(self, device, **kwargs):
        """

        Args:
            device:
        """
        super().__init__(device)
        self.device = device
        self.sample_edge_line_iterations = 1

    def init_model(self, device, **kwargs):
        self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device)
        self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device)
        self.structure_upsample = load_jit_model(
            ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device
        )
        self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device)

    @staticmethod
    def is_downloaded() -> bool:
        model_paths = [
            get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL),
            get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL),
            get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
            get_cache_path_by_url(ZITS_INPAINT_MODEL_URL),
        ]
        return all([os.path.exists(it) for it in model_paths])

    def wireframe_edge_and_line(self, items, enable: bool):
        # In the end, this adds the "edge" and "line" keys to items.
        if not enable:
            items["edge"] = torch.zeros_like(items["masks"])
            items["line"] = torch.zeros_like(items["masks"])
            return

        start = time.time()
        try:
            line_256 = self.wireframe_forward(
                items["img_512"],
                h=256,
                w=256,
                masks=items["mask_512"],
                mask_th=0.85,
            )
        except:
            line_256 = torch.zeros_like(items["mask_256"])

        print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms")

        # np_line = (line[0][0].numpy() * 255).astype(np.uint8)
        # cv2.imwrite("line.jpg", np_line)

        start = time.time()
        edge_pred, line_pred = self.sample_edge_line_logits(
            context=[items["img_256"], items["edge_256"], line_256],
            mask=items["mask_256"].clone(),
            iterations=self.sample_edge_line_iterations,
            add_v=0.05,
            mul_v=4,
        )
        print(f"sample_edge_line_logits time: {(time.time() - start) * 1000:.2f}ms")

        # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
        # cv2.imwrite("edge_pred.jpg", np_edge_pred)
        # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
        # cv2.imwrite("line_pred.jpg", np_line_pred)
        # exit()

        input_size = min(items["h"], items["w"])
        if input_size != 256 and input_size > 256:
            while edge_pred.shape[2] < input_size:
                edge_pred = self.structure_upsample(edge_pred)
                edge_pred = torch.sigmoid((edge_pred + 2) * 2)

                line_pred = self.structure_upsample(line_pred)
                line_pred = torch.sigmoid((line_pred + 2) * 2)

            edge_pred = F.interpolate(
                edge_pred,
                size=(input_size, input_size),
                mode="bilinear",
                align_corners=False,
            )
            line_pred = F.interpolate(
                line_pred,
                size=(input_size, input_size),
                mode="bilinear",
                align_corners=False,
            )

        # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
        # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred)
        # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
        # cv2.imwrite("line_pred_upsample.jpg", np_line_pred)
        # exit()

        items["edge"] = edge_pred.detach()
        items["line"] = line_pred.detach()

    @torch.no_grad()
    def forward(self, image, mask, config: Config):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W]
        return: BGR IMAGE
        """
        mask = mask[:, :, 0]
        items = load_image(image, mask, device=self.device)

        self.wireframe_edge_and_line(items, config.zits_wireframe)

        inpainted_image = self.inpaint(
            items["images"],
            items["masks"],
            items["edge"],
            items["line"],
            items["rel_pos"],
            items["direct"],
        )

        inpainted_image = inpainted_image * 255.0
        inpainted_image = (
            inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
        )
        inpainted_image = inpainted_image[:, :, ::-1]

        # cv2.imwrite("inpainted.jpg", inpainted_image)
        # exit()

        return inpainted_image

    def wireframe_forward(self, images, h, w, masks, mask_th=0.925):
        lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1)
        lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1)
        images = images * 255.0
        # the masks value of lcnn is 127.5
        masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5
        masked_images = (masked_images - lcnn_mean) / lcnn_std

        def to_int(x):
            return tuple(map(int, x))

        lines_tensor = []
        lmap = np.zeros((h, w))

        output_masked = self.wireframe(masked_images)

        output_masked = to_device(output_masked, "cpu")
        if output_masked["num_proposals"] == 0:
            lines_masked = []
            scores_masked = []
        else:
            lines_masked = output_masked["lines_pred"].numpy()
            lines_masked = [
                [line[1] * h, line[0] * w, line[3] * h, line[2] * w]
                for line in lines_masked
            ]
            scores_masked = output_masked["lines_score"].numpy()

        for line, score in zip(lines_masked, scores_masked):
            if score > mask_th:
                rr, cc, value = skimage.draw.line_aa(
                    *to_int(line[0:2]), *to_int(line[2:4])
                )
                lmap[rr, cc] = np.maximum(lmap[rr, cc], value)

        lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8)
        lines_tensor.append(to_tensor(lmap).unsqueeze(0))

        lines_tensor = torch.cat(lines_tensor, dim=0)
        return lines_tensor.detach().to(self.device)

    def sample_edge_line_logits(
        self, context, mask=None, iterations=1, add_v=0, mul_v=4
    ):
        [img, edge, line] = context

        img = img * (1 - mask)
        edge = edge * (1 - mask)
        line = line * (1 - mask)

        for i in range(iterations):
            edge_logits, line_logits = self.edge_line(img, edge, line, masks=mask)

            edge_pred = torch.sigmoid(edge_logits)
            line_pred = torch.sigmoid((line_logits + add_v) * mul_v)
            edge = edge + edge_pred * mask
            edge[edge >= 0.25] = 1
            edge[edge < 0.25] = 0
            line = line + line_pred * mask

            b, _, h, w = edge_pred.shape
            edge_pred = edge_pred.reshape(b, -1, 1)
            line_pred = line_pred.reshape(b, -1, 1)
            mask = mask.reshape(b, -1)

            edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1)
            line_probs = torch.cat([1 - line_pred, line_pred], dim=-1)
            edge_probs[:, :, 1] += 0.5
            line_probs[:, :, 1] += 0.5
            edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100)
            line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100)

            indices = torch.sort(
                edge_max_probs + line_max_probs, dim=-1, descending=True
            )[1]

            for ii in range(b):
                keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))

                assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "Error!!!"
                mask[ii][indices[ii, :keep]] = 0

            mask = mask.reshape(b, 1, h, w)
            edge = edge * (1 - mask)
            line = line * (1 - mask)

        edge, line = edge.to(torch.float32), line.to(torch.float32)
        return edge, line
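A small sketch (not part of the committed files) of the masked position encoding that `load_image` feeds to the ZITS transformer, showing the returned shapes; the toy mask and the import path are assumptions:

# Illustrative sketch, not part of the diff.
import numpy as np
from lama_cleaner.model.zits import load_masked_position_encoding

mask = np.zeros((256, 256), dtype=np.uint8)
mask[96:160, 96:160] = 255   # 255 marks the region to inpaint

rel_pos, abs_pos, direct = load_masked_position_encoding(mask)
# rel_pos / abs_pos: per-pixel "distance ring" index counted from the mask border;
# direct: 4-channel directional flags for each ring pixel.
print(rel_pos.shape, abs_pos.shape, direct.shape)   # (256, 256) (256, 256) (256, 256, 4)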
lama_cleaner/model_manager.py
ADDED
@@ -0,0 +1,43 @@
from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.mat import MAT
from lama_cleaner.model.sd import SD14
from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config

models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.4": SD14, "cv2": OpenCV2}


class ModelManager:
    def __init__(self, name: str, device, **kwargs):
        self.name = name
        self.device = device
        self.kwargs = kwargs
        self.model = self.init_model(name, device, **kwargs)

    def init_model(self, name: str, device, **kwargs):
        if name in models:
            model = models[name](device, **kwargs)
        else:
            raise NotImplementedError(f"Not supported model: {name}")
        return model

    def is_downloaded(self, name: str) -> bool:
        if name in models:
            return models[name].is_downloaded()
        else:
            raise NotImplementedError(f"Not supported model: {name}")

    def __call__(self, image, mask, config: Config):
        return self.model(image, mask, config)

    def switch(self, new_name: str):
        if new_name == self.name:
            return
        try:
            self.model = self.init_model(new_name, self.device, **self.kwargs)
            self.name = new_name
        except NotImplementedError as e:
            raise e
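For context, a hedged usage sketch (not part of the committed files) of how `ModelManager` is typically driven. The mask file name and Config values are examples, and the exact image/mask shape expectations live in `lama_cleaner.model.base`, which is not shown here:

# Illustrative sketch, not part of the diff.
import cv2
import torch
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
manager = ModelManager(name="lama", device=device)

image = cv2.cvtColor(cv2.imread("image0.jpeg"), cv2.COLOR_BGR2RGB)  # RGB [H, W, 3]
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)                 # hypothetical mask, 255 = remove

config = Config(
    ldm_steps=25,
    hd_strategy=HDStrategy.ORIGINAL,
    hd_strategy_crop_margin=196,
    hd_strategy_crop_trigger_size=1280,
    hd_strategy_resize_limit=2048,
)
result = manager(image, mask, config)   # dispatches to the currently selected model
# manager.switch("zits")                # hot-swap to another registered model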
lama_cleaner/schema.py
ADDED
@@ -0,0 +1,50 @@
from enum import Enum

from pydantic import BaseModel


class HDStrategy(str, Enum):
    ORIGINAL = "Original"
    RESIZE = "Resize"
    CROP = "Crop"


class LDMSampler(str, Enum):
    ddim = "ddim"
    plms = "plms"


class SDSampler(str, Enum):
    ddim = "ddim"
    pndm = "pndm"


class Config(BaseModel):
    ldm_steps: int
    ldm_sampler: str = LDMSampler.plms
    zits_wireframe: bool = True
    hd_strategy: str
    hd_strategy_crop_margin: int
    hd_strategy_crop_trigger_size: int
    hd_strategy_resize_limit: int

    prompt: str = ""
    # Always expressed at the original image scale
    use_croper: bool = False
    croper_x: int = None
    croper_y: int = None
    croper_height: int = None
    croper_width: int = None

    # sd
    sd_mask_blur: int = 0
    sd_strength: float = 0.75
    sd_steps: int = 50
    sd_guidance_scale: float = 7.5
    sd_sampler: str = SDSampler.ddim
    # -1 means a random seed
    sd_seed: int = 42

    # cv2
    cv2_flag: str = 'INPAINT_NS'
    cv2_radius: int = 4
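`Config` is a pydantic v1 model, so only the fields without defaults above must be supplied and everything else falls back to its default. A short sketch (not part of the committed files; field values are examples):

# Illustrative sketch, not part of the diff.
from lama_cleaner.schema import Config, HDStrategy

cfg = Config(
    ldm_steps=25,
    hd_strategy=HDStrategy.CROP,
    hd_strategy_crop_margin=196,
    hd_strategy_crop_trigger_size=1280,
    hd_strategy_resize_limit=2048,
)
print(cfg.zits_wireframe)   # True, from the default above
print(cfg.sd_seed)          # 42, from the default above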
lama_cleaner/settings.py
ADDED
@@ -0,0 +1,124 @@
"""
Django settings for lama_cleaner project.

Generated by 'django-admin startproject' using Django 4.1.2.

For more information on this file, see
https://docs.djangoproject.com/en/4.1/topics/settings/

For the full list of settings and their values, see
https://docs.djangoproject.com/en/4.1/ref/settings/
"""

from pathlib import Path

# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/4.1/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'django-insecure-=x2n@zasb2nkq$)frp(&h*tsozyka+jb5(&3^7@u5@ven@-sdu'

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = []


# Application definition

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'inpainting',
]

MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'lama_cleaner.urls'

TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'DIRS': [],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        },
    },
]

WSGI_APPLICATION = 'lama_cleaner.wsgi.application'


# Database
# https://docs.djangoproject.com/en/4.1/ref/settings/#databases

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': BASE_DIR / 'db.sqlite3',
    }
}


# Password validation
# https://docs.djangoproject.com/en/4.1/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
    },
]


# Internationalization
# https://docs.djangoproject.com/en/4.1/topics/i18n/

LANGUAGE_CODE = 'en-us'

TIME_ZONE = 'UTC'

USE_I18N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.1/howto/static-files/

STATIC_URL = 'static/'

# Default primary key field type
# https://docs.djangoproject.com/en/4.1/ref/settings/#default-auto-field

DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
lama_cleaner/urls.py
ADDED
@@ -0,0 +1,22 @@
"""lama_cleaner URL Configuration

The `urlpatterns` list routes URLs to views. For more information please see:
    https://docs.djangoproject.com/en/4.1/topics/http/urls/
Examples:
Function views
    1. Add an import: from my_app import views
    2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
    1. Add an import: from other_app.views import Home
    2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
    1. Import the include() function: from django.urls import include, path
    2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import path, include

urlpatterns = [
    path('admin/', admin.site.urls),
    path('inpainting/', include('inpainting.urls')),
]
lama_cleaner/wsgi.py
ADDED
@@ -0,0 +1,16 @@
"""
WSGI config for lama_cleaner project.

It exposes the WSGI callable as a module-level variable named ``application``.

For more information on this file, see
https://docs.djangoproject.com/en/4.1/howto/deployment/wsgi/
"""

import os

from django.core.wsgi import get_wsgi_application

os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'lama_cleaner.settings')

application = get_wsgi_application()
requirements.txt
ADDED
@@ -0,0 +1,12 @@
opencv-python==4.6.0.66
pytest==7.1.3
torch==2.2.0
pydantic==1.10.2
loguru==0.6.0
tqdm==4.64.1
Pillow==9.2.0
diffusers==0.4.2
transformers
scikit-image==0.19.3
gradio
timm