from typing import Any, Dict

import base64
import io
import logging

import requests
import torch
from PIL import Image
from moviepy.editor import VideoFileClip  # moviepy 1.x; v2 moved this to `from moviepy import VideoFileClip`
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
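
# Dependency note (inferred from the imports, not pinned by the source): the
# endpoint image needs transformers with Qwen2-VL support, qwen-vl-utils,
# moviepy 1.x, pillow, requests, and torch listed in requirements.txt.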

class EndpointHandler:
    def __init__(self, path: str = ""):
        # `path` is the model directory supplied by the Inference Endpoints runtime.
        self.model_dir = path
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_dir, torch_dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(self.model_dir)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (str): The input text; image/video references are wrapped in
                paired <image> markers (see _parse_input).
            max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 128.
        Return:
            A dictionary containing the generated text.
        """
        prompt = data.get("inputs")
        max_new_tokens = data.get("max_new_tokens", 128)

        # Construct the messages list from the input string
        messages = [{"role": "user", "content": self._parse_input(prompt)}]

        # Prepare for inference (using qwen_vl_utils)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)

        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Move the batch to wherever device_map="auto" placed the model
        model_inputs = model_inputs.to(self.model.device)

        # Inference
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # Strip the prompt tokens so only the newly generated tokens are decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]  # Single-prompt batch, so return the first (only) string

        return {"generated_text": output_text}

    def _parse_input(self, input_string):
        """Parses the input string into Qwen2-VL content items.

        Media references must be wrapped in paired <image> markers, e.g.
        "Describe this <image>https://example.com/cat.jpg<image> in detail".
        A reference prefixed with "video:" is treated as a video path/URL.
        """
        content = []
        parts = input_string.split("<image>")
        for i, part in enumerate(parts):
            if i % 2 == 0:  # Even segments are plain text between markers
                if part.strip():
                    content.append({"type": "text", "text": part.strip()})
            else:  # Odd segments are the media references themselves
                if part.lower().startswith("video:"):
                    # Slice by prefix length so "Video:"/"VIDEO:" also work
                    video_path = part[len("video:"):].strip()
                    video_frames = self._extract_video_frames(video_path)
                    if video_frames:
                        content.append({"type": "video", "video": video_frames, "fps": 1})
                else:
                    image = self._load_image(part.strip())
                    if image:
                        content.append({"type": "image", "image": image})
        return content
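
    # Illustrative result of _parse_input (URL is a placeholder):
    #   _parse_input("Describe <image>https://example.com/cat.jpg<image> briefly")
    #   -> [{"type": "text", "text": "Describe"},
    #       {"type": "image", "image": <PIL.Image.Image>},
    #       {"type": "text", "text": "briefly"}]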

    def _load_image(self, image_data):
        """Loads an image from a URL or base64-encoded data URI."""
        if image_data.startswith("http"):
            try:
                response = requests.get(image_data, stream=True, timeout=30)
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content))
            except Exception as e:
                logging.error(f"Error loading image from URL: {e}")
                return None
        elif image_data.startswith("data:image"):
            try:
                # Strip the "data:image/...;base64," header before decoding
                image_data = image_data.split(",", 1)[1]
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                logging.error(f"Error loading image from base64: {e}")
                return None
        else:
            logging.error("Invalid image data format. Must be a URL or a base64 data URI.")
            return None
        return image

    def _extract_video_frames(self, video_path, fps=1):
        """Extracts frames from a video at the specified FPS using MoviePy."""
        try:
            video = VideoFileClip(video_path)
            # Sample frames at `fps` and convert each RGB array to a PIL image
            frames = [
                Image.fromarray(frame.astype("uint8"), "RGB")
                for frame in video.iter_frames(fps=fps)
            ]
            video.close()
            return frames
        except Exception as e:
            logging.error(f"Error extracting video frames: {e}")
            return None
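
# --- Local smoke test (illustrative sketch, not part of the handler API) ---
# Assumes a Qwen2-VL checkpoint at "./qwen2-vl" and a reachable image URL;
# both are placeholders to swap for your own model directory and media.
if __name__ == "__main__":
    handler = EndpointHandler(path="./qwen2-vl")
    result = handler({
        "inputs": "Describe this photo. <image>https://example.com/cat.jpg<image>",
        "max_new_tokens": 64,
    })
    print(result["generated_text"])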