hperkins committed
Commit 0afa3fc
1 Parent(s): 2b1f1d8

Update handler.py

Files changed (1):
  1. handler.py +27 -66
handler.py CHANGED
@@ -1,7 +1,4 @@
 from typing import Dict, Any
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
-from modelscope import snapshot_download
-from qwen_vl_utils import process_vision_info
 import torch
 import os
 import base64
@@ -9,36 +6,26 @@ import io
 from PIL import Image
 import logging
 import requests
-import subprocess
-from moviepy.editor import VideoFileClip
 import traceback # For formatting exception tracebacks
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+from moviepy.editor import VideoFileClip
 
 class EndpointHandler():
     """
     Handler class for the Qwen2-VL-7B-Instruct model on Hugging Face Inference Endpoints.
-
     This handler processes text, image, and video inputs, leveraging the Qwen2-VL model
-    for multimodal understanding and generation. It includes a runtime workaround to
-    install FFmpeg if it's not available in the environment.
+    for multimodal understanding and generation.
     """
 
     def __init__(self, path=""):
         """
-        Initializes the handler, installs FFmpeg, and loads the Qwen2-VL model.
-
+        Initializes the handler and loads the Qwen2-VL model.
         Args:
            path (str, optional): The path to the Qwen2-VL model directory. Defaults to "".
         """
         self.model_dir = path
 
-        # Install FFmpeg at runtime (this will run once during container initialization)
-        try:
-            subprocess.run(["apt-get", "update"], check=True)
-            subprocess.run(["apt-get", "install", "-y", "ffmpeg"], check=True)
-            logging.info("FFmpeg installed successfully.")
-        except subprocess.CalledProcessError as e:
-            logging.error(f"Error installing FFmpeg: {e}")
-
         # Load the Qwen2-VL model
         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
             self.model_dir, torch_dtype="auto", device_map="auto"
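Note: with the runtime `apt-get` workaround removed, the container image itself has to provide FFmpeg for MoviePy's video decoding (MoviePy often gets a binary via `imageio-ffmpeg`, so a system install may not be strictly required). If you want to keep the fail-fast behavior without installing anything at runtime, a minimal sketch of a startup probe could look like this; `check_ffmpeg` is a hypothetical helper, not part of this commit:

```python
import logging
import shutil

def check_ffmpeg() -> None:
    """Log loudly at startup if no FFmpeg binary is on PATH."""
    path = shutil.which("ffmpeg")
    if path is None:
        # Without a usable FFmpeg, video requests fail later with opaque errors.
        logging.error("FFmpeg not found on PATH; video decoding may fail.")
    else:
        logging.info(f"Using FFmpeg at {path}")
```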
@@ -48,12 +35,10 @@ class EndpointHandler():
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
         Processes the input data and returns the Qwen2-VL model's output.
-
         Args:
             data (Dict[str, Any]): A dictionary containing the input data.
                 - "inputs" (str): The input text, including image/video references.
                 - "max_new_tokens" (int, optional): Max tokens to generate (default: 128).
-
         Returns:
             Dict[str, Any]: A dictionary containing the generated text.
         """
@@ -69,9 +54,6 @@ class EndpointHandler():
         )
         image_inputs, video_inputs = process_vision_info(messages)
 
-        logging.debug(f"Image inputs: {image_inputs}")
-        logging.debug(f"Video inputs: {video_inputs}")
-
         inputs = self.processor(
             text=[text],
             images=image_inputs,
@@ -95,10 +77,8 @@ class EndpointHandler():
     def _parse_input(self, input_string):
         """
         Parses the input string to identify image/video references and text.
-
         Args:
             input_string (str): The input string containing text, image, and video references.
-
         Returns:
             list: A list of dictionaries representing the parsed content.
         """
@@ -110,9 +90,7 @@ class EndpointHandler():
             else: # Image/video part
                 if part.lower().startswith("video:"):
                     video_path = part.split("video:")[1].strip()
-                    print(f"Video path: {video_path}")
                     video_frames = self._extract_video_frames(video_path)
-                    print(f"Number of frames extracted: {len(video_frames) if video_frames else 0}")
                     if video_frames:
                         content.append({"type": "video", "video": video_frames, "fps": 1})
                     else:
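The `{"type": "video", "video": video_frames, "fps": 1}` entry built here matches the message-content format that `qwen_vl_utils.process_vision_info` consumes in `__call__`. For reference, a typical Qwen2-VL message list looks like this sketch, with stand-in frames replacing the handler's extracted ones:

```python
from PIL import Image
from qwen_vl_utils import process_vision_info

# Stand-in frames; in the handler these come from _extract_video_frames().
video_frames = [Image.new("RGB", (224, 224)) for _ in range(4)]

messages = [
    {
        "role": "user",
        "content": [
            # Same dict shape that _parse_input appends for "video:" parts.
            {"type": "video", "video": video_frames, "fps": 1},
            {"type": "text", "text": "Describe what happens in this clip."},
        ],
    }
]
image_inputs, video_inputs = process_vision_info(messages)
```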
@@ -124,59 +102,42 @@ class EndpointHandler():
     def _load_image(self, image_data):
         """
         Loads an image from a URL or base64 encoded string.
-
         Args:
             image_data (str): The image data, either a URL or a base64 encoded string.
-
         Returns:
             PIL.Image.Image or None: The loaded image, or None if loading fails.
         """
-        if image_data.startswith("http"):
-            try:
-                image = Image.open(requests.get(image_data, stream=True).raw)
-            except Exception as e:
-                logging.error(f"Error loading image from URL: {e}")
-                return None
-        elif image_data.startswith("data:image"):
-            try:
-                image_data = image_data.split(",")[1]
-                image_bytes = base64.b64decode(image_data)
-                image = Image.open(io.BytesIO(image_bytes))
-            except Exception as e:
-                logging.error(f"Error loading image from base64: {e}")
-                return None
-        else:
-            logging.error("Invalid image data format. Must be URL or base64 encoded.")
-            return None
-        return image
+        try:
+            if image_data.startswith("http"):
+                response = requests.get(image_data, stream=True)
+                response.raise_for_status() # Check for HTTP errors
+                return Image.open(response.raw)
+            elif image_data.startswith("data:image"):
+                base64_data = image_data.split(",")[1]
+                image_bytes = base64.b64decode(base64_data)
+                return Image.open(io.BytesIO(image_bytes))
+        except requests.exceptions.RequestException as e:
+            logging.error(f"HTTP error occurred while loading image: {e}")
+        except IOError as e:
+            logging.error(f"Error opening image: {e}")
+        return None
 
     def _extract_video_frames(self, video_path, fps=1):
         """
         Extracts frames from a video at the specified FPS using MoviePy.
-
         Args:
             video_path (str): The path or URL of the video file.
             fps (int, optional): The desired frames per second. Defaults to 1.
-
         Returns:
-        list or None: A list of PIL Images representing the extracted frames,
+            list or None: A list of PIL Images representing the extracted frames,
                 or None if extraction fails.
         """
         try:
-            print(f"Attempting to load video from: {video_path}")
-            video = VideoFileClip(video_path)
-            print(f"Video loaded: {video}")
-
-            frames = [
-                Image.fromarray(frame.astype('uint8'), 'RGB')
-                for frame in video.iter_frames(fps=fps)
-            ]
-            print(f"Number of frames: {len(frames)}")
-            print(f"Frame type: {type(frames[0]) if frames else None}")
-            print(f"Frame size: {frames[0].size if frames else None}")
-            video.close()
-            return frames
+            with VideoFileClip(video_path) as video:
+                return [Image.fromarray(frame.astype('uint8'), 'RGB') for frame in video.iter_frames(fps=fps)]
         except Exception as e:
-            error_message = f"Error extracting video frames: {e}\n{traceback.format_exc()}"
-            logging.error(error_message) # Log the formatted error message
-            return None
+            logging.error(f"Error extracting video frames: {e}")
+            return None
+
+# Additional configurations for logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
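The rewritten `_load_image` accepts either an `http(s)` URL or a `data:image/...;base64,` URI, taking the segment after the first comma and base64-decoding it. A small sketch of building such a URI for a local file, where the file name is a placeholder:

```python
import base64

# "photo.png" is a placeholder; any image file works.
with open("photo.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("ascii")
data_uri = f"data:image/png;base64,{encoded}"
# _load_image splits this on the first comma and decodes the remainder.
```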
 
 
 
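One caveat on the new module-level `logging.basicConfig` call: it mutates process-wide logging state at import time, and it is a no-op if the serving runtime has already attached handlers to the root logger. If that side effect is unwanted, the more conservative pattern is a module-scoped logger; a minimal sketch:

```python
import logging

# Configures nothing globally; records propagate to whatever
# handlers the hosting process has installed.
logger = logging.getLogger(__name__)
logger.info("Qwen2-VL handler module loaded")
```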