Spaces:
Runtime error
Runtime error
Harisreedhar
commited on
Commit
•
7f475d2
1
Parent(s):
27c3130
Add soft erosion and fix face parsing video
Browse files- app.py +40 -24
- face_parsing/__init__.py +1 -1
- face_parsing/swap.py +60 -17
- swapper.py +3 -3
app.py
CHANGED
@@ -17,7 +17,7 @@ from moviepy.editor import VideoFileClip, ImageSequenceClip
|
|
17 |
|
18 |
from face_analyser import detect_conditions, analyse_face
|
19 |
from utils import trim_video, StreamerThread, ProcessBar, open_directory
|
20 |
-
from face_parsing import init_parser, swap_regions, mask_regions, mask_regions_to_list
|
21 |
from swapper import (
|
22 |
swap_face,
|
23 |
swap_face_with_condition,
|
@@ -59,8 +59,9 @@ MASK_INCLUDE = [
|
|
59 |
"L-Lip",
|
60 |
"U-Lip"
|
61 |
]
|
62 |
-
|
63 |
-
|
|
|
64 |
|
65 |
FACE_SWAPPER = None
|
66 |
FACE_ANALYSER = None
|
@@ -84,6 +85,8 @@ else:
|
|
84 |
USE_CUDA = False
|
85 |
print("\n********** Running on CPU **********\n")
|
86 |
|
|
|
|
|
87 |
|
88 |
## ------------------------------ LOAD MODELS ------------------------------
|
89 |
|
@@ -114,7 +117,7 @@ def load_face_parser_model(name="./assets/pretrained_models/79999_iter.pth"):
|
|
114 |
global FACE_PARSER
|
115 |
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
|
116 |
if FACE_PARSER is None:
|
117 |
-
FACE_PARSER = init_parser(name,
|
118 |
|
119 |
|
120 |
load_face_analyser_model()
|
@@ -137,9 +140,10 @@ def process(
|
|
137 |
distance,
|
138 |
face_enhance,
|
139 |
enable_face_parser,
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
143 |
*specifics,
|
144 |
):
|
145 |
global WORKSPACE
|
@@ -196,14 +200,18 @@ def process(
|
|
196 |
|
197 |
yield "### \n ⌛ Analysing Face...", *ui_before()
|
198 |
|
199 |
-
|
200 |
-
|
|
|
|
|
|
|
|
|
201 |
models = {
|
202 |
"swap": FACE_SWAPPER,
|
203 |
"enhance": FACE_ENHANCER,
|
204 |
"enhance_sett": face_enhance,
|
205 |
"face_parser": FACE_PARSER,
|
206 |
-
"face_parser_sett": (enable_face_parser,
|
207 |
}
|
208 |
|
209 |
## ------------------------------ ANALYSE SOURCE & SPECIFIC ------------------------------
|
@@ -301,9 +309,9 @@ def process(
|
|
301 |
|
302 |
if condition == "Specific Face":
|
303 |
swapped = swap_specific(
|
304 |
-
frame,
|
305 |
-
analysed_target,
|
306 |
analysed_source_specific,
|
|
|
|
|
307 |
models,
|
308 |
threshold=distance,
|
309 |
)
|
@@ -381,9 +389,9 @@ def process(
|
|
381 |
|
382 |
if condition == "Specific Face":
|
383 |
swapped = swap_specific(
|
384 |
-
target,
|
385 |
-
analysed_target,
|
386 |
analysed_source_specific,
|
|
|
|
|
387 |
models,
|
388 |
threshold=distance,
|
389 |
)
|
@@ -636,16 +644,23 @@ with gr.Blocks(css=css) as interface:
|
|
636 |
label="Include",
|
637 |
interactive=True,
|
638 |
)
|
639 |
-
|
640 |
-
|
641 |
-
value=
|
642 |
-
|
643 |
-
label="Exclude",
|
644 |
interactive=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
645 |
)
|
646 |
-
|
647 |
-
label="Blur
|
648 |
-
value=
|
649 |
minimum=0,
|
650 |
interactive=True,
|
651 |
)
|
@@ -827,8 +842,9 @@ with gr.Blocks(css=css) as interface:
|
|
827 |
enable_face_enhance,
|
828 |
enable_face_parser_mask,
|
829 |
mask_include,
|
830 |
-
|
831 |
-
|
|
|
832 |
*src_specific_inputs,
|
833 |
]
|
834 |
|
|
|
17 |
|
18 |
from face_analyser import detect_conditions, analyse_face
|
19 |
from utils import trim_video, StreamerThread, ProcessBar, open_directory
|
20 |
+
from face_parsing import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
|
21 |
from swapper import (
|
22 |
swap_face,
|
23 |
swap_face_with_condition,
|
|
|
59 |
"L-Lip",
|
60 |
"U-Lip"
|
61 |
]
|
62 |
+
MASK_SOFT_KERNEL = 17
|
63 |
+
MASK_SOFT_ITERATIONS = 7
|
64 |
+
MASK_BLUR_AMOUNT = 20
|
65 |
|
66 |
FACE_SWAPPER = None
|
67 |
FACE_ANALYSER = None
|
|
|
85 |
USE_CUDA = False
|
86 |
print("\n********** Running on CPU **********\n")
|
87 |
|
88 |
+
device = "cuda" if USE_CUDA else "cpu"
|
89 |
+
|
90 |
|
91 |
## ------------------------------ LOAD MODELS ------------------------------
|
92 |
|
|
|
117 |
global FACE_PARSER
|
118 |
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
|
119 |
if FACE_PARSER is None:
|
120 |
+
FACE_PARSER = init_parser(name, mode=device)
|
121 |
|
122 |
|
123 |
load_face_analyser_model()
|
|
|
140 |
distance,
|
141 |
face_enhance,
|
142 |
enable_face_parser,
|
143 |
+
mask_includes,
|
144 |
+
mask_soft_kernel,
|
145 |
+
mask_soft_iterations,
|
146 |
+
blur_amount,
|
147 |
*specifics,
|
148 |
):
|
149 |
global WORKSPACE
|
|
|
200 |
|
201 |
yield "### \n ⌛ Analysing Face...", *ui_before()
|
202 |
|
203 |
+
includes = mask_regions_to_list(mask_includes)
|
204 |
+
if mask_soft_iterations > 0:
|
205 |
+
smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=int(mask_soft_iterations)).to(device)
|
206 |
+
else:
|
207 |
+
smooth_mask = None
|
208 |
+
|
209 |
models = {
|
210 |
"swap": FACE_SWAPPER,
|
211 |
"enhance": FACE_ENHANCER,
|
212 |
"enhance_sett": face_enhance,
|
213 |
"face_parser": FACE_PARSER,
|
214 |
+
"face_parser_sett": (enable_face_parser, includes, smooth_mask, int(blur_amount))
|
215 |
}
|
216 |
|
217 |
## ------------------------------ ANALYSE SOURCE & SPECIFIC ------------------------------
|
|
|
309 |
|
310 |
if condition == "Specific Face":
|
311 |
swapped = swap_specific(
|
|
|
|
|
312 |
analysed_source_specific,
|
313 |
+
analysed_target,
|
314 |
+
frame,
|
315 |
models,
|
316 |
threshold=distance,
|
317 |
)
|
|
|
389 |
|
390 |
if condition == "Specific Face":
|
391 |
swapped = swap_specific(
|
|
|
|
|
392 |
analysed_source_specific,
|
393 |
+
analysed_target,
|
394 |
+
target,
|
395 |
models,
|
396 |
threshold=distance,
|
397 |
)
|
|
|
644 |
label="Include",
|
645 |
interactive=True,
|
646 |
)
|
647 |
+
mask_soft_kernel = gr.Number(
|
648 |
+
label="Soft Erode Kernel",
|
649 |
+
value=MASK_SOFT_KERNEL,
|
650 |
+
minimum=3,
|
|
|
651 |
interactive=True,
|
652 |
+
visible = False
|
653 |
+
)
|
654 |
+
mask_soft_iterations = gr.Number(
|
655 |
+
label="Soft Erode Iterations",
|
656 |
+
value=MASK_SOFT_ITERATIONS,
|
657 |
+
minimum=0,
|
658 |
+
interactive=True,
|
659 |
+
|
660 |
)
|
661 |
+
blur_amount = gr.Number(
|
662 |
+
label="Mask Blur",
|
663 |
+
value=MASK_BLUR_AMOUNT,
|
664 |
minimum=0,
|
665 |
interactive=True,
|
666 |
)
|
|
|
842 |
enable_face_enhance,
|
843 |
enable_face_parser_mask,
|
844 |
mask_include,
|
845 |
+
mask_soft_kernel,
|
846 |
+
mask_soft_iterations,
|
847 |
+
blur_amount,
|
848 |
*src_specific_inputs,
|
849 |
]
|
850 |
|
face_parsing/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list
|
|
|
1 |
+
from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
|
face_parsing/swap.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
import torch
|
|
|
|
|
2 |
import torchvision.transforms as transforms
|
3 |
import cv2
|
4 |
import numpy as np
|
@@ -27,15 +29,44 @@ mask_regions = {
|
|
27 |
"Hat":18
|
28 |
}
|
29 |
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
n_classes = 19
|
37 |
net = BiSeNet(n_classes=n_classes)
|
38 |
-
if
|
39 |
net.cuda()
|
40 |
net.load_state_dict(torch.load(pth_path))
|
41 |
else:
|
@@ -55,8 +86,7 @@ def image_to_parsing(img, net):
|
|
55 |
img = torch.unsqueeze(img, 0)
|
56 |
|
57 |
with torch.no_grad():
|
58 |
-
|
59 |
-
img = img.cuda()
|
60 |
out = net(img)[0]
|
61 |
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
62 |
return parsing
|
@@ -68,20 +98,33 @@ def get_mask(parsing, classes):
|
|
68 |
res += parsing == val
|
69 |
return res
|
70 |
|
71 |
-
def swap_regions(source, target, net, includes=[1,2,3,4,5,10,11,12,13],
|
72 |
parsing = image_to_parsing(source, net)
|
|
|
73 |
if len(includes) == 0:
|
74 |
return source, np.zeros_like(source)
|
|
|
75 |
include_mask = get_mask(parsing, includes)
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
def mask_regions_to_list(values):
|
87 |
out_ids = []
|
|
|
1 |
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
import torchvision.transforms as transforms
|
5 |
import cv2
|
6 |
import numpy as np
|
|
|
29 |
"Hat":18
|
30 |
}
|
31 |
|
32 |
+
# Borrowed from simswap
|
33 |
+
# https://github.com/neuralchen/SimSwap/blob/26c84d2901bd56eda4d5e3c5ca6da16e65dc82a6/util/reverse2original.py#L30
|
34 |
+
class SoftErosion(nn.Module):
|
35 |
+
def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
|
36 |
+
super(SoftErosion, self).__init__()
|
37 |
+
r = kernel_size // 2
|
38 |
+
self.padding = r
|
39 |
+
self.iterations = iterations
|
40 |
+
self.threshold = threshold
|
41 |
|
42 |
+
# Create kernel
|
43 |
+
y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
|
44 |
+
dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
|
45 |
+
kernel = dist.max() - dist
|
46 |
+
kernel /= kernel.sum()
|
47 |
+
kernel = kernel.view(1, 1, *kernel.shape)
|
48 |
+
self.register_buffer('weight', kernel)
|
49 |
|
50 |
+
def forward(self, x):
|
51 |
+
x = x.float()
|
52 |
+
for i in range(self.iterations - 1):
|
53 |
+
x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
|
54 |
+
x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
|
55 |
+
|
56 |
+
mask = x >= self.threshold
|
57 |
+
x[mask] = 1.0
|
58 |
+
x[~mask] /= x[~mask].max()
|
59 |
+
|
60 |
+
return x, mask
|
61 |
+
|
62 |
+
device = "cpu"
|
63 |
+
|
64 |
+
def init_parser(pth_path, mode="cpu"):
|
65 |
+
global device
|
66 |
+
device = mode
|
67 |
n_classes = 19
|
68 |
net = BiSeNet(n_classes=n_classes)
|
69 |
+
if device == "cuda":
|
70 |
net.cuda()
|
71 |
net.load_state_dict(torch.load(pth_path))
|
72 |
else:
|
|
|
86 |
img = torch.unsqueeze(img, 0)
|
87 |
|
88 |
with torch.no_grad():
|
89 |
+
img = img.to(device)
|
|
|
90 |
out = net(img)[0]
|
91 |
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
92 |
return parsing
|
|
|
98 |
res += parsing == val
|
99 |
return res
|
100 |
|
101 |
+
def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,13], blur=10):
|
102 |
parsing = image_to_parsing(source, net)
|
103 |
+
|
104 |
if len(includes) == 0:
|
105 |
return source, np.zeros_like(source)
|
106 |
+
|
107 |
include_mask = get_mask(parsing, includes)
|
108 |
+
mask = np.repeat(include_mask[:, :, np.newaxis], 3, axis=2).astype("float32")
|
109 |
+
|
110 |
+
if smooth_mask is not None:
|
111 |
+
mask_tensor = torch.from_numpy(mask.copy().transpose((2, 0, 1))).float().to(device)
|
112 |
+
face_mask_tensor = mask_tensor[0] + mask_tensor[1]
|
113 |
+
soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
|
114 |
+
soft_face_mask_tensor.squeeze_()
|
115 |
+
mask = np.repeat(soft_face_mask_tensor.cpu().numpy()[:, :, np.newaxis], 3, axis=2)
|
116 |
+
|
117 |
+
if blur > 0:
|
118 |
+
mask = cv2.GaussianBlur(mask, (0, 0), blur)
|
119 |
+
|
120 |
+
resized_source = cv2.resize((source/255).astype("float32"), (512, 512))
|
121 |
+
resized_target = cv2.resize((target/255).astype("float32"), (512, 512))
|
122 |
+
|
123 |
+
result = mask * resized_source + (1 - mask) * resized_target
|
124 |
+
normalized_result = (result - np.min(result)) / (np.max(result) - np.min(result))
|
125 |
+
result = cv2.resize((result*255).astype("uint8"), (source.shape[1], source.shape[0]))
|
126 |
+
|
127 |
+
return result
|
128 |
|
129 |
def mask_regions_to_list(values):
|
130 |
out_ids = []
|
swapper.py
CHANGED
@@ -25,10 +25,10 @@ def swap_face(whole_img, target_face, source_face, models):
|
|
25 |
aimg, _ = face_align.norm_crop2(whole_img, target_face.kps, image_size=image_size)
|
26 |
|
27 |
if face_parser is not None:
|
28 |
-
fp_enable,
|
29 |
if fp_enable:
|
30 |
-
bgr_fake
|
31 |
-
bgr_fake, aimg, face_parser,
|
32 |
)
|
33 |
|
34 |
if fe_enable:
|
|
|
25 |
aimg, _ = face_align.norm_crop2(whole_img, target_face.kps, image_size=image_size)
|
26 |
|
27 |
if face_parser is not None:
|
28 |
+
fp_enable, includes, smooth_mask, blur_amount = models.get("face_parser_sett")
|
29 |
if fp_enable:
|
30 |
+
bgr_fake = swap_regions(
|
31 |
+
bgr_fake, aimg, face_parser, smooth_mask, includes=includes, blur=blur_amount
|
32 |
)
|
33 |
|
34 |
if fe_enable:
|