File size: 12,222 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
 # pylint: disable=global-statement
import os
import io
import math
import base64
import numpy as np
import mediapipe as mp
from PIL import Image, ImageOps
from pi_heif import register_heif_opener
from skimage.metrics import structural_similarity as ssim
from scipy.stats import beta

import util
import sdapi
import options

# mediapipe solutions are created lazily on first use and dropped again by unload()
face_model = body_model = segmentation_model = None
# grayscale arrays of every image already checked by detect_simmilar(); cleared by reset()
all_images = list()
# per-result-type count of saved images, used by save_image() as a filename prefix; cleared by reset()
all_images_by_type = dict()


class Result:
    """Mutable record describing one image as it moves through the processing pipeline."""

    def __init__(self, typ: str, fn: str, tag: str = None, requested: list = None):
        """Create an empty result for input file ``fn``.

        Args:
            typ: result type label (file() uses 'face', 'body', 'original' or 'unknown').
            fn: path of the input file.
            tag: optional comma-separated user tag(s).
            requested: names of the processing steps asked for. When omitted a fresh
                empty list is used; a shared ``[]`` default would be the same object
                for every instance.
        """
        self.type = typ  # result type label; also part of the saved filename
        self.input = fn  # input file path
        self.output = ''  # path of the saved file, filled in by save_image()
        self.basename = ''  # output path without the extension
        self.message = ''  # reason the image was rejected or a step failed
        self.image = None  # image currently being processed
        self.caption = ''  # caption produced by interrogation
        self.tag = tag
        self.tags = []  # tags produced by interrogation
        self.ops = []  # names of the steps applied so far, in order
        self.steps = [] if requested is None else requested


def detect_blur(image: Image):
    """Score the high-frequency content of ``image`` with an FFT high-pass filter.

    The grayscale spectrum is shifted so low frequencies sit in the centre. A square
    window of half-size ``options.process.blur_samplesize`` around the centre is
    zeroed, and the spectrum is transformed back.

    Returns:
        Mean of the natural-log magnitude of the reconstruction, rounded to 2 decimals.
        If the window covers the whole spectrum the reconstruction is all zeros and
        the result is -inf.
    """
    # based on <https://github.com/karthik9319/Blur-Detection/>
    gray = ImageOps.grayscale(image)
    width, height = image.size
    cx, cy = width // 2, height // 2
    half = options.process.blur_samplesize
    spectrum = np.fft.fftshift(np.fft.fft2(gray))
    # Clamp the lower bounds at 0: a negative slice start would wrap to the far end of
    # the axis and zero only a sliver instead of the intended centred window.
    spectrum[max(0, cy - half): cy + half, max(0, cx - half): cx + half] = 0
    recon = np.fft.ifft2(np.fft.ifftshift(spectrum))
    magnitude = np.log(np.abs(recon))
    return round(np.mean(magnitude), 2)


def detect_dynamicrange(image: 'Image.Image'):
    """Score how closely the brightness histogram of ``image`` matches a Beta(2, 2) shape.

    Per-pixel brightness is the weighted RMS of the R, G, B channels (weights
    0.299 / 0.587 / 0.114). It is binned into a 256-bin histogram over [0, 255],
    and the cosine similarity with a normalised Beta(2, 2) reference is returned.
    2-D (single-channel) input is accepted; its values are used directly as
    brightness. For inputs with more than three channels only the first three are used.

    Args:
        image: a PIL image or anything ``np.asarray`` can turn into an array.

    Returns:
        Cosine similarity in [0, 1], rounded to 2 decimals.
    """
    # based on <https://towardsdatascience.com/measuring-enhancing-image-quality-attributes-234b0f250e10>
    pixels = np.float32(np.asarray(image))
    height, width = pixels.shape[:2] # pylint: disable=unsubscriptable-object
    if pixels.ndim == 2:
        # Single channel: the weights sum to 1, so the weighted RMS of one value is the
        # value itself. Indexing [..., 0] here would select image columns instead.
        brightness_image = pixels
    else:
        RGB = [0.299, 0.587, 0.114]
        brightness_image = np.sqrt(pixels[..., 0] ** 2 * RGB[0] + pixels[..., 1] ** 2 * RGB[1] + pixels[..., 2] ** 2 * RGB[2]) # pylint: disable=unsubscriptable-object
    hist, _ = np.histogram(brightness_image, bins=256, range=(0, 255))
    img_brightness_pmf = hist / (height * width)
    ref = beta(2, 2).pdf(np.linspace(0, 1, 256))
    ref_pmf = ref / np.sum(ref)
    dot_product = np.dot(ref_pmf, img_brightness_pmf)
    squared_dist_a = np.sum(ref_pmf ** 2)
    squared_dist_b = np.sum(img_brightness_pmf ** 2)
    res = dot_product / math.sqrt(squared_dist_a * squared_dist_b)
    return round(res, 2)


