File size: 10,367 Bytes
3b8e6ee
 
 
 
 
 
 
 
de33098
3b8e6ee
de33098
3b8e6ee
 
 
 
 
 
 
 
 
 
 
 
8f63650
 
 
3b8e6ee
 
 
 
 
fd8bc49
83a6d20
 
3b8e6ee
 
 
 
 
 
 
 
 
 
7c0ff63
3b8e6ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de33098
3b8e6ee
 
 
 
 
 
 
 
 
 
 
 
de33098
3b8e6ee
b5a9780
3b8e6ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d42f4c5
 
ca97dd1
d42f4c5
ca97dd1
 
 
4b7632d
 
 
30c1c6f
732cd32
 
 
a648d42
c043471
3b8e6ee
 
44d732d
daaa781
3b8e6ee
 
d39ac3a
ef949f6
732cd32
4b7632d
be45df9
4b7632d
 
ffd211a
3409d90
766c020
be45df9
04c5803
4b7632d
774f569
ca97dd1
 
4a5727c
 
 
ca97dd1
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
## **** The code below is borrowed from the multimodalart Space
from pydoc import describe
import gradio as gr
import torch
from omegaconf import OmegaConf
import sys 
sys.path.append(".")
sys.path.append('./taming-transformers')
#sys.path.append('./latent-diffusion')
from taming.models import vqgan 
from util import instantiate_from_config
from huggingface_hub import hf_hub_download

# Fetch the text2img checkpoint file from the Hugging Face Hub repo; the returned
# path is later handed to load_model_from_config() as the checkpoint to load.
model_path_e = hf_hub_download(repo_id="multimodalart/compvis-latent-diffusion-text2img-large", filename="txt2img-f8-large.ckpt")

#@title Import stuff
import argparse, os, sys, glob
import numpy as np
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
import transformers
import gc
from util import instantiate_from_config
from ddim import DDIMSampler
from plms import PLMSSampler
from open_clip import tokenizer
import open_clip

