pix2pix-instruct-IE / handler.py
sergeipetrov's picture
Update handler.py
bfbe48c verified
raw
history blame
1.26 kB
from typing import Dict, List, Any
from PIL import Image
from io import BytesIO
import torch
import base64
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
# Target device for the pipeline: the GPU when CUDA is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler():
    """Custom inference handler that edits an image from a text instruction.

    The ``timbrooks/instruct-pix2pix`` pipeline is loaded once when the
    handler is constructed; each call edits a single image.
    """

    def __init__(self, path=""):
        """Load the InstructPix2Pix pipeline onto the module-level ``device``.

        Args:
            path: Unused. The weights are always loaded from the hard-coded
                ``timbrooks/instruct-pix2pix`` model id.
        """
        model_id = "timbrooks/instruct-pix2pix"
        # Half precision only when CUDA is available; float32 otherwise.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        # safety_checker=None: the pipeline is created without a safety checker.
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,
        )
        self.pipe.to(device)
        # Replace the default scheduler with Euler-ancestral, reusing its config.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Edit the supplied image according to the text instruction.

        Args:
            data: Request payload. The image/prompt mapping is read from
                ``data["inputs"]`` or, when that key is absent, from ``data``
                itself. The mapping must provide:

                * ``image``: the source image, base64-encoded.
                * ``prompt``: the edit instruction.

                A ``parameters`` entry, if present, is ignored: the step count
                and image guidance scale are fixed. ``data`` is not modified.

        Returns:
            The first image produced by the pipeline (``.images[0]``). With
            the pipeline's default output type this is a PIL image, not a
            base64 string.

        Raises:
            KeyError: If ``image`` or ``prompt`` is missing from the inputs.
            TypeError: If the inputs object does not support key lookup.
        """
        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        try:
            encoded_image = inputs["image"]
            prompt = inputs["prompt"]
        except KeyError as err:
            raise KeyError(
                f"request inputs are missing the required {err.args[0]!r} field"
            ) from None

        # Decode the base64 payload into a PIL image.
        image = Image.open(BytesIO(base64.b64decode(encoded_image)))
        # Normalise to 3-channel RGB so RGBA, palette or grayscale uploads do
        # not reach the pipeline with an unexpected channel count.
        image = image.convert("RGB")

        result = self.pipe(
            prompt,
            image=image,
            num_inference_steps=10,
            image_guidance_scale=1,
        )
        return result.images[0]