"""Gradio demo that runs a horse-to-zebra CycleGAN generator on an uploaded image.

On start-up the script downloads the generator weights from the Hugging Face
Hub, loads them into a ``Transformer`` network, and launches a Gradio
interface whose single input image is passed through ``inference``.
"""

# NOTE(review): `cgitb.enable` and `ctypes.wintypes.HFONT` are not referenced
# anywhere in this file and look like accidental IDE auto-imports. `cgitb` was
# removed from the standard library in Python 3.13 (PEP 594), so the first
# line will break start-up there. Confirm nothing relies on them, then drop
# both lines.
from cgitb import enable
from ctypes.wintypes import HFONT
import logging
import os
import sys

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.autograd import Variable

from network.Transformer import Transformer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
MODEL_PATH = "models"
COLOUR_MODEL = "RGB"
MODEL_REPO = "NDugar/horse_to_zebra_cycle_GAN"
MODEL_FILE = "h2z-85epoch.pth"

# Model initialisation
model_hfhub = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)

model = Transformer()

enable_gpu = torch.cuda.is_available()
map_location = torch.device("cuda") if enable_gpu else "cpu"

# NOTE(review): torch.load unpickles the checkpoint, and the file comes from a
# remote repository. Only point MODEL_REPO at a source you trust (or switch to
# a safer loading mode if the installed torch version supports one).
model.load_state_dict(torch.load(model_hfhub, map_location=map_location))

# The weights are mapped to `map_location`, but the module itself must also be
# moved there; otherwise a CUDA input would hit CPU parameters.
model.to(map_location)

# Inference only: put the network in evaluation mode.
model.eval()


# Functions


def get_model():
    """Return the module-level generator network."""
    return model


def adjust_image_for_model(img):
    """Log the image dimensions and return the image unchanged.

    Args:
        img: A PIL image (anything exposing ``height`` and ``width``).

    Returns:
        The same ``img`` object that was passed in.
    """
    logger.info("Image Height: %s, Image Width: %s", img.height, img.width)
    return img


def inference(img, style=None):
    """Run the generator on ``img`` and return the translated image.

    Args:
        img: Input PIL image supplied by the Gradio interface.
        style: Unused. Kept (now optional) for backward compatibility with
            callers that pass a second argument; the interface below only
            supplies the image.

    Returns:
        A PIL image built from the network output.
    """
    img = adjust_image_for_model(img)

    # H x W x 3 array in RGB order, then swap to BGR channel order before
    # feeding the network (the output is converted back BGR -> RGB below).
    rgb_pixels = np.asarray(img.convert(COLOUR_MODEL))
    bgr_pixels = rgb_pixels[:, :, [2, 1, 0]]

    # ToTensor yields a C x H x W float tensor in [0, 1]; add a batch axis and
    # rescale to [-1, 1].
    input_image = transforms.ToTensor()(bgr_pixels).unsqueeze(0)
    input_image = -1 + 2 * input_image

    if enable_gpu:
        logger.info("CUDA found. Using GPU.")
    else:
        logger.info("CUDA not found. Using CPU.")
    # Same device as the model (see the module-level `model.to(...)`).
    input_image = input_image.to(map_location).float()

    net = get_model()
    # No gradients are needed for a forward-only pass.
    with torch.no_grad():
        output_image = net(input_image)[0]

    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]
    # Map the output from [-1, 1] back to [0, 1] for PIL conversion.
    output_image = output_image.cpu().float() * 0.5 + 0.5

    return transforms.ToPILImage()(output_image)


# Gradio setup

title = "Horse 2 Zebra GAN"
description = "Gradio Demo for CycleGAN"

# `article` and `examples` were previously passed here without ever being
# defined; they are omitted so the interface falls back to its defaults.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Image(
            type="pil",
            label="Input Photo",
        ),
    ],
    outputs=gr.outputs.Image(
        type="pil",
        label="Output Image",
    ),
    title=title,
    description=description,
    allow_flagging="never",
    allow_screenshot=False,
)
demo.launch(enable_queue=True)