"""Gradio demo: face parsing blended over a detection/pose visualisation."""

import functools
from pathlib import Path
from typing import Callable, List, Sequence, Tuple

from rtmlib import YOLOX, RTMPose, draw_bbox, draw_skeleton
import gradio as gr
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as T

TITLE = 'Face Parsing'


def get_palette(num_cls: int) -> List[int]:
    """Return the color map for visualizing the segmentation mask.

    Each class index is turned into an RGB triple by distributing its bits
    over the three channels, three bits per round, filling each channel from
    its most significant bit downwards. Class 0 therefore maps to black.

    Args:
        num_cls: Number of classes.

    Returns:
        A flat list of ``num_cls * 3`` integers laid out as
        ``[r0, g0, b0, r1, g1, b1, ...]``.
    """
    palette = [0] * (num_cls * 3)
    for label in range(num_cls):
        base = label * 3
        remaining = label
        bit = 7  # next channel bit to fill, starting at the MSB
        while remaining:
            palette[base + 0] |= ((remaining >> 0) & 1) << bit
            palette[base + 1] |= ((remaining >> 1) & 1) << bit
            palette[base + 2] |= ((remaining >> 2) & 1) << bit
            bit -= 1
            remaining >>= 3
    return palette


@torch.inference_mode()
def predict(
    image: PIL.Image.Image,
    model: Sequence,
    transform: Callable,
    device: torch.device,
    palette: Sequence[int],
) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
    """Run parsing, detection and pose estimation on one image.

    Args:
        image: Input picture as delivered by the Gradio image component.
        model: Three callables indexed as ``model[0]`` (parsing network fed
            with the transformed tensor batch), ``model[1]`` (detector, called
            with the image array and returning boxes) and ``model[2]`` (pose
            estimator, called with the image array and ``bboxes=``, returning
            keypoints and scores).
        transform: Callable turning the PIL image into the tensor passed to
            the parsing network.
        device: Device the input tensor is moved to before inference.
        palette: Flat RGB palette applied to the predicted label map.

    Returns:
        A ``(labels, blended)`` pair: the colorized label map, and that map
        blended 50/50 with the image carrying the drawn boxes and skeletons.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Gradio passes None when the user submits without an image.
        raise gr.Error('Please provide an input image.')

    # Uploads may be RGBA, grayscale or palettised; everything downstream
    # (3-channel normalisation, RGB blending) assumes three channels.
    image = image.convert('RGB')
    width, height = image.size  # PIL reports (width, height)

    # One array for inference and a separate copy to draw on, so the drawing
    # helpers never touch the pixels handed to the models.
    # NOTE(review): rtmlib usage typically feeds cv2 (BGR) frames; this array
    # is RGB, as in the original code -- confirm the models tolerate it.
    frame = np.array(image)
    canvas = frame.copy()

    bboxes = model[1](frame)
    canvas = draw_bbox(canvas, bboxes)
    keypoints, scores = model[2](frame, bboxes=bboxes)
    canvas = draw_skeleton(canvas, keypoints, scores)

    batch = transform(image).unsqueeze(0).to(device)
    logits = model[0](batch)
    # Bring the prediction back to the input resolution; interpolate expects
    # the target size as (height, width).
    logits = F.interpolate(logits, [height, width], mode="bilinear")
    # Per-pixel class index: argmax over the channel axis of the only sample.
    parsing = torch.argmax(logits[0], dim=0).cpu().numpy()

    output_im = Image.fromarray(np.asarray(parsing, dtype=np.uint8))
    output_im.putpalette(palette)
    output_im = output_im.convert('RGB')

    annotated = Image.fromarray(np.asarray(canvas, dtype=np.uint8))
    res = Image.blend(annotated, output_im, 0.5)
    return output_im, res


def load_parsing_model(device='cpu'):
    """Load the TorchScript face-parsing model in evaluation mode.

    Args:
        device: Target passed to ``torch.jit.load`` as ``map_location`` so the
            weights end up on the device the inputs are sent to, regardless
            of where the checkpoint was saved. Defaults to ``'cpu'``.

    Returns:
        The loaded scripted module.
    """
    model = torch.jit.load(
        Path("models/faceparsing_512_512.pt"), map_location=device
    )
    model.eval()
    return model


def main() -> None:
    """Build the models and preprocessing, then launch the Gradio interface."""
    device = torch.device('cpu')
    parsing_model = load_parsing_model(device)

    transform = T.Compose([
        T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    palette = get_palette(20)

    det_model = YOLOX(
        'models/det.onnx',
        model_input_size=(640, 640),
        backend='onnxruntime',
        device='cpu',
    )
    pose_model = RTMPose(
        'models/pose.onnx',
        model_input_size=(192, 256),
        to_openpose=False,
        backend='onnxruntime',
        device='cpu',
    )

    # Order matters: predict() indexes this list as [parsing, detector, pose].
    model_ls = [parsing_model, det_model, pose_model]

    func = functools.partial(
        predict,
        model=model_ls,
        transform=transform,
        device=device,
        palette=palette,
    )
    gr.Interface(
        fn=func,
        inputs=gr.Image(label='Input', type='pil'),
        outputs=[
            gr.Image(label='Predicted Labels', type='pil'),
            gr.Image(label='Masked', type='pil'),
        ],
        title=TITLE,
    ).queue().launch(show_api=False)


if __name__ == "__main__":
    main()