import cv2
import numpy as np
import torch
import torchvision.transforms.functional as TvF
from PIL import Image

from lib.smplfusion import router, attentionpatch, transformerpatch
from lib.utils.iimage import IImage
from lib.utils import poisson_blend
from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v
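
# High-resolution refinement stage: a 512px inpainting result (lr_image) is
# upscaled back to the original resolution with an SD2-style latent upscaler
# (noise-augmented low-resolution conditioning), optionally restricted by a
# SAM-refined mask, and blended into the untouched high-resolution input.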


def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
    """Tighten the inpainting mask with SAM.

    SAM is prompted with the bounding box of the low-resolution mask on both
    the original and the inpainted image; the union of the two predictions
    crops the mask to the actual object region.
    """
    lr_mask = hr_mask.resize(512)

    # Bounding box of the low-resolution mask, expanded by one pixel.
    x_min, y_min, rect_w, rect_h = cv2.boundingRect(lr_mask.data[0][:, :, 0])
    x_min = max(x_min - 1, 0)
    y_min = max(y_min - 1, 0)
    x_max = x_min + rect_w + 1
    y_max = y_min + rect_h + 1
    input_box = np.array([x_min, y_min, x_max, y_max])

    # Segment the object in the original high-resolution image.
    sam_predictor.set_image(hr_image.resize(512).data[0])
    masks, _, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=True,
    )
    dilation_kernel = np.ones((13, 13), np.uint8)
    original_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8)
    original_object_mask = cv2.dilate(original_object_mask, dilation_kernel)

    # Segment the object in the inpainted low-resolution image.
    sam_predictor.set_image(lr_image.resize(512).data[0])
    masks, _, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=True,
    )
    dilation_kernel = np.ones((3, 3), np.uint8)
    inpainted_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8)
    inpainted_object_mask = cv2.dilate(inpainted_object_mask, dilation_kernel)

    # Keep the mask only where either segmentation fired.
    lr_mask_masking = ((original_object_mask + inpainted_object_mask) > 0).astype(np.uint8)
    new_mask = lr_mask.data[0] * lr_mask_masking[:, :, np.newaxis]
    new_mask = IImage(new_mask).resize(2048, resample=Image.BICUBIC)
    return new_mask
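

# A minimal sketch, assuming the segment-anything package, of how a
# `sam_predictor` compatible with `refine_mask` could be built; the checkpoint
# path refers to the published SAM ViT-H weights, downloaded separately:
#
#     from segment_anything import sam_model_registry, SamPredictor
#
#     sam = sam_model_registry['vit_h'](checkpoint='sam_vit_h_4b8939.pth')
#     sam_predictor = SamPredictor(sam.to('cuda'))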


def run(ddim, sam_predictor, lr_image, hr_image, hr_mask,
        prompt='high resolution professional photo', noise_level=20,
        blend_output=True, blend_trick=True, no_superres=False,
        dt=50, seed=1, guidance_scale=7.5, negative_prompt='', use_sam_mask=False):
    torch.manual_seed(seed)
    dtype = ddim.vae.encoder.conv_in.weight.dtype
    device = ddim.vae.encoder.conv_in.weight.device

    # Use the default attention implementations for this stage.
    router.attention_forward = attentionpatch.default.forward_xformers
    router.basic_transformer_forward = transformerpatch.default.forward

    if use_sam_mask:
        with torch.no_grad():
            hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)

    orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3]

    # Pad to sizes the model can handle; keep the dilated mask for blending.
    hr_image = hr_image.padx(256, padding_mode='reflect')
    hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
    hr_mask_orig = hr_mask
    lr_image = lr_image.padx(64, padding_mode='reflect')

    # Soft low-resolution mask used for latent blending.
    lr_mask = hr_mask.resize((lr_image.torch().shape[2], lr_image.torch().shape[3]),
                             resample=Image.BICUBIC).alpha().torch(vmin=0).to(device)
    lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)

    if no_superres:
        # Skip diffusion entirely: bicubically upscale the inpainted image
        # and Poisson-blend it into the original high-resolution image.
        output_tensor = lr_image.resize((hr_image.torch().shape[2], hr_image.torch().shape[3]),
                                        resample=Image.BICUBIC).torch().cuda()
        output_tensor = (255 * ((output_tensor.clip(-1, 1) + 1) / 2)).to(torch.uint8)
        output_tensor = poisson_blend(
            orig_img=hr_image.data[0][:orig_h, :orig_w, :],
            fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
            mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :]
        )
        return IImage(output_tensor[:orig_h, :orig_w, :])

    # Encode the high-resolution image; its latent is reused to keep
    # unmasked regions fixed during sampling.
    with torch.no_grad():
        hr_z0 = ddim.vae.encode(hr_image.torch().to(dtype=dtype, device=device)).mean * ddim.config.scale_factor
    assert hr_z0.shape[2] == lr_image.torch().shape[2]
    assert hr_z0.shape[3] == lr_image.torch().shape[3]

    # The low-resolution image conditions the UNet; as in the SD2 upscaler,
    # it is noise-augmented by the low-scale model at the given noise_level.
    unet_condition = lr_image.torch().to(memory_format=torch.contiguous_format).to(dtype=dtype, device=device)
    zT = torch.randn((1, 4, unet_condition.shape[2], unet_condition.shape[3])).to(dtype=dtype, device=device)
    with torch.no_grad():
        context = ddim.encoder.encode([negative_prompt, prompt])
        noise_level = torch.tensor([noise_level], device=device).long()
        unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)

    with torch.autocast('cuda'), torch.no_grad():
        zt = zT
        for t in range(999, 0, -dt):
            # Concatenate the noise-augmented low-resolution condition
            # channel-wise, then run the conditional and unconditional passes
            # in one batch for classifier-free guidance.
            _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
            eps_uncond, eps = ddim.unet(
                torch.cat([_zt, _zt]).to(dtype=dtype, device=device),
                timesteps=torch.tensor([t, t]).to(device=device),
                context=context,
                y=torch.cat([noise_level] * 2)
            ).chunk(2)

            ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
            model_output = eps_uncond + guidance_scale * (eps - eps_uncond)

            # The model is v-parameterized: recover the eps and x0 predictions.
            eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)

            if blend_trick:
                # Pin unmasked latent regions to the encoded HR image.
                z0 = z0 * lr_mask + hr_z0 * (1 - lr_mask)

            zt = ddim.schedule.sqrt_alphas[t - dt] * z0 + ddim.schedule.sqrt_one_minus_alphas[t - dt] * eps

    with torch.no_grad():
        output_tensor = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)

    if blend_output:
        # Poisson-blend the decoded result into the original image so that
        # pixels outside the mask stay untouched.
        output_tensor = (255 * ((output_tensor + 1) / 2).clip(0, 1)).to(torch.uint8)
        output_tensor = poisson_blend(
            orig_img=hr_image.data[0][:orig_h, :orig_w, :],
            fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
            mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :]
        )
        return IImage(output_tensor[:orig_h, :orig_w, :])
    else:
        return IImage(output_tensor[:, :, :orig_h, :orig_w])
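

# A hedged usage sketch; `load_sd2_sr_model` and the way the IImage inputs are
# obtained are placeholder names for illustration, not APIs confirmed here:
#
#     ddim = load_sd2_sr_model()          # hypothetical loader for the upscaler
#     result = run(
#         ddim, sam_predictor,
#         lr_image=lr_image,              # 512px inpainted result (IImage)
#         hr_image=hr_image,              # original high-resolution image (IImage)
#         hr_mask=hr_mask,                # inpainting mask (IImage)
#         use_sam_mask=True,
#         seed=1,
#     )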