nguyenp99's picture
Update stamp_processing/utils.py
091dd84 verified
raw
history blame
1.97 kB
import logging
from typing import Tuple, Union
import gdown
import numpy as np
import numpy.typing as npt
import torch
from torch import device as torch_device
from torch.nn import Module
from .module.yolov5 import YOLO_DIR
# Identifiers of the pretrained weight files. They look like Google Drive file
# IDs and are presumably passed as ``file_id`` to ``download_weight`` elsewhere
# in the package — confirm at the call sites.
DETECTOR_WEIGHT_ID: str = "1YHH7pLoZEdyxw2AoLz9G4lrq6uuxweYB"  # detector model weights (per the name)
REMOVER_WEIGHT_ID: str = "1Hd79M8DhCwjFuT198R-QB7ozQbHRGcGM"  # remover model weights (per the name)
def select_device(device: str = "") -> torch_device:
    """Return the ``torch.device`` to run the models on.

    Args:
        device: ``"cpu"`` (case-insensitive) forces the CPU. An explicit CUDA
            ordinal such as ``"0"`` or ``"cuda:0"`` selects that GPU when it
            exists. Any other value, including the default ``""``, picks a GPU
            automatically.

    Returns:
        ``torch.device("cpu")`` when the CPU is requested or CUDA is not
        available; otherwise a ``cuda:N`` device that actually exists on this
        host.
    """
    requested = device.strip().lower()
    if requested == "cpu" or not torch.cuda.is_available():
        return torch_device("cpu")

    gpu_count = torch.cuda.device_count()

    # Honour an explicit ordinal ("1" or "cuda:1") if that GPU is present.
    index_text = requested[len("cuda:"):] if requested.startswith("cuda:") else requested
    if index_text.isdigit() and int(index_text) < gpu_count:
        return torch_device(f"cuda:{int(index_text)}")

    # No usable explicit ordinal. Keep the historical preference for the second
    # GPU, but fall back to the first one on single-GPU hosts, where "cuda:1"
    # does not exist and would fail as soon as a tensor is moved to it.
    return torch_device("cuda:1" if gpu_count > 1 else "cuda:0")
def load_yolo_model(weight_path: str, device: str) -> Tuple[Module, int]:
    """Load a custom YOLOv5 model from the local YOLOv5 checkout via ``torch.hub``.

    Args:
        weight_path: Path to the model weight file, forwarded to the hub
            ``custom`` entry point as ``path``.
        device: Device specification, forwarded to the hub entry point's
            ``device`` argument.

    Returns:
        A ``(model, stride)`` pair, where ``stride`` is ``model.stride``.
    """
    model = torch.hub.load(
        str(YOLO_DIR),
        "custom",
        path=weight_path,
        source="local",
        force_reload=True,
        device=device,
    )
    # Record which weights were loaded through logging instead of a bare print,
    # so applications using this library control the verbosity.
    logging.getLogger(__name__).info("Loaded YOLO weights from %s", weight_path)
    return model, model.stride
def download_weight(file_id: str, output: Union[str, None] = None, quiet: bool = False) -> None:
    """Download a model weight file from Google Drive given its file ID.

    This is best-effort: any failure is logged (with traceback and a link for
    downloading the file manually) and the function returns normally instead
    of raising.

    Args:
        file_id: Google Drive file ID of the weight file.
        output: Destination path, forwarded to ``gdown.cached_download`` as
            ``path``.
        quiet: Forwarded to ``gdown.cached_download`` to suppress its progress
            output.
    """
    url = f"https://drive.google.com/uc?id={file_id}"
    try:
        gdown.cached_download(url=url, path=output, quiet=quiet)
    except Exception:
        # Deliberately swallowed so callers are not crashed by a failed
        # download; logger.exception keeps the traceback that print(e) lost.
        logging.getLogger(__name__).exception(
            "Something went wrong when downloading the weight. "
            "Check your internet connection or manually download the weight "
            "at https://drive.google.com/file/d/%s/view?usp=sharing",
            file_id,
        )
def check_image_shape(image: npt.NDArray) -> None:
    """Validate that ``image`` is a 3-D NumPy array whose last axis has size 3.

    Args:
        image: The array to validate (the last axis presumably holds the
            colour channels).

    Raises:
        TypeError: If ``image`` is not an ``np.ndarray``.
        ValueError: If ``image`` is not 3-dimensional, or its last axis is not
            of size 3.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError("Invalid Type: image must be of type np.ndarray")
    if image.ndim != 3:
        raise ValueError(f"Image must be 3 dimensional, got shape {image.shape}")
    if image.shape[-1] != 3:
        raise ValueError(f"Image must have 3 channels (last axis of size 3), got shape {image.shape}")