File size: 2,602 Bytes
1e5aadc
 
640f9c9
 
 
 
 
 
 
 
b5b3814
640f9c9
 
ace1938
 
640f9c9
 
301a4a3
 
 
 
 
 
 
ace1938
301a4a3
 
 
 
640f9c9
301a4a3
 
640f9c9
 
 
 
af5ee77
301a4a3
af5ee77
640f9c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ace1938
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
os.environ["USE_NATIVE"] = "1"
import math
import torch
import torchvision
import gradio as gr
from PIL import Image
import torchvision
from test_ddgan import load_model, sample
from model_configs import get_model_config
from subprocess import call

def download(filename):
    """Return the relative path of ``filename`` under the ``models/`` folder.

    Despite its name, this function does not fetch anything: it only builds
    the path string by prefixing ``filename`` with ``"models/"``.
    """
    models_dir = "models/"
    return models_dir + filename


# Use the GPU when torch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Models already loaded, keyed by their name in `models`.
cache = dict()

def load(name):
    """Return the model registered under ``name``, loading it on first use.

    A model found in ``cache`` is returned as-is. Otherwise its
    ``(config, checkpoint path)`` pair is looked up in ``models``, printed,
    passed to ``load_model`` together with the module-level ``device``, and
    the result is stored in ``cache`` before being returned.

    Raises:
        KeyError: if ``name`` is not a key of ``models``.
    """
    try:
        return cache[name]
    except KeyError:
        pass  # not loaded yet; fall through and load it below

    config, checkpoint_path = models[name]
    print(config, checkpoint_path)
    loaded = load_model(config, checkpoint_path, device=device)
    cache[name] = loaded
    return loaded

# (model name, key passed to `get_model_config`) for every model offered in
# the UI. The checkpoint file of each model is its name plus ".th".
_MODEL_SPECS = (
    ("diffusion_db_128ch_1timesteps_openclip_vith14", 'ddgan_ddb_v2'),
    ("diffusion_db_192ch_2timesteps_openclip_vith14", 'ddgan_ddb_v3'),
)

# Maps model name -> (model config, checkpoint path), as consumed by `load`.
models = {
    model_name: (get_model_config(config_key), download(model_name + '.th'))
    for model_name, config_key in _MODEL_SPECS
}

# Model preselected in the dropdown.
default = _MODEL_SPECS[0][0]

def gen(md, model_name, md2, text, seed, nb_samples, width, height):
    """Generate a grid of images for a caption with the selected model.

    The parameters follow, in order, the components listed in the Gradio
    interface's ``inputs``.

    Args:
        md: Value of the first Markdown component; unused.
        model_name: Key of ``models`` selecting the model to load.
        md2: Value of the second Markdown component; unused.
        text: Caption to condition on. When it is the empty string, the
            text-encoder outputs are overwritten with Gaussian noise that is
            shared by every sample of the batch.
        seed: Seed given to ``torch.manual_seed`` (converted with ``int``).
        nb_samples: Number of images to generate (converted with ``int``).
        width: Width of the initial noise tensor (converted with ``int``).
        height: Height of the initial noise tensor (converted with ``int``).

    Returns:
        A ``PIL.Image.Image`` with the generated samples tiled four per row.
    """
    print("load ", model_name)
    model = load(model_name)
    print(model)
    torch.manual_seed(int(seed))
    nb_samples = int(nb_samples)
    height = int(height)
    width = int(width)
    with torch.no_grad():
        cond = model.text_encoder([text] * nb_samples)
        if text == "":
            # No caption: fill both encoder outputs with Gaussian noise in
            # place, then copy the first sample's values onto all the others
            # so the whole batch shares one random conditioning.
            cond[0].normal_()
            cond[1].normal_()
            cond[0][1:] = cond[0][0:1]
            cond[1][1:] = cond[1][0:1]

        # The noise is drawn on the CPU and only then moved to `device`;
        # keep it that way so the values produced for a given seed do not
        # depend on which device's random generator is used.
        x_init = torch.randn(nb_samples, 3, height, width).to(device)
        fake_sample = sample(model, x_init=x_init, cond=cond)
        # Rescale (presumably from [-1, 1] -- confirm against `sample`) to
        # [0, 1], and clamp: without the clamp any value slightly outside the
        # range wraps around in the uint8 cast below and shows up as
        # speckled pixels.
        fake_sample = ((fake_sample + 1) / 2).clamp(0, 1)
    grid = torchvision.utils.make_grid(fake_sample, nrow=4)
    # make_grid returns channels-first; PIL expects height x width x channels.
    grid = grid.permute(1, 2, 0).cpu().numpy()
    grid = (grid * 255).astype("uint8")
    return Image.fromarray(grid)

# Heading rendered by the first Markdown component.
text = """
DDGAN
"""

# Components in the exact order `gen` receives their values.
_input_components = [
    gr.Markdown(text),
    # model selection
    gr.Dropdown(list(models), value=default),
    gr.Markdown("If text caption is empty, random CLIP embeddings will be used as input"),
    # text caption
    gr.Textbox(
        lines=1,
        placeholder="Enter text caption here, or leave empty",
        value="Painting of a hamster king  with a crown and a cape  in a magical forest."
    ),
    gr.Number(value=0),    # seed
    gr.Number(value=4),    # nb_samples
    gr.Number(value=256),  # width
    gr.Number(value=256),  # height
]

iface = gr.Interface(fn=gen, inputs=_input_components, outputs="image")
iface.launch(debug=True)