def detect_simmilar(image: Image):
    """Compare ``image`` with every image seen so far, then remember it.

    The image is resized to ``similarity_size`` x ``similarity_size`` and converted to
    grayscale. It is scored with SSIM against each stored array and then appended to
    the module-level ``all_images``.

    Returns:
        The highest SSIM score found, or 0 when nothing has been stored yet (scores
        that are not above 0 never replace the initial 0).
    """
    side = options.process.similarity_size
    thumb = np.array(ImageOps.grayscale(image.resize((side, side))))
    candidates = [0, *(ssim(thumb, seen, data_range=255, channel_axis=None, gradient=False, full=False) for seen in all_images)]
    best = max(candidates)
    all_images.append(thumb)
    return best


def segmentation(res: Result):
    """Replace the background of ``res.image`` with a flat colour.

    A mediapipe SelfieSegmentation model is created on first use and cached in the
    module-level ``segmentation_model``. Pixels whose mask value is above 0.1 are
    kept; all others are set to ``options.process.segmentation_background``. The
    image is replaced and 'segmentation' is appended to ``res.ops``.
    """
    global segmentation_model
    if segmentation_model is None:
        segmentation_model = mp.solutions.selfie_segmentation.SelfieSegmentation(model_selection=options.process.segmentation_model)
    pixels = np.array(res.image)
    mask = segmentation_model.process(pixels).segmentation_mask
    keep = np.stack((mask,) * 3, axis=-1) > 0.1
    fill = np.zeros(pixels.shape, dtype=np.uint8)
    fill[:] = options.process.segmentation_background
    # consider using a joint bilateral filter instead of pure combine
    res.image = Image.fromarray(np.where(keep, pixels, fill))
    res.ops.append('segmentation')
    return res


def unload():
    """Drop every cached mediapipe model.

    After this, the next segmentation(), process_face() or process_body() call
    creates a fresh model.
    """
    global face_model, body_model, segmentation_model
    face_model = body_model = segmentation_model = None


def encode(img):
    """Save ``img`` as JPEG into memory and return the bytes as a base64 ``str``."""
    with io.BytesIO() as buffer:
        img.save(buffer, 'JPEG')
        payload = buffer.getvalue()
    return base64.b64encode(payload).decode()


def reset():
    """Unload cached models and forget the similarity history and per-type save counters."""
    global all_images, all_images_by_type
    unload()
    # rebind to new containers rather than clearing the old ones in place
    all_images_by_type, all_images = {}, []


