import os
import yaml
import torch
import argparse
import numpy as np
import gradio as gr
from PIL import Image
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel
from huggingface_hub import hf_hub_download
from gradio_imageslider import ImageSlider

## local code
from models import seemore


def dict2namespace(config):
    # recursively convert a nested dict into an argparse.Namespace
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace


def load_img(filename, norm=True):
    # load an RGB image; very large inputs are downscaled to keep CPU inference tractable
    img = np.array(Image.open(filename).convert("RGB"))
    h, w = img.shape[:2]
    if w > 1920 or h > 1080:
        new_h, new_w = h // 4, w // 4
        img = np.array(Image.fromarray(img).resize((new_w, new_h), Image.BICUBIC))
    if norm:
        img = img / 255.
        img = img.astype(np.float32)
    return img


def process_img(image):
    # normalize the PIL input to a float32 tensor in [0, 1] with NCHW layout
    img = np.array(image.convert("RGB"))
    img = img / 255.
    img = img.astype(np.float32)
    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        x_hat = model(y)

    restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
    restored_img = np.clip(restored_img, 0., 1.)
    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 to uint8

    # return (input, output) so the ImageSlider can show a before/after comparison
    return image, Image.fromarray(restored_img)


def load_network(net, load_path, strict=True, param_key='params'):
    # load a checkpoint into `net`, unwrapping (Distributed)DataParallel if needed
    if isinstance(net, (DataParallel, DistributedDataParallel)):
        net = net.module
    load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
    if param_key is not None:
        if param_key not in load_net and 'params' in load_net:
            param_key = 'params'
        load_net = load_net[param_key]
    # remove unnecessary 'module.' prefix left by DataParallel wrappers
    for k, v in deepcopy(load_net).items():
        if k.startswith('module.'):
            load_net[k[7:]] = v
            load_net.pop(k)
    net.load_state_dict(load_net, strict=strict)


CONFIG = "configs/eval_seemore_t_x4.yml"
hf_hub_download(repo_id="eduardzamfir/SeemoRe-T", filename="SeemoRe_T_X4.pth", local_dir="./")
MODEL_NAME = "SeemoRe_T_X4.pth"

# parse config file
with open(CONFIG, "r") as f:
    config = yaml.safe_load(f)
cfg = dict2namespace(config)

device = torch.device("cpu")

model = seemore.SeemoRe(scale=cfg.model.scale,
                        in_chans=cfg.model.in_chans,
                        num_experts=cfg.model.num_experts,
                        num_layers=cfg.model.num_layers,
                        embedding_dim=cfg.model.embedding_dim,
                        img_range=cfg.model.img_range,
                        use_shuffle=cfg.model.use_shuffle,
                        global_kernel_size=cfg.model.global_kernel_size,
                        recursive=cfg.model.recursive,
                        lr_space=cfg.model.lr_space,
                        topk=cfg.model.topk)
model = model.to(device)

print("IMAGE MODEL CKPT:", MODEL_NAME)
load_network(model, MODEL_NAME, strict=True, param_key='params')
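
# --- Quick local check (illustrative sketch, not used by the Gradio app) ---
# Assuming one of the bundled example images is present on disk, process_img can be
# exercised directly without launching the UI defined below:
#
#   lr = Image.open("images/0878x4.png").convert("RGB")
#   original, upscaled = process_img(lr)  # returns (input, x4 super-resolved PIL image)
#   upscaled.save("upscaled_x4.png")      # hypothetical output filename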
title = "Online Image Enhancer - Free Upscale & Enhancement"

description = '''
Experience the power of our advanced image enhancement model! Upscale and enhance your images online for free. Simply upload an image to see the transformation.
'''

article = "See More Details: Efficient Image Super-Resolution by Experts Mining"

#### Image examples
examples = [
    ['images/0801x4.png'],
    ['images/0840x4.png'],
    ['images/0841x4.png'],
    ['images/0870x4.png'],
    ['images/0878x4.png'],
    ['images/0884x4.png'],
    ['images/0900x4.png'],
    ['images/img002x4.png'],
    ['images/img003x4.png'],
    ['images/img004x4.png'],
    ['images/img035x4.png'],
    ['images/img053x4.png'],
    ['images/img064x4.png'],
    ['images/img083x4.png'],
    ['images/img092x4.png'],
]

css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""

demo = gr.Interface(
    fn=process_img,
    inputs=[gr.Image(type="pil", label="Input", value="images/0878x4.png")],
    outputs=ImageSlider(label="Super-Resolved Image", type="pil", show_download_button=True),
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)

if __name__ == "__main__":
    demo.launch()