import asyncio
import base64
import io
import queue
import random
import tkinter as tk
import traceback
from threading import RLock, Thread
from tkinter import filedialog, scrolledtext, ttk

import requests
from PIL import Image, ImageTk

from all_models import models

# Base URL of the Hugging Face Space serving both the generation and gallery endpoints.
API_BASE_URL = "https://k00b404-huggingfacediffusion-custom.hf.space/run"
# Upper bound (seconds) per HTTP request so a stalled Space cannot hang a worker forever.
# This is a chosen value for slow diffusion backends, not a documented server limit.
REQUEST_TIMEOUT = 120
# How often (milliseconds) the Tk thread checks whether a background job has finished.
POLL_INTERVAL_MS = 100


class ApiError(Exception):
    """Raised when a Space endpoint answers with a non-200 HTTP status."""


def _decode_image(encoded):
    """Decode a base64-encoded image string into a fully loaded PIL image.

    Accepts a data URI (``data:image/png;base64,...``) as well as bare base64:
    everything after the last comma is decoded, and a string without a comma
    is decoded as-is.
    """
    _, _, payload = encoded.rpartition(",")
    image = Image.open(io.BytesIO(base64.b64decode(payload)))
    image.load()  # decode now so corrupt data fails here, not later inside Tk
    return image