def upscale_restore_image(res: Result, upscale: bool = False, restore: bool = False):
    """Optionally upscale and/or face-restore ``res.image`` through the SD extras API.

    Upscaling is skipped when the image already reaches ``options.process.target_size``
    in both dimensions. When neither step remains, ``res`` is returned untouched
    without encoding the image or calling the API.

    Args:
        res: result whose ``image`` is processed.
        upscale: request a 2x SwinIR_4x upscale.
        restore: request CodeFormer restoration (visibility 1.0, weight 0.2).

    Returns:
        ``res``. ``res.image`` is replaced by the API output on success. Otherwise
        ``res.message`` is set and the image is kept.
    """
    if res.image.width >= options.process.target_size and res.image.height >= options.process.target_size:
        upscale = False
    if not (upscale or restore):
        # Nothing to do: avoid a pointless JPEG encode. It would also fail for image
        # modes JPEG cannot store.
        return res
    kwargs = util.Map({
        'image': encode(res.image),
        'codeformer_visibility': 0.0,
        'codeformer_weight': 0.0,
    })
    if upscale:
        kwargs.upscaler_1 = 'SwinIR_4x'
        kwargs.upscaling_resize = 2
        res.ops.append('upscale')
    if restore:
        kwargs.codeformer_visibility = 1.0
        kwargs.codeformer_weight = 0.2
        res.ops.append('restore')
    result = sdapi.postsync('/sdapi/v1/extra-single-image', kwargs)
    if 'image' not in result:
        res.message = 'failed to upscale/restore image'
    else:
        res.image = Image.open(io.BytesIO(base64.b64decode(result['image'])))
    return res


def interrogate_image(res: Result, tag: str = None):
    """Caption and tag ``res.image`` with every model in ``options.process.interrogate_model``.

    The 'clip' model yields the caption: the first comma-separated part, with ' a '
    collapsed to a space. The 'deepdanbooru' model yields the tags: weights,
    parentheses and backslashes are stripped. When ``tag`` is given, it prefixes the
    caption and its comma-separated parts are prepended to the tags. The second word
    of the caption (if any) is inserted near the front of the tags. Tags of 2 or fewer
    characters are dropped and the list is cut to ``options.process.tag_limit``.

    Returns:
        ``res`` with ``caption`` and ``tags`` set and 'interrogate' appended to ``ops``.
    """
    caption = ''
    tags = []
    encoded = encode(res.image)  # the image is the same for every model: encode once
    for model in options.process.interrogate_model:
        payload = util.Map({ 'image': encoded, 'model': model })
        result = sdapi.postsync('/sdapi/v1/interrogate', payload)
        if model == 'clip':
            caption = result.caption if 'caption' in result else ''
            caption = caption.split(',')[0].replace(' a ', ' ').strip()
            if tag is not None:
                caption = tag + ', ' + caption
        if model == 'deepdanbooru':
            # own local name: reusing `tag` here clobbered the user tag and made the
            # `tag is not None` checks always true
            danbooru = result.caption if 'caption' in result else ''
            tags = [t.replace('(', '').replace(')', '').replace('\\', '').split(':')[0].strip() for t in danbooru.split(',')]
            if tag is not None:
                tags = [t.strip() for t in tag.split(',')] + tags
    words = caption.split(' ')
    if len(words) > 1:  # an empty or single-word caption has no second word to insert
        pos = 0 if len(tags) == 0 else 1
        tags.insert(pos, words[1])
    tags = [t for t in tags if len(t) > 2]
    if len(tags) > options.process.tag_limit:
        tags = tags[:options.process.tag_limit]
    res.caption = caption
    res.tags = tags
    res.ops.append('interrogate')
    return res


def resize_image(res: Result):
    """Shrink ``res.image`` in place to fit a target_size x target_size box and record 'resize'.

    Uses PIL ``thumbnail`` with HAMMING resampling, so the aspect ratio is kept and
    the image is never enlarged.
    """
    limit = options.process.target_size
    res.image.thumbnail((limit, limit), Image.Resampling.HAMMING)
    res.ops.append('resize')
    return res


def square_image(res: Result):
    """Centre ``res.image`` on a black RGB square whose side equals its longer dimension."""
    width, height = res.image.size
    side = width if width > height else height
    canvas = Image.new('RGB', (side, side))
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(res.image, offset)
    res.image = canvas
    res.ops.append('square')
    return res


