# FaceParsing / app.py
# Uploaded by: vemodalen
# Commit message: "Upload 6 files"
# Revision: 10af974 (verified)
from rtmlib import YOLOX, RTMPose, draw_bbox, draw_skeleton
import functools
from typing import Callable
from pathlib import Path
import gradio as gr
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as T
# Heading displayed by the Gradio interface built in main().
TITLE = 'Face Parsing'
def get_palette(num_cls):
    """Build the bit-interleaved color map used to paint segmentation labels.

    This is the colormap scheme popularized by PASCAL VOC. Each label's bits
    are dealt round-robin to the R, G and B channels. The lowest three bits
    set the most significant bit (128) of each channel, the next three set
    the 64 bit, and so on. Label 0 is black.

    Args:
        num_cls: Number of classes (labels ``0 .. num_cls - 1``).

    Returns:
        A flat list ``[r0, g0, b0, r1, g1, b1, ...]`` with ``3 * num_cls``
        ints for non-negative ``num_cls`` (empty otherwise), in the layout
        expected by ``PIL.Image.Image.putpalette``.
    """
    colors = []
    for label in range(num_cls):
        rgb = [0, 0, 0]
        remaining = label
        shift = 7
        while remaining:
            # Channel c takes bit c of the current 3-bit group.
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            remaining >>= 3
            shift -= 1
        colors.extend(rgb)
    return colors
@torch.inference_mode()
def predict(image: PIL.Image.Image, model, transform: Callable,
            device: torch.device,
            palette) -> "tuple[PIL.Image.Image, PIL.Image.Image]":
    """Run face parsing plus person detection / pose estimation on one image.

    Args:
        image: Input picture from the Gradio ``Image`` component
            (``type='pil'``). ``None`` (nothing uploaded) yields
            ``(None, None)`` so the outputs are simply cleared.
        model: Indexable of three callables. ``model[0]`` is the face-parsing
            network. ``model[1]`` is the detector, called on an HxWx3 array
            and returning boxes. ``model[2]`` is the pose estimator, called
            with the array and ``bboxes=`` and returning
            ``(keypoints, scores)``.
        transform: Preprocessing that turns the PIL image into a tensor for
            the parsing network. A batch axis is added here.
        device: Device the parsing input tensor is moved to.
        palette: Flat ``[r, g, b, ...]`` list used to colorize the label map.

    Returns:
        ``(labels, overlay)``. ``labels`` is the colorized per-pixel argmax
        class map. ``overlay`` is a 50/50 blend of that map with the input
        image annotated with detection boxes and skeletons. Both are RGB
        images with the input's size.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading.
        return None, None
    # The parsing transform normalizes exactly 3 channels, so RGBA /
    # grayscale / palette uploads would fail there. Normalize the mode once.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    parsing_model, det_model, pose_model = model[0], model[1], model[2]

    # One array feeds both detection models. Drawing happens on a separate
    # copy so the model inputs are never annotated.
    frame = np.array(image)
    bboxes = det_model(frame)
    keypoints, scores = pose_model(frame, bboxes=bboxes)
    annotated = draw_bbox(frame.copy(), bboxes)
    annotated = draw_skeleton(annotated, keypoints, scores)

    # Face parsing: class scores, resized back to the input resolution.
    # bilinear interpolation requires a 4-D (N, C, H, W) tensor, and
    # image.size is (W, H), hence the explicit (height, width).
    data = transform(image).unsqueeze(0).to(device)
    logits = parsing_model(data)
    logits = F.interpolate(logits, size=(image.height, image.width),
                           mode='bilinear', align_corners=False)
    # Argmax over the class axis of the single batch item -> (H, W) labels.
    parsing = logits[0].argmax(dim=0).cpu().numpy()

    labels = Image.fromarray(parsing.astype(np.uint8))
    labels.putpalette(palette)
    labels = labels.convert('RGB')

    annotated_im = Image.fromarray(np.asarray(annotated, dtype=np.uint8))
    overlay = Image.blend(annotated_im.convert('RGB'), labels, 0.5)
    return labels, overlay
def load_parsing_model(path=Path("models/faceparsing_512_512.pt"),
                       map_location="cpu"):
    """Load the TorchScript face-parsing network, ready for inference.

    Args:
        path: Location of the serialized TorchScript module (str or Path).
            Relative paths resolve against the current working directory.
        map_location: Device the module's tensors are remapped to while
            loading. The default ``"cpu"`` matches the CPU device the app
            runs on, so a checkpoint serialized on a GPU still loads on a
            CPU-only host instead of failing.

    Returns:
        The loaded module, switched to eval mode.
    """
    model = torch.jit.load(Path(path), map_location=map_location)
    # Turn off training-mode behavior (e.g. dropout, batch-norm statistic
    # updates) for any such layers the network contains.
    model.eval()
    return model
def main():
    """Build the models and preprocessing, then serve the Gradio demo.

    Loads the TorchScript face-parsing network and the rtmlib YOLOX detector
    and RTMPose estimator. It binds them with the preprocessing and palette
    into ``predict`` and launches a queued Gradio interface. The interface
    takes one image in and returns a label map and a blended overlay.
    """
    # Single source of truth for where inference runs. The rtmlib models take
    # the device as a string, so they get device.type to stay in sync.
    device = torch.device('cpu')

    parsing_model = load_parsing_model().to(device)

    # NOTE(review): NEAREST resampling of the RGB input is unusual for a
    # downscale (bilinear is the common choice). It is kept as-is in case the
    # network was trained with it; confirm before changing. 512x512 matches
    # the checkpoint name "faceparsing_512_512.pt".
    transform = T.Compose([
        T.Resize((512, 512), interpolation=T.InterpolationMode.NEAREST),
        T.ToTensor(),
        # Standard ImageNet per-channel mean / std.
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    palette = get_palette(20)

    det_model = YOLOX('models/det.onnx', model_input_size=(640, 640),
                      backend='onnxruntime', device=device.type)
    pose_model = RTMPose('models/pose.onnx', model_input_size=(192, 256),
                         to_openpose=False, backend='onnxruntime',
                         device=device.type)

    # Order matters: predict reads model[0], model[1] and model[2].
    models = [parsing_model, det_model, pose_model]
    func = functools.partial(predict,
                             model=models,
                             transform=transform,
                             device=device,
                             palette=palette)
    gr.Interface(
        fn=func,
        inputs=gr.Image(label='Input', type='pil'),
        outputs=[
            gr.Image(label='Predicted Labels', type='pil'),
            gr.Image(label='Masked', type='pil'),
        ],
        title=TITLE,
    ).queue().launch(show_api=False)
# Launch the demo only when run as a script, not when imported.
if __name__ == "__main__":
    main()