File size: 5,495 Bytes
776aee4
 
a700cdc
 
623c6f2
 
 
 
 
 
 
776aee4
 
a700cdc
20c305b
 
3700df9
a700cdc
 
 
 
 
 
 
 
 
 
 
 
 
20c305b
 
 
a700cdc
 
776aee4
 
 
 
 
 
 
20c305b
776aee4
 
 
 
 
623c6f2
 
 
 
 
 
 
 
20c305b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776aee4
 
20c305b
c87b098
 
bd73118
c87b098
 
 
 
bd73118
c87b098
 
c024900
c87b098
c024900
 
 
 
 
 
 
 
 
 
 
 
776aee4
 
 
 
20c305b
776aee4
 
 
20c305b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from typing import Dict, List, Any
import torch
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
from peft import PeftModel
import base64
import numpy as np

def base64_to_numpy(base64_str, shape):
    """Decode ``base64_str`` into a ``uint8`` numpy array reshaped to ``shape``.

    The string is expected to carry the raw, contiguous bytes of the array
    (one byte per element). ``np.frombuffer`` views the decoded bytes without
    copying, so the returned array is read-only.

    Raises:
        binascii.Error: if ``base64_str`` is malformed base64 (e.g. bad padding).
        ValueError: if the number of decoded bytes does not fit ``shape``.
    """
    raw = base64.b64decode(base64_str)
    return np.frombuffer(raw, dtype=np.uint8).reshape(shape)

class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for surf pop-up analysis.

    Loads the LLaVA-NeXT-Video base model, applies the fine-tuned PEFT adapter,
    merges it into the base weights, and answers each request with generated
    coaching text for the submitted video clip.
    """

    def __init__(self, path: str = ""):
        """Load the base model, merge the fine-tuned adapter and load the processor.

        Args:
            path: Model directory supplied by the Inference Endpoints toolkit,
                which constructs the handler as ``EndpointHandler(path)``.
                Unused here: the weights are fetched by Hub ID below. The
                default keeps ``EndpointHandler()`` working as before.
        """
        self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"  # Original base model ID
        self.adapter_model_name = "EnariGmbH/surftown-1.0"  # Fine-tuned adapter model ID

        # Load the base model in half precision; device_map="auto" lets the
        # weights be placed on the available device(s) automatically.
        base_model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Attach the fine-tuned adapter, then fold its weights into the base
        # model and drop the adapter wrappers so inference runs on a plain model.
        peft_model = PeftModel.from_pretrained(base_model, self.adapter_model_name)
        self.model = peft_model.merge_and_unload()

        # The processor (tokenizer + video preprocessing) comes from the
        # adapter repository.
        self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)

        # Ensure the model is in evaluation mode
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Analyze a surfer's pop-up move in a video clip.

        Args:
            data: Request payload containing
                - ``"clip"``: base64-encoded raw ``uint8`` bytes of the clip;
                - ``"clip_shape"``: the shape to reshape the decoded bytes into
                  (presumably ``(frames, height, width, channels)`` — confirm
                  against the client).

        Returns:
            A one-element list: ``[{"generated_text": ...}]`` on success, or
            ``[{"error": ...}]`` when the input is missing or cannot be decoded.
        """
        # Extract inputs from the data dictionary
        clip_base64 = data.get("clip")
        clip_shape = data.get("clip_shape")  # Expect the shape to be passed

        if clip_base64 is None or clip_shape is None:
            return [{"error": "Missing 'clip' or 'clip_shape' in input data"}]

        # Decode the base64 back to a numpy array. Malformed base64
        # (binascii.Error, a ValueError subclass), a non-iterable or
        # non-integer shape (TypeError) or a shape that does not match the
        # byte count (ValueError) become an error response, not a crash.
        try:
            clip = base64_to_numpy(clip_base64, tuple(clip_shape))
        except (TypeError, ValueError) as err:
            return [{"error": f"Could not decode 'clip' with shape {clip_shape!r}: {err}"}]

        # NOTE: the exact text (including indentation and typos) is kept as-is;
        # it presumably matches the prompt used during fine-tuning.
        prompt = """
        You are a surfing coach specialized on perfecting surfer's pop-up move. Please analyze the surfer's pop-up move in detail from the video. 
                    In your detailed analysis you should always mention: Wave Position and paddling, Pushing Phase, Transition, Reaching Phase and finnaly Balance and Control.
                    At the end of your answer you must provide suggestions on how the surfer can improve in the next pop-up.
                    Never mention your name in the answer and keep the answers short and direct. 
                    Your answers should ALWAYS follow this structure:
                        Description: \n
                            Wave Position and paddling: .\n.
                            Pushing Phase: \n. 
                            Transition: \n. 
                            Reaching Phase: \n 
                            Balance and Control: \n\n\n
                        Summary: \n 
                            Suggestions for improvement:\n 
                            """

        # Single-turn conversation: the instruction text followed by the video.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "video"},
                ],
            },
        ]

        # Apply the chat template to create the prompt for the model
        chat_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)

        # Tokenize the prompt, preprocess the frames and move the tensors to
        # the model's device. The processor emits the frames under
        # "pixel_values_videos", the model's video input; they are passed
        # through unchanged. ("pixel_values" is the model's *image* input and
        # would additionally require "image_sizes".)
        inputs_video = self.processor(
            text=chat_prompt, videos=clip, padding=True, return_tensors="pt"
        ).to(self.model.device)

        # Generate output from the model
        generate_kwargs = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9}
        output = self.model.generate(**inputs_video, **generate_kwargs)
        generated_text = self.processor.batch_decode(output, skip_special_tokens=True)

        # The decoded text still contains the prompt; keep what follows the
        # first "ASSISTANT:" marker. If the marker is absent, return the whole
        # text instead of silently chopping off its first characters.
        full_text = generated_text[0]
        _, marker_found, answer = full_text.partition("ASSISTANT:")
        assistant_answer = (answer if marker_found else full_text).strip()

        return [{"generated_text": assistant_answer}]