File size: 4,541 Bytes
776aee4 a700cdc 623c6f2 776aee4 a700cdc 20c305b 3700df9 a700cdc 20c305b a700cdc 776aee4 20c305b 776aee4 623c6f2 20c305b 776aee4 20c305b 776aee4 20c305b 776aee4 20c305b 776aee4 20c305b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
from typing import Dict, List, Any
import torch
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
from peft import PeftModel
import base64
import numpy as np
def base64_to_numpy(base64_str, shape, dtype=np.uint8):
    """Decode a base64 string of raw array bytes into a NumPy array.

    Args:
        base64_str: Base64 text (``str`` or ``bytes``) encoding the array's raw
            buffer, e.g. the output of ``base64.b64encode(arr.tobytes())``.
        shape: Target shape; anything ``ndarray.reshape`` accepts (a single
            ``-1`` entry is inferred).
        dtype: Element type of the encoded buffer. Defaults to ``np.uint8``.

    Returns:
        An array with the requested shape that shares memory with the decoded
        bytes. It is read-only; call ``.copy()`` before mutating it.

    Raises:
        ValueError: If the input is not valid base64 (``binascii.Error`` is a
            ``ValueError`` subclass), the byte count is not a multiple of the
            dtype's item size, or the element count does not fit ``shape``.
    """
    arr_bytes = base64.b64decode(base64_str)
    # frombuffer wraps the immutable bytes object without copying.
    arr = np.frombuffer(arr_bytes, dtype=dtype)
    try:
        return arr.reshape(shape)
    except ValueError as err:
        # numpy's own message omits the dtype; spell out what was decoded so a
        # client sending a wrong shape can see the mismatch.
        raise ValueError(
            f"decoded {arr.size} elements of dtype {np.dtype(dtype)} "
            f"cannot be reshaped to {shape!r}"
        ) from err
# Instruction prompt sent with every clip. Kept byte-for-byte as written
# (including typos and stray punctuation) because the adapter was presumably
# fine-tuned against this exact text — TODO confirm before editing it.
_POP_UP_PROMPT = """
You are a surfing coach specialized on perfecting surfer's pop-up move. Please analyze the surfer's pop-up move in detail from the video.
In your detailed analysis you should always mention: Wave Position and paddling, Pushing Phase, Transition, Reaching Phase and finnaly Balance and Control.
At the end of your answer you must provide suggestions on how the surfer can improve in the next pop-up.
Never mention your name in the answer and keep the answers short and direct.
Your answers should ALWAYS follow this structure:
Description: \n
Wave Position and paddling: .\n.
Pushing Phase: \n.
Transition: \n.
Reaching Phase: \n
Balance and Control: \n\n\n
Summary: \n
Suggestions for improvement:\n
"""


class EndpointHandler:
    """Inference handler that analyzes a surfer's pop-up move in a video clip.

    On construction it loads the LLaVA-NeXT-Video base model in float16,
    merges the fine-tuned PEFT adapter into it, and loads the processor.
    Each call decodes a base64 video clip, prompts the model, and returns the
    assistant's reply.
    """

    # Default sampling settings passed to ``model.generate``.
    GENERATE_KWARGS: Dict[str, Any] = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9}

    def __init__(self, path: str = "") -> None:
        """Load the base model, merge the adapter weights, and load the processor.

        Args:
            path: Model directory argument of the Inference Endpoints
                custom-handler interface (``EndpointHandler(path)``). Unused
                here, because the weights are fetched by hub ID. It is accepted
                so the handler can be constructed that way.
        """
        self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"  # original base model ID
        self.adapter_model_name = "EnariGmbH/surftown-1.0"  # fine-tuned adapter model ID

        base_model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        # Attach the adapter, then fold its weights into the base model so
        # inference runs on a plain model without the PEFT wrapper.
        peft_model = PeftModel.from_pretrained(base_model, self.adapter_model_name)
        self.model = peft_model.merge_and_unload()
        self.model.eval()

        # The processor comes from the adapter repo, not the base model repo.
        self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Analyze the surfer's pop-up in a base64-encoded video clip.

        Args:
            data: Request payload. ``"clip"`` is a base64 string of raw
                ``uint8`` bytes and ``"clip_shape"`` is a sequence of ints used
                to reshape them. If ``"clip"`` is not at the top level, both
                fields are read from a nested ``"inputs"`` dict instead.

        Returns:
            A one-element list. It is ``[{"generated_text": answer}]`` on
            success, or ``[{"error": message}]`` when input is missing or
            cannot be decoded.
        """
        payload = data
        if "clip" not in data and isinstance(data.get("inputs"), dict):
            payload = data["inputs"]

        clip_base64 = payload.get("clip")
        clip_shape = payload.get("clip_shape")
        if clip_base64 is None or clip_shape is None:
            return [{"error": "Missing 'clip' or 'clip_shape' in input data"}]

        try:
            clip = base64_to_numpy(clip_base64, tuple(clip_shape))
        except (ValueError, TypeError) as err:
            # Possible causes: invalid base64 (binascii.Error is a ValueError),
            # a byte count that doesn't fit the shape, or a non-sequence or
            # non-integer shape.
            return [{"error": f"Could not decode 'clip': {err}"}]

        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": _POP_UP_PROMPT},
                    {"type": "video"},
                ],
            },
        ]
        chat_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)

        inputs_video = self.processor(
            text=chat_prompt, videos=clip, padding=True, return_tensors="pt"
        ).to(self.model.device)

        output = self.model.generate(**inputs_video, **self.GENERATE_KWARGS)
        generated_text = self.processor.batch_decode(output, skip_special_tokens=True)
        return [{"generated_text": self._extract_answer(generated_text[0])}]

    @staticmethod
    def _extract_answer(text: str) -> str:
        """Return the assistant's reply from decoded text that echoes the prompt.

        Takes everything after the first ``"ASSISTANT:"`` marker and strips
        surrounding whitespace. If the marker is absent, returns the whole
        stripped text rather than dropping any of it.
        """
        _, marker, answer = text.partition("ASSISTANT:")
        return (answer if marker else text).strip()
|