File size: 2,608 Bytes
3bbb319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) OpenMMLab. All rights reserved.
"""Tests for async interface."""

import asyncio
import os
import sys

import asynctest
import mmcv
import torch

from mmdet.apis import async_inference_detector, init_detector

if sys.version_info >= (3, 7):
    from mmdet.utils.contextmanagers import concurrent


class AsyncTestCase(asynctest.TestCase):
    """asynctest base class whose coroutine tests are bounded by a timeout.

    ``use_default_loop`` and ``forbid_get_event_loop`` are asynctest
    class-level switches; see ``asynctest.TestCase`` for their exact effect.
    """

    use_default_loop = False
    forbid_get_event_loop = True

    # Seconds a single coroutine test may run; read once at import time from
    # the ASYNCIO_TEST_TIMEOUT environment variable (default: 30).
    TEST_TIMEOUT = int(os.environ.get('ASYNCIO_TEST_TIMEOUT', '30'))

    def _run_test_method(self, method):
        """Call ``method``; if it returns a coroutine, drive it on
        ``self.loop`` wrapped in ``asyncio.wait_for`` with ``TEST_TIMEOUT``.

        Presumably overrides the asynctest hook of the same name in order to
        add the timeout.
        """
        outcome = method()
        if not asyncio.iscoroutine(outcome):
            # Plain synchronous test: nothing left to drive.
            return
        bounded = asyncio.wait_for(outcome, timeout=self.TEST_TIMEOUT)
        self.loop.run_until_complete(bounded)


class MaskRCNNDetector:
    """Small async wrapper around an mmdet detector used by the tests below.

    Args:
        model_config: Model config forwarded to ``init_detector``.
        checkpoint: Optional checkpoint forwarded to ``init_detector``.
            Defaults to ``None`` (no weights loaded).
        streamqueue_size (int): Number of CUDA streams ``init`` creates.
        device (str): Device for the model and the CUDA streams.
    """

    def __init__(self,
                 model_config,
                 checkpoint=None,
                 streamqueue_size=3,
                 device='cuda:0'):

        self.streamqueue_size = streamqueue_size
        self.device = device
        # build the model and load checkpoint; the caller's checkpoint is
        # forwarded (it used to be dropped in favour of a hard-coded None).
        self.model = init_detector(
            model_config, checkpoint=checkpoint, device=self.device)
        # Filled by ``init()``, which must be awaited before ``apredict``.
        self.streamqueue = None

    async def init(self):
        """Fill ``self.streamqueue`` with ``streamqueue_size`` CUDA streams
        created on ``self.device``."""
        self.streamqueue = asyncio.Queue()
        for _ in range(self.streamqueue_size):
            stream = torch.cuda.Stream(device=self.device)
            self.streamqueue.put_nowait(stream)

    if sys.version_info >= (3, 7):

        async def apredict(self, img):
            """Run async inference on ``img``.

            Args:
                img: Image file path (loaded with ``mmcv.imread``) or an
                    already-loaded image, which is passed through unchanged.

            Returns:
                Whatever ``async_inference_detector`` returns for the model.
            """
            if isinstance(img, str):
                img = mmcv.imread(img)
            # ``concurrent`` presumably borrows a stream from the pool for
            # the duration of the inference call.
            async with concurrent(self.streamqueue):
                result = await async_inference_detector(self.model, img)
            return result


class AsyncInferenceTestCase(AsyncTestCase):
    """End-to-end check of async Mask R-CNN inference on the demo image."""

    if sys.version_info >= (3, 7):

        async def test_simple_inference(self):
            """Run Mask R-CNN via ``apredict`` on ``demo/demo.jpg`` and expect
            non-empty bbox results. Skipped when CUDA is unavailable."""
            if not torch.cuda.is_available():
                import pytest

                pytest.skip('test requires GPU and torch+cuda')

            ori_grad_enabled = torch.is_grad_enabled()
            # Paths are relative to the current working directory (expected
            # to be the repository root). The previous expression
            # ``os.path.dirname(os.path.dirname(__name__))`` always produced
            # '' since ``__name__`` is a dotted module name, not a file path.
            root_dir = ''
            model_config = os.path.join(
                root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
            try:
                detector = MaskRCNNDetector(model_config)
                await detector.init()
                img_path = os.path.join(root_dir, 'demo/demo.jpg')
                bboxes, _ = await detector.apredict(img_path)
                self.assertTrue(bboxes)
            finally:
                # async inference detector will hack grad_enabled, so restore
                # it here -- even when the test fails -- to avoid influencing
                # other tests.
                torch.set_grad_enabled(ori_grad_enabled)