TH9817 committed on
Commit
8bcf68c
1 Parent(s): 610c3d1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import av
2
+ import torch
3
+ import numpy as np
4
+ from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor
5
+
6
def read_video_pyav(container, indices):
    """
    Decode selected frames of a video with the PyAV decoder.

    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): Frame indices to decode. They may be given in
            any order and may contain duplicates; each requested frame is
            returned once.
    Returns:
        result (np.ndarray): np array of decoded frames of shape
            (num_frames, height, width, 3), in decode order.
    Raises:
        IndexError: If `indices` is empty.
        ValueError: If none of the requested indices exist in the video.
    """
    # Build the lookup set once: membership is then O(1) per decoded frame
    # instead of a linear scan over `indices` for every frame.
    wanted = set(np.asarray(indices).ravel().tolist())
    if not wanted:
        raise IndexError("indices must contain at least one frame index")

    # Use the true maximum rather than `indices[-1]` so that unsorted input
    # does not make the loop stop before all requested frames are seen.
    last_index = max(wanted)

    container.seek(0)
    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i > last_index:
            # Every requested frame has been passed; stop decoding early.
            break
        if i in wanted:
            frames.append(frame)

    if not frames:
        raise ValueError("none of the requested frame indices exist in the video")
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])
25
+
26
# `hf_hub_download` is used below but was never imported, which made the
# script fail with NameError. It is provided by `huggingface_hub`, a package
# that `transformers` itself depends on.
from huggingface_hub import hf_hub_download

# Load the model in half-precision
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", torch_dtype=torch.float16, device_map="auto")
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")

# Load the video as an np.array, sampling 8 frames uniformly
video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
# Open the container in a `with` block so the file handle is released once
# the frames have been decoded.
with av.open(video_path) as container:
    total_frames = container.streams.video[0].frames
    # PyAV reports 0 when the frame count is unknown; that would give
    # `np.arange` a zero step, so fail early with a clear message instead.
    if total_frames <= 0:
        raise ValueError(f"could not determine the number of frames in {video_path}")
    indices = np.arange(0, total_frames, total_frames / 8).astype(int)
    video = read_video_pyav(container, indices)

# For better results, we recommend to prompt the model in the following format
prompt = "USER: <video>\nWhy is this funny? ASSISTANT:"
inputs = processor(text=prompt, videos=video, return_tensors="pt")

out = model.generate(**inputs, max_new_tokens=60)
# The decoded text was previously computed and thrown away; keep it and print
# it so that running the script actually shows the model's answer.
decoded = processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(decoded)