File size: 1,891 Bytes
f076f26
 
 
 
 
 
63043dc
 
dcc0b95
 
41339c4
dcc0b95
 
 
f076f26
 
dcc0b95
 
9c2101e
 
 
 
dcc0b95
f076f26
 
 
 
 
 
dcc0b95
 
 
f076f26
 
 
c57eb63
c3ebea1
 
c57eb63
 
 
f076f26
 
 
5581cd6
dcc0b95
bfbe48c
5581cd6
dcc0b95
5581cd6
 
 
f076f26
 
8468f08
c3ebea1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import base64
import logging
import subprocess
import sys
from io import BytesIO
from typing import Dict, List, Any

import diffusers
import torch
import transformers
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
from PIL import ImageOps

# Set the root logger's level to DEBUG; every logger that does not set its
# own level inherits this threshold.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dump GPU diagnostics to stdout. On CPU-only hosts the nvidia-smi binary is
# usually absent and subprocess.run raises OSError (FileNotFoundError); since
# the handler supports CPU execution above, that must not abort module import.
try:
    subprocess.run(["nvidia-smi"], check=False)
except OSError as exc:
    logger.warning(f"nvidia-smi could not be run: {exc}")

# Record the library versions and the selected device once at import time
# so environment mismatches can be diagnosed from the logs.
logger.info("torch version: %s", torch.__version__)
logger.info("diffusers version: %s", diffusers.__version__)
logger.info("transformers version: %s", transformers.__version__)

logger.info("device: %s", device)

class EndpointHandler():
    """Request handler that edits images with the InstructPix2Pix pipeline.

    The pipeline is loaded once in ``__init__``; each ``__call__`` decodes a
    base64-encoded image from the request and edits it according to a text
    instruction.
    """

    # Per-request generation settings used when the request does not supply
    # them; these are the values the handler previously hard-coded.
    DEFAULT_NUM_INFERENCE_STEPS = 10
    DEFAULT_IMAGE_GUIDANCE_SCALE = 1

    def __init__(self, path=""):
        """Load the pipeline, swap in an Euler-ancestral scheduler, move to ``device``.

        Args:
            path: Accepted for interface compatibility but unused; the model is
                always loaded from ``timbrooks/instruct-pix2pix``.
        """
        model_id = "timbrooks/instruct-pix2pix"
        # Half precision only when CUDA is available; float32 otherwise.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id, torch_dtype=dtype, safety_checker=None
        )
        # Reuse the loaded scheduler's config so only the sampling algorithm changes.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.to(device)

        logger.info("PIPE LOADED")

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """Apply the prompt's instruction to the supplied image.

        data dict (keys are popped, so the caller's dict is mutated):
            inputs: base64 encoded image,
            parameters: dict:
                prompt: str
                num_inference_steps: optional, defaults to 10
                image_guidance_scale: optional, defaults to 1
        If ``parameters`` is missing, the remaining top-level keys of ``data``
        are read instead, so ``prompt`` may also be given at the top level.

        returns:
            The first generated image as a ``PIL.Image.Image``.

        raises:
            KeyError: if no ``prompt`` is provided.
        """
        image_data = data.pop("inputs", data)
        # Length of the encoded payload (sys.getsizeof would add object overhead).
        logger.info(f"Raw img size: {len(image_data)}")
        image = self._decode_image(image_data)
        # Pixel dimensions (width, height); a shallow object size is meaningless here.
        logger.info(f"PIL Image img size: {image.size}")

        parameters = data.pop("parameters", data)
        prompt = parameters['prompt']
        num_inference_steps = parameters.get("num_inference_steps", self.DEFAULT_NUM_INFERENCE_STEPS)
        image_guidance_scale = parameters.get("image_guidance_scale", self.DEFAULT_IMAGE_GUIDANCE_SCALE)

        images = self.pipe(
            prompt,
            image=image,
            num_inference_steps=num_inference_steps,
            image_guidance_scale=image_guidance_scale,
        ).images

        return images[0]

    @staticmethod
    def _decode_image(image_data):
        """Decode base64 data into an upright, 3-channel RGB PIL image."""
        image = Image.open(BytesIO(base64.b64decode(image_data)))
        # Honour the EXIF orientation tag, then normalise the mode: RGBA,
        # palette or grayscale images would otherwise reach the pipeline with
        # a channel count other than 3.
        return ImageOps.exif_transpose(image).convert("RGB")