Merge remote-tracking branch 'myrepo/main'
handler.py (ADDED, +88 -0)

import torch
# Qwen2-VL is exposed in transformers as Qwen2VLForConditionalGeneration;
# the tokenizer and processor load through the Auto classes
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from PIL import Image
import base64
import json
import os
import tempfile
import cv2

class Qwen2VL7bHandler:
    def __init__(self):
        # Model, tokenizer and processor are created lazily in initialize()
        self.model = None
        self.tokenizer = None
        self.processor = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def initialize(self, ctx):
        # Load the model and processor within the inference environment
        model_dir = ctx.system_properties.get("model_dir")
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.processor = AutoProcessor.from_pretrained(model_dir)
        self.model.to(self.device)
        self.model.eval()

    def preprocess(self, data):
        # Process incoming requests and extract video data
        video_data = data.get('video')
        if not video_data:
            raise ValueError("Video data is required")

        # Decode the base64 video
        frames = self.extract_frames_from_video(video_data)
        inputs = self.processor(images=frames, return_tensors="pt").to(self.device)
        return inputs

    def extract_frames_from_video(self, video_data):
        # Decode the base64 video data; cv2.VideoCapture needs a file path,
        # so the bytes are written to a temporary file first
        video_bytes = base64.b64decode(video_data)
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:  # suffix assumes an MP4 payload
            tmp.write(video_bytes)
            tmp_path = tmp.name

        # Capture frames from the video
        frames = []
        try:
            vidcap = cv2.VideoCapture(tmp_path)
            success, frame = vidcap.read()
            while success:
                # Convert the frame from BGR to RGB format
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame_rgb)
                frames.append(pil_image)
                success, frame = vidcap.read()
            vidcap.release()
        finally:
            os.remove(tmp_path)

        return frames

    def inference(self, inputs):
        # Generate output token ids from the preprocessed inputs
        # (generate is used rather than a single forward pass so the model
        # produces a text continuation; 128 new tokens is an arbitrary cap)
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=128)
        return generated_ids

    def postprocess(self, inference_output):
        # Convert the generated token ids into a format suitable for the response
        predicted_text = self.tokenizer.decode(inference_output[0], skip_special_tokens=True)
        return {"result": predicted_text}

    def handle(self, data, context):
        try:
            # Deserialize the request data
            request_data = json.loads(data[0].get("body"))
            # Preprocess the input data
            inputs = self.preprocess(request_data)
            # Perform inference
            outputs = self.inference(inputs)
            # Postprocess the output
            result = self.postprocess(outputs)
            return [json.dumps(result)]
        except Exception as e:
            return [json.dumps({"error": str(e)})]


# Instantiate the handler for use in deployment
_service = Qwen2VL7bHandler()


def handle(data, context):
    if _service.model is None:
        _service.initialize(context)
    return _service.handle(data, context)
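
A note on preprocess: the handler passes only the decoded frames to the processor. The Qwen2-VL processor in transformers normally also expects a text prompt built from a chat template so that the vision placeholder tokens are inserted correctly. A minimal sketch of how that step could look, assuming the frames are passed as a single video and using a placeholder prompt (both the prompt text and the single-video layout are assumptions, not part of the committed handler):

# Sketch only, not the committed code. Assumes `processor` is the AutoProcessor
# loaded in initialize() and `frames` is the list of PIL images returned by
# extract_frames_from_video().
messages = [{
    "role": "user",
    "content": [
        {"type": "video"},
        {"type": "text", "text": "Describe this video."},  # placeholder prompt (assumption)
    ],
}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=[prompt], videos=[frames], return_tensors="pt")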
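
For reference, the module-level handle() entry point follows the TorchServe custom-handler convention, so a client calls it with a JSON body carrying the base64-encoded video. A minimal client sketch, assuming the model has been registered under the name qwen2-vl-7b and TorchServe is listening on its default inference port:

import base64
import json
import requests

# Read a local video and base64-encode it to match the 'video' field preprocess() expects
with open("sample.mp4", "rb") as f:
    video_b64 = base64.b64encode(f.read()).decode("utf-8")

# Host, port and model name are assumptions about the deployment
resp = requests.post(
    "http://localhost:8080/predictions/qwen2-vl-7b",
    data=json.dumps({"video": video_b64}),
    headers={"Content-Type": "application/json"},
)
print(resp.json())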