Pedro Cuenca commited on
Commit
704ee93
1 Parent(s): 1b9d7ec

Update demo to use Suraj's backend server.

Browse files

Former-commit-id: 6a2df0b8bb5ea2d2e88a376e02d0d5f4b1f033db

Files changed (2) hide show
  1. README.md +1 -1
  2. app/app_gradio_ngrok.py +105 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🎨
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
7
- app_file: app/app_gradio.py
8
  pinned: false
9
  ---
10
 
 
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
7
+ app_file: app/app_gradio_ngrok.py
8
  pinned: false
9
  ---
10
 
app/app_gradio_ngrok.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import requests
5
+ from PIL import Image
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from io import BytesIO
9
+ import base64
10
+
11
+ import gradio as gr
12
+
13
+
14
+ def compose_predictions(images, caption=None):
15
+ increased_h = 0 if caption is None else 48
16
+ w, h = images[0].size[0], images[0].size[1]
17
+ img = Image.new("RGB", (len(images)*w, h + increased_h))
18
+ for i, img_ in enumerate(images):
19
+ img.paste(img_, (i*w, increased_h))
20
+
21
+ if caption is not None:
22
+ draw = ImageDraw.Draw(img)
23
+ font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
24
+ draw.text((20, 3), caption, (255,255,255), font=font)
25
+ return img
26
+
27
+ def top_k_predictions(prompt, num_candidates=32, k=8):
28
+ images = hallucinate(prompt, num_images=num_candidates)
29
+ images = clip_top_k(prompt, images, k=k)
30
+ return images
31
+
32
+ class ServiceError(Exception):
33
+ def __init__(self, status_code):
34
+ self.status_code = status_code
35
+
36
+ def get_images_from_ngrok(prompt):
37
+ r = requests.post(
38
+ "https://dd7123a7e01c.ngrok.io/generate",
39
+ json={"prompt": prompt}
40
+ )
41
+ if r.status_code == 200:
42
+ images = r.json()["images"]
43
+ images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
44
+ return images
45
+ else:
46
+ raise ServiceError(r.status_code)
47
+
48
+ def run_inference(prompt):
49
+ try:
50
+ images = get_images_from_ngrok(prompt)
51
+ predictions = compose_predictions(images)
52
+ output_title = f"""
53
+ <p style="font-size:22px; font-style:bold">Best predictions</p>
54
+ <p>We asked our model to generate 32 candidates for your prompt:</p>
55
+
56
+ <pre>
57
+
58
+ <b>{prompt}</b>
59
+ </pre>
60
+ <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
61
+ similarity of the text and the image representations.</p>
62
+
63
+ <p>This is the result:</p>
64
+ """
65
+
66
+ output_description = """
67
+ <p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.<p>
68
+ <p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
69
+ """
70
+
71
+ except ServiceError:
72
+ output_title = f"""
73
+ Sorry, there was an error retrieving the images. Please, try again later or <a href="mailto:[email protected]">contact us here</a>.
74
+ """
75
+ predictions = None
76
+ output_description = ""
77
+
78
+ return (output_title, predictions, output_description)
79
+
80
+ outputs = [
81
+ gr.outputs.HTML(label=""), # To be used as title
82
+ gr.outputs.Image(label=''),
83
+ gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
84
+ ]
85
+
86
+ description = """
87
+ Welcome to our demo of DALL·E-mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
88
+ It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
89
+
90
+ Please, write what you would like the model to generate, or select one of the examples below.
91
+ """
92
+ gr.Interface(run_inference,
93
+ inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
94
+ outputs=outputs,
95
+ title='DALL·E mini',
96
+ description=description,
97
+ article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
98
+ layout='vertical',
99
+ theme='huggingface',
100
+ examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
101
+ allow_flagging=False,
102
+ live=False,
103
+ server_name="0.0.0.0", # Bind to all interfaces (I think)
104
+ # server_port=8999
105
+ ).launch(share=True)