chenjoya committed
Commit 98f88b8
1 Parent(s): 4cc5cdc

Create inference.py

Files changed (1)
  1. inference.py +124 -0
inference.py ADDED
@@ -0,0 +1,124 @@
+ import torch, torchvision, transformers, collections
+ torchvision.set_video_backend('video_reader')
+ from dataclasses import asdict
+ from torchvision.io import read_video
+
+ from models import build_model_and_tokenizer, parse_args, fast_greedy_generate
+
+ logger = transformers.logging.get_logger('liveinfer')
+
+ # python -m demo.cli --resume_from_checkpoint ...
+
+ class LiveInfer:
+     def __init__(self) -> None:
+         args = parse_args()
+         self.model, self.tokenizer = build_model_and_tokenizer(is_training=False, set_vision_inside=True, **asdict(args))
+         self.model.to('cuda')
+
+         # visual
+         self.hidden_size = self.model.config.hidden_size
+         self.frame_fps = args.frame_fps
+         self.frame_interval = 1 / self.frame_fps
+         self.frame_resolution = self.model.config.frame_resolution
+         self.frame_num_tokens = self.model.config.frame_num_tokens
+         self.frame_v_placeholder = self.model.config.v_placeholder * self.frame_num_tokens
+         self.frame_token_interval_id = self.model.config.frame_token_interval_id
+         self.frame_placeholder_ids = torch.tensor(self.model.config.v_placeholder_id).repeat(self.model.config.frame_num_tokens).reshape(1, -1)
+
+         # generation
+         self.system_prompt = args.system_prompt
+         self.inplace_output_ids = torch.zeros(1, 100, device='cuda', dtype=torch.long)
+         self.frame_token_interval_threshold = 0.725
+         self.eos_token_id = self.model.config.eos_token_id
+         self._start_ids = self.tokenizer.apply_chat_template([{'role': 'system', 'content': self.system_prompt}], add_stream_prompt=True, return_tensors='pt').to('cuda')
+         self._added_stream_prompt_ids = self.tokenizer.apply_chat_template([{}], add_stream_prompt=True, return_tensors='pt').to('cuda')
+         self._added_stream_generation_ids = self.tokenizer.apply_chat_template([{}], add_stream_generation_prompt=True, return_tensors='pt').to('cuda')
+
+         # app
+         self.reset()
+
+     def _call_for_response(self, video_time, query):
+         if query is not None:
+             self.last_ids = self.tokenizer.apply_chat_template([{'role': 'user', 'content': query}], add_stream_query_prompt=True, add_generation_prompt=True, return_tensors='pt').to('cuda')
+         else:
+             assert self.last_ids == 933, f'{self.last_ids} != 933'  # HACK: 933 is the token id of ']\n'
+             self.last_ids = self._added_stream_generation_ids
+         inputs_embeds = self.model.get_input_embeddings()(self.last_ids)
+         output_ids, self.past_key_values = fast_greedy_generate(model=self.model, inputs_embeds=inputs_embeds, past_key_values=self.past_key_values, eos_token_id=self.eos_token_id, inplace_output_ids=self.inplace_output_ids)
+         self.last_ids = output_ids[:, -1:]
+         if query:
+             query = f'(Video Time = {video_time}s) User: {query}'
+         response = f'(Video Time = {video_time}s) Assistant:{self.tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)}'
+         return query, response
+
+     def _call_for_streaming(self):
+         while self.frame_embeds_queue:
+             # 1. if a pending query predates the next frame, respond to it before ingesting more frames
+             if self.query_queue and self.frame_embeds_queue[0][0] > self.query_queue[0][0]:
+                 video_time, query = self.query_queue.popleft()
+                 return video_time, query
+             video_time, frame_embeds = self.frame_embeds_queue.popleft()
+             if not self.past_key_values:
+                 self.last_ids = self._start_ids
+             elif self.last_ids == self.eos_token_id:
+                 self.last_ids = torch.cat([self.last_ids, self._added_stream_prompt_ids], dim=1)
+             inputs_embeds = torch.cat([
+                 self.model.get_input_embeddings()(self.last_ids).view(1, -1, self.hidden_size),
+                 frame_embeds.view(1, -1, self.hidden_size),
+             ], dim=1)
+             outputs = self.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=self.past_key_values)
+             self.past_key_values = outputs.past_key_values
+             # 2. if a query is due at (or before) this frame's time, respond right after ingesting the frame
+             if self.query_queue and video_time >= self.query_queue[0][0]:
+                 video_time, query = self.query_queue.popleft()
+                 return video_time, query
+             # 3. if the frame-interval token's probability falls below the threshold, suppress it so the argmax picks a response token instead
+             next_score = outputs.logits[:, -1:].softmax(dim=-1)
+             if next_score[:, :, self.frame_token_interval_id] < self.frame_token_interval_threshold:
+                 next_score[:, :, self.frame_token_interval_id].zero_()
+             self.last_ids = next_score.argmax(dim=-1)
+             if self.last_ids != self.frame_token_interval_id:
+                 return video_time, None
+         return None, None
+
+     def reset(self):
+         self.query_queue = collections.deque()
+         self.frame_embeds_queue = collections.deque()
+         self.video_time = 0
+         self.last_frame_idx = -1
+         self.video_tensor = None
+         self.last_ids = torch.tensor([[]], device='cuda', dtype=torch.long)
+         self.past_key_values = None
+
+     def input_query_stream(self, query, history=None, video_time=None):
+         if video_time is None:
+             self.query_queue.append((self.video_time, query))
+         else:
+             self.query_queue.append((video_time, query))
+         if not self.past_key_values:
+             return f'(NOTE: No video stream here. Please select or upload a video. Then the assistant will answer "{query} (at {self.video_time}s)" in the video stream)'
+         return f'(NOTE: Received "{query}" (at {self.video_time}s). Please wait until previous frames have been processed)'
+
+     def input_video_stream(self, video_time):
+         frame_idx = int(video_time * self.frame_fps)
+         if frame_idx > self.last_frame_idx:
+             ranger = range(self.last_frame_idx + 1, frame_idx + 1)
+             frames_embeds = self.model.visual_embed(self.video_tensor[ranger]).split(self.frame_num_tokens)
+             self.frame_embeds_queue.extend([(r / self.frame_fps, frame_embeds) for r, frame_embeds in zip(ranger, frames_embeds)])
+         self.last_frame_idx = frame_idx
+         self.video_time = video_time
+
+     def load_video(self, video_path):
+         self.video_tensor = read_video(video_path, pts_unit='sec', output_format='TCHW')[0].to('cuda')
+         self.num_video_frames = self.video_tensor.size(0)
+         self.video_duration = self.video_tensor.size(0) / self.frame_fps  # assumes the decoded frames are already sampled at frame_fps
+         logger.warning(f'{video_path} -> {self.video_tensor.shape}, {self.frame_fps} FPS')
+
+     def __call__(self):
+         while not self.frame_embeds_queue:  # busy-wait until at least one frame has been ingested
+             continue
+         video_time, query = self._call_for_streaming()
+         response = None
+         if video_time is not None:
+             query, response = self._call_for_response(video_time, query)
+         return query, response
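
For context, a minimal driving loop for this class might look like the sketch below. It is not part of the commit: the module path (inference), the video filename (demo.mp4), and the query text are placeholders, and it assumes the checkpoint is supplied on the command line so that parse_args() inside LiveInfer.__init__ can pick it up (e.g. python sketch.py --resume_from_checkpoint ...).

# Hypothetical driving loop (sketch, not part of this commit).
from inference import LiveInfer  # assumes inference.py is importable

liveinfer = LiveInfer()            # parses CLI args, builds the model on CUDA
liveinfer.load_video('demo.mp4')   # placeholder path
liveinfer.input_query_stream('What is happening in the video?', video_time=1.0)

# Simulate the stream: feed one frame per tick and print any responses.
num_ticks = int(liveinfer.video_duration * liveinfer.frame_fps)
for tick in range(num_ticks):
    liveinfer.input_video_stream(tick / liveinfer.frame_fps)
    query, response = liveinfer()
    if response is not None:
        print(query or '', response)

The pacing here is simulated; a live demo would instead feed frames in wall-clock time at frame_fps.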