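# Gradio demo: describe a short local video clip with LLaVA-NeXT-Video-7B loaded in 4-bit.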
import av
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import BitsAndBytesConfig, LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
import gradio as gr


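# Load the 7B model in 4-bit (via bitsandbytes) so it fits in limited GPU memory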
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
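# device_map="auto" requires the accelerate package and places the weights on the available GPU(s)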
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    "llava-hf/LLaVA-NeXT-Video-7B-hf",
    quantization_config=quantization_config,
    device_map='auto'
)


def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.

    Args:
        container (av.container.input.InputContainer): PyAV container.
        indices (List[int]): List of frame indices to decode.

    Returns:
        np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])

def chat(video_number, token):
    # Build the path to a local scene clip; the sample videos from the hub are kept here for reference:
    # video_path_1 = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
    # video_path_2 = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="karate.mp4", repo_type="dataset")
    video_path = "./sample1-Scene-{0}.mp4".format(video_number)
    
    container = av.open(video_path)
    
    # Sample 8 frames uniformly from the video (more can be sampled for longer clips)
    total_frames = container.streams.video[0].frames
    indices = np.arange(0, total_frames, total_frames / 8).astype(int)
    clip = read_video_pyav(container, indices)
    
    
    # A second clip could be prepared the same way and passed alongside the first (unused in this demo):
    # container = av.open(video_path_2)
    # total_frames = container.streams.video[0].frames
    # indices = np.arange(0, total_frames, total_frames / 8).astype(int)
    # clip_karate = read_video_pyav(container, indices)
    
    # Each "content" is a list of dicts and you can add image/video/text modalities
    conversation = [
          {
              "role": "user",
              "content": [
                  {"type": "text", "text": "What happens in the video?"},
                  {"type": "video"},
                  ],
          },
    ]

    conversation_2 = [
          {
              "role": "user",
              "content": [
                  {"type": "text", "text": "What do you see in this video?"},
                  {"type": "video"},
                  ],
          },
    ]
    
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    # prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True)

    inputs = processor(text=prompt, videos=clip, padding=True, return_tensors="pt").to(model.device)


    # Gradio sliders return floats, so cast to int before passing as max_new_tokens
    generate_kwargs = {"max_new_tokens": int(token), "do_sample": True, "top_p": 0.9}

    output = model.generate(**inputs, **generate_kwargs)
    generated_text = processor.batch_decode(output, skip_special_tokens=True)

    # The decoded output echoes the prompt; skip the first 45 characters to return only the answer
    return generated_text[0][45:]

demo = gr.Interface(
    fn=chat,
    inputs=[
        gr.Textbox(label="Scene number (fills ./sample1-Scene-<n>.mp4)"),
        gr.Slider(100, 300, step=1, label="Max new tokens"),
    ],
    outputs=["text"],
)
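# Passing share=True to launch() would additionally expose a temporary public URL.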

# Launch the Gradio app
demo.launch()