# -*- coding: utf-8 -*-
# file: test.py
# time: 05/12/2022
# author: yangheng <[email protected]>
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2021. All Rights Reserved.
import os
from pathlib import Path
from typing import Union

import autocuda
import findfile
import torch
from PIL import Image
from pyabsa.utils.pyabsa_utils import fprint
from torch import nn
from torchvision import transforms

from .utils.prepare_images import *
from .Models import *


class ImageMagnifier:
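    """2x image super-resolution based on a pretrained CARN_V2 model.

    The model weights ("CARN_model_checkpoint.pt") are located with findfile
    and loaded onto the device selected by autocuda.
    """
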
    def __init__(self):
        self.device = autocuda.auto_cuda()
        self.model_cran_v2 = CARN_V2(
            color_channels=3,
            mid_channels=64,
            conv=nn.Conv2d,
            single_conv_size=3,
            single_conv_group=1,
            scale=2,
            activation=nn.LeakyReLU(0.1),
            SEBlock=True,
            repeat_blocks=3,
            atrous=(1, 1, 1),
        )

        self.model_cran_v2 = network_to_half(self.model_cran_v2)
        self.checkpoint = findfile.find_cwd_file("CARN_model_checkpoint.pt")
        self.model_cran_v2.load_state_dict(
            torch.load(self.checkpoint, map_location="cpu")
        )
        # When running on a GPU, comment out the next line so the model stays in fp16;
        # on CPU the weights must be converted back to fp32.
        self.model_cran_v2 = self.model_cran_v2.float().to(self.device)

    def __image_scale(self, img, scale_factor: int = 2):
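        """Upscale ``img`` by one 2x pass: split it into 64-pixel patches,
        run each patch through CARN_V2 under autocast, and merge the upscaled
        patches back into a single PIL image."""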
        img_splitter = ImageSplitter(
            seg_size=64, scale_factor=scale_factor, boarder_pad_size=3
        )
        img_patches = img_splitter.split_img_tensor(img, scale_method=None, img_pad=0)
        with torch.no_grad():
            if self.device != "cpu":
                with torch.cuda.amp.autocast():
                    out = [self.model_cran_v2(i.to(self.device)) for i in img_patches]
            else:
                with torch.cpu.amp.autocast():
                    out = [self.model_cran_v2(i) for i in img_patches]
        img_upscale = img_splitter.merge_img_tensor(out)

        # merge_img_tensor returns a batched tensor; convert the single image to PIL
        return transforms.ToPILImage()(img_upscale[0])

    def magnify(self, img, scale_factor: int = 2):
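        """Upscale a PIL image by chaining 2x passes of the model.

        Because each pass doubles the size, the effective scale is
        ``scale_factor`` rounded down to the nearest power of two
        (e.g. 3 -> 2x, 4 -> 4x, 6 -> 4x).
        """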
        fprint("scale factor reset to:", scale_factor // 2 * 2)
        _scale_factor = scale_factor
        while _scale_factor // 2 > 0:
            img = self.__image_scale(img, scale_factor=2)
            _scale_factor = _scale_factor // 2
        return img

    def magnify_from_file(
        self, img_path: Union[str, Path], scale_factor: int = 2, save_img: bool = True
    ):
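        """Magnify a single image file or every image in a directory.

        When ``save_img`` is True, the upscaled image overwrites the original file.
        """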

        if not os.path.exists(img_path):
            raise FileNotFoundError("Path not found: {}".format(img_path))
        if os.path.isfile(img_path):
            try:
                img = Image.open(img_path)
                img = self.magnify(img, scale_factor)
                if save_img:
                    img.save(img_path)
            except Exception as e:
                fprint(img_path, e)
            else:
                fprint(img_path, "Done.")

        elif os.path.isdir(img_path):
            for path in os.listdir(img_path):
                try:
                    img = Image.open(os.path.join(img_path, path))
                    img = self.magnify(img, scale_factor)
                    if save_img:
                        img.save(os.path.join(img_path, path))
                except Exception as e:
                    fprint(path, e)
                    continue
                fprint(path, "Done.")
        else:
            raise TypeError("Path is not a file or directory.")
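

# A minimal usage sketch (not part of the original module). It assumes the
# "CARN_model_checkpoint.pt" weights are discoverable by findfile as in
# __init__ above; the input path "lena.png" is purely illustrative.
if __name__ == "__main__":
    magnifier = ImageMagnifier()
    # 4x upscaling is performed as two 2x passes; the result overwrites the
    # input file because magnify_from_file saves in place when save_img=True.
    magnifier.magnify_from_file("lena.png", scale_factor=4, save_img=True)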