# Spaces: Running
from rtmlib import YOLOX, RTMPose, draw_bbox, draw_skeleton | |
import functools | |
from typing import Callable | |
from pathlib import Path | |
import gradio as gr | |
import numpy as np | |
import PIL.Image | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from PIL import Image | |
import torchvision.transforms as T | |
# Heading shown at the top of the Gradio interface built in main().
TITLE = "Face Parsing"
def get_palette(num_cls):
    """Build the flat RGB colour map used to visualise a segmentation mask.

    The bits of each label index are dealt out three at a time to the red,
    green and blue channels, filling each channel from its most significant
    bit downwards, so that consecutive labels get clearly distinct colours.

    Args:
        num_cls: Number of classes.

    Returns:
        A flat list ``[r0, g0, b0, r1, g1, b1, ...]`` with ``3 * num_cls``
        entries (empty when ``num_cls`` is not positive).
    """
    palette = []
    for label in range(num_cls):
        rgb = [0, 0, 0]
        remaining = label
        shift = 7  # start at the top bit of each 8-bit channel
        while remaining:
            # The low three bits go to R, G and B respectively.
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            remaining >>= 3
            shift -= 1
        palette.extend(rgb)
    return palette
def predict(image: PIL.Image.Image, model, transform: Callable,
            device: torch.device, palette) -> tuple:
    """Run detection, pose estimation and face parsing on a single image.

    Args:
        image: Input picture. It is converted to RGB when it arrives in
            another mode, because the rest of the pipeline works on
            three-channel data.
        model: Indexable holding three callables:
            ``model[0]`` maps the transformed image batch to a per-class
            score map, ``model[1]`` maps an image array to bounding boxes,
            and ``model[2]`` maps an image array plus ``bboxes=`` to
            ``(keypoints, scores)``.
        transform: Callable that turns the PIL image into the tensor fed
            to ``model[0]`` (a batch dimension is added here).
        device: Device the input tensor is moved to before inference.
        palette: Flat ``[r, g, b, ...]`` list applied to the label image.

    Returns:
        A pair ``(label_image, overlay)`` of RGB PIL images: the
        colourised parsing result, and that result blended 50/50 with the
        input annotated with boxes and skeletons.

    Raises:
        gr.Error: If ``image`` is ``None`` (for example, the form was
            submitted without a picture).
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    if image.mode != "RGB":
        image = image.convert("RGB")

    # One array feeds both rtmlib models; drawing happens on a separate
    # copy so the pose model still sees the unannotated picture.
    frame = np.array(image)
    bboxes = model[1](frame)
    canvas = draw_bbox(frame.copy(), bboxes)
    keypoints, scores = model[2](frame, bboxes=bboxes)
    canvas = draw_skeleton(canvas, keypoints, scores)

    width, height = image.size
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        batch = transform(image).unsqueeze(0).to(device)
        logits = model[0](batch)
        # Scale the score map back to the input resolution (H, W).
        logits = F.interpolate(logits, [height, width], mode="bilinear")
        # Channel-wise argmax gives one class index per pixel.
        parsing = logits[0].argmax(dim=0).cpu().numpy()

    output_im = Image.fromarray(np.asarray(parsing, dtype=np.uint8))
    output_im.putpalette(palette)
    output_im = output_im.convert('RGB')

    annotated = Image.fromarray(np.asarray(canvas, dtype=np.uint8))
    res = Image.blend(annotated.convert('RGB'), output_im, 0.5)
    return output_im, res
def load_parsing_model(model_path=Path("models/faceparsing_512_512.pt"),
                       map_location="cpu") -> torch.jit.ScriptModule:
    """Load the TorchScript face-parsing network in evaluation mode.

    Args:
        model_path: Location of the serialized TorchScript file (string or
            ``Path``). Defaults to the checkpoint shipped with the demo.
        map_location: Device the stored tensors are remapped to while
            loading. Defaults to ``"cpu"`` so a checkpoint saved on a GPU
            still loads and matches the CPU tensors the demo feeds it.

    Returns:
        The loaded module, switched to ``eval()`` mode.
    """
    model = torch.jit.load(model_path, map_location=map_location)
    model.eval()
    return model
def main():
    """Load the models, wire them into ``predict`` and launch the Gradio demo."""
    device = torch.device('cpu')

    parsing_model = load_parsing_model()
    transform = T.Compose([
        T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    palette = get_palette(20)

    detector = YOLOX(
        'models/det.onnx',
        model_input_size=(640, 640),
        backend='onnxruntime',
        device='cpu',
    )
    pose_estimator = RTMPose(
        'models/pose.onnx',
        model_input_size=(192, 256),
        to_openpose=False,
        backend='onnxruntime',
        device='cpu',
    )

    # predict() indexes this list: 0 = parsing, 1 = detector, 2 = pose.
    models = [parsing_model, detector, pose_estimator]

    handler = functools.partial(
        predict,
        model=models,
        transform=transform,
        device=device,
        palette=palette,
    )

    demo = gr.Interface(
        fn=handler,
        inputs=gr.Image(label='Input', type='pil'),
        outputs=[
            gr.Image(label='Predicted Labels', type='pil'),
            gr.Image(label='Masked', type='pil'),
        ],
        title=TITLE,
    )
    demo.queue().launch(show_api=False)


if __name__ == "__main__":
    main()