File size: 2,997 Bytes
8da8f47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import sys

import os
import math
import argparse
import random
import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from data.data_sampler import DistIterSampler

import options.options as option
from utils import util
from data.util import read_img 
from data import create_dataloader, create_dataset
from models import create_model
import numpy as np
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline


def init_dist(backend='nccl', **kwargs):
    ''' initialization for distributed training'''
    # if mp.get_start_method(allow_none=True) is None:
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def load_image(image, message = None):
    # img_GT = read_img(None, image_path)
    img_GT = image / 255
    # print(img_GT)
    img_GT = img_GT[:, :, [2, 1, 0]]
    img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0)
    img_GT = torch.nn.functional.interpolate(img_GT, size=(512, 512), mode='nearest', align_corners=None)
    img_GT = img_GT.unsqueeze(0)

    _, T, C, W, H = img_GT.shape
    list_h = []
    R = 0
    G = 0
    B = 255
    image = Image.new('RGB', (W, H), (R, G, B))
    result = np.array(image) / 255.
    expanded_matrix = np.expand_dims(result, axis=0) 
    expanded_matrix = np.repeat(expanded_matrix, T, axis=0)
    imgs_LQ = torch.from_numpy(np.ascontiguousarray(expanded_matrix)).float()
    imgs_LQ = imgs_LQ.permute(0, 3, 1, 2)
    imgs_LQ = torch.nn.functional.interpolate(imgs_LQ, size=(W, H), mode='nearest', align_corners=None)
    imgs_LQ = imgs_LQ.unsqueeze(0)

    list_h.append(imgs_LQ)

    list_h = torch.stack(list_h, dim=0)

    return {
            'LQ': list_h,
            'GT': img_GT,
            'MES': message
        }


def image_editing(image_numpy, mask_image, prompt):

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        ).to("cuda")
    
    pil_image = Image.fromarray(image_numpy)
    print(mask_image.shape)
    print("maskmin", mask_image.min(), "maskmax", mask_image.max())
    mask_image = Image.fromarray(mask_image.astype(np.uint8)).convert("L")
    image_init = pil_image.convert("RGB").resize((512, 512))
    
    h, w = mask_image.size
    
    image_inpaint = pipe(prompt=prompt, image=image_init, mask_image=mask_image, height=w, width=h).images[0]
    image_inpaint = np.array(image_inpaint) / 255.
    image = np.array(image_init) / 255.
    mask_image = np.array(mask_image)
    mask_image = np.stack([mask_image] * 3, axis=-1) / 255.
    mask_image = mask_image.astype(np.uint8)
    image_fuse = image * (1 - mask_image) + image_inpaint * mask_image

    return image_fuse