# Custom inference endpoint handler for Qwen2-VL.
from typing import Dict, Any
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from modelscope import snapshot_download
from qwen_vl_utils import process_vision_info
import torch
import os
import base64
import io
from PIL import Image
import logging
import requests
from moviepy.editor import VideoFileClip
class EndpointHandler():
    """Hugging Face inference-endpoint handler for a Qwen2-VL model.

    Loads the model and processor from ``path`` at startup, then serves
    requests whose ``inputs`` string interleaves text with ``<image>``-
    delimited references (image URL, base64 data-URI, or ``video:<path>``).
    """

    def __init__(self, path=""):
        # path: model directory (or repo id) the endpoint mounts the weights at.
        self.model_dir = path
        # device_map="auto" lets accelerate place the weights across available
        # devices; torch_dtype="auto" follows the checkpoint's dtype.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_dir, torch_dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(self.model_dir)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one generation request.

        data args:
            inputs (str): The input text, including any image or video
                references wrapped in ``<image>...<image>`` markers
                (see ``_parse_input``).
            max_new_tokens (int, optional): Maximum number of tokens to
                generate. Defaults to 128.
        Return:
            A dictionary ``{"generated_text": str}`` containing only the
            newly generated continuation (prompt tokens stripped).
        Raises:
            ValueError: if the payload has no "inputs" entry.
        """
        prompt = data.get("inputs")
        if prompt is None:
            # Fail loudly with a clear message instead of an opaque
            # AttributeError deep inside _parse_input.
            raise ValueError('Request payload must contain an "inputs" string.')
        max_new_tokens = data.get("max_new_tokens", 128)

        # Construct the chat message list from the raw input string.
        messages = [{"role": "user", "content": self._parse_input(prompt)}]

        # Prepare for inference (qwen_vl_utils resolves image/video content).
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        # Distinct name: the original rebound `inputs` (request string ->
        # tensor batch), which obscured the code.
        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        model_inputs = model_inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Inference: generate, then drop the prompt prefix from each sequence
        # so only the new tokens are decoded.
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]  # single-element batch -> single string
        return {"generated_text": output_text}

    def _parse_input(self, input_string):
        """Parse the input string into a chat-content list.

        The string is split on ``<image>`` markers: even-indexed segments are
        text, odd-indexed segments are media references — ``video:<path>``
        (case-insensitive) for video, anything else is handed to
        ``_load_image``. Media that fails to load is skipped (errors are
        logged) and empty text fragments are dropped.
        """
        content = []
        for i, part in enumerate(input_string.split("<image>")):
            if i % 2 == 0:  # text segment
                text = part.strip()
                # Skip empty fragments (adjacent markers, leading/trailing
                # markers) — the original appended {"text": ""} entries.
                if text:
                    content.append({"type": "text", "text": text})
                continue
            ref = part.strip()
            if ref.lower().startswith("video:"):
                # Slice by prefix length so the match stays case-insensitive:
                # the old split("video:")[1] raised IndexError on "Video:...".
                video_path = ref[len("video:"):].strip()
                video_frames = self._extract_video_frames(video_path)
                if video_frames:
                    content.append({"type": "video", "video": video_frames, "fps": 1})
            else:
                image = self._load_image(ref)
                if image:
                    content.append({"type": "image", "image": image})
        return content

    def _load_image(self, image_data):
        """Load a PIL image from an http(s) URL or base64 data-URI.

        Returns None (and logs the reason) on any failure or on an
        unrecognized format.
        """
        if image_data.startswith("http"):
            try:
                # stream=True + .raw hands the response body to PIL without
                # buffering it twice; the timeout keeps a stalled remote host
                # from hanging the whole endpoint worker.
                image = Image.open(requests.get(image_data, stream=True, timeout=30).raw)
            except Exception as e:
                logging.error(f"Error loading image from URL: {e}")
                return None
        elif image_data.startswith("data:image"):
            try:
                image_data = image_data.split(",")[1]  # strip the data-URI header
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                logging.error(f"Error loading image from base64: {e}")
                return None
        else:
            logging.error("Invalid image data format. Must be URL or base64 encoded.")
            return None
        return image

    def _extract_video_frames(self, video_path, fps=1):
        """Extract frames from a video at the specified FPS using MoviePy.

        Returns a list of PIL images, or None (after logging) if the video
        cannot be opened or decoded.
        """
        video = None
        try:
            video = VideoFileClip(video_path)
            return [
                Image.fromarray(frame.astype('uint8'), 'RGB')
                for frame in video.iter_frames(fps=fps)
            ]
        except Exception as e:
            logging.error(f"Error extracting video frames: {e}")
            return None
        finally:
            # Close in all cases — the original leaked the clip's reader when
            # iter_frames raised partway through decoding.
            if video is not None:
                video.close()