Spaces:
PAIR
/
Running on A10G

File size: 6,246 Bytes
bfd34e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29ac50c
bfd34e9
da1e12f
 
bfd34e9
 
 
 
 
 
 
 
 
 
 
 
 
da1e12f
bfd34e9
 
 
 
 
 
 
 
 
 
 
 
 
 
da1e12f
bfd34e9
 
 
 
da1e12f
 
bfd34e9
 
 
 
da1e12f
bfd34e9
 
 
 
 
 
 
 
 
da1e12f
 
bfd34e9
 
 
 
da1e12f
bfd34e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
from functools import partial
from glob import glob
from pathlib import Path as PythonPath

import cv2
import torchvision.transforms.functional as TvF
import torch
import torch.nn as nn
import numpy as np
from inspect import isfunction
from PIL import Image

from lib import smplfusion
from lib.smplfusion import share, router, attentionpatch, transformerpatch
from lib.utils.iimage import IImage
from lib.utils import poisson_blend
from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v


def _sam_box_mask(sam_predictor, image, box, kernel_size):
    """Return SAM's multimask predictions for ``box`` on ``image`` (resized to 512), unioned and dilated.

    The union is a uint8 0/1 array; dilation uses a ``kernel_size`` x ``kernel_size`` all-ones kernel.
    """
    sam_predictor.set_image(image.resize(512).data[0])
    candidates, _, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=box[None, :],
        multimask_output=True,
    )
    union = (np.sum(candidates, axis=0) > 0).astype(np.uint8)
    return cv2.dilate(union, np.ones((kernel_size, kernel_size)))


def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
    """Restrict ``hr_mask`` to the object(s) SAM finds inside the mask's bounding box.

    The mask is resized to 512, and its bounding box (grown by one pixel, clamped at 0 on the
    top/left) is used as a SAM box prompt twice:
      * on ``hr_image`` (called the "original" object in the first version of this code), dilated 13x13;
      * on ``lr_image`` (called the "inpainted" object), dilated 3x3.
    Only mask pixels covered by either segmentation are kept.

    Returns:
        IImage of the refined mask, resized with ``resize(2048, resample=Image.BICUBIC)``.
        NOTE(review): 2048 is hard-coded (4x the 512 working size); confirm that callers always
        pass a 2048-sized ``hr_mask``.
    """
    mask_512 = hr_mask.resize(512)

    left, top, width, height = cv2.boundingRect(mask_512.data[0][:, :, 0])
    left = max(left - 1, 0)
    top = max(top - 1, 0)
    box = np.array([left, top, left + width + 1, top + height + 1])

    # Same order as before: segment on hr_image first, then on lr_image (the predictor is stateful).
    original_segmentation = _sam_box_mask(sam_predictor, hr_image, box, 13)
    inpainted_segmentation = _sam_box_mask(sam_predictor, lr_image, box, 3)

    keep = ((original_segmentation + inpainted_segmentation) > 0).astype(np.uint8)
    refined = mask_512.data[0] * keep[:, :, np.newaxis]
    return IImage(refined).resize(2048, resample=Image.BICUBIC)


def _to_uint8(image_tensor):
    """Map a float image tensor from [-1, 1] to uint8 [0, 255], clamping out-of-range values."""
    return (255 * ((image_tensor + 1) / 2).clip(0, 1)).to(torch.uint8)


def _poisson_blend_crop(hr_image, hr_mask, output_uint8, orig_h, orig_w):
    """Poisson-blend a uint8 (1, C, H, W) tensor into ``hr_image`` under ``hr_mask``.

    All three inputs are cropped to the top-left ``orig_h`` x ``orig_w`` region, which removes
    the padding added by ``run``. The blended array is returned wrapped in an IImage.
    """
    fake_img = output_uint8.cpu().permute(0, 2, 3, 1)[0].numpy()
    blended = poisson_blend(
        orig_img=hr_image.data[0][:orig_h, :orig_w, :],
        fake_img=fake_img[:orig_h, :orig_w, :],
        mask=hr_mask.alpha().data[0][:orig_h, :orig_w, :],
    )
    return IImage(blended[:orig_h, :orig_w, :])


