File size: 10,857 Bytes
3d4d894
 
 
 
 
 
be0162b
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be0162b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04b1201
ef697d2
be0162b
 
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a54498b
3d4d894
 
 
 
 
be0162b
 
3d4d894
 
a54498b
3d4d894
 
 
 
 
 
 
 
 
 
 
a54498b
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
be0162b
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be0162b
3d4d894
 
 
 
be0162b
 
 
 
 
 
 
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be0162b
 
 
 
 
 
3d4d894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""This file contains methods for inference and image generation."""
import logging
import threading
import time
from contextlib import contextmanager
from itertools import product
from time import perf_counter
from typing import Callable, Dict, Iterator, List, Tuple

import numpy as np
import streamlit as st
import torch
from compel import Compel
from diffusers import ControlNetModel, UniPCMultistepScheduler
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from PIL import ImageFilter
from scipy.signal import fftconvolve
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation

from config import WIDTH, HEIGHT
from palette import ade_palette
from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline

LOGGING = logging.getLogger(__name__)


class ControlNetPipeline:
    """Wrapper around the ControlNet inpaint/img2img pipeline that serializes calls.

    Each call draws a ticket number and polls (every 0.5 s) until its ticket is
    at the head of ``waiting_queue``. The ticket is released only after the
    wrapped pipeline has returned or raised, so at most one generation runs at
    a time and callers are served in the order they drew tickets.
    """

    def __init__(self):
        # NOTE(review): not read anywhere in this file; kept in case other code uses it.
        self.in_use = False
        self.controlnet = ControlNetModel.from_pretrained(
            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)

        self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16
        )

        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")

        # FIFO of ticket numbers; the head is the caller allowed to run.
        self.waiting_queue = []
        self.count = 0
        # Guards ticket allocation and queue mutation so that concurrent callers
        # sharing this instance never draw the same ticket.
        self._queue_lock = threading.Lock()

    def __call__(self, **kwargs):
        """Run the wrapped pipeline with ``kwargs`` once it is this caller's turn.

        Returns:
            Whatever the wrapped pipeline returns.
        Raises:
            Any exception raised by the wrapped pipeline. The caller's ticket is
            released either way, so later callers are not blocked forever.
        """
        with self._queue_lock:
            self.count += 1
            number = self.count
            self.waiting_queue.append(number)

        try:
            # Wait until our ticket reaches the head of the queue.
            while self.waiting_queue[0] != number:
                print(f"Wait for your turn {number} in queue {self.waiting_queue}")
                time.sleep(0.5)

            print("It's the turn of", number)
            return self.pipe(**kwargs)
        finally:
            # Release the ticket only after the pipeline has finished, so the
            # next caller cannot start while this generation is still running.
            with self._queue_lock:
                self.waiting_queue.remove(number)


@contextmanager
def catchtime(message: str) -> Iterator[Callable[[], float]]:
    """Measure and log the wall-clock time spent inside a ``with`` block.

    The block receives a zero-argument callable that returns the seconds
    elapsed so far, so intermediate timings can be read. On exit, including
    exit through an exception, the total is logged at INFO level as
    ``"<message>: <seconds> seconds"``.

    Args:
        message (str): label prefixed to the logged duration.
    Yields:
        Callable[[], float]: returns seconds elapsed since entering the block.
    """
    start = perf_counter()
    try:
        yield lambda: perf_counter() - start
    finally:
        # Log in ``finally`` so failing steps still report how long they ran.
        LOGGING.info('%s: %.3f seconds', message, perf_counter() - start)


def convolution(mask: Image.Image, size=9) -> Image.Image:
    """Soften the edges of a mask with a ``size`` x ``size`` box blur.

    The mask is converted to greyscale ("L") and convolved with a normalized
    box kernel. The outermost ``size`` rows and columns on every side are then
    restored from the unblurred mask.

    Args:
        mask (Image.Image): masking image (any mode accepted by ``convert("L")``).
        size (int, optional): side length of the box kernel and width of the
            preserved border, in pixels. Defaults to 9.
    Returns:
        Image.Image: blurred mask in mode "L", same size as the input.
    Raises:
        ValueError: if ``size`` is smaller than 1.
    """
    if size < 1:
        # size == 0 would make ``[-0:]`` below select the whole array.
        raise ValueError(f"size must be a positive integer, got {size}")

    mask_arr = np.array(mask.convert("L"))
    kernel = np.ones((size, size)) / size**2
    blurred = fftconvolve(mask_arr, kernel, 'same')
    # FFT convolution leaves float round-off (e.g. 254.9999 or -1e-13).
    # Round and clamp before casting so uniform regions keep their exact
    # value instead of being truncated (255 -> 254). Casting negative floats
    # to uint8 is not well defined.
    blurred = np.clip(np.rint(blurred), 0, 255).astype(np.uint8)

    border = size
    # Replace the borders with the original (unblurred) values.
    blurred[:border, :] = mask_arr[:border, :]
    blurred[-border:, :] = mask_arr[-border:, :]
    blurred[:, :border] = mask_arr[:, :border]
    blurred[:, -border:] = mask_arr[:, -border:]

    return Image.fromarray(blurred).convert("L")


def postprocess_image_masking(inpainted: Image.Image, image: Image.Image,
                              mask: Image.Image) -> Image.Image:
    """Blend the generated image back onto the original through ``mask``.

    Uses ``Image.composite``: where ``mask`` is 255 the inpainted pixel is
    kept, where it is 0 the original pixel is kept, and intermediate values
    mix the two.

    Args:
        inpainted (Image.Image): inpainted (generated) image.
        image (Image.Image): original image; must match ``inpainted`` in size.
        mask (Image.Image): blending mask of mode "1", "L" or "RGBA", same size
            as the images.
    Returns:
        Image.Image: the composite, converted to RGB.
    """
    composite = Image.composite(inpainted.convert("RGBA"), image.convert("RGBA"), mask)
    return composite.convert("RGB")


@st.experimental_singleton(max_entries=5)
def get_controlnet() -> ControlNetPipeline:
    """Build the queued ControlNet pipeline wrapper.

    Cached with ``st.experimental_singleton``, so the models are loaded once
    and the same ``ControlNetPipeline`` instance is shared between callers.

    Returns:
        ControlNetPipeline: wrapper that serializes calls to the ControlNet
        inpaint/img2img pipeline (not a bare ``ControlNetModel``).
    """
    return ControlNetPipeline()


@st.experimental_singleton(max_entries=5)
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Load the UperNet ConvNeXt-small image processor and segmentation model.

    Both come from the same Hugging Face checkpoint. The result is cached
    with ``st.experimental_singleton``.

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
        ``(image_processor, image_segmentor)``.
    """
    checkpoint = "openmmlab/upernet-convnext-small"
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    segmentor = UperNetForSemanticSegmentation.from_pretrained(checkpoint)
    return processor, segmentor


