# KandiSuperRes/sr_pipeline.py
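"""KandiSuperRes super-resolution pipeline.

Wraps two inference paths: a distilled "flash" UNet sampled with a few-step
turbo schedule (optionally followed by a MoVQ-latent refiner), and the
original SR UNet sampled with DPM-Solver over overlapping panorama views.
"""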
import torch
import numpy as np
import PIL.Image
import torchvision.transforms as T
import torch.nn.functional as F

from KandiSuperRes.model.unet import UNet
from KandiSuperRes.model.unet_sr import UNet as UNet_sr
from KandiSuperRes.movq import MoVQ
from KandiSuperRes.model.diffusion_sr import DPMSolver
from KandiSuperRes.model.diffusion_refine import BaseDiffusion, get_named_beta_schedule
from KandiSuperRes.model.diffusion_sr_turbo import BaseDiffusion as BaseDiffusion_turbo

class KandiSuperResPipeline:

    def __init__(
        self,
        scale: int,
        device: str,
        dtype: str,
        flash: bool,
        sr_model: UNet_sr,
        movq: MoVQ = None,
        refiner: UNet = None,
    ):
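        # In flash mode, `device` and `dtype` are mappings keyed by sub-model
        # name ('sr_model', 'movq', 'refiner'); otherwise a single device/dtype.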
        self.device = device
        self.dtype = dtype
        self.scale = scale
        self.flash = flash
        self.to_pil = T.ToPILImage()
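        # Force RGB, then map pixel values to a float tensor in [-1, 1],
        # the range the diffusion models expect.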
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.ToTensor(),
            T.Lambda(lambda img: 2. * img - 1.),
        ])
        self.sr_model = sr_model
        self.movq = movq
        self.refiner = refiner

    def __call__(
        self,
        pil_image: PIL.Image.Image = None,
        steps: int = 5,
        view_batch_size: int = 15,
        seed: int = 0,
        refine: bool = True,
    ) -> PIL.Image.Image:
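        # Two inference paths: the distilled "flash" model (few-step turbo
        # schedule, optional latent refiner) or DPM-Solver panorama sampling.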
        if self.flash:
            betas_turbo = get_named_beta_schedule('linear', 1000)
            base_diffusion_sr = BaseDiffusion_turbo(betas_turbo)
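            # Crop H and W down to multiples of 32 so the UNet's
            # downsampling stages divide evenly.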
            old_height = pil_image.size[1]
            old_width = pil_image.size[0]
            height = int(old_height - np.mod(old_height, 32))
            width = int(old_width - np.mod(old_width, 32))
            pil_image = pil_image.resize((width, height))
            lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device['sr_model'])
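            # Few-step turbo sampling, conditioned on the low-res input,
            # directly at the target resolution.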
            sr_image = base_diffusion_sr.p_sample_loop(
                self.sr_model, (1, 3, height * self.scale, width * self.scale),
                self.device['sr_model'], self.dtype['sr_model'], lowres_img=lr_image
            )
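            # Optional refinement: encode the SR result into MoVQ latent
            # space, denoise it tile-by-tile with the refiner UNet, then
            # decode back to pixels.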
            if refine:
                betas = get_named_beta_schedule('cosine', 1000)
                base_diffusion = BaseDiffusion(betas, 0.99)
                with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
                    lr_image_latent = self.movq.encode(sr_image)
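                # Precomputed text-conditioning tensors shipped with the weights.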
                context = torch.load('weights/context.pt').to(self.dtype['refiner'])
                context_mask = torch.load('weights/context_mask.pt').to(self.dtype['refiner'])
                with torch.no_grad():
                    with torch.cuda.amp.autocast(dtype=self.dtype['refiner']):
                        refiner_image = base_diffusion.refine_tiled(self.refiner, lr_image_latent, context, context_mask)
                    with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
                        refiner_image = self.movq.decode(refiner_image)
                refiner_image = torch.clip((refiner_image + 1.) / 2., 0., 1.)
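                # The crop to multiples of 32 may have shrunk the frame;
                # resize back to exactly `scale` x the original size.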
                if old_height * self.scale != refiner_image.shape[2] or old_width * self.scale != refiner_image.shape[3]:
                    refiner_image = F.interpolate(refiner_image, [old_height * self.scale, old_width * self.scale], mode='bilinear', align_corners=True)
                refined_pil_image = self.to_pil(refiner_image[0])
                return refined_pil_image

            sr_image = torch.clip((sr_image + 1.) / 2., 0., 1.)
            if old_height * self.scale != sr_image.shape[2] or old_width * self.scale != sr_image.shape[3]:
                sr_image = F.interpolate(sr_image, [old_height * self.scale, old_width * self.scale], mode='bilinear', align_corners=True)
            pil_sr_image = self.to_pil(sr_image[0])
            return pil_sr_image
        else:
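            # Non-flash path: DPM-Solver sampling over overlapping views
            # (panorama-style tiling) to cover large outputs.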
            base_diffusion = DPMSolver(steps)
            lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device)
            old_height = pil_image.size[1]
            old_width = pil_image.size[0]
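            # Round H and W up to even numbers before scaling; the result is
            # resized back to the exact target size below.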
            height = int(old_height + np.mod(old_height, 2)) * self.scale
            width = int(old_width + np.mod(old_width, 2)) * self.scale
            sr_image = base_diffusion.generate_panorama(
                height, width, self.device, self.dtype, steps,
                self.sr_model, lowres_img=lr_image,
                view_batch_size=view_batch_size, eta=0.0, seed=seed
            )
            sr_image = torch.clip((sr_image + 1.) / 2., 0., 1.)
            if old_height * self.scale != height or old_width * self.scale != width:
                sr_image = F.interpolate(sr_image, [old_height * self.scale, old_width * self.scale], mode='bilinear', align_corners=True)
            pil_sr_image = self.to_pil(sr_image[0])
            return pil_sr_image
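

# Minimal usage sketch (hypothetical: assumes the package exposes a
# `get_SR_pipeline` helper, as in the project README, that builds this
# pipeline with pretrained weights and the file names below are placeholders).
if __name__ == '__main__':
    from PIL import Image
    from KandiSuperRes import get_SR_pipeline

    sr_pipe = get_SR_pipeline(device='cuda', fp16=True, flash=True, scale=2)
    lr_image = Image.open('low_res.png')          # any RGB-convertible image
    sr_image = sr_pipe(lr_image, refine=True)     # 2x upscale + refiner pass
    sr_image.save('high_res.png')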