|
from typing import Dict, List, Any |
|
import torch |
|
from torch import autocast |
|
from diffusers import StableDiffusionXLPipeline |
|
import base64 |
|
from io import BytesIO |
|
|
|
# Resolve the compute device once, at import time. The handler below is
# GPU-only, so importing this module on a machine without CUDA is an error.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

if not device.type == "cuda":
    raise ValueError('need to run on gpu')
|
|
|
|
|
class EndpointHandler:
    """Text-to-image inference handler backed by Stable Diffusion XL.

    The pipeline is loaded once in the constructor and reused for every
    request served through ``__call__``.
    """

    # Hub id of the checkpoint loaded by the constructor.
    MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

    def __init__(self, path=""):
        """Load the fp16 SDXL pipeline and move it to the module-level device.

        Args:
            path: Accepted for interface compatibility. It is not used here;
                the weights are always loaded from ``MODEL_ID``.
        """
        pipe = StableDiffusionXLPipeline.from_pretrained(
            self.MODEL_ID,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        self.pipe = pipe.to(device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate one image for the request and return it base64-encoded.

        Args:
            data: Request payload. The prompt is read from its ``"inputs"``
                key; the payload itself is not modified.

        Returns:
            A dict with a single key ``"image"`` whose value is the
            base64 text of the first generated image, encoded as JPEG.

        Raises:
            ValueError: If the payload carries no usable ``"inputs"`` value
                (neither a string nor a list).
        """
        print(data)

        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data)

        print(device)

        # Fail fast with a clear message instead of handing a non-prompt
        # object (e.g. the whole payload dict) to the pipeline.
        if not isinstance(inputs, (str, list)):
            raise ValueError(
                'request payload must contain "inputs" as a prompt string '
                "or a list of prompt strings"
            )

        with autocast(device.type):
            # Only the first generated image is returned to the caller.
            image = self.pipe(inputs, guidance_scale=7.5).images[0]

        # Serialize in memory: JPEG bytes -> base64 -> str, so the result
        # can be embedded in a JSON response.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        encoded = base64.b64encode(buffered.getvalue())

        return {"image": encoded.decode()}
|
|
|
|
|
|