yangheng's picture
init
9842c28
raw
history blame
3.59 kB
# -*- coding: utf-8 -*-
# file: test.py
# time: 05/12/2022
# author: yangheng <[email protected]>
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2021. All Rights Reserved.
from pathlib import Path
from typing import Union
import autocuda
import findfile
from pyabsa.utils.pyabsa_utils import fprint
from torchvision import transforms
from .utils.prepare_images import *
from .Models import *
class ImageMagnifier:
    """Upscale images with a CARN-V2 network by repeated 2x passes."""

    def __init__(self):
        """Build the CARN-V2 2x network and load its weights.

        The weights file ``CARN_model_checkpoint.pt`` is located with
        ``findfile.find_cwd_file`` and loaded with ``map_location="cpu"``
        before the model is cast to fp32 and moved to ``self.device``.

        Raises:
            FileNotFoundError: if the checkpoint lookup returns nothing.
        """
        self.device = autocuda.auto_cuda()
        self.model_cran_v2 = CARN_V2(
            color_channels=3,
            mid_channels=64,
            conv=nn.Conv2d,
            single_conv_size=3,
            single_conv_group=1,
            scale=2,
            activation=nn.LeakyReLU(0.1),
            SEBlock=True,
            repeat_blocks=3,
            atrous=(1, 1, 1),
        )
        # Converted before loading, presumably so parameter dtypes match the
        # stored (fp16?) checkpoint -- TODO confirm against the checkpoint file.
        self.model_cran_v2 = network_to_half(self.model_cran_v2)
        self.checkpoint = findfile.find_cwd_file("CARN_model_checkpoint.pt")
        # NOTE(review): guard in case findfile yields None/"" when the file is
        # missing, which would otherwise surface as an obscure torch.load error.
        if not self.checkpoint:
            raise FileNotFoundError(
                "CARN_model_checkpoint.pt was not found from the current working directory."
            )
        self.model_cran_v2.load_state_dict(
            torch.load(self.checkpoint, map_location="cpu")
        )
        # Cast to fp32 for inference. On GPU this cast can be dropped to keep
        # fp16 weights (autocast is used in __image_scale either way).
        self.model_cran_v2 = self.model_cran_v2.float().to(self.device)

    def __image_scale(self, img, scale_factor: int = 2):
        """Run one upscaling pass of the network over ``img``.

        The image is split by ``ImageSplitter`` (seg_size=64, border padding 3),
        each patch is run through the model under autocast, and the outputs are
        merged back into one tensor.

        Args:
            img: input accepted by ``ImageSplitter.split_img_tensor``; a PIL
                image when called via ``magnify_from_file``.
            scale_factor: scale used to merge the patches; should match the
                network's built-in ``scale=2``.

        Returns:
            The upscaled image as a ``PIL.Image``.
        """
        img_splitter = ImageSplitter(
            seg_size=64, scale_factor=scale_factor, boarder_pad_size=3
        )
        img_patches = img_splitter.split_img_tensor(img, scale_method=None, img_pad=0)
        with torch.no_grad():
            if self.device != "cpu":
                with torch.cuda.amp.autocast():
                    out = [self.model_cran_v2(p.to(self.device)) for p in img_patches]
            else:
                with torch.cpu.amp.autocast():
                    out = [self.model_cran_v2(p) for p in img_patches]
        img_upscale = img_splitter.merge_img_tensor(out)
        # ToPILImage converts float tensors with mul(255).byte(), which wraps
        # instead of saturating for values outside [0, 1]; clamp so small
        # overshoots from the network don't turn bright pixels dark.
        return transforms.ToPILImage()(img_upscale[0].clamp(0, 1))

    @staticmethod
    def _doubling_passes(scale_factor) -> int:
        """Return how many 2x passes ``magnify`` applies for ``scale_factor``.

        Counts how many times ``scale_factor`` can be floor-halved while the
        quotient stays positive: floor(log2(scale_factor)) for values >= 1,
        and 0 for anything smaller.
        """
        passes = 0
        remaining = scale_factor
        while remaining // 2 > 0:
            passes += 1
            remaining //= 2
        return passes

    def magnify(self, img, scale_factor: int = 2):
        """Upscale ``img`` by repeated 2x passes of the network.

        Only powers of two are reachable, so the applied scale is the largest
        power of two not exceeding ``scale_factor`` (1, i.e. unchanged, when
        ``scale_factor`` < 2). The applied scale is reported via ``fprint``.

        Args:
            img: image to upscale (a PIL image when called via ``magnify_from_file``).
            scale_factor: requested overall magnification.

        Returns:
            The upscaled image, or ``img`` itself when no pass is applied.
        """
        passes = self._doubling_passes(scale_factor)
        # Report the scale that is actually applied, not the requested one.
        fprint("scale factor reset to:", 2 ** passes)
        for _ in range(passes):
            img = self.__image_scale(img, scale_factor=2)
        return img

    def _magnify_one_file(self, file_path, label, scale_factor, save_img):
        """Magnify the image at ``file_path``, reporting the outcome under ``label``.

        Any exception (unreadable file, model error, save error) is reported with
        ``fprint(label, e)`` instead of propagating, so batch processing goes
        on; "Done." is reported only when every step succeeded.
        """
        try:
            # The context manager releases the source file handle even on failure.
            with Image.open(file_path) as src:
                result = self.magnify(src, scale_factor)
                if save_img:
                    # Overwrites the source image in place.
                    result.save(file_path)
        except Exception as e:
            fprint(label, e)
        else:
            fprint(label, "Done.")

    def magnify_from_file(
        self, img_path: Union[str, Path], scale_factor: int = 2, save_img: bool = True
    ):
        """Magnify one image file, or every entry of a directory.

        Args:
            img_path: an image file, or a directory whose entries (from
                ``os.listdir``, non-recursive) are each processed.
            scale_factor: forwarded to ``magnify``.
            save_img: when True, each result overwrites its source file.

        Raises:
            FileNotFoundError: if ``img_path`` does not exist.
            TypeError: if ``img_path`` exists but is neither a file nor a directory.

        Per-image failures are reported via ``fprint`` rather than raised.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError("Path is not found.")
        if os.path.isfile(img_path):
            self._magnify_one_file(img_path, img_path, scale_factor, save_img)
        elif os.path.isdir(img_path):
            for name in os.listdir(img_path):
                self._magnify_one_file(
                    os.path.join(img_path, name), name, scale_factor, save_img
                )
        else:
            raise TypeError("Path is not a file or directory.")