"""Custom inference handler that edits an image with InstructPix2Pix."""

from typing import Dict, List, Any
from PIL import Image
from io import BytesIO
import torch
import base64
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
import diffusers
import transformers
import logging
import subprocess
import sys

# The root logger is configured on purpose so that library messages are
# emitted at DEBUG level as well.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Best-effort GPU diagnostics: on hosts without the NVIDIA tools the binary
# does not exist, and a failed lookup must not prevent the module from loading.
try:
    subprocess.run(["nvidia-smi"], check=False)
except OSError as exc:
    logger.warning("nvidia-smi could not be executed: %s", exc)

logger.info(f"torch version: {torch.__version__}")
logger.info(f"diffusers version: {diffusers.__version__}")
logger.info(f"transformers version: {transformers.__version__}")
logger.info(f"device: {device}")


class EndpointHandler():
    """Loads the InstructPix2Pix pipeline once and serves edit requests."""

    def __init__(self, path=""):
        """Build the pipeline and move it to the selected device.

        Args:
            path: Accepted for interface compatibility; the model is always
                loaded from the fixed ``timbrooks/instruct-pix2pix`` id.
        """
        model_id = "timbrooks/instruct-pix2pix"
        # Half precision is only used when CUDA is available.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,
        )
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )
        self.pipe.to(device)
        logger.info("PIPE LOADED")

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """Edit the supplied image according to a text prompt.

        Args:
            data: Request payload. Keys read here (both are removed from
                ``data`` as a side effect):

                * ``inputs``: base64 encoded image. If the key is absent the
                  payload itself is used as the image data.
                * ``parameters``: dict with a required ``prompt`` string and
                  optional ``num_inference_steps`` (default 10) and
                  ``image_guidance_scale`` (default 1). If the key is absent
                  the payload itself is searched for these entries.

        Returns:
            The first image produced by the pipeline (a PIL image with the
            pipeline's default output type), not a base64 string.

        Raises:
            KeyError: If no ``prompt`` entry is present.
        """
        image_data = data.pop("inputs", data)
        logger.info(f"Raw img size: {sys.getsizeof(image_data)}")

        # decode base64 image to PIL
        image = Image.open(BytesIO(base64.b64decode(image_data)))
        logger.info(f"PIL Image img size: {sys.getsizeof(image)}")
        # Uploads may be RGBA, palette or grayscale; normalise to three
        # channels before handing the image to the pipeline.
        image = image.convert("RGB")

        parameters = data.pop("parameters", data)
        if "prompt" not in parameters:
            raise KeyError("request parameters must contain a 'prompt' entry")
        prompt = parameters["prompt"]

        # Defaults reproduce the previously hard-coded sampling settings.
        num_inference_steps = parameters.get("num_inference_steps", 10)
        image_guidance_scale = parameters.get("image_guidance_scale", 1)

        images = self.pipe(
            prompt,
            image=image,
            num_inference_steps=num_inference_steps,
            image_guidance_scale=image_guidance_scale,
        ).images
        return images[0]