class ImageGeneratorApp:
    """Tk application with a local (simulated) multi-model tab and a remote API tab.

    Slow work (local generation, HTTP calls) runs on a worker thread; its result
    is applied to the widgets on the Tk thread, so the window stays responsive.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("Image Generator")
        self.lock = RLock()  # not referenced elsewhere in this file; kept for compatibility
        self.models = models  # Use the imported models
        self.num_models = 6  # Number of models to display
        self.create_widgets()

    def create_widgets(self):
        """Build the notebook holding the local and API tabs."""
        self.notebook = ttk.Notebook(self.root)
        self.notebook.pack(fill=tk.BOTH, expand=True)
        self.create_local_tab()
        self.create_api_tab()

    def create_local_tab(self):
        """Build the tab that runs up to ``num_models`` local generators side by side."""
        local_frame = ttk.Frame(self.notebook)
        self.notebook.add(local_frame, text="Local Generation")

        # Prompt input
        prompt_label = ttk.Label(local_frame, text="Your prompt:")
        prompt_label.pack(pady=5)
        self.local_prompt_input = scrolledtext.ScrolledText(local_frame, height=4)
        self.local_prompt_input.pack(pady=5, padx=10, fill=tk.X)

        # Generate button
        generate_button = ttk.Button(local_frame, text="Generate Images", command=self.generate_images)
        generate_button.pack(pady=10)

        # Image display area: one label per model slot, three per row
        image_frame = ttk.Frame(local_frame)
        image_frame.pack(pady=10, padx=10)
        self.image_labels = []
        for i in range(self.num_models):
            label = ttk.Label(image_frame)
            label.grid(row=i // 3, column=i % 3, padx=5, pady=5)
            self.image_labels.append(label)

        # Model selection: model_vars[i] is the checkbox for models[i] / image_labels[i]
        model_frame = ttk.LabelFrame(local_frame, text="Model Selection")
        model_frame.pack(pady=10, padx=10, fill=tk.X)
        self.model_vars = []
        for model in self.models[:self.num_models]:
            var = tk.BooleanVar(value=True)
            cb = ttk.Checkbutton(model_frame, text=model, variable=var)
            cb.pack(anchor=tk.W)
            self.model_vars.append(var)

    def create_api_tab(self):
        """Build the tab that generates a single image through the remote Space API."""
        api_frame = ttk.Frame(self.notebook)
        self.notebook.add(api_frame, text="API Generation")

        # Model selection (guarded so an empty model list does not raise IndexError)
        model_label = ttk.Label(api_frame, text="Model:")
        model_label.pack(pady=5)
        self.api_model_var = tk.StringVar(value=self.models[0] if self.models else "")
        model_combobox = ttk.Combobox(api_frame, textvariable=self.api_model_var, values=self.models)
        model_combobox.pack(pady=5, padx=10, fill=tk.X)

        # Prompt input
        prompt_label = ttk.Label(api_frame, text="Your prompt:")
        prompt_label.pack(pady=5)
        self.api_prompt_input = scrolledtext.ScrolledText(api_frame, height=4)
        self.api_prompt_input.pack(pady=5, padx=10, fill=tk.X)

        # Generate button
        generate_button = ttk.Button(api_frame, text="Generate Image", command=self.generate_api_image)
        generate_button.pack(pady=10)

        # Image display
        self.api_image_label = ttk.Label(api_frame)
        self.api_image_label.pack(pady=10)

        # Add to gallery button
        add_gallery_button = ttk.Button(api_frame, text="Add to Gallery", command=self.add_to_gallery)
        add_gallery_button.pack(pady=10)

        # Gallery display
        self.gallery_frame = ttk.Frame(api_frame)
        self.gallery_frame.pack(pady=10, padx=10, fill=tk.BOTH, expand=True)

    # ------------------------------------------------------------------ local tab

    async def generate_image(self, model, prompt):
        """Return an image for ``model``/``prompt``.

        This is a placeholder for actual model inference: it sleeps 1-3 s and
        returns a 256x256 solid red, green, or blue image.
        """
        await asyncio.sleep(random.uniform(1, 3))  # Random delay to simulate processing time
        return Image.new('RGB', (256, 256), color=random.choice(['red', 'green', 'blue']))

    async def generate_all_images(self, prompt):
        """Generate images for every checked model and show each in its own slot.

        Kept for callers that drive it from an event loop on the Tk thread;
        ``generate_images`` uses the non-blocking path instead.
        """
        slots = self._selected_slots()
        images = await self._generate_for_slots(slots, prompt)
        self._show_local_images(slots, images)

    def generate_images(self):
        """Start local generation for the checked models without blocking the UI."""
        prompt = self.local_prompt_input.get("1.0", tk.END).strip()
        if not prompt:
            return
        for label in self.image_labels:
            label.configure(image='')
        # Read the checkboxes here, on the Tk thread; the worker must not touch Tk objects.
        slots = self._selected_slots()

        def on_done(images, error):
            if error is not None:
                self._report_error(error)
            else:
                self._show_local_images(slots, images)

        self._run_in_background(
            lambda: asyncio.run(self._generate_for_slots(slots, prompt)), on_done
        )

    def _selected_slots(self):
        """Return ``(label, model)`` pairs for the checked models, keeping each model's own label."""
        candidates = zip(self.image_labels, self.models[:self.num_models], self.model_vars)
        return [(label, model) for label, model, var in candidates if var.get()]

    async def _generate_for_slots(self, slots, prompt):
        """Run ``generate_image`` concurrently for each slot; results follow ``slots`` order."""
        return await asyncio.gather(*(self.generate_image(model, prompt) for _, model in slots))

    def _show_local_images(self, slots, images):
        """Display each image in the label that belongs to the model that produced it."""
        for (label, _), image in zip(slots, images):
            if image:
                photo = ImageTk.PhotoImage(image)
                label.configure(image=photo)
                label.image = photo  # keep a Python reference; Tk does not hold one itself
            else:
                label.configure(image='')

    # ------------------------------------------------------------------ API tab

    def generate_api_image(self):
        """Request one image from the Space for the chosen model and prompt, in the background."""
        model_str = self.api_model_var.get()
        prompt = self.api_prompt_input.get("1.0", tk.END).strip()
        if not prompt:
            return
        payload = {"data": [model_str, prompt]}

        def on_done(image, error):
            if error is not None:
                self._report_error(error)
                return
            photo = ImageTk.PhotoImage(image)
            self.api_image_label.configure(image=photo)
            self.api_image_label.image = photo  # keep a Python reference for Tk
            self.current_api_image = image

        self._run_in_background(lambda: self._post_generation(payload), on_done)

    @staticmethod
    def _post_generation(payload):
        """POST to ``gen_fn_4`` and return the decoded image (runs on a worker thread).

        Raises ApiError on a non-200 status; requests exceptions propagate.
        """
        response = requests.post(f"{API_BASE_URL}/gen_fn_4", json=payload, timeout=REQUEST_TIMEOUT)
        if response.status_code != 200:
            raise ApiError(f"Error: {response.status_code}")
        return _decode_image(response.json()['data'][0])

    def add_to_gallery(self):
        """Upload the last API image to the Space gallery and redraw the gallery."""
        if not hasattr(self, 'current_api_image'):
            return
        # Encode to PNG in memory: no temp file in the cwd and no file handle to leak.
        buffer = io.BytesIO()
        self.current_api_image.save(buffer, format="PNG")
        files = {'data': ('image.png', buffer.getvalue(), 'image/png')}
        data = {'data': ['', self.api_model_var.get(), '[]']}

        def on_done(gallery_data, error):
            if error is not None:
                self._report_error(error)
            else:
                self.update_gallery(gallery_data)

        self._run_in_background(lambda: self._post_gallery(files, data), on_done)

    @staticmethod
    def _post_gallery(files, data):
        """POST to ``add_gallery_4`` and return ``result['data'][0]`` (runs on a worker thread).

        Raises ApiError on a non-200 status; requests exceptions propagate.
        """
        response = requests.post(
            f"{API_BASE_URL}/add_gallery_4", files=files, data=data, timeout=REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            raise ApiError(f"Error: {response.status_code}")
        return response.json()['data'][0]

    def update_gallery(self, gallery_data):
        """Replace the gallery grid with the images in ``gallery_data``.

        Assumes each item is a dict with an ``'image'`` key holding a base64
        (optionally data-URI) string. NOTE(review): schema inferred from usage.
        """
        for widget in self.gallery_frame.winfo_children():
            widget.destroy()
        for i, item in enumerate(gallery_data):
            photo = ImageTk.PhotoImage(_decode_image(item['image']))
            label = ttk.Label(self.gallery_frame, image=photo)
            label.image = photo  # keep a Python reference for Tk
            label.grid(row=i // 3, column=i % 3, padx=5, pady=5)

    # ------------------------------------------------------------------ plumbing

    def _run_in_background(self, work, on_done):
        """Run ``work()`` on a daemon thread, then call ``on_done(result, error)`` on the Tk thread.

        The worker only puts its outcome on a queue. The Tk thread polls that
        queue with ``after``, so widgets are only ever touched from the Tk thread.
        """
        outcome = queue.Queue(maxsize=1)

        def worker():
            try:
                outcome.put((work(), None))
            except Exception as exc:  # handed to on_done; reported on the Tk thread
                outcome.put((None, exc))

        def poll():
            try:
                result, error = outcome.get_nowait()
            except queue.Empty:
                self.root.after(POLL_INTERVAL_MS, poll)
                return
            on_done(result, error)

        Thread(target=worker, daemon=True).start()
        self.root.after(POLL_INTERVAL_MS, poll)

    @staticmethod
    def _report_error(error):
        """Print an HTTP status error as one line; print anything else as a full traceback."""
        if isinstance(error, ApiError):
            print(error)
        else:
            traceback.print_exception(type(error), error, error.__traceback__)


if __name__ == "__main__":
    root = tk.Tk()
    app = ImageGeneratorApp(root)
    root.mainloop()