ehristoforu commited on
Commit
2786a41
1 Parent(s): 9e319ca

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import requests
4
+ import time
5
+ import json
6
+ import base64
7
+ import os
8
+ from PIL import Image
9
+ from io import BytesIO
10
+
11
+ class Prodia:
12
+ def __init__(self, api_key, base=None):
13
+ self.base = base or "https://api.prodia.com/v1"
14
+ self.headers = {
15
+ "X-Prodia-Key": api_key
16
+ }
17
+
18
+ def generate(self, params):
19
+ response = self._post(f"{self.base}/sdxl/generate", params)
20
+ return response.json()
21
+
22
+ def get_job(self, job_id):
23
+ response = self._get(f"{self.base}/job/{job_id}")
24
+ return response.json()
25
+
26
+ def wait(self, job):
27
+ job_result = job
28
+
29
+ while job_result['status'] not in ['succeeded', 'failed']:
30
+ time.sleep(0.25)
31
+ job_result = self.get_job(job['job'])
32
+
33
+ return job_result
34
+
35
+ def list_models(self):
36
+ response = self._get(f"{self.base}/sdxl/models")
37
+ return response.json()
38
+
39
+ def list_samplers(self):
40
+ response = self._get(f"{self.base}/sdxl/samplers")
41
+ return response.json()
42
+
43
+ def _post(self, url, params):
44
+ headers = {
45
+ **self.headers,
46
+ "Content-Type": "application/json"
47
+ }
48
+ response = requests.post(url, headers=headers, data=json.dumps(params))
49
+
50
+ if response.status_code != 200:
51
+ raise Exception(f"Bad Prodia Response: {response.status_code}")
52
+
53
+ return response
54
+
55
+ def _get(self, url):
56
+ response = requests.get(url, headers=self.headers)
57
+
58
+ if response.status_code != 200:
59
+ raise Exception(f"Bad Prodia Response: {response.status_code}")
60
+
61
+ return response
62
+
63
+
64
+ def image_to_base64(image_path):
65
+ # Open the image with PIL
66
+ with Image.open(image_path) as image:
67
+ # Convert the image to bytes
68
+ buffered = BytesIO()
69
+ image.save(buffered, format="PNG") # You can change format to PNG if needed
70
+
71
+ # Encode the bytes to base64
72
+ img_str = base64.b64encode(buffered.getvalue())
73
+
74
+ return img_str.decode('utf-8') # Convert bytes to string
75
+
76
+
77
+
78
+ prodia_client = Prodia(api_key=os.getenv("PRODIA_API_KEY"))
79
+
80
+ def flip_text(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed):
81
+ result = prodia_client.generate({
82
+ "prompt": prompt,
83
+ "negative_prompt": negative_prompt,
84
+ "model": model,
85
+ "steps": steps,
86
+ "sampler": sampler,
87
+ "cfg_scale": cfg_scale,
88
+ "width": width,
89
+ "height": height,
90
+ "seed": seed
91
+ })
92
+
93
+ job = prodia_client.wait(result)
94
+
95
+ return job["imageUrl"]
96
+
97
+ css = """
98
+ #generate {
99
+ height: 100%;
100
+ }
101
+ """
102
+
103
+ with gr.Blocks(css=css, model="sd_xl_base_1.0.safetensors [be9edd61]", sampler="DPM++ 2M Karras", batch_size=1, batch_count=1) as demo:
104
+ with gr.Row():
105
+ with gr.Column(scale=1):
106
+ gr.HTML(value=""""<h1><center>Fast SDXL on <a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0" target="_blank">stabilityai/stable-diffusion-xl-base-1.0</a>""")
107
+ with gr.Column(scale=6, min_width=600):
108
+ prompt = gr.Textbox(label="Prompt", placeholder="a cute cat, 8k", show_label=true, lines=1)
109
+ text_button = gr.Button("Generate", variant='primary', elem_id="generate")
110
+
111
+ with gr.Row():
112
+ with gr.Accordion("Additionals inputs"):
113
+ with gr.Column(scale=1):
114
+ negative_prompt = gr.Textbox(label="Negative Prompt", value="text, blurry", placeholder="What you don't want to see in the image", show_label=True, lines=1)
115
+ with gr.Column(scale=1):
116
+ steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=30, value=25, step=1)
117
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=1)
118
+ seed = gr.Number(label="Seed", value=-1)
119
+ with gr.Column(scale=1):
120
+ width = gr.Slider(label="↔️ Width", minimum=1024, maximum=1024, value=1024, step=8)
121
+ height = gr.Slider(label="↕️ Height", minimum=1024, maximum=1024, value=1024, step=8)
122
+
123
+
124
+
125
+ with gr.Column(scale=1):
126
+ image_output = gr.Image()
127
+
128
+ text_button.click(flip_text, inputs=[prompt, negative_prompt, steps, cfg_scale, width, height, seed], outputs=image_output)
129
+
130
+ demo.queue(concurrency_count=16, max_size=20, api_open=False).launch(max_threads=64)