import os
from functools import partial
from glob import glob
from pathlib import Path as PythonPath
from inspect import isfunction

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as TvF
from PIL import Image

from lib import smplfusion
from lib.smplfusion import share, router, attentionpatch, transformerpatch
from lib.utils.iimage import IImage
from lib.utils import poisson_blend
from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v


def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
    """Refine the inpainting mask with SAM so that only the generated object region is kept."""
    lr_mask = hr_mask.resize(512)

    # Bounding box of the low-resolution mask, expanded by one pixel on each side.
    x_min, y_min, rect_w, rect_h = cv2.boundingRect(lr_mask.data[0][:, :, 0])
    x_min = max(x_min - 1, 0)
    y_min = max(y_min - 1, 0)
    x_max = x_min + rect_w + 1
    y_max = y_min + rect_h + 1
    input_box = np.array([x_min, y_min, x_max, y_max])

    # Segment the boxed region in the original high-resolution image.
    sam_predictor.set_image(hr_image.resize(512).data[0])
    masks, _, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=True,
    )
    dilation_kernel = np.ones((13, 13))
    original_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8)
    original_object_mask = cv2.dilate(original_object_mask, dilation_kernel)

    # Segment the same region in the low-resolution inpainted image.
    sam_predictor.set_image(lr_image.resize(512).data[0])
    masks, _, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=True,
    )
    dilation_kernel = np.ones((3, 3))
    inpainted_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8)
    inpainted_object_mask = cv2.dilate(inpainted_object_mask, dilation_kernel)

    # Keep only the mask pixels covered by either segmentation, then upscale the result.
    lr_mask_masking = ((original_object_mask + inpainted_object_mask) > 0).astype(np.uint8)
    new_mask = lr_mask.data[0] * lr_mask_masking[:, :, np.newaxis]
    new_mask = IImage(new_mask).resize(2048, resample=Image.BICUBIC)
    return new_mask


def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt='high resolution professional photo',
        noise_level=20, blend_output=True, blend_trick=True, no_superres=False,
        dt=50, seed=1, guidance_scale=7.5, negative_prompt='', use_sam_mask=False):
    torch.manual_seed(seed)

    dtype = ddim.vae.encoder.conv_in.weight.dtype
    device = ddim.vae.encoder.conv_in.weight.device

    router.attention_forward = attentionpatch.default.forward_xformers
    router.basic_transformer_forward = transformerpatch.default.forward

    if use_sam_mask:
        with torch.no_grad():
            hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)

    orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3]

    # Pad images to valid sizes and prepare a blurred low-resolution mask for latent blending.
    hr_image = hr_image.padx(256, padding_mode='reflect')
    hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
    hr_mask_orig = hr_mask
    lr_image = lr_image.padx(64, padding_mode='reflect')
    lr_mask = hr_mask.resize((lr_image.torch().shape[2], lr_image.torch().shape[3]),
                             resample=Image.BICUBIC).alpha().torch(vmin=0).to(device)
    lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)

    if no_superres:
        # Skip diffusion super-resolution: bicubic upscale, then Poisson-blend into the original image.
        output_tensor = lr_image.resize((hr_image.torch().shape[2], hr_image.torch().shape[3]),
                                        resample=Image.BICUBIC).torch().cuda()
        output_tensor = (255 * ((output_tensor.clip(-1, 1) + 1) / 2)).to(torch.uint8)
        output_tensor = poisson_blend(
            orig_img=hr_image.data[0][:orig_h, :orig_w, :],
            fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
            mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :]
        )
        return IImage(output_tensor[:orig_h, :orig_w, :])

    # Encode the high-resolution image into the latent space.
    with torch.no_grad():
        hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype=dtype, device=device)).mean * ddim.config.scale_factor
    assert hr_z0.shape[2] == lr_image.torch().shape[2]
    assert hr_z0.shape[3] == lr_image.torch().shape[3]

    # The low-resolution image conditions the upscaler UNet; sampling starts from pure noise.
    unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype=dtype, device=device)
    zT = torch.randn((1, 4, unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype=dtype, device=device)

    with torch.no_grad():
        context = ddim.encoder.encode([negative_prompt, prompt])

    # Noise augmentation of the conditioning image, as in the SD 2.0 x4 upscaler.
    noise_level = torch.Tensor(1 * [noise_level]).to(device=device).long()
    unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)

    with torch.autocast('cuda'), torch.no_grad():
        zt = zT
        for index, t in enumerate(range(999, 0, -dt)):
            _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)

            # Classifier-free guidance: one pass with the negative prompt, one with the prompt.
            eps_uncond, eps = ddim.unet(
                torch.cat([_zt, _zt]).to(dtype=dtype, device=device),
                timesteps=torch.tensor([t, t]).to(device=device),
                context=context,
                y=torch.cat([noise_level] * 2)
            ).chunk(2)

            ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
            model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
            # The model is v-parameterized; recover the eps and x0 predictions from v.
            eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)

            if blend_trick:
                # Keep the prediction inside the mask, the known HR latent outside it.
                z0 = z0 * lr_mask + hr_z0 * (1 - lr_mask)

            zt = ddim.schedule.sqrt_alphas[t - dt] * z0 + ddim.schedule.sqrt_one_minus_alphas[t - dt] * eps

    with torch.no_grad():
        output_tensor = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)

    if blend_output:
        # Poisson-blend the generated region back into the original high-resolution image.
        output_tensor = (255 * ((output_tensor + 1) / 2).clip(0, 1)).to(torch.uint8)
        output_tensor = poisson_blend(
            orig_img=hr_image.data[0][:orig_h, :orig_w, :],
            fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
            mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :]
        )
        return IImage(output_tensor[:orig_h, :orig_w, :])
    else:
        return IImage(output_tensor[:, :, :orig_h, :orig_w])
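

# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hedged example of how `run` might be wired up.
# The model loader name (`sd2_sr.load_model`), the SAM checkpoint path, and
# the example file names below are assumptions for illustration only; the
# repository's own scripts are the authoritative entry points.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from segment_anything import SamPredictor, sam_model_registry
    from lib.models import sd2_sr

    ddim = sd2_sr.load_model()  # assumed loader for the SD 2.0 x4-upscaler wrapper
    sam = sam_model_registry['vit_h'](checkpoint='checkpoints/sam_vit_h_4b8939.pth')  # hypothetical path
    sam_predictor = SamPredictor(sam.to('cuda'))

    hr_image = IImage(np.array(Image.open('examples/input_2048.png').convert('RGB')))     # original HR photo
    hr_mask = IImage(np.array(Image.open('examples/mask_2048.png').convert('RGB')))       # HR inpainting mask
    lr_image = IImage(np.array(Image.open('examples/inpainted_512.png').convert('RGB')))  # 512px inpainted result

    result = run(ddim, sam_predictor, lr_image, hr_image, hr_mask,
                 prompt='high resolution professional photo', seed=1, use_sam_mask=False)
    Image.fromarray(result.data[0]).save('examples/output_2048.png')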