Spaces:
Running
on
A10G
Running
on
A10G
File size: 6,246 Bytes
bfd34e9 29ac50c bfd34e9 da1e12f bfd34e9 da1e12f bfd34e9 da1e12f bfd34e9 da1e12f bfd34e9 da1e12f bfd34e9 da1e12f bfd34e9 da1e12f bfd34e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
from functools import partial
from glob import glob
from pathlib import Path as PythonPath
import cv2
import torchvision.transforms.functional as TvF
import torch
import torch.nn as nn
import numpy as np
from inspect import isfunction
from PIL import Image
from lib import smplfusion
from lib.smplfusion import share, router, attentionpatch, transformerpatch
from lib.utils.iimage import IImage
from lib.utils import poisson_blend
from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v
def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
    """Gate the inpainting mask with object regions segmented by SAM.

    The mask is downscaled with ``resize(512)`` and its bounding rectangle
    (grown by one pixel, clamped at zero on the top/left) is used as a box
    prompt for ``sam_predictor`` on two pictures: ``hr_image`` and
    ``lr_image``. All mask proposals returned for each picture are merged,
    dilated (13x13 kernel for ``hr_image``, 3x3 for ``lr_image``), and the
    union of both results is multiplied into the downscaled mask.

    Args:
        hr_image: Image object exposing ``resize`` and ``data``; segmented
            after ``resize(512)``.
        hr_mask: Mask object exposing ``resize`` and ``data``; channel 0 of
            ``data[0]`` drives the bounding rectangle.
        lr_image: Second image object, segmented with the same box prompt.
        sam_predictor: Object providing ``set_image`` and ``predict``.

    Returns:
        An ``IImage`` built from the gated mask and resized with
        ``resize(2048, resample=Image.BICUBIC)``.
        NOTE(review): the output size is fixed at 2048 regardless of the
        size of ``hr_mask`` -- confirm callers rely on that.
    """
    small_mask = hr_mask.resize(512)

    # Box prompt: bounding rectangle of the mask with a one-pixel margin.
    left, top, width, height = cv2.boundingRect(small_mask.data[0][:, :, 0])
    left = max(left - 1, 0)
    top = max(top - 1, 0)
    box_prompt = np.array([left, top, left + width + 1, top + height + 1])

    def _dilated_union(image, kernel_size):
        # Merge every SAM proposal for the box, then grow the region.
        sam_predictor.set_image(image.resize(512).data[0])
        proposals, _, _ = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_prompt[None, :],
            multimask_output=True,
        )
        merged = (np.sum(proposals, axis=0) > 0).astype(np.uint8)
        return cv2.dilate(merged, np.ones((kernel_size, kernel_size)))

    # Order matters: the predictor is stateful (set_image, then predict).
    region_in_hr = _dilated_union(hr_image, 13)
    region_in_lr = _dilated_union(lr_image, 3)

    keep = ((region_in_hr + region_in_lr) > 0).astype(np.uint8)
    gated = small_mask.data[0] * keep[:, :, np.newaxis]
    return IImage(gated).resize(2048, resample=Image.BICUBIC)
def _poisson_blend_cropped(generated, hr_image, hr_mask, orig_h, orig_w):
    """Poisson-blend a generated tensor into the original image and crop padding.

    ``generated`` is mapped from [-1, 1] to uint8 [0, 255], converted to an
    HWC numpy array (first batch element), and blended with ``hr_image``
    inside ``hr_mask`` via ``poisson_blend``. Everything is cropped to
    ``orig_h`` x ``orig_w`` before and after blending.
    """
    as_uint8 = (255 * ((generated + 1) / 2).clip(0, 1)).to(torch.uint8)
    blended = poisson_blend(
        orig_img=hr_image.data[0][:orig_h, :orig_w, :],
        fake_img=as_uint8.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
        mask=hr_mask.alpha().data[0][:orig_h, :orig_w, :],
    )
    return IImage(blended[:orig_h, :orig_w, :])