def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: configuration object exposing a ``model`` entry that
            ``instantiate_from_config`` understands.
        ckpt: path of the checkpoint file; it must contain a ``"state_dict"`` entry.
        verbose: when true, print the missing / unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The instantiated model, switched to eval mode. It is left on the CPU;
        moving it to another device is up to the caller.
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU so loading also works on machines without CUDA.
    checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))
    state_dict = checkpoint["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are reported instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    for heading, keys in (("missing keys:", missing_keys), ("unexpected keys:", unexpected_keys)):
        if len(keys) > 0 and verbose:
            print(heading)
            print(keys)

    model.eval()
    return model

def load_safety_model(clip_model):
    """Load the NSFW classifier matching ``clip_model``, downloading it if not cached.

    The classifier archive is fetched from the LAION CLIP-based-NSFW-Detector
    repository and unpacked under ``~/.cache/clip_retrieval/<clip_model>``
    the first time it is needed; later calls reuse the extracted directory.

    Args:
        clip_model: name of the CLIP variant whose embeddings will be
            classified. Supported values are ``"ViT-L/14"`` and ``"ViT-B/32"``.

    Returns:
        The Keras model loaded from the cache directory, after one throwaway
        ``predict`` call on random input.

    Raises:
        ValueError: if ``clip_model`` is not one of the supported names.
    """
    # One table instead of two parallel if/elif chains:
    # CLIP variant -> (extracted model directory name, embedding width, archive URL).
    known_models = {
        "ViT-L/14": (
            "clip_autokeras_binary_nsfw",
            768,
            "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_binary_nsfw.zip",
        ),
        "ViT-B/32": (
            "clip_autokeras_nsfw_b32",
            512,
            "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_nsfw_b32.zip",
        ),
    }

    # Validate first so an unsupported name fails fast, before the heavy
    # autokeras / tensorflow imports below are attempted.
    if clip_model not in known_models:
        raise ValueError("Unknown clip model")
    model_subdir, dim, url_model = known_models[clip_model]

    import autokeras as ak  # pylint: disable=import-outside-toplevel
    from tensorflow.keras.models import load_model  # pylint: disable=import-outside-toplevel
    from os.path import expanduser  # pylint: disable=import-outside-toplevel

    home = expanduser("~")
    cache_folder = home + "/.cache/clip_retrieval/" + clip_model.replace("/", "_")
    model_dir = cache_folder + "/" + model_subdir

    if not os.path.exists(model_dir):
        import zipfile  # pylint: disable=import-outside-toplevel
        from urllib.request import urlretrieve  # pylint: disable=import-outside-toplevel

        os.makedirs(cache_folder, exist_ok=True)
        # The archive is stored under the same file name for every variant;
        # this is safe because cache_folder is already specific to clip_model.
        path_to_zip_file = cache_folder + "/clip_autokeras_binary_nsfw.zip"
        urlretrieve(url_model, path_to_zip_file)
        with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
            zip_ref.extractall(cache_folder)

    loaded_model = load_model(model_dir, custom_objects=ak.CUSTOM_OBJECTS)
    # One prediction on random data right after loading -- presumably a warm-up
    # so the first real request does not pay the initialization cost.
    loaded_model.predict(np.random.rand(10 ** 3, dim).astype("float32"), batch_size=10 ** 3)

    return loaded_model
    
def is_unsafe(safety_model, embeddings, threshold=0.5):
    """Return True when any embedding is scored above ``threshold`` by the safety model.

    Args:
        safety_model: object with a ``predict(embeddings, batch_size=...)``
            method whose result is iterable, each item holding the score at
            index 0.
        embeddings: array of embeddings; ``embeddings.shape[0]`` is used as
            the prediction batch size.
        threshold: score above which an embedding counts as unsafe.

    Returns:
        A plain ``bool``: True if at least one score exceeds ``threshold``,
        False otherwise (including for an empty batch).
    """
    # Nothing to classify; also avoids asking the model for a batch size of 0.
    if embeddings.shape[0] == 0:
        return False
    nsfw_values = safety_model.predict(embeddings, batch_size=embeddings.shape[0])
    scores = np.array([row[0] for row in nsfw_values])
    # Reduce with np.any: taking the truth value of the comparison array
    # directly raises ValueError as soon as the batch holds more than one item.
    return bool(np.any(scores > threshold))

# Import-time setup: build the latent-diffusion model from its YAML config and
# the checkpoint path stored in model_path_e.
config = OmegaConf.load("./txt2img-1p4B-eval.yaml")
model = load_model_from_config(config,model_path_e)
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

#NSFW CLIP Filter
# The safety classifier and the CLIP encoder are both the ViT-B/32 variant;
# run() embeds each generated image with clip_model and scores the embedding
# with safety_model. Note that clip_model is not moved to `device` here.
safety_model = load_safety_model("ViT-B/32")
clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')


def run(prompt, steps, width, height, images, scale):
    """Generate images for ``prompt`` and return them combined into one grid image.

    Every decoded sample is embedded with the module-level CLIP model and
    checked by the NSFW classifier before it is accepted.

    Args:
        prompt: text prompt used as conditioning.
        steps: number of sampler steps (converted with ``int``).
        width: image width in pixels (converted with ``int``); the latent
            width is ``width // 8``.
        height: image height in pixels (converted with ``int``); the latent
            height is ``height // 8``.
        images: number of samples to generate (converted with ``int``).
        scale: unconditional guidance scale; when greater than 0 an
            empty-prompt conditioning is computed and passed to the sampler.

    Returns:
        A PIL image containing all samples in a grid with two images per row.
        NOTE: if the NSFW classifier flags any sample, a
        ``(None, None, message)`` tuple is returned instead, so callers must
        be prepared for both shapes.

    Side effects:
        Creates ``./outputs`` and ``./outputs/samples`` if needed and saves
        the grid as ``./outputs/<prompt with spaces replaced by '-'>.png``.
    """
    opt = argparse.Namespace(
        prompt=prompt,
        outdir='./outputs',
        ddim_steps=int(steps),
        ddim_eta=1,
        n_iter=1,
        W=int(width),
        H=int(height),
        n_samples=int(images),
        scale=scale,
        plms=True
    )

    # plms is hard-coded to True above, so the PLMS sampler (with eta forced
    # to 0) is what actually runs; the DDIM branch is kept as the alternative.
    if opt.plms:
        opt.ddim_eta = 0
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir
    prompt = opt.prompt

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)

    all_samples = []
    # A single no_grad scope covers both sampling and the CLIP safety check.
    with torch.no_grad(), torch.cuda.amp.autocast(), model.ema_scope():
        uc = None
        if opt.scale > 0:
            uc = model.get_learned_conditioning(opt.n_samples * [""])
        for _ in range(opt.n_iter):
            c = model.get_learned_conditioning(opt.n_samples * [prompt])
            shape = [4, opt.H // 8, opt.W // 8]
            samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                             conditioning=c,
                                             batch_size=opt.n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=opt.scale,
                                             unconditional_conditioning=uc,
                                             eta=opt.ddim_eta)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Map the decoder output from [-1, 1] to [0, 1].
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples_ddim:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                image_vector = Image.fromarray(x_sample.astype(np.uint8))
                image_preprocess = preprocess(image_vector).unsqueeze(0)
                image_features = clip_model.encode_image(image_preprocess)
                # Normalize the embedding to unit length before classification.
                image_features /= image_features.norm(dim=-1, keepdim=True)
                query = image_features.cpu().detach().numpy().astype("float32")
                if is_unsafe(safety_model, query, 0.5):
                    # A single flagged sample rejects the whole request.
                    return (None, None, "Sorry, potential NSFW content was detected on your outputs by our NSFW detection model. Try again with different prompts. If you feel your prompt was not supposed to give NSFW outputs, this may be due to a bias in the model. Read more about biases in the Biases Acknowledgment section below.")
            all_samples.append(x_samples_ddim)

    # additionally, save as grid
    grid = torch.stack(all_samples, 0)
    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
    grid = make_grid(grid, nrow=2)
    # to image
    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
    grid_image = Image.fromarray(grid.astype(np.uint8))

    # The prompt becomes the file name: a path separator inside it would point
    # into a directory that does not exist and make save() fail, so replace it.
    file_stem = prompt.replace(" ", "-")
    for separator in (os.sep, os.altsep):
        if separator:
            file_stem = file_stem.replace(separator, "_")
    grid_image.save(os.path.join(outpath, f'{file_stem}.png'))
    return grid_image
        
