Spaces:
Running
Running
File size: 3,084 Bytes
4a0cc75 0ce8592 4a0cc75 0ce8592 8e4035c 0ce8592 8e4035c d5e510d 8e4035c d5e510d 8e4035c d5e510d 0ce8592 8e4035c 4a0cc75 d5e510d 4a0cc75 d5e510d 4a0cc75 d5e510d 8e4035c 0ce8592 d5e510d 0ce8592 4a0cc75 0ce8592 88a3ed8 8e4035c 0ce8592 df2c5c8 ff6b31d d5e510d 0ce8592 8e4035c d5e510d ff6b31d 0ce8592 d5e510d 4a0cc75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import io
from PIL import Image
import requests
import random
import dom
import os
# Number of avatar candidates generated per request.
NUM_IMAGES = 2

# Device selection: prefer CUDA, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Model configuration.
# Remote text-to-image endpoint used by `query` below.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
_api_token = os.getenv('api_token')
if not _api_token:
    # Without a token the header becomes "Bearer None"; warn at startup so the
    # cause is visible instead of surfacing later as a failed image request.
    print("WARNING: environment variable 'api_token' is not set; image generation requests will likely be rejected.")
headers = {"Authorization": f"Bearer {_api_token}"}

# Image-description model, pinned to a fixed revision because it runs remote code
# (trust_remote_code=True).
model_id_image_description = "vikhyatk/moondream2"
revision = "2024-08-26"

# bfloat16 only on CUDA (GPU optimization); float32 everywhere else.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load the models once at import time so every request reuses them.
print("Cargando modelo de descripción de imágenes...")
model_description = AutoModelForCausalLM.from_pretrained(
    model_id_image_description,
    trust_remote_code=True,
    revision=revision,
    torch_dtype=torch_dtype,  # was computed but never applied before
).to(device)  # was computed but never applied: the model always stayed on CPU
tokenizer_description = AutoTokenizer.from_pretrained(model_id_image_description, revision=revision)
def generate_description(image_path):
    """Describe the image at *image_path* using the moondream2 model.

    Args:
        image_path: Path to the image file (the Gradio Image component is
            configured with ``type="filepath"``).

    Returns:
        Whatever ``model_description.answer_question`` returns for the prompt
        "Describe this image to create an avatar" (presumably a text string).
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # releases it as soon as the image has been encoded.
    with Image.open(image_path) as image:
        enc_image = model_description.encode_image(image)
    return model_description.answer_question(
        enc_image, "Describe this image to create an avatar", tokenizer_description
    )
def query(payload, timeout=120):
    """POST *payload* as JSON to ``API_URL`` and return the raw response body.

    Args:
        payload: JSON-serialisable request body (here: ``inputs`` prompt plus
            ``parameters``).
        timeout: Seconds to wait for the server before giving up, so a stalled
            request cannot block the worker indefinitely.

    Returns:
        The response body as bytes (the encoded image on success).

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within *timeout*.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    # An error response body is not image data; fail here with the HTTP status
    # rather than letting PIL raise an opaque decoding error later.
    response.raise_for_status()
    return response.content
def generate_image_by_description(description, avatar_style=None, num_images=NUM_IMAGES):
    """Generate pigeon avatar images from a text description.

    Args:
        description: Text describing the subject (e.g. from
            ``generate_description``).
        avatar_style: Optional style name appended to the prompt; falsy values
            (``None``, ``""``) mean "no style".
        num_images: How many images to request (at most 1001, the size of the
            seed pool).

    Returns:
        A list of ``num_images`` PIL images, one per distinct seed.
    """
    # The prompt does not change between iterations, so build it once.
    prompt = f"Create a pigeon profile avatar. Use the following description: {description}."
    if avatar_style:
        prompt += f" Use {avatar_style} style."

    # Draw distinct seeds from the original 0..1000 range; independent randint
    # draws could repeat and return two identical images.
    seeds = random.sample(range(1001), num_images)

    images = []
    for seed in seeds:
        image_bytes = query({"inputs": prompt, "parameters": {"seed": seed}})
        images.append(Image.open(io.BytesIO(image_bytes)))
    return images
def process_and_generate(image, avatar_style):
    """Gradio click handler: describe the uploaded image, then generate avatars.

    Args:
        image: File path from the Image component, or ``None`` when nothing
            has been uploaded.
        avatar_style: Selected style from the Radio component (may be ``None``).

    Returns:
        The list of generated PIL images for the Gallery output.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image of the pigeon first.")
    description = generate_description(image)
    return generate_image_by_description(description, avatar_style)
# Build the Gradio UI: an input column (image upload, examples, style choice,
# generate button) and an output column (gallery of generated avatars).
# NOTE(review): the source this was taken from lost its indentation; the nesting
# below (both Markdown panels in the first row, the two columns in the second
# row, the click handler at Blocks level) is reconstructed — confirm against the
# deployed layout.
with gr.Blocks(js=dom.generate_title) as demo:
    with gr.Row():
        gr.Markdown(dom.generate_markdown)
        gr.Markdown(dom.models)
    with gr.Row():
        with gr.Column(scale=2, min_width=300):
            selected_image = gr.Image(type="filepath", label="Upload an Image of the Pigeon", height=300)
            example_image = gr.Examples(["./examples/pigeon.webp"], label="Example Images", inputs=[selected_image])
            avatar_style = gr.Radio(
                ["Realistic", "Pixel Art", "Imaginative", "Cartoon"],
                label="(optional) Select the avatar style:"
            )
            generate_button = gr.Button("Generate Avatar", variant="primary")
        with gr.Column(scale=2, min_width=300):
            generated_image = gr.Gallery(type="pil", label="Generated Avatar", height=300)
    generate_button.click(process_and_generate, inputs=[selected_image, avatar_style], outputs=generated_image)

# Start the server only when executed as a script, so importing this module
# (e.g. by tooling that looks up `demo`) does not launch a second server.
if __name__ == "__main__":
    demo.launch()
|