File size: 1,951 Bytes
45099b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f18a0c
 
 
45099b6
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import logging
from typing import Tuple, Union

import gdown
import numpy as np
import numpy.typing as npt
import torch
from torch import device as torch_device
from torch.nn import Module

from .module.yolov5 import YOLO_DIR


# File IDs of the pretrained weights. Presumably these are Google Drive file
# IDs meant to be passed to download_weight() as `file_id` -- TODO confirm
# against the callers.
DETECTOR_WEIGHT_ID = "1YHH7pLoZEdyxw2AoLz9G4lrq6uuxweYB"
REMOVER_WEIGHT_ID = "1Hd79M8DhCwjFuT198R-QB7ozQbHRGcGM"


def select_device(device: str = "") -> torch_device:
    """Choose the torch device to run on.

    Args:
        device: requested device string. Only a case-insensitive "cpu" is
            honoured explicitly; any other value (including the default "")
            selects the first CUDA device when CUDA is available.

    Returns:
        ``torch.device("cpu")`` or ``torch.device("cuda:0")``.
    """
    # CPU was requested explicitly, or there is no GPU to fall back on.
    if device.lower() == "cpu" or not torch.cuda.is_available():
        return torch_device("cpu")
    return torch_device("cuda:0")


def load_yolo_model(weight_path: str, device: str) -> Tuple[Module, int]:
    """Load a custom YOLOv5 model from ``weight_path`` through the local torch hub repo.

    Args:
        weight_path: path to the custom YOLOv5 weight file.
        device: device to place the model on (anything ``Module.to`` accepts,
            e.g. "cpu", "cuda:0" or a ``torch.device``). An empty value leaves
            the model wherever the hub loader placed it.

    Returns:
        A ``(model, stride)`` tuple, where ``stride`` is ``model.stride``.
    """
    # source="local" resolves the "custom" entrypoint from the repo at YOLO_DIR
    # instead of fetching it from GitHub.
    model = torch.hub.load(str(YOLO_DIR), "custom", path=weight_path, source="local", force_reload=True)
    logging.getLogger(__name__).debug("Loaded YOLOv5 weights from %s", weight_path)
    # Honour the requested device; previously `device` was silently ignored.
    if device:
        model = model.to(device)
    return model, model.stride


def download_weight(file_id: str, output: Union[str, None] = None, quiet: bool = False) -> None:
    """Fetch a model weight file from Google Drive by its file ID.

    Args:
        file_id: Google Drive file ID of the weight.
        output: destination path forwarded to ``gdown.cached_download`` as ``path``.
        quiet: forwarded to ``gdown.cached_download`` to silence its progress output.

    Errors are not raised: any exception is printed together with a hint and a
    link for downloading the file manually.
    """
    share_link = f"https://drive.google.com/file/d/{file_id}/view?usp=sharing"
    try:
        gdown.cached_download(url=f"https://drive.google.com/uc?id={file_id}", path=output, quiet=quiet)
    except Exception as err:
        # Best-effort: report the failure and how to recover, but don't raise.
        for message in (
            err,
            "Something went wrong when downloading the weight",
            "Check your internet connection or manually download the weight " f"at {share_link}",
        ):
            print(message)


def check_image_shape(image: npt.NDArray) -> None:
    """Validate that ``image`` is a 3-D NumPy array whose last axis has size 3.

    Args:
        image: candidate image array.

    Raises:
        TypeError: if ``image`` is not a ``np.ndarray``.
        ValueError: if ``image`` is not 3-dimensional, or its last axis
            (channels) does not have size 3.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError("Invalid Type: List value must be of type np.ndarray")
    if image.ndim != 3:
        raise ValueError(f"Invalid image shape {image.shape}: image must be 3 dimensional")
    # Channels are expected in the last axis.
    if image.shape[-1] != 3:
        raise ValueError(f"Invalid image shape {image.shape}: image must have 3 channels in the last axis")