"""Gradio app that generates pigeon profile avatars from an uploaded photo.

The uploaded image is described by the ``vikhyatk/moondream2`` model, and that
description is sent as a text prompt to the FLUX.1-dev model through the
Hugging Face Inference API, which returns the generated avatar images.
"""

import io
import os
import random

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

import dom

# Number of avatar images requested per click.
NUM_IMAGES = 2

# Seconds to wait on the inference API before aborting the request, so a
# stalled connection cannot hang the UI callback forever.
REQUEST_TIMEOUT = 120

# Device configuration
# NOTE(review): `device` is computed here but is not passed to the model load
# below — confirm whether the model should be moved to it.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

# Model configuration
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
# NOTE: os.getenv returns None when `api_token` is unset, which yields the
# literal header value "Bearer None".
headers = {"Authorization": f"Bearer {os.getenv('api_token')}"}
model_id_image_description = "vikhyatk/moondream2"
revision = "2024-08-26"
# NOTE(review): `torch_dtype` is likewise computed but not passed to
# from_pretrained below.
torch_dtype = torch.float32
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16  # GPU optimization

# Persistent model loading: done once at import time and reused by every call.
print("Cargando modelo de descripción de imágenes...")
model_description = AutoModelForCausalLM.from_pretrained(
    model_id_image_description, trust_remote_code=True, revision=revision
)
tokenizer_description = AutoTokenizer.from_pretrained(
    model_id_image_description, revision=revision
)


def generate_description(image_path):
    """Return the description model's answer for the image at *image_path*.

    Args:
        image_path: Path of the image file to describe.

    Returns:
        Whatever ``model_description.answer_question`` returns for the
        avatar-description question.
    """
    # Encode inside the context manager so the underlying file handle is
    # released as soon as the pixel data has been consumed.
    with Image.open(image_path) as image_test:
        enc_image = model_description.encode_image(image_test)
    description = model_description.answer_question(
        enc_image, "Describe this image to create an avatar", tokenizer_description
    )
    return description


def query(payload):
    """POST *payload* as JSON to the image-generation API and return the raw body.

    Args:
        payload: JSON-serialisable request body.

    Returns:
        The response body as bytes.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.Timeout: If the API does not respond within REQUEST_TIMEOUT.
    """
    response = requests.post(
        API_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT
    )
    # Fail here on an error status instead of handing an error body to the
    # image decoder further down the line.
    response.raise_for_status()
    return response.content


def generate_image_by_description(description, avatar_style=None):
    """Generate NUM_IMAGES avatar images from a textual description.

    Args:
        description: Text describing the pigeon, embedded in the prompt.
        avatar_style: Optional style name appended to the prompt when truthy.

    Returns:
        A list of PIL images, one per API call.
    """
    # The prompt does not change between iterations, so build it once.
    prompt = f"Create a pigeon profile avatar. Use the following description: {description}."
    if avatar_style:
        prompt += f" Use {avatar_style} style."

    images = []
    for _ in range(NUM_IMAGES):
        # A fresh random seed per request so the images differ from each other.
        image_bytes = query(
            {"inputs": prompt, "parameters": {"seed": random.randint(0, 1000)}}
        )
        images.append(Image.open(io.BytesIO(image_bytes)))
    print(images)
    return images


def process_and_generate(image, avatar_style):
    """Describe the uploaded image and generate avatars from that description.

    Args:
        image: File path of the uploaded image, or None/empty when nothing
            was uploaded.
        avatar_style: Optional avatar style selected in the UI.

    Returns:
        The list of generated PIL images.

    Raises:
        gr.Error: If no image was provided or the image API request failed.
    """
    if not image:
        raise gr.Error("Please upload an image of the pigeon first.")
    description = generate_description(image)
    try:
        return generate_image_by_description(description, avatar_style)
    except requests.RequestException as err:
        # Surface network/API failures as a readable message in the UI.
        raise gr.Error(f"Image generation request failed: {err}") from err


with gr.Blocks(js=dom.generate_title) as demo:
    with gr.Row():
        gr.Markdown(dom.generate_markdown)
    with gr.Row():
        # Left column: input image, example, style selector and trigger button.
        with gr.Column(scale=2, min_width=300):
            selected_image = gr.Image(
                type="filepath", label="Upload an Image of the Pigeon", height=300
            )
            example_image = gr.Examples(
                ["./examples/pigeon.webp"],
                label="Example Images",
                inputs=[selected_image],
            )
            avatar_style = gr.Radio(
                ["Realistic", "Pixel Art", "Imaginative", "Cartoon"],
                label="(optional) Select the avatar style:",
            )
            generate_button = gr.Button("Generate Avatar", variant="primary")
        # Right column: gallery showing the generated avatars.
        with gr.Column(scale=2, min_width=300):
            generated_image = gr.Gallery(type="pil", label="Generated Avatar", height=300)

    generate_button.click(
        process_and_generate,
        inputs=[selected_image, avatar_style],
        outputs=generated_image,
    )

demo.launch()