Spaces:
Running
on
Zero
Running
on
Zero
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
import os | |
import os.path as osp | |
import glob | |
import onnxruntime | |
from .arcface_onnx import * | |
from .retinaface import * | |
#from .scrfd import * | |
from .landmark import * | |
from .attribute import Attribute | |
from .inswapper import INSwapper | |
from ..utils import download_onnx | |
__all__ = ['get_model'] | |
class PickableInferenceSession(onnxruntime.InferenceSession):
    """``onnxruntime.InferenceSession`` wrapper that can be pickled.

    The pickled state holds only the model path and the provider
    configuration; unpickling builds a new session from them by calling
    ``__init__`` again.
    """

    # Constructor keywords that are carried through a pickle round-trip.
    # Any other keyword (for example ``sess_options``) is not saved, so a
    # restored session is created without it.
    _PICKLED_KWARGS = ('providers', 'provider_options')

    def __init__(self, model_path, **kwargs):
        """Create the session and remember how to rebuild it.

        Args:
            model_path: Model location forwarded to ``InferenceSession``.
            **kwargs: Forwarded unchanged to ``InferenceSession``.
        """
        super().__init__(model_path, **kwargs)
        self.model_path = model_path
        # Keep the provider configuration so that an unpickled copy is
        # created with the same providers instead of the library defaults.
        self._session_kwargs = {
            key: kwargs[key] for key in self._PICKLED_KWARGS if key in kwargs
        }

    def __getstate__(self):
        """Return the minimal state needed to rebuild the session."""
        return {
            'model_path': self.model_path,
            'session_kwargs': self._session_kwargs,
        }

    def __setstate__(self, values):
        """Rebuild the session from a state produced by ``__getstate__``."""
        # States written by the previous implementation only contain
        # 'model_path'; fall back to no extra keywords for those.
        session_kwargs = values.get('session_kwargs') or {}
        self.__init__(values['model_path'], **session_kwargs)
class ModelRouter:
    """Choose the model wrapper class that matches an ONNX file."""

    def __init__(self, onnx_file):
        """Remember the ONNX file to route.

        Args:
            onnx_file: Path of the ONNX model file.
        """
        self.onnx_file = onnx_file

    def get_model(self, **kwargs):
        """Open the ONNX file and wrap it in the matching model class.

        The wrapper is chosen from the number of outputs, the number of
        inputs, and dimensions 2 and 3 of the first input's shape. Rules
        are tried in this order: ``RetinaFace``, ``Landmark``,
        ``Attribute``, ``INSwapper``, ``ArcFaceONNX``.

        Args:
            **kwargs: Forwarded to ``PickableInferenceSession``.

        Returns:
            An instance of the matching wrapper class, or ``None`` when no
            rule matches.
        """
        session = PickableInferenceSession(self.onnx_file, **kwargs)
        inputs = session.get_inputs()
        input_shape = inputs[0].shape
        outputs = session.get_outputs()

        if len(outputs) >= 5:
            return RetinaFace(model_file=self.onnx_file, session=session)

        # Every remaining rule reads dimensions 2 and 3 of the first input
        # (presumably height and width of an NCHW tensor -- confirm). A
        # shape without them cannot match any rule.
        if len(input_shape) < 4:
            return None
        dim2, dim3 = input_shape[2], input_shape[3]

        if dim2 == 192 and dim3 == 192:
            return Landmark(model_file=self.onnx_file, session=session)
        if dim2 == 96 and dim3 == 96:
            return Attribute(model_file=self.onnx_file, session=session)
        if len(inputs) == 2 and dim2 == 128 and dim3 == 128:
            return INSwapper(model_file=self.onnx_file, session=session)
        # A dynamic dimension is reported as a string or None rather than an
        # int; the isinstance guard keeps the ordering comparison from
        # raising TypeError on such shapes.
        if (dim2 == dim3 and isinstance(dim2, int)
                and dim2 >= 112 and dim2 % 16 == 0):
            return ArcFaceONNX(model_file=self.onnx_file, session=session)

        # Unknown model layout: callers receive None rather than an error.
        return None
def find_onnx_file(dir_path):
    """Return the ``.onnx`` file directly inside ``dir_path`` that sorts last.

    Args:
        dir_path: Directory to search. Sub-directories are not searched.

    Returns:
        The matching path with the greatest string value, or ``None`` when
        ``dir_path`` does not exist or contains no ``*.onnx`` entry.
    """
    if not os.path.exists(dir_path):
        return None
    # Escape the directory part so glob metacharacters in its name
    # ('[', '*', '?') are matched literally instead of as a pattern.
    matches = glob.glob("%s/*.onnx" % glob.escape(dir_path))
    if not matches:
        return None
    return max(matches)
def get_default_providers():
    """Return a new list of the default execution provider names.

    The list names the CUDA provider first and the CPU provider second.
    """
    default_providers = ['CUDAExecutionProvider']
    default_providers.append('CPUExecutionProvider')
    return default_providers
def get_default_provider_options():
    """Return the default provider options, which is ``None``."""
    default_options = None
    return default_options
def get_model(name, **kwargs):
    """Load a model by name or by ``.onnx`` path and return its wrapper.

    Args:
        name: Either a path ending in ``.onnx``, or a model name that is
            looked up as a directory under ``<root>/models``.
        **kwargs: Optional settings.
            root: Base directory, ``'~/.insightface'`` by default; a
                leading ``~`` is expanded.
            download: When true and the ``.onnx`` path does not exist, the
                file is fetched with ``download_onnx``. Default ``False``.
            download_zip: Forwarded to ``download_onnx``. Default ``False``.
            providers: Forwarded to the router; defaults to
                ``get_default_providers()``.
            provider_options: Forwarded to the router; defaults to
                ``get_default_provider_options()``.

    Returns:
        Whatever ``ModelRouter.get_model`` returns, or ``None`` when
        ``name`` is a model name and no ``.onnx`` file is found for it.

    Raises:
        AssertionError: If the resolved model file does not exist or is
            not a regular file.
    """
    base_dir = os.path.expanduser(kwargs.get('root', '~/.insightface'))
    models_dir = osp.join(base_dir, 'models')
    may_download = kwargs.get('download', False)
    as_zip = kwargs.get('download_zip', False)

    if name.endswith('.onnx'):
        # An explicit file path is used as given.
        onnx_path = name
    else:
        # A bare model name maps to a directory holding the .onnx file.
        onnx_path = find_onnx_file(os.path.join(models_dir, name))
        if onnx_path is None:
            return None

    if not osp.exists(onnx_path) and may_download:
        onnx_path = download_onnx(
            'models', onnx_path, root=base_dir, download_zip=as_zip)

    assert osp.exists(onnx_path), 'model_file %s should exist'%onnx_path
    assert osp.isfile(onnx_path), 'model_file %s should be a file'%onnx_path

    router = ModelRouter(onnx_path)
    session_options = {
        'providers': kwargs.get('providers', get_default_providers()),
        'provider_options': kwargs.get(
            'provider_options', get_default_provider_options()),
    }
    return router.get_model(**session_options)