pix2pix-instruct-IE / handler.py
sergeipetrov's picture
sergeipetrov HF staff
Update handler.py
63043dc verified
raw
history blame contribute delete
No virus
1.89 kB
import base64
import logging
import subprocess
import sys
from io import BytesIO
from typing import Dict, List, Any

import diffusers
import torch
import transformers
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
from PIL import ImageOps
# Module-level setup: runs once when the endpoint runtime imports this handler.
# Configures the root logger, picks the compute device, and logs environment info.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Print GPU diagnostics. On hosts without the NVIDIA tools the nvidia-smi binary is
# missing and subprocess.run raises; that must not abort the import, since the
# handler is written to fall back to CPU (see `device` above).
try:
    subprocess.run(["nvidia-smi"], check=False)
except OSError as err:
    logger.warning(f"nvidia-smi unavailable ({err}); skipping GPU diagnostics")
logger.info(f"torch version: {torch.__version__}")
logger.info(f"diffusers version: {diffusers.__version__}")
logger.info(f"transformers version: {transformers.__version__}")
logger.info(f"device: {device}")
class EndpointHandler():
    """Inference Endpoints handler that edits an image from a text instruction.

    Loads the ``timbrooks/instruct-pix2pix`` InstructPix2Pix pipeline once at
    construction; each call applies ``parameters["prompt"]`` to the input image.
    """

    def __init__(self, path=""):
        """Load the pipeline, install the Euler ancestral scheduler, move to ``device``.

        Args:
            path: Repository path passed in by the endpoint runtime. It is not used
                here; the weights are always loaded by hub id below.
        """
        model_id = "timbrooks/instruct-pix2pix"
        # float16 when CUDA is available, float32 otherwise; no safety checker is loaded.
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            safety_checker=None,
        )
        # Replace the default scheduler, reusing the loaded scheduler's config.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.to(device)
        logger.info("PIPE LOADED")

    @staticmethod
    def _load_image(image_data: Any) -> "Image.Image":
        """Convert the request's ``inputs`` value into an upright RGB PIL image.

        Accepts an already-decoded PIL image, or base64 data (str/bytes, optionally
        prefixed with a ``data:<mime>;base64,`` header) that PIL can open.
        """
        if isinstance(image_data, Image.Image):
            image = image_data
        else:
            if isinstance(image_data, str) and image_data.startswith("data:"):
                # Drop a data-URL header such as "data:image/png;base64,".
                image_data = image_data.split(",", 1)[-1]
            image = Image.open(BytesIO(base64.b64decode(image_data)))
        # Honour any EXIF orientation tag, then normalize to 3-channel RGB so
        # RGBA / palette / grayscale uploads all reach the pipeline in one format.
        image = ImageOps.exif_transpose(image)
        return image.convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Apply a text edit instruction to an image.

        Args:
            data: request payload. It is mutated: ``inputs`` and ``parameters``
                are popped from it.
                inputs: base64-encoded image (str/bytes, optionally a data URL)
                    or an already-decoded PIL image.
                parameters: dict with
                    prompt (str, required): the edit instruction.
                    num_inference_steps (optional): defaults to 10.
                    image_guidance_scale (optional): defaults to 1.
                When ``inputs`` or ``parameters`` is absent, ``data`` itself is
                used in its place (the ``dict.pop(key, data)`` fallback below).

        Returns:
            The first image from the pipeline output. This is not base64-encoded;
            it is presumably a PIL image under the pipeline's default
            ``output_type``, so serialization is left to the caller.

        Raises:
            KeyError: if no ``prompt`` is supplied.
        """
        image_data = data.pop("inputs", data)
        logger.info(f"Raw img size: {sys.getsizeof(image_data)}")
        image = self._load_image(image_data)
        # sys.getsizeof on a PIL image only measures the Python wrapper object,
        # so log the pixel dimensions and mode instead.
        logger.info(f"PIL Image size: {image.size}, mode: {image.mode}")
        parameters = data.pop("parameters", data)
        prompt = parameters["prompt"]
        images = self.pipe(
            prompt,
            image=image,
            num_inference_steps=parameters.get("num_inference_steps", 10),
            image_guidance_scale=parameters.get("image_guidance_scale", 1),
        ).images
        return images[0]