# Copyright (c) OpenMMLab. All rights reserved.
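"""Tests for the ONNX Runtime and TensorRT deployment wrappers.

Each test exports a small model (DBNet for detection, CRNN for
recognition) to ONNX, converts the ONNX file to a TensorRT engine, and
checks that both wrappers return outputs with the expected structure.
"""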
import tempfile
from functools import partial

import mmcv
import numpy as np
import pytest
import torch

from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
                                   TensorRTDetector, TensorRTRecognizer)
from mmocr.models import build_detector


def test_detector_wrapper():
    try:
        import onnxruntime as ort  # noqa: F401
        import tensorrt as trt
        from mmcv.tensorrt import onnx2trt, save_trt_engine
    except ImportError:
        pytest.skip('ONNXRuntime or TensorRT is not available.')
    cfg = dict(
        model=dict(
            type='DBNet',
            backbone=dict(
                type='ResNet',
                depth=18,
                num_stages=4,
                out_indices=(0, 1, 2, 3),
                frozen_stages=-1,
                norm_cfg=dict(type='BN', requires_grad=True),
                init_cfg=dict(
                    type='Pretrained', checkpoint='torchvision://resnet18'),
                norm_eval=False,
                style='caffe'),
            neck=dict(
                type='FPNC',
                in_channels=[64, 128, 256, 512],
                lateral_channels=256),
            bbox_head=dict(
                type='DBHead',
                text_repr_type='quad',
                in_channels=256,
                loss=dict(type='DBLoss', alpha=5.0, beta=10.0,
                          bbce_loss=True)),
            train_cfg=None,
            test_cfg=None))
    cfg = mmcv.Config(cfg)
    pytorch_model = build_detector(cfg.model, None, None)
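    # No checkpoint is explicitly loaded here; the test only verifies the
    # structure of the exported model's outputs, not their accuracy.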
    # prepare data
    inputs = torch.rand(1, 3, 224, 224)
    img_metas = [{
        'img_shape': [1, 3, 224, 224],
        'ori_shape': [1, 3, 224, 224],
        'pad_shape': [1, 3, 224, 224],
        'filename': None,
        'scale_factor': np.array([1, 1, 1, 1])
    }]
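    # torch.onnx.export traces a call that takes only tensors, so swap in
    # forward_dummy, which runs the raw forward pass without img_metas.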
    pytorch_model.forward = pytorch_model.forward_dummy
    with tempfile.TemporaryDirectory() as tmpdirname:
        onnx_path = f'{tmpdirname}/tmp.onnx'
        with torch.no_grad():
            torch.onnx.export(
                pytorch_model,
                inputs,
                onnx_path,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=False,
                opset_version=11)

        # TensorRT part
        def get_GiB(x: int):
            """Return x GiB in bytes."""
            return x * (1 << 30)

        trt_path = onnx_path.replace('.onnx', '.trt')
        min_shape = [1, 3, 224, 224]
        max_shape = [1, 3, 224, 224]
        # create trt engine and wrapper
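        # onnx2trt expects a [min, opt, max] shape profile per input; the
        # model was exported with static shapes, so the same shape serves
        # all three slots (opt reuses min_shape).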
        opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
        max_workspace_size = get_GiB(1)
        trt_engine = onnx2trt(
            onnx_path,
            opt_shape_dict,
            log_level=trt.Logger.ERROR,
            fp16_mode=False,
            max_workspace_size=max_workspace_size)
        save_trt_engine(trt_engine, trt_path)
        print(f'Successfully created TensorRT engine: {trt_path}')
        wrap_onnx = ONNXRuntimeDetector(onnx_path, cfg, 0)
        wrap_trt = TensorRTDetector(trt_path, cfg, 0)
        assert isinstance(wrap_onnx, ONNXRuntimeDetector)
        assert isinstance(wrap_trt, TensorRTDetector)
        with torch.no_grad():
            onnx_outputs = wrap_onnx.simple_test(
                inputs, img_metas, rescale=False)
            trt_outputs = wrap_trt.simple_test(
                inputs, img_metas, rescale=False)
        assert isinstance(onnx_outputs[0], dict)
        assert isinstance(trt_outputs[0], dict)
        assert 'boundary_result' in onnx_outputs[0]
        assert 'boundary_result' in trt_outputs[0]


def test_recognizer_wrapper():
    try:
        import onnxruntime as ort  # noqa: F401
        import tensorrt as trt
        from mmcv.tensorrt import onnx2trt, save_trt_engine
    except ImportError:
        pytest.skip('ONNXRuntime or TensorRT is not available.')
    cfg = dict(
        label_convertor=dict(
            type='CTCConvertor',
            dict_type='DICT36',
            with_unknown=False,
            lower=True),
        model=dict(
            type='CRNNNet',
            preprocessor=None,
            backbone=dict(
                type='VeryDeepVgg', leaky_relu=False, input_channels=1),
            encoder=None,
            decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
            loss=dict(type='CTCLoss'),
            label_convertor=dict(
                type='CTCConvertor',
                dict_type='DICT36',
                with_unknown=False,
                lower=True),
            pretrained=None),
        train_cfg=None,
        test_cfg=None)
    cfg = mmcv.Config(cfg)
    pytorch_model = build_detector(cfg.model, None, None)
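    # build_detector also builds recognizers: in MMOCR 0.x both model
    # types are registered in the shared DETECTORS registry.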
    # prepare data
    inputs = torch.rand(1, 1, 32, 32)
    img_metas = [{
        'img_shape': [1, 1, 32, 32],
        'ori_shape': [1, 1, 32, 32],
        'pad_shape': [1, 1, 32, 32],
        'filename': None,
        'scale_factor': np.array([1, 1, 1, 1])
    }]
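    # Instead of swapping in a dummy forward, bind the metadata and test
    # flags with functools.partial so the traced call takes only the
    # image tensor.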
    pytorch_model.forward = partial(
        pytorch_model.forward,
        img_metas=img_metas,
        return_loss=False,
        rescale=True)
    with tempfile.TemporaryDirectory() as tmpdirname:
        onnx_path = f'{tmpdirname}/tmp.onnx'
        with torch.no_grad():
            torch.onnx.export(
                pytorch_model,
                inputs,
                onnx_path,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=False,
                opset_version=11)

        # TensorRT part
        def get_GiB(x: int):
            """Return x GiB in bytes."""
            return x * (1 << 30)

        trt_path = onnx_path.replace('.onnx', '.trt')
        min_shape = [1, 1, 32, 32]
        max_shape = [1, 1, 32, 32]
        # create trt engine and wrapper
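        # As above: one static shape fills the [min, opt, max] profile
        # slots that onnx2trt expects for each input.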
        opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
        max_workspace_size = get_GiB(1)
        trt_engine = onnx2trt(
            onnx_path,
            opt_shape_dict,
            log_level=trt.Logger.ERROR,
            fp16_mode=False,
            max_workspace_size=max_workspace_size)
        save_trt_engine(trt_engine, trt_path)
        print(f'Successfully created TensorRT engine: {trt_path}')
        wrap_onnx = ONNXRuntimeRecognizer(onnx_path, cfg, 0)
        wrap_trt = TensorRTRecognizer(trt_path, cfg, 0)
        assert isinstance(wrap_onnx, ONNXRuntimeRecognizer)
        assert isinstance(wrap_trt, TensorRTRecognizer)
        with torch.no_grad():
            onnx_outputs = wrap_onnx.simple_test(
                inputs, img_metas, rescale=False)
            trt_outputs = wrap_trt.simple_test(
                inputs, img_metas, rescale=False)
        assert isinstance(onnx_outputs[0], dict)
        assert isinstance(trt_outputs[0], dict)
        assert 'text' in onnx_outputs[0]
        assert 'text' in trt_outputs[0]
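
# Note: running these tests requires onnxruntime, tensorrt, and an mmcv
# build compiled with TensorRT support (plus a CUDA-capable GPU for the
# engine); otherwise they skip themselves. They can be selected with,
# e.g., `pytest -k wrapper path/to/this_file.py`, where the path is a
# placeholder for wherever this module lives in the repo.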