gokaygokay committed f8cf70e (parent: 66156c6): image generation

Files changed:
- app.py: +22 -0
- llm_inference.py: +17 -1
app.py
CHANGED
@@ -1,6 +1,8 @@
 import gradio as gr
 from llm_inference import LLMInferenceNode
 import random
+from PIL import Image
+import io
 
 title = """<h1 align="center">Random Prompt Generator</h1>
 <p><center>
@@ -69,6 +71,7 @@ def create_interface():
         generate_button = gr.Button("GENERATE")
         with gr.Row():
             text_output = gr.Textbox(label="LLM Generated Text", lines=10, show_copy_button=True)
+            image_output = gr.Image(label="Generated Image", type="pil")
 
     # Updated Models based on provider
     def update_model_choices(provider):
@@ -157,6 +160,25 @@ def create_interface():
         api_name="generate_random_prompt_with_llm"
     )
 
+    # Add image generation button
+    generate_image_button = gr.Button("Generate Image")
+
+    # Function to generate image
+    def generate_image(text):
+        try:
+            image = llm_node.generate_image(text)
+            return image
+        except Exception as e:
+            print(f"An error occurred while generating the image: {e}")
+            return None
+
+    # Connect the image generation button
+    generate_image_button.click(
+        generate_image,
+        inputs=[text_output],
+        outputs=[image_output]
+    )
+
     return demo
 
 if __name__ == "__main__":
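For context, the app.py hunks above amount to the following self-contained Gradio pattern. This is a minimal sketch, not the app itself: the handler below is a hypothetical stub standing in for llm_node.generate_image, and only the components touched by this commit are recreated.

import gradio as gr
from PIL import Image


def generate_image(text):
    # Hypothetical stub standing in for llm_node.generate_image(text)
    try:
        return Image.new("RGB", (256, 256), "gray")
    except Exception as e:
        print(f"An error occurred while generating the image: {e}")
        return None  # gr.Image renders nothing when the handler returns None


with gr.Blocks() as demo:
    text_output = gr.Textbox(label="LLM Generated Text", lines=10, show_copy_button=True)
    image_output = gr.Image(label="Generated Image", type="pil")
    generate_image_button = gr.Button("Generate Image")
    generate_image_button.click(generate_image, inputs=[text_output], outputs=[image_output])

if __name__ == "__main__":
    demo.launch()

Returning None from the handler on failure leaves the image component empty instead of surfacing a traceback in the UI, which matches the try/except added in the diff.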
llm_inference.py
CHANGED
@@ -2,6 +2,9 @@ import os
 import random  # Import the random module
 from groq import Groq
 from openai import OpenAI
+import requests
+from PIL import Image
+import io
 
 class LLMInferenceNode:
     def __init__(self):
@@ -237,4 +240,17 @@ Your output is only the caption itself, no comments or extra formatting. The cap
 
         except Exception as e:
             print(f"An error occurred: {e}")
-            return f"Error occurred while processing the request: {str(e)}"
+            return f"Error occurred while processing the request: {str(e)}"
+
+    def generate_image(self, prompt):
+        API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
+        headers = {"Authorization": f"Bearer {self.huggingface_token}"}
+
+        response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
+
+        if response.status_code != 200:
+            raise Exception(f"Error generating image: {response.text}")
+
+        image_bytes = response.content
+        image = Image.open(io.BytesIO(image_bytes))
+        return image
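The new generate_image method is a thin wrapper around the Hugging Face serverless Inference API's text-to-image endpoint, which returns raw image bytes on success. Below is a minimal standalone sketch of the same request, assuming a valid token in an HF_TOKEN environment variable; the diff itself reads the token from self.huggingface_token, whose initialization is outside this hunk.

import io
import os

import requests
from PIL import Image

# Same endpoint the new method targets; reading HF_TOKEN from the
# environment is an assumption of this sketch, not part of the commit
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

# The Inference API returns raw image bytes on success
response = requests.post(API_URL, headers=headers, json={"inputs": "a lighthouse at dusk"})
response.raise_for_status()

image = Image.open(io.BytesIO(response.content))
image.save("output.png")

Unlike the diff, this sketch uses response.raise_for_status() rather than a manual status-code check; both approaches surface non-200 responses as exceptions.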