vemodalen committed on
Commit
10af974
1 Parent(s): cd1a54d

Upload 6 files

.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
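Note that the appended block is the stock GitHub Python .gitignore template rather than attribute rules: .gitattributes reads each line as a pattern followed by attributes, so pattern-only entries such as __pycache__/ set no attributes and are effectively inert here. For comparison, a hedged sketch of attribute-style lines that would LFS-track this commit's new model files (an illustration only, not part of the commit; the Hub's default rules may already cover these extensions):

*.onnx filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text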
app.py ADDED
@@ -0,0 +1,108 @@
+from rtmlib import YOLOX, RTMPose, draw_bbox, draw_skeleton
+import functools
+from typing import Callable
+from pathlib import Path
+import gradio as gr
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+import torchvision.transforms as T
+
+
+TITLE = 'Face Parsing'
+
+def get_palette(num_cls):
+    """ Returns the color map for visualizing the segmentation mask.
+    Args:
+        num_cls: Number of classes
+    Returns:
+        The color map
+    """
+
+    n = num_cls
+    palette = [0] * (n * 3)
+    for j in range(0, n):
+        lab = j
+        palette[j * 3 + 0] = 0
+        palette[j * 3 + 1] = 0
+        palette[j * 3 + 2] = 0
+        i = 0
+        while lab:
+            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
+            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
+            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
+            i += 1
+            lab >>= 3
+    return palette
+
+@torch.inference_mode()
+def predict(image: PIL.Image.Image, model, transform: Callable,
+            device: torch.device, palette) -> tuple:
+    img_show = np.array(image.copy())
+    bboxes = model[1](np.array(image))
+    img_show = draw_bbox(img_show, bboxes)
+    keypoints, scores = model[2](np.array(image), bboxes=bboxes)
+    img_show = draw_skeleton(img_show, keypoints, scores)
+
+    data = transform(image)
+    data = data.unsqueeze(0).to(device)
+    out = model[0](data)
+    out = F.interpolate(out, [image.size[1], image.size[0]], mode="bilinear")
+    output = out[0].permute(1, 2, 0)
+    parsing = torch.argmax(output, dim=2).cpu().numpy()
+
+    output_im = Image.fromarray(np.asarray(parsing, dtype=np.uint8))
+    image = Image.fromarray(np.asarray(img_show, dtype=np.uint8))
+    output_im.putpalette(palette)
+    output_im = output_im.convert('RGB')
+    # output_im.save('output.png')
+
+    res = Image.blend(image.convert('RGB'), output_im, 0.5)
+    return output_im, res
+
+
+def load_parsing_model():
+    model = torch.jit.load(Path("models/faceparsing_512_512.pt"))
+    model.eval()
+    return model
+
+
+def main():
+    device = torch.device('cpu')
+    model_ls = []
+    model = load_parsing_model()
+
+    transform = T.Compose([
+        T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
+        T.ToTensor(),
+        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+    ])
+    palette = get_palette(20)
+    det_model = YOLOX('models/det.onnx', model_input_size=(640, 640), backend='onnxruntime', device='cpu')
+    pose_model = RTMPose('models/pose.onnx', model_input_size=(192, 256), to_openpose=False, backend='onnxruntime', device='cpu')
+
+    model_ls.append(model)
+    model_ls.append(det_model)
+    model_ls.append(pose_model)
+
+    func = functools.partial(predict,
+                             model=model_ls,
+                             transform=transform,
+                             device=device, palette=palette)
+
+
+    gr.Interface(
+        fn=func,
+        inputs=gr.Image(label='Input', type='pil'),
+        outputs=[
+            gr.Image(label='Predicted Labels', type='pil'),
+            gr.Image(label='Masked', type='pil'),
+        ],
+        title=TITLE,
+    ).queue().launch(show_api=False)
+
+if __name__ == "__main__":
+    main()
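For reference, get_palette builds the familiar PASCAL-VOC-style color map: it scatters the bits of each class index across the R, G, and B channels so nearby class ids get visually distinct colors, and PIL's putpalette consumes the resulting flat [R0, G0, B0, R1, ...] list. Below is a minimal sketch of driving the same pipeline without the Gradio UI, assuming the models/ files from this commit are in place; input.jpg is a hypothetical test image, not part of the commit. Importing app only runs its module-level definitions, since main() sits behind the __main__ guard.

# Sketch only: run the commit's pipeline headlessly, assuming app.py is
# importable from the repo root. input.jpg is a hypothetical test image.
import PIL.Image
import torch
import torchvision.transforms as T
from rtmlib import YOLOX, RTMPose

from app import get_palette, load_parsing_model, predict

transform = T.Compose([
    T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# predict() indexes this list: [0] parsing model, [1] detector, [2] pose model.
models = [
    load_parsing_model(),
    YOLOX('models/det.onnx', model_input_size=(640, 640),
          backend='onnxruntime', device='cpu'),
    RTMPose('models/pose.onnx', model_input_size=(192, 256),
            to_openpose=False, backend='onnxruntime', device='cpu'),
]
labels, blended = predict(PIL.Image.open('input.jpg').convert('RGB'),
                          model=models, transform=transform,
                          device=torch.device('cpu'), palette=get_palette(20))
labels.save('labels.png')    # color-coded per-pixel classes
blended.save('blended.png')  # detections and skeleton blended with the mask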
models/det.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3dea6513388889f0fff4b77bf7a26013600321b9eb9ceb0e9a400a82572f5f23
+size 101400344
models/faceparsing_512_512.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8267ce549f91c0e87f575e80d339b6cebf743c05fb23064cc82e2ac180bfc5e8
+size 257745243
models/pose.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7bca282009dd5e5b8a94bb27e8015f610630643659b862818803febf9107e2e5
+size 368041127
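Each of the three model entries above is a Git LFS pointer, which is what actually gets committed: a short key-value text file recording the spec version, the SHA-256 object id, and the byte size, while the real weights live in LFS storage. A minimal sketch of reading such a pointer, assuming the file has not yet been smudged into the real binary by git lfs:

# Sketch: parse a Git LFS pointer file of the three-line form shown above.
def parse_lfs_pointer(path):
    fields = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            key, _, value = line.strip().partition(' ')
            fields[key] = value
    return fields

# parse_lfs_pointer('models/det.onnx') would yield:
# {'version': 'https://git-lfs.github.com/spec/v1',
#  'oid': 'sha256:3dea6513388889f0fff4b77bf7a26013600321b9eb9ceb0e9a400a82572f5f23',
#  'size': '101400344'}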
requirements.txt ADDED
Binary file (312 Bytes).
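The diff viewer does not render the contents of requirements.txt. Judging purely from the imports in app.py and the onnxruntime backend it selects, a plausible reconstruction would be the list below; this is an assumption for illustration, not the actual 312-byte file, which may pin different packages or versions.

# Assumed dependency list inferred from app.py; not the committed file.
rtmlib
gradio
torch
torchvision
onnxruntime
numpy
Pillow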