# File size: 1,951 bytes (revisions 45099b6, 0f18a0c)
import logging
from typing import Tuple, Union
import gdown
import numpy as np
import numpy.typing as npt
import torch
from torch import device as torch_device
from torch.nn import Module
from .module.yolov5 import YOLO_DIR
# Google Drive file IDs of the pretrained weights. Presumably these are the
# `file_id` values passed to download_weight -- verify against callers.
DETECTOR_WEIGHT_ID = "1YHH7pLoZEdyxw2AoLz9G4lrq6uuxweYB"
REMOVER_WEIGHT_ID = "1Hd79M8DhCwjFuT198R-QB7ozQbHRGcGM"
def select_device(device: str = "") -> torch_device:
    """Choose the torch device to run on.

    Asking for ``"cpu"`` (any letter case) always yields the CPU. Any other
    value selects ``cuda:0`` when CUDA is available and otherwise falls back
    to the CPU.
    """
    # Short-circuit: CUDA availability is only probed when CPU was not
    # explicitly requested.
    if device.lower() == "cpu" or not torch.cuda.is_available():
        return torch_device("cpu")
    return torch_device("cuda:0")
def load_yolo_model(weight_path: str, device: str) -> Tuple[Module, int]:
    """Load a custom-weight YOLOv5 model through the local torch hub repo.

    Args:
        weight_path: Path to the weight file, forwarded as ``path`` to the
            ``custom`` entry point of the hub repo at ``YOLO_DIR``.
        device: Currently unused by this function; kept for interface
            compatibility.

    Returns:
        A ``(model, stride)`` tuple where ``stride`` is ``model.stride``.
    """
    # Trace through logging instead of print() so library callers are not
    # spammed on stdout; enable DEBUG logging to see it.
    logging.getLogger(__name__).debug("Loading YOLO weights from %s", weight_path)
    model = torch.hub.load(
        str(YOLO_DIR),
        "custom",
        path=weight_path,
        source="local",
        force_reload=True,
    )
    # NOTE(review): `device` is ignored, so the model stays wherever the hub
    # entry point places it. Confirm whether it should be moved explicitly
    # (e.g. model.to(...)) before relying on the `device` argument.
    return model, model.stride
def download_weight(file_id: str, output: Union[str, None] = None, quiet: bool = False) -> None:
    """Download a model weight file from Google Drive given its file ID.

    Best-effort: if the download raises, the error (with traceback) is logged
    together with a manual-download link, and the exception is NOT propagated
    to the caller.

    Args:
        file_id: Google Drive file ID of the weight.
        output: Destination path, passed straight to ``gdown.cached_download``
            as ``path``.
        quiet: Passed straight to ``gdown.cached_download`` as ``quiet``.
    """
    logger = logging.getLogger(__name__)
    url = f"https://drive.google.com/uc?id={file_id}"
    try:
        gdown.cached_download(url=url, path=output, quiet=quiet)
    except Exception:  # deliberate best-effort: report instead of crashing
        # logger.exception records the traceback that the old print(e) lost.
        # Unconfigured logging still emits ERROR records to stderr.
        logger.exception(
            "Something went wrong when downloading the weight. "
            "Check your internet connection or manually download the weight "
            "at https://drive.google.com/file/d/%s/view?usp=sharing",
            file_id,
        )
def check_image_shape(image: npt.NDArray) -> None:
    """Validate that *image* is a 3-D NumPy array whose last axis has size 3.

    Only the type, the rank and the size of the last axis are checked;
    presumably the input is an (H, W, 3) channels-last image.

    Raises:
        TypeError: If *image* is not an ``np.ndarray``.
        ValueError: If *image* is not 3-dimensional, or its last axis is not
            of size 3.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(
            f"Invalid Type: image must be of type np.ndarray, got {type(image).__name__}"
        )
    # Rank is checked first so that shape[-1] below is guaranteed to exist.
    if image.ndim != 3:
        raise ValueError(f"Image must be 3 dimensional, got shape {image.shape}")
    if image.shape[-1] != 3:
        raise ValueError(
            f"Image must have 3 channels in its last axis, got shape {image.shape}"
        )
|