anujg1508 commited on
Commit
5fb4029
1 Parent(s): 2802bb2

upload handler

Browse files
Files changed (1) hide show
  1. handler.py +24 -31
handler.py CHANGED
@@ -1,39 +1,32 @@
1
- from typing import Dict, List, Any
2
  import torch
3
  from torch import autocast
4
- from diffusers import StableDiffusionPipeline
5
  import base64
6
  from io import BytesIO
7
 
 
8
 
9
- # set device
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
-
12
- class EndpointHandler():
13
- def __init__(self, path=""):
14
- # load the optimized model
15
- self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
16
- self.pipe = self.pipe.to(device)
17
 
18
 
19
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
20
- """
21
- Args:
22
- data (:obj:):
23
- includes the input data and the parameters for the inference.
24
- Return:
25
- A :obj:`dict`:. base64 encoded image
26
- """
27
- inputs = data.pop("inputs", data)
28
-
29
- # run inference pipeline
30
- with autocast(device.type):
31
- image = self.pipe(inputs, guidance_scale=7.5)["sample"][0]
32
-
33
- # encode image as base 64
34
- buffered = BytesIO()
35
- image.save(buffered, format="JPEG")
36
- img_str = base64.b64encode(buffered.getvalue())
37
-
38
- # postprocess the prediction
39
- return {"image": img_str}
 
1
+ from typing import Dict, List, Any
2
  import torch
3
  from torch import autocast
4
+ from diffusers import StableDiffusionAdapterPipeline
5
  import base64
6
  from io import BytesIO
7
 
8
# Pick the torch device up front; this handler is GPU-only, so refuse to
# start at import time rather than fail later inside the pipeline.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')

if not device.type == "cuda":
    raise ValueError('need to run on gpu')
 
 
 
 
 
 
12
 
13
 
14
class EndpointHandler():
    """Inference-endpoint handler: takes a text prompt, returns a
    base64-encoded JPEG of the generated image."""

    def __init__(self, path=""):
        # Load fp16 weights from `path` and move the pipeline to the
        # module-level `device` (guaranteed CUDA by the import-time check).
        self.pipe = StableDiffusionAdapterPipeline.from_pretrained(path, torch_dtype=torch.float16)
        self.pipe = self.pipe.to(device)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Run one generation.

        Args:
            data: request payload; the prompt is taken from the "inputs"
                key, falling back to the whole payload if the key is absent.
        Returns:
            dict with key "image" holding the base64-encoded JPEG string.
        """
        inputs = data.pop("inputs", data)

        # Run the diffusion pipeline under autocast for fp16 inference.
        with autocast(device.type):
            # Diffusers pipeline outputs expose generated PIL images via
            # `.images`; the previous ["samples"] key does not exist and
            # raised a KeyError.
            image = self.pipe(inputs, guidance_scale=7.5).images[0]

        # Serialize the PIL image to an in-memory JPEG buffer.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        # Fix: the base64 module has no `base64encode`; the correct
        # function is `b64encode`.
        img_str = base64.b64encode(buffered.getvalue())

        return {"image": img_str.decode()}