# Copyright (c) OpenMMLab. All rights reserved. """Tests for async interface.""" import asyncio import os import sys import asynctest import mmcv import torch from mmdet.apis import async_inference_detector, init_detector if sys.version_info >= (3, 7): from mmdet.utils.contextmanagers import concurrent class AsyncTestCase(asynctest.TestCase): use_default_loop = False forbid_get_event_loop = True TEST_TIMEOUT = int(os.getenv('ASYNCIO_TEST_TIMEOUT', '30')) def _run_test_method(self, method): result = method() if asyncio.iscoroutine(result): self.loop.run_until_complete( asyncio.wait_for(result, timeout=self.TEST_TIMEOUT)) class MaskRCNNDetector: def __init__(self, model_config, checkpoint=None, streamqueue_size=3, device='cuda:0'): self.streamqueue_size = streamqueue_size self.device = device # build the model and load checkpoint self.model = init_detector( model_config, checkpoint=None, device=self.device) self.streamqueue = None async def init(self): self.streamqueue = asyncio.Queue() for _ in range(self.streamqueue_size): stream = torch.cuda.Stream(device=self.device) self.streamqueue.put_nowait(stream) if sys.version_info >= (3, 7): async def apredict(self, img): if isinstance(img, str): img = mmcv.imread(img) async with concurrent(self.streamqueue): result = await async_inference_detector(self.model, img) return result class AsyncInferenceTestCase(AsyncTestCase): if sys.version_info >= (3, 7): async def test_simple_inference(self): if not torch.cuda.is_available(): import pytest pytest.skip('test requires GPU and torch+cuda') ori_grad_enabled = torch.is_grad_enabled() root_dir = os.path.dirname(os.path.dirname(__name__)) model_config = os.path.join( root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py') detector = MaskRCNNDetector(model_config) await detector.init() img_path = os.path.join(root_dir, 'demo/demo.jpg') bboxes, _ = await detector.apredict(img_path) self.assertTrue(bboxes) # asy inference detector will hack grad_enabled, # so restore here to avoid it to influence other tests torch.set_grad_enabled(ori_grad_enabled)