def process_face(res: Result):
    """Crop ``res.image`` around the first face found by mediapipe face detection.

    The detector is created on first use and cached in the module-level ``face_model``.
    The relative bounding box is grown by ``options.process.face_pad`` (half on each
    side, in relative units) and clamped to the image. When no face is detected, or
    the detected box extends past any image edge, ``res.image`` is set to None and
    ``res.message`` explains why.
    """
    global face_model
    res.ops.append('face')
    if face_model is None:
        face_model = mp.solutions.face_detection.FaceDetection(min_detection_confidence=options.process.face_score, model_selection=options.process.face_model)
    results = face_model.process(np.array(res.image))
    if results.detections is None:
        res.message = 'no face detected'
        res.image = None
        return res
    box = results.detections[0].location_data.relative_bounding_box
    # The right/bottom edges are xmin + width and ymin + height. The former check of
    # width - xmin could never exceed 1 and let faces cut off at those edges through.
    if box.xmin < 0 or box.ymin < 0 or (box.xmin + box.width) > 1 or (box.ymin + box.height) > 1:
        res.message = 'face out of frame'
        res.image = None
        return res
    half_pad = options.process.face_pad / 2
    width, height = res.image.width, res.image.height
    # Compute each edge independently so that clamping the left/top edge does not
    # push the right/bottom edge outward (or past the image border).
    left = max(0, (box.xmin - half_pad) * width)
    top = max(0, (box.ymin - half_pad) * height)
    right = min(width, (box.xmin + box.width + half_pad) * width)
    bottom = min(height, (box.ymin + box.height + half_pad) * height)
    res.image = res.image.crop((left, top, right, bottom))
    return res


def process_body(res: Result):
    """Crop ``res.image`` to the bounding box of the visible mediapipe pose landmarks.

    The pose model is created on first use and cached in the module-level ``body_model``.
    Only landmarks with visibility above ``options.process.body_visibility`` count. The
    box is padded by half of ``options.process.body_pad`` (relative units) on each side
    and clamped to the image. When no pose is found, or fewer than
    ``options.process.body_parts`` landmarks (or none at all) are visible,
    ``res.image`` is set to None and ``res.message`` explains why.
    """
    global body_model
    res.ops.append('body')
    if body_model is None:
        body_model = mp.solutions.pose.Pose(static_image_mode=True, min_detection_confidence=options.process.body_score, model_complexity=options.process.body_model)
    results = body_model.process(np.array(res.image))
    if results.pose_landmarks is None:
        res.message = 'no body detected'
        res.image = None
        return res
    visible = [lm for lm in results.pose_landmarks.landmark if lm.visibility > options.process.body_visibility]
    # Reject an empty list too; otherwise min()/max() below raise ValueError when body_parts <= 0.
    if not visible or len(visible) < options.process.body_parts:
        res.message = f'insufficient body parts detected: {len(visible)}'
        res.image = None
        return res
    half_pad = options.process.body_pad / 2
    width, height = res.image.width, res.image.height
    left = max(0, width * (min(lm.x for lm in visible) - half_pad))
    top = max(0, height * (min(lm.y for lm in visible) - half_pad))
    right = min(width, width * (max(lm.x for lm in visible) + half_pad))
    bottom = min(height, height * (max(lm.y for lm in visible) + half_pad))
    res.image = res.image.crop((left, top, right, bottom))
    return res


def process_original(res: Result):
    """Primary step that keeps the image as-is; it only records 'original' in ``res.ops``."""
    res.ops += ['original']
    return res


