|
|
|
"""Tests for async interface.""" |
|
|
|
import asyncio |
|
import os |
|
import sys |
|
|
|
import asynctest |
|
import mmcv |
|
import torch |
|
|
|
from mmdet.apis import async_inference_detector, init_detector |
|
|
|
if sys.version_info >= (3, 7): |
|
from mmdet.utils.contextmanagers import concurrent |
|
|
|
|
|
class AsyncTestCase(asynctest.TestCase):
    """Base test case that drives coroutine test methods with a timeout.

    The timeout (in seconds) is taken from the ``ASYNCIO_TEST_TIMEOUT``
    environment variable and falls back to 30 when it is not set.
    """

    # asynctest configuration flags; see the asynctest documentation for
    # their exact semantics.
    use_default_loop = False
    forbid_get_event_loop = True

    TEST_TIMEOUT = int(os.environ.get('ASYNCIO_TEST_TIMEOUT', '30'))

    def _run_test_method(self, method):
        """Invoke ``method`` and wait on its result if it is a coroutine.

        Args:
            method: Zero-argument callable to execute.
        """
        outcome = method()
        if not asyncio.iscoroutine(outcome):
            # Plain synchronous test: calling it was all there is to do.
            return
        run = self.loop.run_until_complete
        # Bound the wait so a hung coroutine fails instead of blocking.
        run(asyncio.wait_for(outcome, timeout=self.TEST_TIMEOUT))
|
|
|
|
|
class MaskRCNNDetector:
    """Small async wrapper around a model built with ``init_detector``.

    ``init()`` must be awaited before ``apredict()``: it creates the queue
    of CUDA streams that ``apredict`` hands to ``concurrent``.
    """

    def __init__(self,
                 model_config,
                 checkpoint=None,
                 streamqueue_size=3,
                 device='cuda:0'):
        """Build the detector model.

        Args:
            model_config: Config passed as the first argument to
                ``init_detector``.
            checkpoint: Checkpoint forwarded to ``init_detector``.
                Defaults to None.
            streamqueue_size: Number of CUDA streams ``init()`` puts into
                the stream queue. Defaults to 3.
            device: Device used both for the model and for the CUDA
                streams. Defaults to ``'cuda:0'``.
        """
        self.streamqueue_size = streamqueue_size
        self.device = device

        # Forward the caller's checkpoint (None by default) instead of
        # always passing None, which silently ignored the argument.
        self.model = init_detector(
            model_config, checkpoint=checkpoint, device=self.device)
        # Populated by ``init()``.
        self.streamqueue = None

    async def init(self):
        """Create the stream queue and fill it with CUDA streams.

        Puts ``self.streamqueue_size`` new ``torch.cuda.Stream`` objects
        for ``self.device`` into a fresh ``asyncio.Queue``.
        """
        self.streamqueue = asyncio.Queue()
        for _ in range(self.streamqueue_size):
            stream = torch.cuda.Stream(device=self.device)
            self.streamqueue.put_nowait(stream)

    # ``concurrent`` is only imported on Python >= 3.7, so the method that
    # needs it is only defined there.
    if sys.version_info >= (3, 7):

        async def apredict(self, img):
            """Run asynchronous inference on one image.

            Args:
                img: The image to process. A ``str`` is treated as a path
                    and loaded with ``mmcv.imread`` first.

            Returns:
                The value returned by ``async_inference_detector`` for
                ``self.model`` and the image.
            """
            if isinstance(img, str):
                img = mmcv.imread(img)
            async with concurrent(self.streamqueue):
                result = await async_inference_detector(self.model, img)
            return result
|
|
|
|
|
class AsyncInferenceTestCase(AsyncTestCase):
    """Smoke test for asynchronous inference through ``MaskRCNNDetector``."""

    # The async inference path is only available on Python >= 3.7 (see the
    # guarded ``concurrent`` import), so the test is only defined there.
    if sys.version_info >= (3, 7):

        async def test_simple_inference(self):
            """Run one async prediction on the demo image and check it.

            Skipped when CUDA is not available. The global autograd mode is
            captured up front and restored afterwards, including when the
            test body raises.
            """
            if not torch.cuda.is_available():
                import pytest

                pytest.skip('test requires GPU and torch+cuda')

            ori_grad_enabled = torch.is_grad_enabled()
            try:
                # NOTE(review): ``__name__`` is a module name, not a path,
                # so this normally yields '' and the paths below resolve
                # against the current working directory - confirm whether
                # ``__file__`` was intended.
                root_dir = os.path.dirname(os.path.dirname(__name__))
                model_config = os.path.join(
                    root_dir,
                    'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
                detector = MaskRCNNDetector(model_config)
                await detector.init()
                img_path = os.path.join(root_dir, 'demo/demo.jpg')
                bboxes, _ = await detector.apredict(img_path)
                self.assertTrue(bboxes)
            finally:
                # Restore the grad mode even on failure so a broken run
                # does not leak global state into later tests.
                torch.set_grad_enabled(ori_grad_enabled)
|
|