import gradio as gr from PIL import Image import torch import numpy as np from os.path import exists as path_exists from git.repo.base import Repo from einops import rearrange import torchvision.transforms as transforms from torchvision.utils import make_grid if not (path_exists(f"rudalle-aspect-ratio")): Repo.clone_from("https://github.com/shonenkov-AI/rudalle-aspect-ratio", "rudalle-aspect-ratio") import sys sys.path.append('./rudalle-aspect-ratio') from rudalle_aspect_ratio import RuDalleAspectRatio, get_rudalle_model from rudalle import get_vae, get_tokenizer from rudalle.pipelines import show #model_path_e = hf_hub_download(repo_id="multimodalart/compvis-latent-diffusion-text2img-large", filename="txt2img-f8-large.ckpt") device = 'cuda' dalle_surreal = get_rudalle_model('Surrealist_XL', fp16=True, device=device) dalle_real = get_rudalle_model('Malevich',fp16=True,device=device) dalle_emoji = get_rudalle_model('Emojich',fp16=True,device=device) vae, tokenizer = get_vae().to(device), get_tokenizer() def np_gallery(array, ncols=3): nindex, height, width, intensity = array.shape nrows = nindex//ncols assert nindex == nrows*ncols # want result.shape = (height*nrows, width*ncols, intensity) result = (array.reshape(nrows, ncols, height, width, intensity) .swapaxes(1,2) .reshape(height*nrows, width*ncols, intensity)) return result def image_to_np(image): return np.asarray(image) def run(prompt, aspect_ratio, model): if(model=='Surrealism'): dalle = dalle_surreal elif(model=='Realism'): dalle = dalle_real elif(model=='Emoji'): dalle = dalle_emoji if(aspect_ratio == 'Square'): aspect_ratio_value = 1 top_k = 512 elif(aspect_ratio == 'Horizontal'): aspect_ratio_value = 32/9 top_k = 1024 elif(aspect_ratio == 'Vertical'): aspect_ratio_value = 9/32 top_k = 512 rudalle_ar = RuDalleAspectRatio( dalle=dalle, vae=vae, tokenizer=tokenizer, aspect_ratio=aspect_ratio_value, bs=1, device=device ) _, result_pil_images = rudalle_ar.generate_images(prompt, top_k, 0.975, 1) #np_images = map(image_to_np,result_pil_images) #np_grid = np_gallery(np.array(list(np_images)),2) #result_grid = Image.fromarray(np_grid) return(result_pil_images[0]) image = gr.outputs.Image(type="pil", label="Your result") iface = gr.Interface(fn=run, inputs=[ gr.inputs.Textbox(label="Prompt (if not in Russian, it will be automatically translated to Russian)",default="chalk pastel drawing of a dog wearing a funny hat"), #gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=45,maximum=50,minimum=1,step=1), gr.inputs.Radio(label="Aspect Ratio", choices=["Square", "Horizontal", "Vertical"],default="Horizontal"), gr.inputs.Dropdown(label="Model", choices=["Surrealism","Realism", "Emoji"], default="Surrealism") #gr.inputs.Radio(label="Height", choices=[32,64,128,256,512],default=256), #gr.inputs.Slider(label="Images - How many images you wish to generate", default=2, step=1, minimum=1, maximum=4), #gr.inputs.Slider(label="Diversity scale - How different from one another you wish the images to be",default=5.0, minimum=1.0, maximum=15.0), #gr.inputs.Slider(label="ETA - between 0 and 1. Lower values can provide better quality, higher values can be more diverse",default=0.0,minimum=0.0, maximum=1.0,step=0.1), ], outputs=image, #css=css, title="Generate images from text with ruDALLE", description="