def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20,
        blend_output = True, blend_trick = True, no_superres = False,
        dt = 50, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False):
    """Upscale ``lr_image`` with the ``ddim`` model and merge it into ``hr_image``.

    Side effects: seeds the global torch RNG with ``seed`` and overwrites
    ``router.attention_forward`` / ``router.basic_transformer_forward``.

    Args:
        ddim: Model bundle; this function uses its ``vae``, ``unet``,
            ``encoder``, ``low_scale_model``, ``schedule`` and
            ``config.scale_factor``. dtype and device are taken from
            ``ddim.vae.encoder.conv_in.weight``.
        sam_predictor: Only forwarded to ``refine_mask`` when
            ``use_sam_mask`` is true.
        lr_image: Low-resolution image object (padded to a multiple given by
            ``padx(64)``); it conditions the UNet.
        hr_image: High-resolution original image object (``padx(256)``).
        hr_mask: Mask object; padded like ``hr_image`` and dilated by 19.
        prompt: Text prompt, encoded together with ``negative_prompt``.
        noise_level: Value passed to ``ddim.low_scale_model`` as a long tensor.
        blend_output: If true, Poisson-blend the decoded result into
            ``hr_image``; otherwise return the decoded tensor cropped to the
            original size.
        blend_trick: If true, after every step the predicted clean latent is
            replaced by the encoded ``hr_image`` latent outside the blurred mask.
        no_superres: If true, skip diffusion entirely: bicubic-resize
            ``lr_image`` to the padded ``hr_image`` size and Poisson-blend it.
        dt: Positive timestep stride; steps visit ``range(999, 0, -dt)``.
        seed: Seed for ``torch.manual_seed``.
        guidance_scale: Weight of the prompt-conditioned output relative to
            the negative-prompt output.
        negative_prompt: Text for the first (unconditional) half of the batch.
        use_sam_mask: If true, refine ``hr_mask`` with ``refine_mask`` first.

    Returns:
        An ``IImage`` cropped to the original (unpadded) ``hr_image`` size.

    Raises:
        ValueError: If the diffusion path is taken and ``dt`` is not positive.
        AssertionError: If the encoded latent's spatial size differs from the
            padded ``lr_image``.
    """
    torch.manual_seed(seed)
    dtype = ddim.vae.encoder.conv_in.weight.dtype
    device = ddim.vae.encoder.conv_in.weight.device

    router.attention_forward = attentionpatch.default.forward_xformers
    router.basic_transformer_forward = transformerpatch.default.forward

    if use_sam_mask:
        with torch.no_grad():
            hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)

    # Remember the unpadded size so every output can be cropped back to it.
    orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3]
    hr_image = hr_image.padx(256, padding_mode='reflect')
    hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
    hr_mask_orig = hr_mask
    lr_image = lr_image.padx(64, padding_mode='reflect')
    lr_mask = hr_mask.resize(
        (lr_image.torch().shape[2], lr_image.torch().shape[3]), resample = Image.BICUBIC
    ).alpha().torch(vmin=0).to(device)
    # Soft mask edges so the latent blend below has no hard seam.
    lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)

    if no_superres:
        # No GPU hop needed here: the tensor is converted to numpy on the CPU
        # for blending anyway.
        upscaled = lr_image.resize(
            (hr_image.torch().shape[2], hr_image.torch().shape[3]), resample = Image.BICUBIC
        ).torch()
        return _poisson_blend_cropped(upscaled, hr_image, hr_mask_orig, orig_h, orig_w)

    # Validate only on the diffusion path; a non-positive stride would either
    # make range() fail or leave the loop empty and z0 undefined.
    if dt <= 0:
        raise ValueError(f"dt must be a positive integer, got {dt!r}")

    # encode hr image
    with torch.no_grad():
        hr_z0 = ddim.vae.encode(hr_image.torch().to(dtype=dtype, device=device)).mean * ddim.config.scale_factor

    assert hr_z0.shape[2] == lr_image.torch().shape[2]
    assert hr_z0.shape[3] == lr_image.torch().shape[3]

    # NOTE(review): lr_image.cuda() is a project API call and is kept as-is.
    unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype=dtype, device=device)
    # Noise is drawn on the CPU (seeded above) and then moved to the model device.
    zT = torch.randn((1, 4, unet_condition.shape[2], unet_condition.shape[3])).to(dtype=dtype, device=device)

    with torch.no_grad():
        context = ddim.encoder.encode([negative_prompt, prompt])

    noise_level = torch.Tensor(1 * [noise_level]).to(device=device).long()
    unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)

    with torch.autocast('cuda'), torch.no_grad():
        zt = zT
        for t in range(999, 0, -dt):
            _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
            # Batch of two: first half pairs with negative_prompt, second with prompt.
            out_uncond, out_cond = ddim.unet(
                torch.cat([_zt, _zt]).to(dtype=dtype, device=device),
                timesteps = torch.tensor([t, t]).to(device=device),
                context = context,
                y=torch.cat([noise_level]*2)
            ).chunk(2)

            ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
            model_output = (out_uncond + guidance_scale * (out_cond - out_uncond))
            eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            if blend_trick:
                # Keep the original content outside the mask at every step.
                z0 = z0 * lr_mask + hr_z0 * (1-lr_mask)

            # Re-noise to the next timestep only if there is one. On the last
            # step t - dt is <= 0; a negative value would wrap around to the
            # end of the schedule, and the result is never used.
            next_t = t - dt
            if next_t > 0:
                zt = ddim.schedule.sqrt_alphas[next_t] * z0 + ddim.schedule.sqrt_one_minus_alphas[next_t] * eps

    with torch.no_grad():
        output_tensor = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)

    if blend_output:
        return _poisson_blend_cropped(output_tensor, hr_image, hr_mask_orig, orig_h, orig_w)
    return IImage(output_tensor[:, :, :orig_h, :orig_w])
|