arjunanand13 committed on
Commit
87cf582
1 Parent(s): 0f59469

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +109 -0
handler.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+ import torch
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ import requests
8
+ from transformers import AutoModelForCausalLM, AutoProcessor
9
+ from tokenizers import Tokenizer, pre_tokenizers
10
+ import os
11
+
12
def install(package):
    """Install *package* with pip, using the interpreter running this code.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # Going through ``sys.executable -m pip`` targets the environment of the
    # current interpreter rather than whichever ``pip`` is first on PATH.
    pip_command = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-warn-script-location",
        package,
    ]
    subprocess.check_call(pip_command)
14
+
15
class EndpointHandler:
    """Callable inference handler for an image + text prompt model.

    On construction it loads the ``arjunanand13/florence-enphaseall2-25e``
    checkpoint and its processor. Calling the instance with a request payload
    returns either ``{"generated_text": ...}`` or ``{"error": ...}``.
    """

    def __init__(self, path: str = "") -> None:
        """Install runtime dependencies and load model, tokenizer and processor.

        Args:
            path: Accepted for interface compatibility; it is not used, the
                checkpoint is always loaded from ``self.model_name``.
        """
        # Best-effort installs: a failure is printed and start-up continues,
        # so one unavailable package does not abort loading outright.
        required_packages = ['timm', 'einops', 'flash-attn', 'Pillow']
        for package in required_packages:
            try:
                install(package)
                print(f"Successfully installed {package}")
            except Exception as e:
                print(f"Failed to install {package}: {str(e)}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # NOTE: trust_remote_code=True runs Python code shipped with the model
        # repository; only keep it for repositories you control or have audited.
        self.model_name = "arjunanand13/florence-enphaseall2-25e"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name, trust_remote_code=True
        ).to(self.device)

        # Loaded separately with a whitespace pre-tokenizer. Nothing else in
        # this class reads self.tokenizer; it may be None if loading failed.
        self.tokenizer = self.load_tokenizer()

        self.processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_tokenizer(self):
        """Load the checkpoint's tokenizer and attach a whitespace pre-tokenizer.

        Returns:
            The ``tokenizers.Tokenizer`` instance, or ``None`` if loading
            failed (the error is printed rather than raised).
        """
        try:
            tokenizer = Tokenizer.from_pretrained(self.model_name)
            tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
            print("[INFO] Whitespace pre-tokenizer added.")
            return tokenizer
        except Exception as e:
            print(f"[ERROR] Failed to load tokenizer: {str(e)}")
            return None

    def process_image(self, image_data):
        """Open an image given a filesystem path or a base64-encoded string.

        A string shorter than 256 characters that names an existing path is
        read from disk; any other input is treated as base64 image data.

        Args:
            image_data: Path to an image file, or base64-encoded image bytes.

        Returns:
            The opened PIL image with its pixel data already loaded, or
            ``None`` if the input could not be read or decoded (the error is
            printed rather than raised).
        """
        print("[DEBUG] Attempting to process image")
        try:
            if isinstance(image_data, str) and len(image_data) < 256 and os.path.exists(image_data):
                with open(image_data, 'rb') as image_file:
                    print("[DEBUG] File opened successfully")
                    image = Image.open(image_file)
                    # Image.open only reads the header. Pull the pixel data in
                    # now, before the with-block closes the underlying file;
                    # otherwise any later use of the image fails.
                    image.load()
            else:
                print("[DEBUG] Decoding base64 image data")
                image_bytes = base64.b64decode(image_data)
                image = Image.open(BytesIO(image_bytes))
                # Decode eagerly so truncated or corrupt payloads are reported
                # here instead of surfacing later during preprocessing.
                image.load()

            print("[DEBUG] Image opened:", image.format, image.size, image.mode)
            return image
        except Exception as e:
            print(f"[ERROR] Error processing image: {str(e)}")
            return None

    @staticmethod
    def _preview(value, limit: int = 80):
        """Return *value* shortened for logging when it is a long string."""
        if isinstance(value, str) and len(value) > limit:
            return f"{value[:limit]}... [{len(value)} chars]"
        return value

    def __call__(self, data):
        """Run the model on a request payload and return the generated text.

        Args:
            data: Request payload. If it has an ``"inputs"`` key, that value
                is used; otherwise the payload itself is. A dict input is read
                for ``"image"`` (path or base64 string) and ``"text"``
                (prompt, default ``""``); any other input is treated as the
                image itself with the prompt ``"What is in this image?"``.
                The payload is not modified.

        Returns:
            ``{"generated_text": str}`` on success, or ``{"error": str}`` if
            the image is missing or unreadable, or if any step raises.
        """
        try:
            # .get rather than .pop: leave the caller's payload untouched.
            inputs = data.get("inputs", data)

            if isinstance(inputs, dict):
                image_ref = inputs.get("image", None)
                text_input = inputs.get("text", "")
            else:
                image_ref = inputs
                text_input = "What is in this image?"

            # Base64 payloads can be megabytes long; log only a short preview.
            print("[INFO] Image path:", self._preview(image_ref), "| Text input:", text_input)

            if not image_ref:
                return {"error": "No image provided: expected a file path or base64-encoded image."}

            image = self.process_image(image_ref)
            if image is None:
                return {"error": "Could not read or decode the provided image."}

            model_inputs = self.processor(
                images=image,
                text=text_input,
                return_tensors="pt"
            )

            # Move tensors to the model's device; leave non-tensor entries as-is.
            model_inputs = {
                key: value.to(self.device) if isinstance(value, torch.Tensor) else value
                for key, value in model_inputs.items()
            }

            # NOTE(review): no generation arguments are passed, so output length
            # and decoding strategy come from the model's generation defaults.
            # Confirm those allow sufficiently long outputs.
            with torch.no_grad():
                outputs = self.model.generate(**model_inputs)

            decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
            print(f"[INFO] Generated text: {decoded_outputs[0]}")
            return {"generated_text": decoded_outputs[0]}

        except Exception as e:
            # Top-level request boundary: report the failure to the client
            # instead of letting the exception propagate.
            print(f"[ERROR] {str(e)}")
            return {"error": str(e)}
109
+