|
from typing import Dict, List, Any |
|
import torch |
|
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor |
|
from peft import PeftModel |
|
import base64 |
|
import numpy as np |
|
|
|
def base64_to_numpy(base64_str, shape):
    """Turn a base64-encoded byte buffer into a ``uint8`` NumPy array of ``shape``.

    Args:
        base64_str: Base64 text (``str`` or ``bytes``) holding the raw bytes.
        shape: Target shape; the product of its dimensions must equal the
            number of decoded bytes.

    Returns:
        A ``np.uint8`` array over the decoded bytes. It is read-only, because
        ``np.frombuffer`` wraps an immutable ``bytes`` object.

    Raises:
        binascii.Error: if ``base64_str`` is not valid base64 (e.g. bad padding).
        ValueError: if the decoded byte count does not match ``shape``.
    """
    return np.frombuffer(base64.b64decode(base64_str), dtype=np.uint8).reshape(shape)
|
|
|
class EndpointHandler:
    """Custom Inference Endpoints handler that coaches a surfer's pop-up from a video clip.

    On construction it loads the LLaVA-NeXT-Video base model, applies the
    ``EnariGmbH/surftown-1.0`` PEFT adapter, and merges the adapter into the
    base weights. Each call decodes the request's clip, sends it with a fixed
    coaching prompt, and returns the generated analysis.
    """

    # Fixed instruction sent with every clip. Kept verbatim (including its
    # wording quirks) because the adapter may have been tuned on this exact text.
    _COACH_PROMPT = """
You are a surfing coach specialized on perfecting surfer's pop-up move. Please analyze the surfer's pop-up move in detail from the video.
In your detailed analysis you should always mention: Wave Position and paddling, Pushing Phase, Transition, Reaching Phase and finnaly Balance and Control.
At the end of your answer you must provide suggestions on how the surfer can improve in the next pop-up.
Never mention your name in the answer and keep the answers short and direct.
Your answers should ALWAYS follow this structure:
Description: \n
Wave Position and paddling: .\n.
Pushing Phase: \n.
Transition: \n.
Reaching Phase: \n
Balance and Control: \n\n\n
Summary: \n
Suggestions for improvement:\n
"""

    def __init__(self, path: str = ""):
        """Load the base model, merge the adapter into it, and load the processor.

        Args:
            path: Model directory passed in by the Inference Endpoints runtime.
                It is accepted so the runtime can instantiate the handler, but
                it is unused: weights are fetched from the Hub by repo name.
        """
        self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"
        self.adapter_model_name = "EnariGmbH/surftown-1.0"

        base_model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        # Attach the fine-tuned adapter, then fold its weights into the base
        # model so inference runs on a single plain model.
        peft_model = PeftModel.from_pretrained(base_model, self.adapter_model_name)
        self.model = peft_model.merge_and_unload()
        self.model.eval()

        # The processor (chat template, tokenizer, video preprocessing) is
        # loaded from the adapter repository, not the base model repository.
        self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Analyse one surf clip and return the model's coaching feedback.

        Args:
            data: Request payload with two keys. ``"clip"`` is a base64 string
                of raw ``uint8`` frame bytes. ``"clip_shape"`` is the shape to
                reshape those bytes into; presumably
                ``(frames, height, width, channels)``. TODO: confirm against
                the client.

        Returns:
            ``[{"generated_text": answer}]`` on success, or
            ``[{"error": message}]`` if the clip is missing or cannot be decoded.
        """
        clip_base64 = data.get("clip")
        clip_shape = data.get("clip_shape")
        if clip_base64 is None or clip_shape is None:
            return [{"error": "Missing 'clip' or 'clip_shape' in input data"}]

        try:
            clip = base64_to_numpy(clip_base64, tuple(clip_shape))
        except (ValueError, TypeError) as err:
            # binascii.Error (bad base64) subclasses ValueError. reshape raises
            # ValueError on a size mismatch. tuple()/reshape raise TypeError
            # for a non-iterable or non-integer shape.
            return [{"error": f"Invalid 'clip' or 'clip_shape': {err}"}]

        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": self._COACH_PROMPT},
                    {"type": "video"},
                ],
            },
        ]
        chat_prompt = self.processor.apply_chat_template(
            conversation, add_generation_prompt=True
        )

        # Pass the processor output through unchanged. It carries the frames
        # under ``pixel_values_videos``, which is the keyword the model takes
        # for video input. Renaming it to ``pixel_values`` (the still-image
        # input) breaks generation.
        inputs = self.processor(
            text=chat_prompt, videos=clip, padding=True, return_tensors="pt"
        ).to(self.model.device)

        generate_kwargs = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9}
        output = self.model.generate(**inputs, **generate_kwargs)

        # generate() returns the prompt tokens followed by the continuation.
        # Decoding only the new tokens yields the answer without searching
        # for an "ASSISTANT:" marker that might not be present.
        prompt_length = inputs["input_ids"].shape[1]
        answer = self.processor.batch_decode(
            output[:, prompt_length:], skip_special_tokens=True
        )[0].strip()

        return [{"generated_text": answer}]
|
|