krtk00 commited on
Commit
2e28b6e
1 Parent(s): b3a7a95

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +45 -0
handler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from diffusers import AutoPipelineForText2Image
3
+ import torch
4
+ from PIL import Image
5
+ import base64
6
+ from io import BytesIO
7
+
8
+ class EndpointHandler:
9
+ def __init__(self, path: str = ""):
10
+ """
11
+ Initialize the handler, loading the model and LoRA weights.
12
+ The path parameter is provided by Hugging Face Inference Endpoints to point to the model directory.
13
+ """
14
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ self.pipeline = AutoPipelineForText2Image.from_pretrained(
16
+ 'black-forest-labs/FLUX.1-dev',
17
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
18
+ ).to(self.device)
19
+
20
+ # Load LoRA weights
21
+ lora_weights_path = 'krtk00/pan_crd_lora_v2'
22
+ self.pipeline.load_lora_weights(lora_weights_path, weight_name='lora.safetensors')
23
+
24
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
25
+ """
26
+ This method will be called on every request. The input is expected to be a dictionary
27
+ with a key "inputs" containing the text prompt.
28
+ """
29
+ # Preprocess input
30
+ prompt = data.get("inputs", None)
31
+ if not prompt:
32
+ raise ValueError("No prompt provided in the input")
33
+
34
+ # Run inference
35
+ with torch.no_grad():
36
+ images = self.pipeline(prompt).images
37
+
38
+ # Postprocess output: Convert image to base64
39
+ pil_image = images[0] # Assuming one image is generated
40
+ buffered = BytesIO()
41
+ pil_image.save(buffered, format="PNG")
42
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
43
+
44
+ # Return result
45
+ return {"image": img_str}