yamildiego's picture
test c
4a1e480
raw
history blame contribute delete
No virus
1.89 kB
from typing import List, Any
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
# Configurar el dispositivo para ejecutar el modelo
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
raise ValueError("Se requiere ejecutar en GPU")
# Configurar el tipo de dato mixto basado en la capacidad de la GPU
dtype = torch.bfloat16 if torch.cuda.get_device_capability(device.index)[0] >= 8 else torch.float16
class EndpointHandler():
def __init__(self):
# Inicializar aquí si es necesario
pass
def __call__(self, data: Any) -> List[Any]:
# Configurar el número de imágenes por prompt
num_images_per_prompt = 1
# Cargar los modelos con el tipo de dato y dispositivo correctos
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype).to(device)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype).to(device)
prompt = data.get("inputs", "Una imagen interesante") # Asegúrate de pasar un prompt adecuado
negative_prompt = data.get("negative_prompt", "")
prior_output = prior(
prompt=prompt,
height=512,
width=512,
negative_prompt=negative_prompt,
guidance_scale=7.5,
num_inference_steps=50,
num_images_per_prompt=num_images_per_prompt,
)
decoder_output = decoder(
image_embeddings=prior_output["image_embeddings"].half(),
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=7.5,
output_type="pil",
num_inference_steps=20
)
# Asumiendo que quieres retornar la primera imagen
return [decoder_output.images[0]]