@st.experimental_singleton(max_entries=5)
def get_inpainting_pipeline() -> StableDiffusionInpaintPipeline:
    """Load the Stable Diffusion 2 inpainting pipeline onto the GPU.

    The weights are loaded in float16 without a safety checker. xformers
    memory-efficient attention is enabled before the pipeline is moved to
    CUDA. The result is cached with ``st.experimental_singleton``.

    Returns:
        StableDiffusionInpaintPipeline: the ready-to-use inpainting pipeline.
    """
    inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        safety_checker=None,
        torch_dtype=torch.float16,
    )
    inpaint_pipe.enable_xformers_memory_efficient_attention()
    return inpaint_pipe.to("cuda")


def make_grid_parameters(grid_search: Dict, params: Dict) -> List[Dict]:
    """Expand a parameter grid into one keyword-argument dict per combination.

    The Cartesian product is taken over ``grid_search['generator']``
    (outermost), then ``grid_search['strength']``, then
    ``grid_search['guidance_scale']`` (innermost). Every resulting dict also
    receives the entries of ``params``. Because ``params`` is merged last, a
    key present in both takes its value from ``params``.

    Args:
        grid_search (Dict): grid search parameters; must contain the iterables
            'strength', 'guidance_scale' and 'generator'.
        params (Dict): fixed parameters shared by every combination.
    Returns:
        List[Dict]: one new dict per combination; empty if any grid entry is empty.
    Raises:
        KeyError: if one of the three grid keys is missing.
    """
    return [
        {'strength': strength,
         'guidance_scale': guidance_scale,
         'generator': generator,
         **params}
        for generator, strength, guidance_scale in product(
            grid_search['generator'],
            grid_search['strength'],
            grid_search['guidance_scale'],
        )
    ]


def make_captions(options: List[Dict]) -> List[str]:
    """Build a human-readable caption for each grid setting.

    Args:
        options (List[Dict]): grid parameters; each dict must contain
            'strength', 'guidance_scale' and 'num_inference_steps'.
    Returns:
        List[str]: one caption per option, in the same order.
    Raises:
        KeyError: if an option lacks one of the required keys.
    """
    return [
        f"strength {option['strength']}, guidance {option['guidance_scale']}, "
        f"steps {option['num_inference_steps']}"
        for option in options
    ]