def save_image(res: Result, folder: str):
    """Save ``res.image`` into ``folder`` under a numbered, type-prefixed name, then close it.

    The name is ``NNN-<type>-<input stem><options.process.format>``. ``NNN`` is a
    zero-padded per-type counter kept in the module-level ``all_images_by_type``.
    Nothing happens when there is no image or no folder.

    Returns:
        ``res``, with ``basename`` (path without extension) and ``output`` set and
        'save' appended to ``ops`` when the image was written.
    """
    if res.image is None or folder is None:
        return res
    count = all_images_by_type.get(res.type, 0) + 1
    all_images_by_type[res.type] = count
    # splitext keeps dots inside the stem; split('.')[0] turned 'a.b.jpg' into just 'a'
    stem = os.path.splitext(os.path.basename(res.input))[0]
    res.basename = os.path.join(folder, f'{count:03d}-{res.type}-{stem}')
    res.output = res.basename + options.process.format
    res.image.save(res.output)
    res.image.close()
    res.ops.append('save')
    return res


def file(filename: str, folder: str, tag = None, requested = None):
    """Load one image file, run the requested pipeline steps on it and save the result.

    Steps are selected by name from ``requested``:

    * one primary step, 'face', 'body' or 'original' (checked in that order);
    * validation checks 'blur', 'range' and 'similarity'. The first failing check
      rejects the image and stops processing;
    * 'upscale' / 'restore' through the SD API, then a minimum-resolution check
      against ``options.process.target_size``;
    * post-processing 'interrogate', 'resize', 'square' and 'segment';
    * finally save_image() into ``folder``.

    Args:
        filename: path of the image to open.
        folder: output folder handed to save_image().
        tag: optional comma-separated tag(s) used when interrogating.
        requested: list of step names; no steps when omitted.

    Returns:
        The ``Result``. When the file cannot be opened or a step rejects it,
        ``res.image`` is None and ``res.message`` says why.
    """
    if requested is None:
        requested = []  # fresh list per call instead of a shared mutable default
    res = Result(fn = filename, typ='unknown', tag=tag, requested = requested)
    # open image
    try:
        register_heif_opener()
        image = Image.open(filename)
        image = ImageOps.exif_transpose(image) # rotate image according to EXIF orientation
        # Normalise every mode (RGBA, P, LA, L, CMYK, ...) to RGB. Later steps JPEG-encode
        # the image and feed it to mediapipe as a 3-channel array.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        res.image = image
    except Exception as e:
        res.message = f'error opening: {e}'
        res.image = None
        return res
    # primary steps
    if 'face' in requested:
        res.type = 'face'
        res = process_face(res)
    elif 'body' in requested:
        res.type = 'body'
        res = process_body(res)
    elif 'original' in requested:
        res.type = 'original'
        res = process_original(res)
    if res.image is None:
        return res
    # validation steps: stop at the first failure, because later checks cannot run on a None image
    if 'blur' in requested:
        res.ops.append('blur')
        val = detect_blur(res.image)
        # NOTE(review): images are rejected when the score is *above* blur_score; confirm this direction matches how blur_score is configured
        if val > options.process.blur_score:
            res.message = f'blur check failed: {val}'
            res.image = None
            return res
    if 'range' in requested:
        res.ops.append('range')
        val = detect_dynamicrange(res.image)
        if val < options.process.range_score:
            res.message = f'dynamic range check failed: {val}'
            res.image = None
            return res
    if 'similarity' in requested:
        res.ops.append('similarity')
        val = detect_simmilar(res.image)
        if val > options.process.similarity_score:
            res.message = f'similarity check failed: {val}'
            res.image = None
            return res
    # post processing steps
    res = upscale_restore_image(res, 'upscale' in requested, 'restore' in requested)
    if res.image.width < options.process.target_size or res.image.height < options.process.target_size:
        res.message = f'low resolution: [{res.image.width}, {res.image.height}]'
        res.image = None
        return res
    if 'interrogate' in requested:
        res = interrogate_image(res, tag)
    if 'resize' in requested:
        res = resize_image(res)
    if 'square' in requested:
        res = square_image(res)
    if 'segment' in requested:
        res = segmentation(res)
    # finally save image
    res = save_image(res, folder)
    return res