from typing import List, Any import torch from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline # Configurar el dispositivo para ejecutar el modelo device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device.type != 'cuda': raise ValueError("Se requiere ejecutar en GPU") # Configurar el tipo de dato mixto basado en la capacidad de la GPU dtype = torch.bfloat16 if torch.cuda.get_device_capability(device.index)[0] >= 8 else torch.float16 class EndpointHandler(): def __init__(self): # Inicializar aquí si es necesario pass def __call__(self, data: Any) -> List[Any]: # Configurar el número de imágenes por prompt num_images_per_prompt = 1 # Cargar los modelos con el tipo de dato y dispositivo correctos prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype).to(device) decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype).to(device) prompt = data.get("inputs", "Una imagen interesante") # Asegúrate de pasar un prompt adecuado negative_prompt = data.get("negative_prompt", "") prior_output = prior( prompt=prompt, height=512, width=512, negative_prompt=negative_prompt, guidance_scale=7.5, num_inference_steps=50, num_images_per_prompt=num_images_per_prompt, ) decoder_output = decoder( image_embeddings=prior_output["image_embeddings"].half(), prompt=prompt, negative_prompt=negative_prompt, guidance_scale=7.5, output_type="pil", num_inference_steps=20 ) # Asumiendo que quieres retornar la primera imagen return [decoder_output.images[0]]