## **** The code above is borrowed from the multimodalart Space

import gradio as gr

# Hosted FastSpeech2 text-to-speech model, loaded once when the module is imported.
fastspeech = gr.Interface.load("huggingface/facebook/fastspeech2-en-ljspeech")


def text2speech(text):
    """Return whatever the loaded ``fastspeech`` interface produces for ``text``."""
    audio = fastspeech(text)
    return audio
    
# Shared handle to the hosted NER model; created on first use so the model is
# not reloaded on every request.
_ner_interface = None


def _get_ner_interface():
    """Return the shared NER interface, loading it the first time it is needed."""
    global _ner_interface
    if _ner_interface is None:
        _ner_interface = gr.Interface.load("huggingface/flair/ner-english-ontonotes-large")
    return _ner_interface


def engine(text_input):
    """Extract named entities from ``text_input``, read it aloud and illustrate it.

    Args:
        text_input: the text entered by the user.

    Returns:
        A ``(entities, speech, img)`` tuple: the labelled entities found in
        the text, the text-to-speech result for the whole input, and the
        result of ``run`` for the first entity.

    Raises:
        ValueError: if no labelled entity is found, since there is then
            nothing to build the image prompt from.
    """
    ner = _get_ner_interface()
    entities = ner(text_input)

    # Drop the spans that carry no label (any item containing None).
    entities = [tupl for tupl in entities if None not in tupl]
    if not entities:
        raise ValueError("No named entities were found in the input text, so no image prompt could be built.")

    # Each item is a container -- presumably a (text, label) pair; confirm
    # against the installed gradio version. run() needs the text itself as the
    # prompt, not the whole pair.
    prompt = entities[0][0]
    # NOTE(review): run() returns a (None, None, message) tuple instead of an
    # image when its NSFW filter triggers; that value is passed through as-is.
    img = run(prompt,'50','256','256','1',10)

    speech = text2speech(text_input)
    return entities, speech, img
    
# Build the Gradio UI around engine() and start serving it right away.
# NOTE: `app` is bound to the value returned by .launch(), not to the
# Interface object itself.
app = gr.Interface(fn=engine, 
                   inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
                   #live=True,
                   description="Takes a text as input and reads it out to you.", 
                   # The three outputs line up with engine()'s (entities, speech, img) return value.
                   outputs=[gr.outputs.Textbox(type="auto", label="Text"), gr.outputs.Audio(type="file", label="Speech Answer"), #],
                   gr.outputs.Image(type="file", label="output image")],
                   examples=["On April 17th Sunday George celebrated Easter. He is staying at Empire State building with his parents. He is a citizen of Canada and speaks English and French fluently. His role model is former president Obama. He got 1000 dollar from his mother to visit Disney World and to buy new iPhone mobile.  George likes watching Game of Thrones."]
                   ).launch(debug=True) 
                   
 
 #get_audio = gr.Button("generate audio")
 #get_audio.click(text2speech, inputs=text, outputs=speech)
 
#def greet(name):
#    return "Hello " + name + "!!"

#iface = gr.Interface(fn=greet, inputs="text", outputs="text")
#iface.launch()