hperkins commited on
Commit
45f4c1a
2 Parent(s): cacb254 33ce564

Merge remote-tracking branch 'myrepo/main'

Browse files
Files changed (1) hide show
  1. handler.py +88 -0
handler.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import QwenForVisionLanguage, QwenTokenizer, QwenProcessor
3
+ from PIL import Image
4
+ import base64
5
+ import io
6
+ import json
7
+ import cv2
8
+ import numpy as np
9
+
10
+ class Qwen2VL7bHandler:
11
+ def __init__(self):
12
+ # Initialize the model and processor
13
+ self.model = None
14
+ self.tokenizer = None
15
+ self.processor = None
16
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ def initialize(self, ctx):
19
+ # Load the model and processor within the inference environment
20
+ model_dir = ctx.system_properties.get("model_dir")
21
+ self.model = QwenForVisionLanguage.from_pretrained(model_dir)
22
+ self.tokenizer = QwenTokenizer.from_pretrained(model_dir)
23
+ self.processor = QwenProcessor.from_pretrained(model_dir)
24
+ self.model.to(self.device)
25
+ self.model.eval()
26
+
27
+ def preprocess(self, data):
28
+ # Process incoming requests and extract video data
29
+ video_data = data.get('video')
30
+ if not video_data:
31
+ raise ValueError("Video data is required")
32
+
33
+ # Decode the base64 video
34
+ frames = self.extract_frames_from_video(video_data)
35
+ inputs = self.processor(images=frames, return_tensors="pt").to(self.device)
36
+ return inputs
37
+
38
+ def extract_frames_from_video(self, video_data):
39
+ # Decode the base64 video data
40
+ video_bytes = base64.b64decode(video_data)
41
+ video_array = np.frombuffer(video_bytes, np.uint8)
42
+ video = cv2.imdecode(video_array, cv2.IMREAD_COLOR)
43
+
44
+ # Capture frames from the video
45
+ vidcap = cv2.VideoCapture(io.BytesIO(video_bytes))
46
+ frames = []
47
+ success, frame = vidcap.read()
48
+ while success:
49
+ # Convert the frame from BGR to RGB format
50
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
51
+ pil_image = Image.fromarray(frame_rgb)
52
+ frames.append(pil_image)
53
+ success, frame = vidcap.read()
54
+
55
+ return frames
56
+
57
+ def inference(self, inputs):
58
+ # Perform inference on the preprocessed data
59
+ with torch.no_grad():
60
+ outputs = self.model(**inputs)
61
+ return outputs
62
+
63
+ def postprocess(self, inference_output):
64
+ # Convert the model outputs into a format suitable for the response
65
+ predicted_text = self.tokenizer.decode(inference_output.logits.argmax(-1))
66
+ return {"result": predicted_text}
67
+
68
+ def handle(self, data, context):
69
+ try:
70
+ # Deserialize the request data
71
+ request_data = json.loads(data[0].get("body"))
72
+ # Preprocess the input data
73
+ inputs = self.preprocess(request_data)
74
+ # Perform inference
75
+ outputs = self.inference(inputs)
76
+ # Postprocess the output
77
+ result = self.postprocess(outputs)
78
+ return [json.dumps(result)]
79
+ except Exception as e:
80
+ return [json.dumps({"error": str(e)})]
81
+
82
+ # Instantiate the handler for use in deployment
83
+ _service = Qwen2VL7bHandler()
84
+
85
+ def handle(data, context):
86
+ if not _service.model:
87
+ _service.initialize(context)
88
+ return _service.handle(data, context)