File size: 2,683 Bytes
48e53fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c388933
48e53fd
c388933
 
48e53fd
 
 
 
 
 
 
 
 
 
 
868b5ab
48e53fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
from typing import Dict, List, Any
import sys
rootDir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(rootDir)
from uaiDiffusers.common.imageRequest import ImageRequest
from diffusers import StableDiffusionXLPipeline, DPMSolverSinglestepScheduler        
import torch
from uaiDiffusers.uaiDiffusers import ImagesToBase64

import torch

class EndpointHandler:
    """Inference endpoint that serves Stable Diffusion XL text-to-image requests.

    One ``StableDiffusionXLPipeline`` is kept resident on the CUDA device. It is
    replaced only when a request names a different model than the one loaded.
    """

    # Model loaded at start-up and used whenever a request asks for "default"
    # (or leaves the model name empty/None).
    DEFAULT_MODEL = "sd-community/sdxl-flash"

    def __init__(self, path=""):
        """Preload the default model so the first request does not pay the load cost.

        Args:
            path: Accepted for compatibility with the endpoint-handler
                constructor signature; not used by this handler.
        """
        self.pipe = None
        # Name of the model currently held in ``self.pipe`` ("" when none is loaded).
        self.modelName = ""
        baseReq = ImageRequest()
        baseReq.model = self.DEFAULT_MODEL
        self.LoadModel(baseReq)

    def LoadModel(self, request):
        """Ensure the pipeline for ``request.model`` is the one loaded.

        ``"default"``, ``""`` and ``None`` resolve to ``DEFAULT_MODEL``, and
        ``request.model`` is rewritten to the resolved name. If a different
        model is currently loaded, it is released before the requested one is
        loaded, so only one pipeline is resident at a time.

        Args:
            request: Object with a ``model`` attribute (an ``ImageRequest``).
        """
        if request.model in (None, "", "default"):
            request.model = self.DEFAULT_MODEL
        base = request.model

        if self.pipe is not None and self.modelName == base:
            return  # Requested model is already resident; nothing to do.

        if self.pipe is not None:
            import gc

            # Drop the old pipeline and reclaim its cached GPU memory *before*
            # loading the new one, so two SDXL pipelines never coexist on the GPU.
            # If the subsequent load fails, the handler is left unloaded and the
            # next call will retry the load.
            self.pipe = None
            self.modelName = ""
            gc.collect()
            torch.cuda.empty_cache()

        print(f"Loading model: {base}")
        pipe = StableDiffusionXLPipeline.from_pretrained(base).to("cuda")
        # Ensure the sampler uses "trailing" timesteps.
        pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        # Publish the pipeline and its name only after a successful load.
        self.pipe = pipe
        self.modelName = base

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload, parsed with ``ImageRequest.FromDict``.
                Documented fields:
                  input (:obj:`str` | `PIL.Image` | `np.array`)
                  seed (:obj:`int`)
                  prompt (:obj:`str`)
                  negative_prompt (:obj:`str`)
                  num_images_per_prompt (:obj:`int`)
                  steps (:obj:`int`)
                  guidance_scale (:obj:`float`)
                  width (:obj:`int`)
                  height (:obj:`int`)
                  kwargs
                NOTE(review): input, seed, width and height are not forwarded
                to the pipeline by ``__runProcess__`` — confirm intent.

        Returns:
            ``{"media": [{"media": <encoded image>}, ...]}`` with one entry per
            generated image; this dict is serialized and returned to the client.
        """
        request = ImageRequest.FromDict(data)
        return self.__runProcess__(request)

    def __runProcess__(self, request: "ImageRequest") -> Dict[str, Any]:
        """Run the SDXL pipeline for ``request`` and encode the generated images.

        Loads or switches the model first via ``LoadModel``. Only ``prompt``,
        ``negative_prompt``, ``steps``, ``guidance_scale`` and
        ``num_images_per_prompt`` are passed to the pipeline; each resulting
        image is encoded with ``ImagesToBase64``.
        """
        self.LoadModel(request)

        images = self.pipe(
            request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.steps,
            guidance_scale=request.guidance_scale,
            num_images_per_prompt=request.num_images_per_prompt,
        ).images

        return {"media": [{"media": ImagesToBase64(img)} for img in images]}