@torch.inference_mode()
def make_image_controlnet(image: np.ndarray,
                          mask_image: np.ndarray,
                          controlnet_conditioning_image: np.ndarray,
                          positive_prompt: str, negative_prompt: str,
                          seed: int = 2356132) -> List[Image.Image]:
    """Generate images with the ControlNet inpaint/img2img pipeline.

    Runs the pipeline once per combination of a small strength /
    guidance-scale / seeded-generator grid (currently a single combination).
    Each result is pasted back onto the original image through a blurred copy
    of the mask.

    Args:
        image (np.ndarray): input image, converted with ``Image.fromarray``.
        mask_image (np.ndarray): mask image; multiplied by 255 and cast to
            uint8, so presumably values in [0, 1] or bool. Confirm against
            the caller.
        controlnet_conditioning_image (np.ndarray): conditioning image; it is
            Gaussian-blurred (radius 9) before use.
        positive_prompt (str): positive prompt string
        negative_prompt (str): negative prompt string
        seed (int, optional): base seed for the CUDA generator(s). Defaults to 2356132.
    Returns:
        List[Image.Image]: one generated image per grid setting.
    """
    with catchtime("get controlnet"):
        pipe = get_controlnet()

    torch.cuda.empty_cache()

    common_parameters = {
        'prompt': positive_prompt,
        'negative_prompt': negative_prompt,
        'num_inference_steps': 30,
        'controlnet_conditioning_scale': 1.1,
        'controlnet_conditioning_scale_decay': 0.96,
        'controlnet_steps': 28,
    }

    grid_search = {
        'strength': [1.00],
        'guidance_scale': [7.0],
        # Each entry is a one-element list holding a generator seeded seed + i.
        'generator': [[torch.Generator(device="cuda").manual_seed(seed + i)]
                      for i in range(1)],
    }
    prompt_settings = make_grid_parameters(grid_search, common_parameters)

    mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
    image = Image.fromarray(image).convert("RGB")
    controlnet_conditioning_image = (
        Image.fromarray(controlnet_conditioning_image)
        .convert("RGB")
        .filter(ImageFilter.GaussianBlur(radius=9))
    )
    # Soft-edged mask used only to blend the result back onto the original.
    mask_image_postproc = convolution(mask_image)

    images = []
    with catchtime("Controlnet generation total"):
        for setting in prompt_settings:
            with catchtime("Controlnet generation"):
                generated_image = pipe(
                    **setting,
                    image=image,
                    mask_image=mask_image,
                    controlnet_conditioning_image=controlnet_conditioning_image,
                ).images[0]
                generated_image = postprocess_image_masking(
                    generated_image, image, mask_image_postproc)
            images.append(generated_image)

    return images


@torch.inference_mode()
def make_inpainting(positive_prompt: str,
                    image: Image.Image,
                    mask_image: np.ndarray,
                    negative_prompt: str = "") -> List[Image.Image]:
    """Inpaint ``image`` with the Stable Diffusion 2 inpainting pipeline.

    Args:
        positive_prompt (str): positive prompt string
        image (Image.Image): input image
        mask_image (np.ndarray): mask image; multiplied by 255 and cast to
            uint8 before conversion to PIL, so presumably values in [0, 1] or
            bool. Confirm against the caller.
        negative_prompt (str, optional): negative prompt string. Defaults to "".
    Returns:
        List[Image.Image]: a single-element list containing the generated image,
        rendered at ``WIDTH`` x ``HEIGHT``.
    """
    with catchtime("Get inpainting pipeline"):
        pipe = get_inpainting_pipeline()

    common_parameters = {
        'prompt': positive_prompt,
        'negative_prompt': negative_prompt,
        'num_inference_steps': 20,
    }

    torch.cuda.empty_cache()
    # Convert the mask once, outside the timed generation step.
    mask_pil = Image.fromarray((mask_image * 255).astype(np.uint8))

    with catchtime("Inpainting generation"):
        generated = pipe(image=image,
                         mask_image=mask_pil,
                         height=HEIGHT,
                         width=WIDTH,
                         **common_parameters
                         ).images[0]
    return [generated]


@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(image: Image.Image) -> Image.Image:
    """Colour-code ``image`` by its semantic segmentation.

    Runs the UperNet model returned by ``get_segmentation_pipeline`` and
    resizes the label map back to the input size. Each pixel is then painted
    with the ``ade_palette()`` colour of its label; labels with no palette
    entry stay black.

    Args:
        image (Image.Image): input image
    Returns:
        Image.Image: RGB image of the same size, one colour per class.
    """
    image_processor, image_segmentor = get_segmentation_pipeline()
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # The inference_mode decorator already disables autograd.
    outputs = image_segmentor(pixel_values)

    # target_sizes expects (height, width); PIL's ``size`` is (width, height).
    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    labels = np.asarray(seg)  # height, width integer label map

    palette = np.asarray(ade_palette(), dtype=np.uint8)
    color_seg = np.zeros((*labels.shape, 3), dtype=np.uint8)  # height, width, 3
    # One vectorised palette lookup instead of a full-image pass per class.
    known = (labels >= 0) & (labels < len(palette))
    color_seg[known] = palette[labels[known]]
    return Image.fromarray(color_seg)