# Spaces: Runtime error  (status banner captured from the hosting page)
import asyncio
import base64
import io
import random
import tkinter as tk
from threading import RLock
from tkinter import filedialog, scrolledtext, ttk

import requests
from PIL import Image, ImageTk

from all_models import models
class ImageGeneratorApp:
    """Tkinter front end for generating images locally (simulated) or via a remote Space.

    The window holds a notebook with two tabs:

    * "Local Generation" -- runs :meth:`generate_image` (currently a placeholder
      that returns a solid-colour 256x256 image) concurrently for every checked
      model and shows each result in that model's grid slot.
    * "API Generation" -- POSTs the chosen model and prompt to ``generate_url``,
      displays the returned image, and can upload it to ``gallery_url``.
    """

    # Remote endpoints used by the API tab.
    generate_url = "https://k00b404-huggingfacediffusion-custom.hf.space/run/gen_fn_4"
    gallery_url = "https://k00b404-huggingfacediffusion-custom.hf.space/run/add_gallery_4"
    # Seconds to wait for each HTTP request. Without a timeout ``requests`` may
    # block forever, and these calls run on the Tk thread, freezing the window.
    request_timeout = 300
    # Exceptions that signal a malformed/unexpected server response: bad JSON or
    # base64 (ValueError), missing keys/indices, non-string payloads
    # (TypeError/AttributeError) and undecodable image bytes (OSError).
    _RESPONSE_ERRORS = (ValueError, KeyError, IndexError, TypeError, AttributeError, OSError)

    def __init__(self, root):
        """Build the whole UI inside ``root``, the top-level window whose title is set here."""
        self.root = root
        self.root.title("Image Generator")
        # Kept for compatibility; nothing in this class currently acquires it.
        self.lock = RLock()
        self.models = models  # Use the imported models
        self.num_models = 6  # Number of models to display
        # Most recent image returned by the API tab; None until a request succeeds.
        self.current_api_image = None
        self.create_widgets()

    def create_widgets(self):
        """Create the notebook and populate its two tabs."""
        self.notebook = ttk.Notebook(self.root)
        self.notebook.pack(fill=tk.BOTH, expand=True)
        self.create_local_tab()
        self.create_api_tab()

    def create_local_tab(self):
        """Build the "Local Generation" tab: prompt box, button, image grid, model checkboxes."""
        local_frame = ttk.Frame(self.notebook)
        self.notebook.add(local_frame, text="Local Generation")

        # Prompt input
        prompt_label = ttk.Label(local_frame, text="Your prompt:")
        prompt_label.pack(pady=5)
        self.local_prompt_input = scrolledtext.ScrolledText(local_frame, height=4)
        self.local_prompt_input.pack(pady=5, padx=10, fill=tk.X)

        # Generate button
        generate_button = ttk.Button(local_frame, text="Generate Images", command=self.generate_images)
        generate_button.pack(pady=10)

        # Image display area: one label per displayed model, 3 per row.
        image_frame = ttk.Frame(local_frame)
        image_frame.pack(pady=10, padx=10)
        self.image_labels = []
        for i in range(self.num_models):
            label = ttk.Label(image_frame)
            label.grid(row=i // 3, column=i % 3, padx=5, pady=5)
            self.image_labels.append(label)

        # Model selection: checkbox i controls image label i.
        model_frame = ttk.LabelFrame(local_frame, text="Model Selection")
        model_frame.pack(pady=10, padx=10, fill=tk.X)
        self.model_vars = []
        for model in self.models[:self.num_models]:
            var = tk.BooleanVar(value=True)
            cb = ttk.Checkbutton(model_frame, text=model, variable=var)
            cb.pack(anchor=tk.W)
            self.model_vars.append(var)

    def create_api_tab(self):
        """Build the "API Generation" tab: model picker, prompt, result image, gallery."""
        api_frame = ttk.Frame(self.notebook)
        self.notebook.add(api_frame, text="API Generation")

        # Model selection; fall back to "" so an empty model list does not raise IndexError.
        model_label = ttk.Label(api_frame, text="Model:")
        model_label.pack(pady=5)
        self.api_model_var = tk.StringVar(value=self.models[0] if self.models else "")
        model_combobox = ttk.Combobox(api_frame, textvariable=self.api_model_var, values=self.models)
        model_combobox.pack(pady=5, padx=10, fill=tk.X)

        # Prompt input
        prompt_label = ttk.Label(api_frame, text="Your prompt:")
        prompt_label.pack(pady=5)
        self.api_prompt_input = scrolledtext.ScrolledText(api_frame, height=4)
        self.api_prompt_input.pack(pady=5, padx=10, fill=tk.X)

        # Generate button
        generate_button = ttk.Button(api_frame, text="Generate Image", command=self.generate_api_image)
        generate_button.pack(pady=10)

        # Image display
        self.api_image_label = ttk.Label(api_frame)
        self.api_image_label.pack(pady=10)

        # Add to gallery button
        add_gallery_button = ttk.Button(api_frame, text="Add to Gallery", command=self.add_to_gallery)
        add_gallery_button.pack(pady=10)

        # Gallery display
        self.gallery_frame = ttk.Frame(api_frame)
        self.gallery_frame.pack(pady=10, padx=10, fill=tk.BOTH, expand=True)

    async def generate_image(self, model, prompt):
        """Return an image for ``prompt`` produced by ``model``.

        Placeholder: sleeps 1-3 s to simulate inference and returns a 256x256 RGB
        image filled with a random colour; ``model`` and ``prompt`` are ignored.
        """
        # This is a placeholder for actual model inference
        # In a real application, you would call the actual model here
        await asyncio.sleep(random.uniform(1, 3))  # Random delay to simulate processing time
        return Image.new('RGB', (256, 256), color=random.choice(['red', 'green', 'blue']))

    def _selected_models(self):
        """Return ``(slot, model)`` pairs for every displayed model whose checkbox is ticked.

        ``slot`` is the model's position among the displayed models, which is also
        the index of its label in ``self.image_labels``.
        """
        displayed = zip(self.models[:self.num_models], self.model_vars)
        return [(slot, model) for slot, (model, var) in enumerate(displayed) if var.get()]

    async def generate_all_images(self, prompt):
        """Generate images for all checked models concurrently and show each in its slot.

        Labels belonging to unchecked models are cleared.
        """
        selected = self._selected_models()
        tasks = [asyncio.create_task(self.generate_image(model, prompt)) for _, model in selected]
        results = await asyncio.gather(*tasks)  # results are in the same order as tasks
        # Map each result back to its model's slot. Zipping labels with results
        # directly would shift images left whenever an earlier model is unchecked.
        images = {slot: image for (slot, _), image in zip(selected, results)}
        for slot, label in enumerate(self.image_labels):
            image = images.get(slot)
            if image is None:
                label.configure(image='')
                label.image = None
            else:
                photo = ImageTk.PhotoImage(image)
                label.configure(image=photo)
                # Tk keeps no Python reference; without this the image is garbage-collected.
                label.image = photo

    def generate_images(self):
        """Button handler: run local generation for the Local-tab prompt.

        Does nothing for a blank prompt. ``asyncio.run`` blocks the Tk event loop
        until every selected model has finished.
        """
        prompt = self.local_prompt_input.get("1.0", tk.END).strip()
        if not prompt:
            return
        for label in self.image_labels:
            label.configure(image='')
        # Flush pending redraws so the cleared labels are actually shown before
        # the blocking asyncio.run call stalls the event loop.
        self.root.update_idletasks()
        asyncio.run(self.generate_all_images(prompt))

    @staticmethod
    def _decode_data_url(data_url):
        """Return the bytes of a base64 data URL (``data:<mime>;base64,<payload>``).

        A bare base64 string without the ``data:...,`` header is also accepted.
        Raises ``binascii.Error`` (a ValueError) on invalid base64 and
        AttributeError if ``data_url`` is not a string.
        """
        _, sep, payload = data_url.partition(',')
        return base64.b64decode(payload if sep else data_url)

    @classmethod
    def _open_data_url_image(cls, data_url):
        """Decode a base64 data URL into a fully loaded PIL image."""
        image = Image.open(io.BytesIO(cls._decode_data_url(data_url)))
        # Image.open is lazy; force decoding now so corrupt data raises here.
        image.load()
        return image

    def _post_for_data(self, url, **kwargs):
        """POST to ``url`` and return ``response.json()['data'][0]``, or None on failure.

        ``kwargs`` are forwarded to ``requests.post``. Network errors, non-200
        responses and malformed JSON are reported with ``print`` and yield None.
        """
        try:
            response = requests.post(url, timeout=self.request_timeout, **kwargs)
        except requests.RequestException as exc:
            print(f"Error: {exc}")
            return None
        if response.status_code != 200:
            print(f"Error: {response.status_code}")
            return None
        try:
            return response.json()['data'][0]
        except self._RESPONSE_ERRORS as exc:
            print(f"Error: unexpected response from {url} ({exc!r})")
            return None

    def generate_api_image(self):
        """Button handler: request an image from ``generate_url`` for the API-tab prompt.

        On success the image is displayed and stored in ``self.current_api_image``
        for :meth:`add_to_gallery`. On failure an error is printed and the display
        is left unchanged. Does nothing for a blank prompt.
        """
        model_str = self.api_model_var.get()
        prompt = self.api_prompt_input.get("1.0", tk.END).strip()
        if not prompt:
            return
        payload = {
            "data": [
                model_str,
                prompt,
            ]
        }
        data_url = self._post_for_data(self.generate_url, json=payload)
        if data_url is None:
            return
        try:
            image = self._open_data_url_image(data_url)
        except self._RESPONSE_ERRORS as exc:
            print(f"Error: could not decode generated image ({exc!r})")
            return
        photo = ImageTk.PhotoImage(image)
        self.api_image_label.configure(image=photo)
        self.api_image_label.image = photo  # keep a reference so Tk does not lose it
        self.current_api_image = image

    def add_to_gallery(self):
        """Button handler: upload the last API image to ``gallery_url`` and redraw the gallery.

        Does nothing if no API image has been generated yet. The PNG is encoded in
        memory, so no temporary file is written or left open.
        """
        image = getattr(self, 'current_api_image', None)
        if image is None:
            return
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)
        files = {
            'data': ('image.png', buffer, 'image/png')
        }
        data = {
            'data': ['', self.api_model_var.get(), '[]']
        }
        gallery_data = self._post_for_data(self.gallery_url, files=files, data=data)
        if gallery_data is not None:
            self.update_gallery(gallery_data)

    def update_gallery(self, gallery_data):
        """Replace the gallery grid with the images in ``gallery_data`` (3 per row).

        NOTE(review): each item is assumed to be a dict whose 'image' key holds a
        base64 data URL; this is inferred from the indexing below, so confirm it
        against the Space's API. Items that cannot be decoded are skipped with a
        printed error.
        """
        for widget in self.gallery_frame.winfo_children():
            widget.destroy()
        for i, item in enumerate(gallery_data or []):
            try:
                image = self._open_data_url_image(item['image'])
            except self._RESPONSE_ERRORS as exc:
                print(f"Error: could not decode gallery item {i} ({exc!r})")
                continue
            photo = ImageTk.PhotoImage(image)
            label = ttk.Label(self.gallery_frame, image=photo)
            label.image = photo  # keep a reference so Tk does not lose it
            label.grid(row=i // 3, column=i % 3, padx=5, pady=5)
if __name__ == "__main__":
    # Script entry point: create the Tk root window, build the UI on it,
    # then hand control to Tk's event loop until the window is closed.
    root = tk.Tk()
    app = ImageGeneratorApp(root)
    app.root.mainloop()