def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt='high resolution professional photo', noise_level=20,
        blend_output=True, blend_trick=True, no_superres=False,
        dt=50, seed=1, guidance_scale=7.5, negative_prompt='', use_sam_mask=False):
    """Re-synthesize ``lr_image`` at ``hr_image`` resolution with a diffusion upscaler.

    Steps:
      1. Encode ``hr_image`` with ``ddim.vae``.
      2. Sample over timesteps ``999, 999 - dt, ...`` with classifier-free guidance. The UNet is
         conditioned on ``lr_image`` (after ``ddim.low_scale_model``) and on
         ``[negative_prompt, prompt]``.
      3. Decode the final latent prediction.
      4. Optionally Poisson-blend the result into ``hr_image`` inside ``hr_mask``.
    The result is cropped back to ``hr_image``'s size before padding.

    Side effects: seeds torch's global RNG with ``seed`` and rebinds
    ``router.attention_forward`` / ``router.basic_transformer_forward``.

    Args:
        ddim: model bundle exposing ``vae``, ``unet``, ``encoder``, ``low_scale_model``,
            ``schedule`` and ``config.scale_factor``.
        sam_predictor: only used when ``use_sam_mask`` is true (forwarded to ``refine_mask``).
        lr_image, hr_image, hr_mask: IImage inputs. After reflect padding, the VAE latent of
            ``hr_image`` must match ``lr_image``'s spatial size (asserted).
        prompt, negative_prompt: conditional / unconditional text for guidance.
        noise_level: level passed to ``low_scale_model``; its returned level is the UNet's ``y``.
        blend_output: Poisson-blend into ``hr_image`` (uint8 result) instead of wrapping the raw
            decoded tensor.
        blend_trick: at every step keep the predicted latent only inside a blurred low-res copy
            of ``hr_mask`` and use ``hr_image``'s latent elsewhere.
        no_superres: skip diffusion; bicubic-upscale ``lr_image`` and Poisson-blend it.
        dt: timestep stride; must be positive.
        seed: torch RNG seed.
        guidance_scale: classifier-free guidance weight.
        use_sam_mask: refine ``hr_mask`` with ``refine_mask`` first.

    Returns:
        IImage cropped to the original ``hr_image`` height/width.

    Raises:
        ValueError: if ``dt`` is not positive.
        AssertionError: if the ``hr_image`` latent and ``lr_image`` sizes disagree.
    """
    if dt <= 0:
        raise ValueError(f"dt must be a positive timestep stride, got {dt}")

    torch.manual_seed(seed)
    dtype = ddim.vae.encoder.conv_in.weight.dtype
    device = ddim.vae.encoder.conv_in.weight.device

    router.attention_forward = attentionpatch.default.forward_xformers
    router.basic_transformer_forward = transformerpatch.default.forward

    if use_sam_mask:
        with torch.no_grad():
            hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)

    # Unpadded size, used to crop every output back.
    unpadded_shape = hr_image.torch().shape
    orig_h, orig_w = unpadded_shape[2], unpadded_shape[3]

    hr_image = hr_image.padx(256, padding_mode='reflect')
    hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
    lr_image = lr_image.padx(64, padding_mode='reflect')

    # Convert each (large) padded image to a tensor once and reuse its shape.
    hr_tensor = hr_image.torch()
    hr_h, hr_w = hr_tensor.shape[2], hr_tensor.shape[3]
    lr_shape = lr_image.torch().shape
    lr_h, lr_w = lr_shape[2], lr_shape[3]

    if no_superres:
        upscaled = lr_image.resize((hr_h, hr_w), resample=Image.BICUBIC).torch().to(device)
        return _poisson_blend_crop(hr_image, hr_mask, _to_uint8(upscaled), orig_h, orig_w)

    # Soft low-res mask for latent blending (values in [0, 1] via torch(vmin=0), then blurred).
    lr_mask = hr_mask.resize((lr_h, lr_w), resample=Image.BICUBIC).alpha().torch(vmin=0).to(device)
    lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)

    with torch.no_grad():
        hr_z0 = ddim.vae.encode(hr_tensor.to(dtype=dtype, device=device)).mean * ddim.config.scale_factor

    assert hr_z0.shape[2] == lr_h and hr_z0.shape[3] == lr_w, (
        f"hr_image latent size {tuple(hr_z0.shape[2:])} does not match lr_image size {(lr_h, lr_w)}"
    )

    unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype=dtype, device=device)
    # Draw the initial noise on the CPU generator (as before) so seeded results are unchanged.
    zT = torch.randn((1, 4, unet_condition.shape[2], unet_condition.shape[3])).to(dtype=dtype, device=device)

    with torch.no_grad():
        context = ddim.encoder.encode([negative_prompt, prompt])
        noise_level = torch.Tensor([noise_level]).to(device=device).long()
        unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)

    with torch.autocast('cuda'), torch.no_grad():
        zt = zT
        for t in range(999, 0, -dt):
            _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)

            # Batch of two: index 0 is the negative prompt, index 1 the prompt.
            eps_uncond, eps = ddim.unet(
                torch.cat([_zt, _zt]).to(dtype=dtype, device=device),
                timesteps=torch.tensor([t, t]).to(device=device),
                context=context,
                y=torch.cat([noise_level] * 2),
            ).chunk(2)

            ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
            model_output = eps_uncond + guidance_scale * (eps - eps_uncond)
            eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)

            if blend_trick:
                z0 = z0 * lr_mask + hr_z0 * (1 - lr_mask)

            # Only re-noise when another iteration follows. On the last step t - dt <= 0, so the
            # update would index the schedule with a zero/negative index and the result is unused
            # (z0 is what gets decoded).
            if t > dt:
                t_prev = t - dt
                zt = ddim.schedule.sqrt_alphas[t_prev] * z0 + ddim.schedule.sqrt_one_minus_alphas[t_prev] * eps

    with torch.no_grad():
        output_tensor = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)

    if blend_output:
        return _poisson_blend_crop(hr_image, hr_mask, _to_uint8(output_tensor), orig_h, orig_w)
    return IImage(output_tensor[:, :, :orig_h, :orig_w])