# Spaces: Running
import gradio as gr | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
from diffusers import StableDiffusionPipeline | |
import torch | |
from PIL import Image | |
# Device selection: prefer CUDA, then MPS, and fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Model configuration.
model_id_image = "sd-legacy/stable-diffusion-v1-5"
model_id_image_description = "vikhyatk/moondream2"
# Revision passed to from_pretrained for the description model and its
# tokenizer, so both are loaded from the same pinned snapshot.
revision = "2024-08-26"

# Precision of the Stable Diffusion weights: full precision by default,
# reduced precision on CUDA. bfloat16 is only selected when the GPU reports
# support for it; otherwise float16 is used as the half-precision fallback.
torch_dtype = torch.float32
if torch.cuda.is_available():
    torch_dtype = (
        torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    )
# Load both models once at module level so they persist across requests.
print("Cargando modelo de descripción de imágenes...")
model_description = AutoModelForCausalLM.from_pretrained(
    model_id_image_description,
    trust_remote_code=True,
    revision=revision,
)
tokenizer_description = AutoTokenizer.from_pretrained(
    model_id_image_description,
    revision=revision,
)

print("Cargando modelo de Stable Diffusion...")
pipe_sd = StableDiffusionPipeline.from_pretrained(model_id_image, torch_dtype=torch_dtype)

# Memory optimizations.
pipe_sd.enable_attention_slicing()
if device == "cuda":
    # Sequential CPU offload manages device placement itself. The pipeline
    # must not be moved to CUDA beforehand, otherwise the memory saving is
    # minimal, so .to(device) is reserved for the non-CUDA devices.
    pipe_sd.enable_sequential_cpu_offload()
else:
    pipe_sd = pipe_sd.to(device)
def generate_description(image_path):
    """Describe the image at *image_path* using the image-description model.

    Args:
        image_path: Location of the image; passed straight to ``Image.open``.

    Returns:
        The answer produced by ``model_description.answer_question`` for the
        prompt "Describe this image to create an avatar".
    """
    # Image.open keeps the underlying file open; the context manager releases
    # the handle as soon as the image has been encoded.
    with Image.open(image_path) as source_image:
        enc_image = model_description.encode_image(source_image)
    return model_description.answer_question(
        enc_image,
        "Describe this image to create an avatar",
        tokenizer_description,
    )
def generate_image_by_description(description, avatar_style=None):
    """Generate an avatar image from a textual description.

    Args:
        description: Text inserted into the Stable Diffusion prompt.
        avatar_style: Optional style name; when truthy, a sentence asking for
            that style is appended to the prompt.

    Returns:
        The first image of the pipeline result.
    """
    style_suffix = f" Use {avatar_style} style." if avatar_style else ""
    prompt = (
        f"Create a pigeon profile avatar. Use the following description: {description}."
        + style_suffix
    )
    return pipe_sd(prompt).images[0]
def process_and_generate(image, avatar_style):
    """Describe the uploaded image and generate an avatar from that description.

    Args:
        image: File path of the uploaded image, or ``None`` when nothing was
            uploaded.
        avatar_style: Optional avatar style forwarded to the image generator.

    Returns:
        The image returned by ``generate_image_by_description``.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without an upload there is nothing to describe; report it to the user
    # instead of failing later inside the image loader.
    if image is None:
        raise gr.Error("Please upload an image of the pigeon first.")
    description = generate_description(image)
    return generate_image_by_description(description, avatar_style)
# UI definition: one row with two columns, inputs in the first column and the
# generated avatar in the second.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2, min_width=300):
            selected_image = gr.Image(
                type="filepath",
                label="Upload an Image of the Pigeon",
                height=300,
            )
            avatar_style = gr.Radio(
                choices=["Realistic", "Pixel Art", "Imaginative", "Cartoon"],
                label="(optional) Select the avatar style:",
            )
            generate_button = gr.Button(value="Generate Avatar", variant="primary")
        with gr.Column(scale=2, min_width=300):
            generated_image = gr.Image(
                type="numpy",
                label="Generated Avatar",
                height=300,
            )

    # Clicking the button runs the describe-then-generate handler and shows
    # its result in the output image component.
    generate_button.click(
        fn=process_and_generate,
        inputs=[selected_image, avatar_style],
        outputs=generated_image,
    )

demo.launch()