hysts commited on
Commit
55efca8
1 Parent(s): b5d66a8

Remove files

Browse files
app.py DELETED
@@ -1,18 +0,0 @@
1
- import gradio as gr
2
- import os
3
-
4
-
5
-
6
-
7
- os.environ['SAT_HOME'] = '/home/user/app/sharefs/cogview-new'
8
-
9
- def inference(text):
10
- os.system("""bash ./scripts/inference_cogvideo_pipeline.sh""")
11
- return "output/out.mp4"
12
-
13
- gr.Interface(inference,"text","video").launch()
14
-
15
-
16
-
17
-
18
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cluster_label2.npy DELETED
Binary file (160 kB)
 
coglm_strategy.py DELETED
@@ -1,101 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : coglm_strategy.py
4
- @Time : 2021/10/08 22:22:42
5
- @Author : Ming Ding
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
- import numpy as np
16
- import torch.nn.functional as F
17
-
18
-
19
- def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
20
- # This function has been mostly taken from huggingface conversational ai code at
21
- # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
22
-
23
- if top_k > 0:
24
- # Remove all tokens with a probability less than the last token of the top-k
25
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
26
- logits[indices_to_remove] = filter_value
27
-
28
- if top_p > 0.0:
29
- # convert to 1D
30
- logits = logits.view(logits.size()[1]).contiguous()
31
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
32
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
33
-
34
- # Remove tokens with cumulative probability above the threshold
35
- sorted_indices_to_remove = cumulative_probs > top_p
36
- # Shift the indices to the right to keep also the first token above the threshold
37
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
38
- sorted_indices_to_remove[..., 0] = 0
39
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
40
- logits[indices_to_remove] = filter_value
41
- # going back to 2D
42
- logits = logits.view(1, -1).contiguous()
43
-
44
- return logits
45
-
46
-
47
- class CoglmStrategy:
48
- def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None, temperature2=0.89):
49
- self.invalid_slices = invalid_slices
50
- self.temperature = temperature
51
- self.temperature2 = temperature2
52
- self.topk = top_k
53
- self.top_p = top_p
54
- self.eps = eps
55
- if end_tokens is None:
56
- end_tokens = []
57
- self.end_tokens = end_tokens
58
- self._is_done = False
59
- self.outlier_count_down = torch.zeros(16)
60
- self.vis_list = [[]for i in range(16)]
61
- self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
62
- self.start_pos = -1
63
- self.white_cluster = []
64
- # self.fout = open('tmp.txt', 'w')
65
-
66
- @property
67
- def is_done(self) -> bool:
68
- return self._is_done
69
-
70
- def forward(self, logits, tokens, mems, temperature=None, temperature2=None):
71
- if temperature is None:
72
- temperature = self.temperature
73
- if temperature2 is None:
74
- temperature2 = self.temperature2
75
- logits = logits / temperature
76
- for invalid_slice in self.invalid_slices:
77
- logits[..., invalid_slice] = -65504
78
-
79
- rprobs = F.softmax(logits.float(), dim=-1)
80
- c = self.cluster_labels.expand(*rprobs.shape)
81
- cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
82
- # self.fout.write(str(tokens.shape[-1])+ ' ' + str(cprobs.topk(10)) + '\n')
83
- # self.fout.flush()
84
- best_scores, best_clusters = cprobs.topk(self.topk)
85
- bz = logits.shape[0]
86
- for i in range(bz):
87
- selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
88
- logits[i, self.cluster_labels != selected_cluster] = -65504
89
-
90
- # logits = top_k_logits(logits, self.topk, self.top_p)
91
- probs = F.softmax(logits.float()/temperature2, dim=-1) # float is essetial, due to a bug in Pytorch
92
- pred = torch.multinomial(probs, num_samples=1)
93
-
94
- if pred.numel() == 1 and pred.item() in self.end_tokens:
95
- self._is_done = True
96
- tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
97
- return tokens, mems
98
-
99
- def finalize(self, tokens, mems):
100
- self._is_done = False
101
- return tokens, mems
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cogvideo_pipeline.py DELETED
@@ -1,793 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cogvideo_pipeline.py
4
- @Time : 2022/07/15 11:24:56
5
- @Author : Wenyi Hong
6
- @Version : 1.0
7
- @Contact : [email protected]
8
- '''
9
-
10
- # here put the import lib
11
-
12
- import os
13
- import sys
14
- import torch
15
- import argparse
16
- import time
17
- from torchvision.utils import save_image
18
- import stat
19
- from icetk import icetk as tokenizer
20
- import logging, sys
21
-
22
- import torch.distributed as dist
23
- tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
24
-
25
-
26
- from SwissArmyTransformer import get_args
27
- from SwissArmyTransformer.data_utils import BinaryDataset, make_loaders
28
- from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
29
- from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
30
- from SwissArmyTransformer.resources import auto_create
31
-
32
- from models.cogvideo_cache_model import CogVideoCacheModel
33
- from coglm_strategy import CoglmStrategy
34
-
35
-
36
- def get_masks_and_position_ids_stage1(data, textlen, framelen):
37
- # Extract batch size and sequence length.
38
- tokens = data
39
- seq_length = len(data[0])
40
- # Attention mask (lower triangular).
41
- attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device)
42
- attention_mask[:, :textlen, textlen:] = 0
43
- attention_mask[:, textlen:, textlen:].tril_()
44
- attention_mask.unsqueeze_(1)
45
- # Unaligned version
46
- position_ids = torch.zeros(seq_length, dtype=torch.long,
47
- device=data.device)
48
- torch.arange(textlen, out=position_ids[:textlen],
49
- dtype=torch.long, device=data.device)
50
- torch.arange(512, 512+seq_length-textlen, out=position_ids[textlen:],
51
- dtype=torch.long, device=data.device)
52
- position_ids = position_ids.unsqueeze(0)
53
-
54
- return tokens, attention_mask, position_ids
55
-
56
- def get_masks_and_position_ids_stage2(data, textlen, framelen):
57
- # Extract batch size and sequence length.
58
- tokens = data
59
- seq_length = len(data[0])
60
-
61
- # Attention mask (lower triangular).
62
- attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device)
63
- attention_mask[:, :textlen, textlen:] = 0
64
- attention_mask[:, textlen:, textlen:].tril_()
65
- attention_mask.unsqueeze_(1)
66
-
67
- # Unaligned version
68
- position_ids = torch.zeros(seq_length, dtype=torch.long,
69
- device=data.device)
70
- torch.arange(textlen, out=position_ids[:textlen],
71
- dtype=torch.long, device=data.device)
72
- frame_num = (seq_length-textlen)//framelen
73
- assert frame_num == 5
74
- torch.arange(512, 512+framelen, out=position_ids[textlen:textlen+framelen],
75
- dtype=torch.long, device=data.device)
76
- torch.arange(512+framelen*2, 512+framelen*3, out=position_ids[textlen+framelen:textlen+framelen*2],
77
- dtype=torch.long, device=data.device)
78
- torch.arange(512+framelen*(frame_num-1), 512+framelen*frame_num, out=position_ids[textlen+framelen*2:textlen+framelen*3],
79
- dtype=torch.long, device=data.device)
80
- torch.arange(512+framelen*1, 512+framelen*2, out=position_ids[textlen+framelen*3:textlen+framelen*4],
81
- dtype=torch.long, device=data.device)
82
- torch.arange(512+framelen*3, 512+framelen*4, out=position_ids[textlen+framelen*4:textlen+framelen*5],
83
- dtype=torch.long, device=data.device)
84
-
85
- position_ids = position_ids.unsqueeze(0)
86
-
87
- return tokens, attention_mask, position_ids
88
-
89
- def my_update_mems(hiddens, mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len):
90
- if hiddens is None:
91
- return None, mems_indexs
92
- mem_num = len(hiddens)
93
- ret_mem = []
94
- with torch.no_grad():
95
- for id in range(mem_num):
96
- if hiddens[id][0] is None:
97
- ret_mem.append(None)
98
- else:
99
- if id == 0 and limited_spatial_channel_mem and mems_indexs[id]+hiddens[0][0].shape[1] >= text_len+frame_len:
100
- if mems_indexs[id] == 0:
101
- for layer, hidden in enumerate(hiddens[id]):
102
- mems_buffers[id][layer, :, :text_len] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, :text_len]
103
- new_mem_len_part2 = (mems_indexs[id]+hiddens[0][0].shape[1]-text_len)%frame_len
104
- if new_mem_len_part2 > 0:
105
- for layer, hidden in enumerate(hiddens[id]):
106
- mems_buffers[id][layer, :, text_len:text_len+new_mem_len_part2] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, -new_mem_len_part2:]
107
- mems_indexs[id] = text_len+new_mem_len_part2
108
- else:
109
- for layer, hidden in enumerate(hiddens[id]):
110
- mems_buffers[id][layer, :, mems_indexs[id]:mems_indexs[id]+hidden.shape[1]] = hidden.expand(mems_buffers[id].shape[1], -1, -1)
111
- mems_indexs[id] += hidden.shape[1]
112
- ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]])
113
- return ret_mem, mems_indexs
114
-
115
-
116
- def my_save_multiple_images(imgs, path, subdir, debug=True):
117
- # imgs: list of tensor images
118
- if debug:
119
- imgs = torch.cat(imgs, dim=0)
120
- print("\nSave to: ", path, flush=True)
121
- save_image(imgs, path, normalize=True)
122
- else:
123
- print("\nSave to: ", path, flush=True)
124
- single_frame_path = os.path.join(path, subdir)
125
- os.makedirs(single_frame_path, exist_ok=True)
126
- for i in range(len(imgs)):
127
- save_image(imgs[i], os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), normalize=True)
128
- os.chmod(os.path.join(single_frame_path,f'{str(i).rjust(4,"0")}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
129
- save_image(torch.cat(imgs, dim=0), os.path.join(single_frame_path,f'frame_concat.jpg'), normalize=True)
130
- os.chmod(os.path.join(single_frame_path,f'frame_concat.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
131
-
132
- def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
133
- # The fisrt token's position id of the frame that the next token belongs to;
134
- if total_len < text_len:
135
- return None
136
- return (total_len-text_len)//frame_len * frame_len + text_len
137
-
138
- def my_filling_sequence(
139
- model,
140
- args,
141
- seq,
142
- batch_size,
143
- get_masks_and_position_ids,
144
- text_len,
145
- frame_len,
146
- strategy=BaseStrategy(),
147
- strategy2=BaseStrategy(),
148
- mems=None,
149
- log_text_attention_weights=0, # default to 0: no artificial change
150
- mode_stage1=True,
151
- enforce_no_swin=False,
152
- guider_seq=None,
153
- guider_text_len=0,
154
- guidance_alpha=1,
155
- limited_spatial_channel_mem=False, # 空间通道的存储限制在本帧内
156
- **kw_args
157
- ):
158
- '''
159
- seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
160
- mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
161
- cache, should be first mems.shape[1] parts of context_tokens.
162
- mems are the first-level citizens here, but we don't assume what is memorized.
163
- input mems are used when multi-phase generation.
164
- '''
165
- if guider_seq is not None:
166
- logging.debug("Using Guidance In Inference")
167
- if limited_spatial_channel_mem:
168
- logging.debug("Limit spatial-channel's mem to current frame")
169
- assert len(seq.shape) == 2
170
-
171
- # building the initial tokens, attention_mask, and position_ids
172
- actual_context_length = 0
173
-
174
- while seq[-1][actual_context_length] >= 0: # the last seq has least given tokens
175
- actual_context_length += 1 # [0, context_length-1] are given
176
- assert actual_context_length > 0
177
- current_frame_num = (actual_context_length-text_len) // frame_len
178
- assert current_frame_num >= 0
179
- context_length = text_len + current_frame_num * frame_len
180
-
181
- tokens, attention_mask, position_ids = get_masks_and_position_ids(seq, text_len, frame_len)
182
- tokens = tokens[..., :context_length]
183
- input_tokens = tokens.clone()
184
-
185
- if guider_seq is not None:
186
- guider_index_delta = text_len - guider_text_len
187
- guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(guider_seq, guider_text_len, frame_len)
188
- guider_tokens = guider_tokens[..., :context_length-guider_index_delta]
189
- guider_input_tokens = guider_tokens.clone()
190
-
191
- for fid in range(current_frame_num):
192
- input_tokens[:, text_len+400*fid] = tokenizer['<start_of_image>']
193
- if guider_seq is not None:
194
- guider_input_tokens[:, guider_text_len+400*fid] = tokenizer['<start_of_image>']
195
-
196
- attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
197
- # initialize generation
198
- counter = context_length - 1 # Last fixed index is ``counter''
199
- index = 0 # Next forward starting index, also the length of cache.
200
- mems_buffers_on_GPU = False
201
- mems_indexs = [0, 0]
202
- mems_len = [(400+74) if limited_spatial_channel_mem else 5*400+74, 5*400+74]
203
- mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype)
204
- for mem_len in mems_len]
205
-
206
-
207
- if guider_seq is not None:
208
- guider_attention_mask = guider_attention_mask.type_as(next(model.parameters())) # if fp16
209
- guider_mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype)
210
- for mem_len in mems_len]
211
- guider_mems_indexs = [0, 0]
212
- guider_mems = None
213
-
214
- torch.cuda.empty_cache()
215
- # step-by-step generation
216
- while counter < len(seq[0]) - 1:
217
- # we have generated counter+1 tokens
218
- # Now, we want to generate seq[counter + 1],
219
- # token[:, index: counter+1] needs forwarding.
220
- if index == 0:
221
- group_size = 2 if (input_tokens.shape[0] == batch_size and not mode_stage1) else batch_size
222
-
223
- logits_all = None
224
- for batch_idx in range(0, input_tokens.shape[0], group_size):
225
- logits, *output_per_layers = model(
226
- input_tokens[batch_idx:batch_idx+group_size, index:],
227
- position_ids[..., index: counter+1],
228
- attention_mask, # TODO memlen
229
- mems=mems,
230
- text_len=text_len,
231
- frame_len=frame_len,
232
- counter=counter,
233
- log_text_attention_weights=log_text_attention_weights,
234
- enforce_no_swin=enforce_no_swin,
235
- **kw_args
236
- )
237
- logits_all = torch.cat((logits_all, logits), dim=0) if logits_all is not None else logits
238
- mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]]
239
- next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(text_len, frame_len, mem_kv01[0][0].shape[1])
240
- for id, mem_kv in enumerate(mem_kv01):
241
- for layer, mem_kv_perlayer in enumerate(mem_kv):
242
- if limited_spatial_channel_mem and id == 0:
243
- mems_buffers[id][layer, batch_idx:batch_idx+group_size, :text_len] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :text_len]
244
- mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\
245
- mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:]
246
- else:
247
- mems_buffers[id][layer, batch_idx:batch_idx+group_size, :mem_kv_perlayer.shape[1]] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)
248
- mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[1], mem_kv01[1][0].shape[1]
249
- if limited_spatial_channel_mem:
250
- mems_indexs[0] -= (next_tokens_frame_begin_id - text_len)
251
-
252
- mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)]
253
- logits = logits_all
254
-
255
- # Guider
256
- if guider_seq is not None:
257
- guider_logits_all = None
258
- for batch_idx in range(0, guider_input_tokens.shape[0], group_size):
259
- guider_logits, *guider_output_per_layers = model(
260
- guider_input_tokens[batch_idx:batch_idx+group_size, max(index-guider_index_delta, 0):],
261
- guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta],
262
- guider_attention_mask,
263
- mems=guider_mems,
264
- text_len=guider_text_len,
265
- frame_len=frame_len,
266
- counter=counter-guider_index_delta,
267
- log_text_attention_weights=log_text_attention_weights,
268
- enforce_no_swin=enforce_no_swin,
269
- **kw_args
270
- )
271
- guider_logits_all = torch.cat((guider_logits_all, guider_logits), dim=0) if guider_logits_all is not None else guider_logits
272
- guider_mem_kv01 = [[o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers]]
273
- for id, guider_mem_kv in enumerate(guider_mem_kv01):
274
- for layer, guider_mem_kv_perlayer in enumerate(guider_mem_kv):
275
- if limited_spatial_channel_mem and id == 0:
276
- guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_text_len] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :guider_text_len]
277
- guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(guider_text_len, frame_len, guider_mem_kv_perlayer.shape[1])
278
- guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\
279
- guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:]
280
- else:
281
- guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_mem_kv_perlayer.shape[1]] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)
282
- guider_mems_indexs[0], guider_mems_indexs[1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[1][0].shape[1]
283
- if limited_spatial_channel_mem:
284
- guider_mems_indexs[0] -= (guider_next_tokens_frame_begin_id-guider_text_len)
285
- guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)]
286
- guider_logits = guider_logits_all
287
- else:
288
- if not mems_buffers_on_GPU:
289
- if not mode_stage1:
290
- torch.cuda.empty_cache()
291
- for idx, mem in enumerate(mems):
292
- mems[idx] = mem.to(next(model.parameters()).device)
293
- if guider_seq is not None:
294
- for idx, mem in enumerate(guider_mems):
295
- guider_mems[idx] = mem.to(next(model.parameters()).device)
296
- else:
297
- torch.cuda.empty_cache()
298
- for idx, mem_buffer in enumerate(mems_buffers):
299
- mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
300
- mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)]
301
- if guider_seq is not None:
302
- for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
303
- guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device)
304
- guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)]
305
- mems_buffers_on_GPU = True
306
-
307
- logits, *output_per_layers = model(
308
- input_tokens[:, index:],
309
- position_ids[..., index: counter+1],
310
- attention_mask, # TODO memlen
311
- mems=mems,
312
- text_len=text_len,
313
- frame_len=frame_len,
314
- counter=counter,
315
- log_text_attention_weights=log_text_attention_weights,
316
- enforce_no_swin=enforce_no_swin,
317
- limited_spatial_channel_mem=limited_spatial_channel_mem,
318
- **kw_args
319
- )
320
- mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]
321
-
322
- if guider_seq is not None:
323
- guider_logits, *guider_output_per_layers = model(
324
- guider_input_tokens[:, max(index-guider_index_delta, 0):],
325
- guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta],
326
- guider_attention_mask,
327
- mems=guider_mems,
328
- text_len=guider_text_len,
329
- frame_len=frame_len,
330
- counter=counter-guider_index_delta,
331
- log_text_attention_weights=0,
332
- enforce_no_swin=enforce_no_swin,
333
- limited_spatial_channel_mem=limited_spatial_channel_mem,
334
- **kw_args
335
- )
336
- guider_mem_kv0, guider_mem_kv1 = [o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers]
337
-
338
- if not mems_buffers_on_GPU:
339
- torch.cuda.empty_cache()
340
- for idx, mem_buffer in enumerate(mems_buffers):
341
- mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
342
- if guider_seq is not None:
343
- for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
344
- guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device)
345
- mems_buffers_on_GPU = True
346
-
347
- mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1], mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len)
348
- if guider_seq is not None:
349
- guider_mems, guider_mems_indexs = my_update_mems([guider_mem_kv0, guider_mem_kv1], guider_mems_buffers, guider_mems_indexs, limited_spatial_channel_mem, guider_text_len, frame_len)
350
-
351
-
352
- counter += 1
353
- index = counter
354
-
355
- logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
356
- tokens = tokens.expand(batch_size, -1)
357
- if guider_seq is not None:
358
- guider_logits = guider_logits[:, -1].expand(batch_size, -1)
359
- guider_tokens = guider_tokens.expand(batch_size, -1)
360
-
361
- if seq[-1][counter].item() < 0:
362
- # sampling
363
- guided_logits = guider_logits+(logits-guider_logits)*guidance_alpha if guider_seq is not None else logits
364
- if mode_stage1 and counter < text_len + 400:
365
- tokens, mems = strategy.forward(guided_logits, tokens, mems)
366
- else:
367
- tokens, mems = strategy2.forward(guided_logits, tokens, mems)
368
- if guider_seq is not None:
369
- guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]), dim=1)
370
-
371
- if seq[0][counter].item() >= 0:
372
- for si in range(seq.shape[0]):
373
- if seq[si][counter].item() >= 0:
374
- tokens[si, -1] = seq[si, counter]
375
- if guider_seq is not None:
376
- guider_tokens[si, -1] = guider_seq[si, counter-guider_index_delta]
377
-
378
- else:
379
- tokens = torch.cat((tokens, seq[:, counter:counter+1].clone().expand(tokens.shape[0], 1).to(device=tokens.device, dtype=tokens.dtype)), dim=1)
380
- if guider_seq is not None:
381
- guider_tokens = torch.cat((guider_tokens,
382
- guider_seq[:, counter-guider_index_delta:counter+1-guider_index_delta]
383
- .clone().expand(guider_tokens.shape[0], 1).to(device=guider_tokens.device, dtype=guider_tokens.dtype)), dim=1)
384
-
385
- input_tokens = tokens.clone()
386
- if guider_seq is not None:
387
- guider_input_tokens = guider_tokens.clone()
388
- if (index-text_len-1)//400 < (input_tokens.shape[-1]-text_len-1)//400:
389
- boi_idx = ((index-text_len-1)//400 +1)*400+text_len
390
- while boi_idx < input_tokens.shape[-1]:
391
- input_tokens[:, boi_idx] = tokenizer['<start_of_image>']
392
- if guider_seq is not None:
393
- guider_input_tokens[:, boi_idx-guider_index_delta] = tokenizer['<start_of_image>']
394
- boi_idx += 400
395
-
396
- if strategy.is_done:
397
- break
398
- return strategy.finalize(tokens, mems)
399
-
400
- class InferenceModel_Sequential(CogVideoCacheModel):
401
- def __init__(self, args, transformer=None, parallel_output=True):
402
- super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=-1, cogvideo_stage=1)
403
- # TODO: check it
404
-
405
- def final_forward(self, logits, **kwargs):
406
- logits_parallel = logits
407
- logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
408
- return logits_parallel
409
-
410
- class InferenceModel_Interpolate(CogVideoCacheModel):
411
- def __init__(self, args, transformer=None, parallel_output=True):
412
- super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=10, cogvideo_stage=2)
413
- # TODO: check it
414
-
415
- def final_forward(self, logits, **kwargs):
416
- logits_parallel = logits
417
- logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
418
- return logits_parallel
419
-
420
- def main(args):
421
- assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1
422
- rank_id = args.device % args.parallel_size
423
- generate_frame_num = args.generate_frame_num
424
-
425
- if args.stage_1 or args.both_stages:
426
- model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
427
- model_stage1.eval()
428
- if args.both_stages:
429
- model_stage1 = model_stage1.cpu()
430
-
431
- if args.stage_2 or args.both_stages:
432
- model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2')
433
- model_stage2.eval()
434
- if args.both_stages:
435
- model_stage2 = model_stage2.cpu()
436
-
437
- invalid_slices = [slice(tokenizer.num_image_tokens, None)]
438
- strategy_cogview2 = CoglmStrategy(invalid_slices,
439
- temperature=1.0, top_k=16)
440
- strategy_cogvideo = CoglmStrategy(invalid_slices,
441
- temperature=args.temperature, top_k=args.top_k,
442
- temperature2=args.coglm_temperature2)
443
- if not args.stage_1:
444
- from sr_pipeline import DirectSuperResolution
445
- dsr_path = auto_create('cogview2-dsr', path=None) # path=os.getenv('SAT_HOME', '~/.sat_models')
446
- dsr = DirectSuperResolution(args, dsr_path,
447
- max_bz=12, onCUDA=False)
448
-
449
- def process_stage2(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", parent_given_tokens=None, conddir=None, outputdir=None, gpu_rank=0, gpu_parallel_size=1):
450
- stage2_starttime = time.time()
451
- use_guidance = args.use_guidance_stage2
452
- if args.both_stages:
453
- move_start_time = time.time()
454
- logging.debug("moving stage-2 model to cuda")
455
- model = model.cuda()
456
- logging.debug("moving in stage-2 model takes time: {:.2f}".format(time.time()-move_start_time))
457
-
458
- try:
459
- if parent_given_tokens is None:
460
- assert conddir is not None
461
- parent_given_tokens = torch.load(os.path.join(conddir, 'frame_tokens.pt'), map_location='cpu')
462
- sample_num_allgpu = parent_given_tokens.shape[0]
463
- sample_num = sample_num_allgpu // gpu_parallel_size
464
- assert sample_num * gpu_parallel_size == sample_num_allgpu
465
- parent_given_tokens = parent_given_tokens[gpu_rank*sample_num:(gpu_rank+1)*sample_num]
466
- except:
467
- logging.critical("No frame_tokens found in interpolation, skip")
468
- return False
469
-
470
- # CogVideo Stage2 Generation
471
- while duration >= 0.5: # TODO: You can change the boundary to change the frame rate
472
- parent_given_tokens_num = parent_given_tokens.shape[1]
473
- generate_batchsize_persample = (parent_given_tokens_num-1)//2
474
- generate_batchsize_total = generate_batchsize_persample * sample_num
475
- total_frames = generate_frame_num
476
- frame_len = 400
477
- enc_text = tokenizer.encode(seq_text)
478
- enc_duration = tokenizer.encode(str(float(duration))+"秒")
479
- seq = enc_duration + [tokenizer['<n>']] + enc_text + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
480
- text_len = len(seq) - frame_len*generate_frame_num - 1
481
-
482
- logging.info("[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(int(4/duration), tokenizer.decode(enc_text)))
483
-
484
- # generation
485
- seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1)
486
- for sample_i in range(sample_num):
487
- for i in range(generate_batchsize_persample):
488
- seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i]
489
- seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1]
490
- seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2]
491
-
492
- if use_guidance:
493
- guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(video_guidance_text) + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
494
- guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1
495
- guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1)
496
- for sample_i in range(sample_num):
497
- for i in range(generate_batchsize_persample):
498
- guider_seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i]
499
- guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1]
500
- guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2]
501
- video_log_text_attention_weights = 0
502
- else:
503
- guider_seq=None
504
- guider_text_len=0
505
- video_log_text_attention_weights = 1.4
506
-
507
- mbz = args.max_inference_batch_size
508
-
509
- assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
510
- output_list = []
511
- start_time = time.time()
512
- for tim in range(max(generate_batchsize_total // mbz, 1)):
513
- input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
514
- guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
515
- output_list.append(
516
- my_filling_sequence(model, args, input_seq,
517
- batch_size=min(generate_batchsize_total, mbz),
518
- get_masks_and_position_ids=get_masks_and_position_ids_stage2,
519
- text_len=text_len, frame_len=frame_len,
520
- strategy=strategy_cogview2,
521
- strategy2=strategy_cogvideo,
522
- log_text_attention_weights=video_log_text_attention_weights,
523
- mode_stage1=False,
524
- guider_seq=guider_seq2,
525
- guider_text_len=guider_text_len,
526
- guidance_alpha=args.guidance_alpha,
527
- limited_spatial_channel_mem=True,
528
- )[0]
529
- )
530
- logging.info("Duration {:.2f}, Taken time {:.2f}\n".format(duration, time.time() - start_time))
531
-
532
- output_tokens = torch.cat(output_list, dim=0)
533
- output_tokens = output_tokens[:, text_len+1:text_len+1+(total_frames)*400].reshape(sample_num, -1, 400*total_frames)
534
- output_tokens_merge = torch.cat((output_tokens[:, :, :1*400],
535
- output_tokens[:, :, 400*3:4*400],
536
- output_tokens[:, :, 400*1:2*400],
537
- output_tokens[:, :, 400*4:(total_frames)*400]), dim=2).reshape(sample_num, -1, 400)
538
-
539
- output_tokens_merge = torch.cat((output_tokens_merge, output_tokens[:, -1:, 400*2:3*400]), dim=1)
540
- duration /= 2
541
- parent_given_tokens = output_tokens_merge
542
-
543
- if args.both_stages:
544
- move_start_time = time.time()
545
- logging.debug("moving stage 2 model to cpu")
546
- model = model.cpu()
547
- torch.cuda.empty_cache()
548
- logging.debug("moving out model2 takes time: {:.2f}".format(time.time()-move_start_time))
549
-
550
- logging.info("CogVideo Stage2 completed. Taken time {:.2f}\n".format(time.time() - stage2_starttime))
551
-
552
- # decoding
553
- # imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
554
- # os.makedirs(output_dir_full_path, exist_ok=True)
555
- # my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False)
556
- # torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt'))
557
- # os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2")
558
-
559
- # direct super-resolution by CogView2
560
- logging.info("[Direct super-resolution]")
561
- dsr_starttime = time.time()
562
- enc_text = tokenizer.encode(seq_text)
563
- frame_num_per_sample = parent_given_tokens.shape[1]
564
- parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
565
- text_seq = torch.cuda.LongTensor(enc_text, device=args.device).unsqueeze(0).repeat(parent_given_tokens_2d.shape[0], 1)
566
- sred_tokens = dsr(text_seq, parent_given_tokens_2d)
567
- decoded_sr_videos = []
568
-
569
- for sample_i in range(sample_num):
570
- decoded_sr_imgs = []
571
- for frame_i in range(frame_num_per_sample):
572
- decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:])
573
- decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480)))
574
- decoded_sr_videos.append(decoded_sr_imgs)
575
-
576
- for sample_i in range(sample_num):
577
- my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
578
- os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
579
-
580
- logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime))
581
-
582
- return True
583
-
584
-
585
- def process_stage1(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1):
586
- process_start_time = time.time()
587
- use_guide = args.use_guidance_stage1
588
- if args.both_stages:
589
- move_start_time = time.time()
590
- logging.debug("moving stage 1 model to cuda")
591
- model = model.cuda()
592
- logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time))
593
-
594
- if video_raw_text is None:
595
- video_raw_text = seq_text
596
- mbz = args.stage1_max_inference_batch_size if args.stage1_max_inference_batch_size > 0 else args.max_inference_batch_size
597
- assert batch_size < mbz or batch_size % mbz == 0
598
- frame_len = 400
599
-
600
- # generate the first frame:
601
- enc_text = tokenizer.encode(seq_text+image_text_suffix)
602
- seq_1st = enc_text + [tokenizer['<start_of_image>']] + [-1]*400 # IV!! # test local!!! # test randboi!!!
603
- logging.info("[Generating First Frame with CogView2]Raw text: {:s}".format(tokenizer.decode(enc_text)))
604
- text_len_1st = len(seq_1st) - frame_len*1 - 1
605
-
606
- seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
607
- output_list_1st = []
608
- for tim in range(max(batch_size // mbz, 1)):
609
- start_time = time.time()
610
- output_list_1st.append(
611
- my_filling_sequence(model, args,seq_1st.clone(),
612
- batch_size=min(batch_size, mbz),
613
- get_masks_and_position_ids=get_masks_and_position_ids_stage1,
614
- text_len=text_len_1st,
615
- frame_len=frame_len,
616
- strategy=strategy_cogview2,
617
- strategy2=strategy_cogvideo,
618
- log_text_attention_weights=1.4,
619
- enforce_no_swin=True,
620
- mode_stage1=True,
621
- )[0]
622
- )
623
- logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
624
- output_tokens_1st = torch.cat(output_list_1st, dim=0)
625
- given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400]
626
-
627
- # generate subsequent frames:
628
- total_frames = generate_frame_num
629
- enc_duration = tokenizer.encode(str(float(duration))+"秒")
630
- if use_guide:
631
- video_raw_text = video_raw_text + " 视频"
632
- enc_text_video = tokenizer.encode(video_raw_text)
633
- seq = enc_duration + [tokenizer['<n>']] + enc_text_video + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
634
- guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(video_guidance_text) + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
635
- logging.info("[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(4/duration, tokenizer.decode(enc_text_video)))
636
-
637
- text_len = len(seq) - frame_len*generate_frame_num - 1
638
- guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1
639
- seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(batch_size, 1)
640
- guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(batch_size, 1)
641
-
642
- for given_frame_id in range(given_tokens.shape[1]):
643
- seq[:, text_len+1+given_frame_id*400: text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id]
644
- guider_seq[:, guider_text_len+1+given_frame_id*400:guider_text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id]
645
- output_list = []
646
-
647
- if use_guide:
648
- video_log_text_attention_weights = 0
649
- else:
650
- guider_seq = None
651
- video_log_text_attention_weights = 1.4
652
-
653
- for tim in range(max(batch_size // mbz, 1)):
654
- start_time = time.time()
655
- input_seq = seq[:min(batch_size, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
656
- guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
657
- output_list.append(
658
- my_filling_sequence(model, args,input_seq,
659
- batch_size=min(batch_size, mbz),
660
- get_masks_and_position_ids=get_masks_and_position_ids_stage1,
661
- text_len=text_len, frame_len=frame_len,
662
- strategy=strategy_cogview2,
663
- strategy2=strategy_cogvideo,
664
- log_text_attention_weights=video_log_text_attention_weights,
665
- guider_seq=guider_seq2,
666
- guider_text_len=guider_text_len,
667
- guidance_alpha=args.guidance_alpha,
668
- limited_spatial_channel_mem=True,
669
- mode_stage1=True,
670
- )[0]
671
- )
672
-
673
- output_tokens = torch.cat(output_list, dim=0)[:, 1+text_len:]
674
-
675
- if args.both_stages:
676
- move_start_time = time.time()
677
- logging.debug("moving stage 1 model to cpu")
678
- model = model.cpu()
679
- torch.cuda.empty_cache()
680
- logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time))
681
-
682
- # decoding
683
- imgs, sred_imgs, txts = [], [], []
684
- for seq in output_tokens:
685
- decoded_imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()[i*400: (i+1)*400]), size=(480, 480)) for i in range(total_frames)]
686
- imgs.append(decoded_imgs) # only the last image (target)
687
-
688
- assert len(imgs) == batch_size
689
- save_tokens = output_tokens[:, :+total_frames*400].reshape(-1, total_frames, 400).cpu()
690
- if outputdir is not None:
691
- for clip_i in range(len(imgs)):
692
- # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True)
693
- my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False)
694
- os.system(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25")
695
- torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt'))
696
-
697
- logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time))
698
-
699
- return save_tokens
700
-
701
- # ======================================================================================================
702
-
703
- if args.stage_1 or args.both_stages:
704
- if args.input_source != "interactive":
705
- with open(args.input_source, 'r') as fin:
706
- promptlist = fin.readlines()
707
- promptlist = [p.strip() for p in promptlist]
708
- else:
709
- promptlist = None
710
-
711
- now_qi = -1
712
- while True:
713
- now_qi += 1
714
-
715
- if promptlist is not None: # with input-source
716
- if args.multi_gpu:
717
- if now_qi % dist.get_world_size() != dist.get_rank():
718
- continue
719
- rk = dist.get_rank()
720
- else:
721
- rk = 0
722
- raw_text = promptlist[now_qi]
723
- raw_text = raw_text.strip()
724
- print(f'Working on Line No. {now_qi} on {rk}... [{raw_text}]')
725
- else: # interactive
726
- raw_text = input("\nPlease Input Query (stop to exit) >>> ")
727
- raw_text = raw_text.strip()
728
- if not raw_text:
729
- print('Query should not be empty!')
730
- continue
731
- if raw_text == "stop":
732
- return
733
-
734
- try:
735
- path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
736
- parent_given_tokens = process_stage1(model_stage1, raw_text, duration=4.0, video_raw_text=raw_text, video_guidance_text="视频",
737
- image_text_suffix=" 高清摄影",
738
- outputdir=path if args.stage_1 else None, batch_size=args.batch_size)
739
- if args.both_stages:
740
- process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
741
- video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
742
- outputdir=path,
743
- gpu_rank=0, gpu_parallel_size=1) # TODO: 修改
744
- except (ValueError, FileNotFoundError) as e:
745
- print(e)
746
- continue
747
-
748
- elif args.stage_2:
749
- sample_dirs = os.listdir(args.output_path)
750
- for sample in sample_dirs:
751
- raw_text = sample.split('_')[-1]
752
- path = os.path.join(args.output_path, sample, 'Interp')
753
- parent_given_tokens = torch.load(os.path.join(args.output_path, sample, "frame_tokens.pt"))
754
-
755
- process_stage2(raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
756
- video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
757
- outputdir=path,
758
- gpu_rank=0, gpu_parallel_size=1) # TODO: 修改
759
-
760
- else:
761
- assert False
762
-
763
-
764
- if __name__ == "__main__":
765
- logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
766
-
767
- py_parser = argparse.ArgumentParser(add_help=False)
768
- py_parser.add_argument('--generate-frame-num', type=int, default=5)
769
- py_parser.add_argument('--coglm-temperature2', type=float, default=0.89)
770
- # py_parser.add_argument("--interp-duration", type=float, default=-1) # -1是顺序生成,0是超分,0.5/1/2是插帧
771
- # py_parser.add_argument("--total-duration", type=float, default=4.0) # 整个的时间
772
- py_parser.add_argument('--use-guidance-stage1', action='store_true')
773
- py_parser.add_argument('--use-guidance-stage2', action='store_false')
774
- py_parser.add_argument('--guidance-alpha', type=float, default=3.0)
775
- py_parser.add_argument('--stage-1', action='store_true') # stage 1: sequential generation
776
- py_parser.add_argument('--stage-2', action='store_false') # stage 2: interp + dsr
777
- py_parser.add_argument('--both-stages', action='store_false') # stage 1&2: sequential generation; interp + dsr
778
- py_parser.add_argument('--parallel-size', type=int, default=1)
779
- py_parser.add_argument('--stage1-max-inference-batch-size', type=int, default=1) # -1: use max-inference-batch-size
780
- py_parser.add_argument('--multi-gpu', action='store_false')
781
-
782
- CogVideoCacheModel.add_model_specific_args(py_parser)
783
-
784
- known, args_list = py_parser.parse_known_args()
785
- args = get_args(args_list)
786
- args = argparse.Namespace(**vars(args), **vars(known))
787
- args.layout = [int(x) for x in args.layout.split(',')]
788
- args.do_train = False
789
-
790
- torch.cuda.set_device(args.device)
791
-
792
- with torch.no_grad():
793
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/cogvideo_cache_model.py DELETED
@@ -1,695 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cogvideo_cache_model.py
4
- @Time : 2022/07/15 11:22:19
5
- @Author : Wenyi Hong
6
- @Version : 1.0
7
- @Contact : [email protected]
8
- '''
9
-
10
- # here put the import lib
11
-
12
- from multiprocessing import context
13
- from tkinter import E
14
- import torch
15
- from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
16
-
17
- from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
18
- from SwissArmyTransformer.model.transformer import unscaled_init_method
19
- from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
20
- import torch.nn.functional as F
21
- from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
22
- import math
23
-
24
-
25
- class PositionEmbeddingMixin(BaseMixin):
26
- def __init__(self, additional_sequence_length, hidden_size,
27
- init_method_std=0.02, reinit_slice=slice(512, 912),
28
- ):
29
- super(PositionEmbeddingMixin, self).__init__()
30
- self.reinit_slice = reinit_slice
31
- self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
32
- torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
33
-
34
- def reinit(self, parent_model=None):
35
- old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
36
- old_len, hidden_size = old_weights.shape
37
- assert hidden_size == self.position_embeddings.weight.shape[-1]
38
- self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
39
-
40
-
41
- def window_partition(x, window_size):
42
- """
43
- Args:
44
- x: (B, framenum, H, W, C)
45
- window_size (int): window size
46
- Returns:
47
- windows: (num_windows*B, frame_num, window_size, window_size, C)
48
- """
49
- B, framenum, H, W, C = x.shape
50
- x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
51
- windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
52
- return windows
53
-
54
- def window_reverse(windows, window_size, H, W):
55
- """
56
- Args:
57
- windows: (num_windows*B, frame_num, window_size, window_size, C)
58
- window_size (int): Window size
59
- H (int): Height of image
60
- W (int): Width of image
61
- Returns:
62
- x: (B, frame_num, H, W, C)
63
- """
64
- B = int(windows.shape[0] / (H * W / window_size / window_size))
65
- framenum = windows.shape[1]
66
- x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
67
- x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
68
- return x
69
-
70
- class WindowAttentionMixin(BaseMixin):
71
- def __init__(self, num_layers,
72
- hidden_size,
73
- frame_resolution,
74
- window_size,
75
- shift_size,
76
- n_head,
77
- frame_num,
78
- init_method=unscaled_init_method(0.02),
79
- output_layer_init_method=unscaled_init_method(0.02),
80
- time_dim_attend_length=0
81
- ):
82
- super(WindowAttentionMixin, self).__init__()
83
- self.num_layers = num_layers # replace attention in the LAST n layers
84
- self.query_key_value = torch.nn.ModuleList(
85
- [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
86
- gather_output=False,init_method=init_method)
87
- for layer_id in range(num_layers)
88
- ])
89
- self.dense = torch.nn.ModuleList(
90
- [RowParallelLinear(
91
- hidden_size,
92
- hidden_size,
93
- input_is_parallel=True,
94
- init_method=output_layer_init_method,
95
- bias=True,
96
- module=self,
97
- name="dense")
98
- for layer_id in range(num_layers)
99
- ])
100
-
101
- self.n_head = n_head
102
- self.window_size = window_size
103
- self.frame_resolution = frame_resolution
104
- self.frame_len = frame_resolution * frame_resolution
105
- self.time_dim_attend_length = time_dim_attend_length
106
- assert frame_resolution % window_size == 0
107
- assert 0 < shift_size < window_size
108
- nW = (self.frame_resolution // self.window_size) ** 2
109
- ws_squre = self.window_size * self.window_size
110
-
111
- # odd non-shift, even shift
112
- img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
113
- h_slices = (slice(0, -shift_size),
114
- slice(-shift_size, None))
115
- w_slices = (slice(0, -shift_size),
116
- slice(-shift_size, None))
117
- cnt = 0
118
- for h in h_slices:
119
- for w in w_slices:
120
- img_mask[:, :, h, w, :] = cnt
121
- cnt += 1
122
- mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
123
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
124
- sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
125
- sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
126
- attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
127
- attn_mask = attn_mask.tril()
128
-
129
- causal_mask = torch.ones(ws_squre*frame_num, ws_squre*frame_num)
130
- causal_mask = causal_mask.tril()
131
-
132
- self.shift_sizes = [0, shift_size]
133
- self.attn_mask = attn_mask
134
- self.causal_mask = causal_mask
135
- self.mask_initialized = False
136
-
137
- self.attn_distribution = torch.nn.ParameterList([
138
- torch.nn.Parameter(torch.zeros(hidden_size))
139
- for _ in range(num_layers)
140
- ])
141
-
142
- def reinit(self, *pre_mixins):
143
- start_layer = len(self.transformer.layers) - self.num_layers
144
- assert start_layer >= 0
145
- for layer_id in range(self.num_layers):
146
- old_attention = self.transformer.layers[start_layer + layer_id].attention
147
- self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
148
- self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
149
-
150
- def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
151
- # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
152
- if not self.mask_initialized:
153
- self.attn_mask = self.attn_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
154
- self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
155
- self.mask_initialized = True
156
- b0, s1, h0 = frame_hidden_state.shape
157
- h = h0 // self.n_head
158
- frame_len = self.frame_resolution * self.frame_resolution
159
- frame_num = s1 // frame_len
160
- if stage == 2:
161
- assert frame_num == 3
162
- assert frame_num*frame_len == s1
163
- wind_square = self.window_size * self.window_size
164
- nW = frame_len // wind_square
165
- bswin = b0 * nW
166
-
167
- if memkv_text is not None:
168
- s0 = memkv_text.shape[-2]
169
- k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
170
- v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
171
-
172
- # shift
173
- frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
174
- if self.shift_sizes[layer_id%2] > 0:
175
- frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
176
- # window partition
177
- frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
178
- qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
179
- .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
180
- q, k, v = qkv[0], qkv[1], qkv[2]
181
- attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
182
-
183
- if stage == 1:
184
- if self.shift_sizes[layer_id%2] > 0:
185
- attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square),
186
- self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))\
187
- - 10000.0 * (1.0 - self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))
188
- attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
189
- else:
190
- attn = torch.mul(attn, self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))\
191
- - 10000.0 * (1.0 - self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))
192
-
193
- if memkv_text is None:
194
- attn = F.softmax(attn, dim=-1)
195
- if attn_dropout is not None:
196
- with get_cuda_rng_tracker().fork():
197
- attn = attn_dropout(attn)
198
- context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
199
- else:
200
- attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
201
- attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
202
- attn = torch.cat((attn, attn_frame2text), dim=-1)
203
- attn = F.softmax(attn, dim=-1)
204
-
205
- if attn_dropout is not None:
206
- with get_cuda_rng_tracker().fork():
207
- attn = attn_dropout(attn)
208
-
209
- context_swin = (torch.matmul(attn[..., :-s0], v) +
210
- torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
211
- .reshape(bswin, self.n_head, frame_num*wind_square, h))\
212
- .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
213
-
214
- context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
215
-
216
- # reverse cycle shift
217
- if self.shift_sizes[layer_id%2] > 0:
218
- context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
219
- ret_context = context_swin.reshape(b0, s1, h0)
220
-
221
- # for mem
222
- memk = k.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
223
- memv = v.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
224
- memk = window_reverse(memk, self.window_size, self.frame_resolution, self.frame_resolution)
225
- memv = window_reverse(memv, self.window_size, self.frame_resolution, self.frame_resolution)
226
- if self.shift_sizes[layer_id%2] > 0:
227
- memk = torch.roll(memk, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
228
- memv = torch.roll(memv, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
229
- memk, memv = memk.reshape(b0, s1, h0), memv.reshape(b0, s1, h0)
230
-
231
- ret_mem = torch.cat((memk, memv), dim=-1)
232
- return ret_context, ret_mem
233
-
234
- def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
235
- # frame_hidden_state [batchsize, 1, n_head*hiddensize_perhead]
236
- # memkv [batchsize, pos, hidden_size*2] (include frames only)
237
- # if memkv_text is not None: will attend to text
238
- # pos: token's pos
239
- b0, sin, h0 = frame_hidden_state.shape
240
- h = h0 // self.n_head
241
- assert sin == 1
242
- this_qkv = self.query_key_value[layer_id](frame_hidden_state)
243
- thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
244
- s1 = memkv.shape[1] if memkv is not None else 0
245
- frame_len = self.frame_resolution * self.frame_resolution
246
- frame_num_before = s1 // frame_len
247
-
248
-
249
- if memkv is not None:
250
- pos_inframe = pos - frame_num_before * frame_len
251
-
252
- xpos = pos_inframe // self.frame_resolution # pos = xpos*self.frame_resolution + ypos
253
- ypos = pos_inframe % self.frame_resolution
254
- # [start, end)
255
- if self.shift_sizes[layer_id%2] > 0:
256
- xstart = ((xpos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
257
- ystart = ((ypos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
258
- xend = xstart + self.window_size
259
- yend = ystart + self.window_size
260
- xstart, ystart = max(0, xstart), max(0, ystart)
261
- xend, yend = min(xend, self.frame_resolution), min(yend, self.frame_resolution)
262
- else:
263
- xstart = (xpos // self.window_size) * self.window_size
264
- ystart = (ypos // self.window_size) * self.window_size
265
- xend, yend = xstart + self.window_size, ystart+self.window_size
266
-
267
- # select index
268
- selected_index = list()
269
- if frame_num_before > 0:
270
- # frames before
271
- frame_attended_start = max(0, frame_num_before-self.time_dim_attend_length+1) if self.time_dim_attend_length > 0 else 0
272
- for x in range(xstart, xend):
273
- for y in range(ystart, yend):
274
- selected_index.append(x*self.frame_resolution+y+frame_len*frame_attended_start)
275
- cnt_per_frame = len(selected_index)
276
- for _ in range((frame_num_before-frame_attended_start-1)*cnt_per_frame):
277
- selected_index.append(selected_index[-cnt_per_frame]+frame_len)
278
-
279
- # the last frame
280
- for x in range(xstart, xend):
281
- for y in range(ystart, yend):
282
- tmppos = x*self.frame_resolution+y + frame_num_before * frame_len
283
- if tmppos < pos:
284
- selected_index.append(tmppos)
285
- else:
286
- break
287
- cnt_all = len(selected_index)+1
288
- selected_index = torch.tensor(selected_index, device=memkv.device)
289
- used_memkv = torch.index_select(memkv, 1, selected_index)
290
- used_k, used_v = used_memkv[..., :h0], used_memkv[..., h0:]
291
- used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
292
- used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
293
- if memkv_text is not None:
294
- cnt_all += memkv_text.shape[-2]
295
- used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
296
- used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
297
- used_k = used_k.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
298
- used_v = used_v.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
299
- else:
300
- used_k = thisk
301
- used_v = thisv
302
-
303
- if memkv_text is not None:
304
- used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
305
- used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
306
- used_k = used_k.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
307
- used_v = used_v.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
308
- else:
309
- used_k = used_k.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
310
- used_v = used_v.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
311
-
312
- thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
313
- attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
314
- if memkv_text is not None:
315
- attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
316
- attn = F.softmax(attn, dim=-1)
317
- context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
318
-
319
- return context_swin, this_qkv[..., h0:]
320
-
321
- class FullAttentionMixin(BaseMixin):
322
- def __init__(self, num_layers,
323
- hidden_size,
324
- frame_resolution,
325
- n_head,
326
- frame_num,
327
- init_method=unscaled_init_method(0.02),
328
- output_layer_init_method=unscaled_init_method(0.02),
329
- **kwargs,
330
- ):
331
- super(FullAttentionMixin, self).__init__()
332
- self.num_layers = num_layers # replace attention in the LAST n layers
333
- self.query_key_value = torch.nn.ModuleList(
334
- [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
335
- gather_output=False,init_method=init_method)
336
- for layer_id in range(num_layers)
337
- ])
338
- self.dense = torch.nn.ModuleList(
339
- [RowParallelLinear(
340
- hidden_size,
341
- hidden_size,
342
- input_is_parallel=True,
343
- init_method=output_layer_init_method,
344
- bias=True,
345
- module=self,
346
- name="dense")
347
- for layer_id in range(num_layers)
348
- ])
349
-
350
- self.n_head = n_head
351
- self.frame_resolution = frame_resolution
352
- self.frame_len = frame_resolution * frame_resolution
353
-
354
- self.attn_distribution = torch.nn.ParameterList([
355
- torch.nn.Parameter(torch.zeros(hidden_size))
356
- for _ in range(num_layers)
357
- ])
358
-
359
- def reinit(self, *pre_mixins):
360
- start_layer = len(self.transformer.layers) - self.num_layers
361
- assert start_layer >= 0
362
- for layer_id in range(self.num_layers):
363
- old_attention = self.transformer.layers[start_layer + layer_id].attention
364
- self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
365
- self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
366
-
367
-
368
- def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
369
- # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
370
- assert stage == 1
371
-
372
- b0, s1, h0 = frame_hidden_state.shape
373
- h = h0 // self.n_head
374
- frame_len = self.frame_resolution * self.frame_resolution
375
- frame_num = s1 // frame_len
376
- assert frame_num*frame_len == s1
377
-
378
- if memkv_text is not None:
379
- s0 = memkv_text.shape[-2]
380
- k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
381
- v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
382
- qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
383
- .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
384
- q, k, v = qkv[0], qkv[1], qkv[2]
385
- attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
386
- attn = attn - 10000.0 * (1.0-torch.ones(b0, self.n_head, s1, s1, device=attn.device, dtype=attn.dtype).tril())
387
-
388
- if memkv_text is None:
389
- attn = F.softmax(attn, dim=-1)
390
- if attn_dropout is not None:
391
- with get_cuda_rng_tracker().fork():
392
- attn = attn_dropout(attn)
393
- context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
394
- else:
395
- attn_frame2text = torch.matmul(q / math.sqrt(h), k_text.transpose(-1, -2)) #[b0, s1, s0]
396
- attn = torch.cat((attn, attn_frame2text), dim=-1)
397
- attn = F.softmax(attn, dim=-1)
398
- if attn_dropout is not None:
399
- with get_cuda_rng_tracker().fork():
400
- attn = attn_dropout(attn)
401
- context_swin = (torch.matmul(attn[..., :-s0], v) + torch.matmul(attn[..., -s0:], v_text))\
402
- .permute(0, 2, 1, 3).reshape(b0, s1, h0)
403
-
404
- # for mem
405
- memk = k.permute(0, 2, 1, 3).reshape(b0, s1, h0)
406
- memv = v.permute(0, 2, 1, 3).reshape(b0, s1, h0)
407
- ret_mem = torch.cat((memk, memv), dim=-1)
408
-
409
- return context_swin, ret_mem
410
-
411
- def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
412
- # pos: current token's pos
413
- b0, sin, h0 = frame_hidden_state.shape
414
- h = h0 // self.n_head
415
- assert sin == 1
416
- assert stage == 1
417
-
418
- this_qkv = self.query_key_value[layer_id](frame_hidden_state)
419
- thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
420
-
421
- if memkv is not None:
422
- used_k, used_v = memkv[..., :h0], memkv[..., h0:]
423
- used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
424
- used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
425
- else:
426
- used_k, used_v = thisk, thisv
427
-
428
- if memkv_text is not None:
429
- used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
430
- used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
431
-
432
- used_k = used_k.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
433
- used_v = used_v.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
434
- thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
435
- attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
436
- if memkv_text is not None:
437
- attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
438
- attn = F.softmax(attn, dim=-1)
439
-
440
- context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
441
-
442
- return context_swin, this_qkv[..., h0:]
443
-
444
-
445
- def attention_localframe_and_text_NAR(q0, k0, v0, attention_mask,
446
- n_head, text_len, frame_len, frame_num,
447
- attention_dropout=None, log_text_attention_weights=0, stage=1, **kwargs):
448
- b, s0, h0 = q0.shape
449
- s1 = s0 - text_len
450
- h = h0 // n_head
451
- assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
452
- # attention_mask.shape [4, b or 1, 1, text_len+frame_len, text_len+frame_len]
453
- if stage == 2:
454
- assert frame_num == 3
455
-
456
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
457
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
458
- k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
459
- k0T = k0.transpose(-1, -2)
460
-
461
- score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
462
- score_any2text += log_text_attention_weights
463
- score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask[..., :text_len, :text_len]) \
464
- - 10000.0 * (1.0 - attention_mask[..., :text_len, :text_len])
465
- # context for text
466
- attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
467
- if attention_dropout is not None:
468
- with get_cuda_rng_tracker().fork():
469
- attention_probs_text = attention_dropout(attention_probs_text)
470
- context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
471
- context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
472
-
473
- if frame_num > 0:
474
- score_any2text_part2 = score_any2text[..., text_len:, :]
475
-
476
- # score: frame local
477
- q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
478
- v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
479
- k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
480
- score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
481
- if stage == 1:
482
- score_frame_local0 = torch.mul(score_frame_local0, attention_mask[..., text_len:, text_len:].unsqueeze(1)) \
483
- - 10000.0 * (1.0 - attention_mask[..., text_len:, text_len:].unsqueeze(1))
484
-
485
- # context for frame
486
- score_frame_all = torch.cat((score_any2text_part2,
487
- score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
488
- attention_probs_frame = F.softmax(score_frame_all, dim=-1)
489
- if attention_dropout is not None:
490
- with get_cuda_rng_tracker().fork():
491
- attention_probs_frame = attention_dropout(attention_probs_frame)
492
- context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
493
- context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
494
- view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
495
-
496
- context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
497
- else:
498
- context_frame = None
499
-
500
- return context_text2text, context_frame
501
-
502
- def attention_localframe_and_text_AR(q0, k0, v0, n_head, text_len, frame_len, frame_num,
503
- attention_dropout=None, log_text_attention_weights=0, layer_id=None, limited_spatial_channel_mem=False, stage=1, **kwargs):
504
- # limited_spatial_channel_mem=True means: mems in spatial channel is consisted of {mem_text, mem_current_frame}
505
- b, s0, h0 = k0.shape
506
- frame_num_before = (s0-text_len-1) // frame_len # frame_num == frame_num_before or frame_num == frame_num_before+1
507
- h = h0 // n_head
508
- assert q0.shape[1] == 1
509
- assert v0.shape[1] == k0.shape[1]
510
-
511
- q0 = q0.reshape(b, 1, n_head, h).permute(0, 2, 1, 3)
512
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
513
- k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
514
-
515
- if limited_spatial_channel_mem:
516
- assert frame_num_before == 0
517
- assert stage == 1 # not implemented for stage-2 yet
518
- score = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
519
- score[..., :text_len] += log_text_attention_weights
520
- attention_probs_frame = F.softmax(score, dim=-1)
521
- context_frame = torch.matmul(attention_probs_frame, v0).transpose(1, 2).reshape(b, 1, h0)
522
-
523
- else:
524
- score_token2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
525
- score_token2text += log_text_attention_weights
526
- score_frame_local0 = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., text_len+frame_num_before*frame_len:])
527
- score_frame_all = torch.cat((score_token2text,
528
- score_frame_local0), dim=-1)
529
- attention_probs_frame = F.softmax(score_frame_all, dim=-1)
530
-
531
- context_token2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
532
- context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:], \
533
- v0[:, :, text_len+frame_num_before*frame_len:, :])
534
- context_frame = (context_token2text + context_frame_local0).transpose(1, 2).reshape(b, 1, h0)
535
-
536
- return context_frame
537
-
538
-
539
- class CogVideoCacheModel(BaseModel):
540
- def __init__(self, args, transformer=None, parallel_output=True, window_size=None, cogvideo_stage=None):
541
- super().__init__(args, transformer=transformer, parallel_output=parallel_output)
542
- self.layout = args.layout # [64, 64+1024, 64+6*1024]
543
- self.stage = cogvideo_stage if cogvideo_stage is not None else args.cogvideo_stage # 1 or 2
544
- self.n_head = args.num_attention_heads
545
- self.window_size = window_size if window_size is not None else args.window_size
546
-
547
- frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
548
- self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
549
- args.additional_seqlen, args.hidden_size
550
- ))
551
-
552
- if self.stage == 1:
553
- self.add_mixin('attention_plus', FullAttentionMixin(
554
- num_layers=args.num_layers,
555
- hidden_size=args.hidden_size,
556
- frame_resolution=frame_resolution,
557
- n_head=args.num_attention_heads,
558
- frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
559
- ))
560
- else:
561
- self.add_mixin('attention_plus', WindowAttentionMixin(
562
- num_layers=args.num_layers,
563
- hidden_size=args.hidden_size,
564
- frame_resolution=frame_resolution,
565
- window_size=self.window_size,
566
- shift_size=self.window_size//2,
567
- n_head=args.num_attention_heads,
568
- frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
569
- ))
570
-
571
-
572
- @classmethod
573
- def add_model_specific_args(cls, parser):
574
- group = parser.add_argument_group('VideoSwinLocalModel', 'video swin local model configurations')
575
- group.add_argument("--layout", type=str, default='64, 464, 2064')
576
- group.add_argument("--window-size", type=int, default=10) # 优先级在直接参数赋值之后
577
- group.add_argument("--additional-seqlen", type=int, default=2000)
578
- group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) # 优先级在直接参数赋值之后
579
- return parser
580
-
581
- def disable_untrainable_params(self):
582
- pass
583
-
584
- def position_embedding_forward(self, position_ids, **kw_args):
585
- if position_ids.shape[-1] > 1:
586
- if self.stage == 1:
587
- if position_ids[0,-1] >= (512+400):
588
- frame_num = position_ids.shape[-1] // 400
589
- position_embeddings = torch.cat(
590
- (
591
- self.transformer.position_embeddings(position_ids[..., :-400*(frame_num-1)]),
592
- self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -400*(frame_num-1):]-(512+400))
593
- ),
594
- dim=-2
595
- )
596
- else:
597
- position_embeddings = self.transformer.position_embeddings(position_ids)
598
- else:
599
- # given 3, interpolate 2
600
- position_embeddings = torch.cat(
601
- (
602
- self.transformer.position_embeddings(position_ids[..., :-800]),
603
- self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -800:]-(512+400))
604
- ),
605
- dim=-2
606
- )
607
- else:
608
- if position_ids[0, 0] >= (512+400):
609
- position_embeddings = self.get_mixin('extra_position_embedding').position_embeddings(position_ids-(512+400))
610
- else:
611
- position_embeddings = self.transformer.position_embeddings(position_ids)
612
- return position_embeddings
613
-
614
- def attention_forward(self, hidden_states, mask, layer_id, mems=None, log_text_attention_weights=0, text_len=0, frame_len=0, counter=0, enforce_no_swin=False, limited_spatial_channel_mem=False, **kw_args):
615
- attn_module = self.transformer.layers[layer_id].attention
616
- hidden_size = hidden_states.shape[-1]
617
-
618
- # base model qkv
619
- if mems is None:
620
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
621
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
622
- assert (q0.shape[1]-text_len) % frame_len == 0
623
- memkv0 = torch.cat((k0, v0), dim=-1)
624
- context_text, context_frame_local_text = attention_localframe_and_text_NAR(
625
- q0, k0, v0,
626
- mask,
627
- n_head=attn_module.num_attention_heads_per_partition,
628
- text_len=text_len,
629
- frame_len=frame_len,
630
- frame_num=(q0.shape[1]-text_len)//frame_len,
631
- log_text_attention_weights=log_text_attention_weights,
632
- stage=self.stage
633
- )
634
-
635
- # change: self.swin_attend_to_text默认为True:
636
- memkv1_text = self.get_mixin('attention_plus').query_key_value[layer_id](hidden_states[..., :text_len, :])[..., hidden_size:]
637
- output_text = attn_module.dense(context_text)
638
-
639
- if (q0.shape[1]-text_len)//frame_len > 0:
640
- assert (q0.shape[1]-text_len) % frame_len == 0
641
- context_frame_swin, memkv1_frame = self.get_mixin('attention_plus').attention_extra_NAR_inference(
642
- hidden_states[:,text_len:], layer_id, memkv_text=memkv1_text, stage=self.stage)
643
- if not enforce_no_swin:
644
- attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
645
- attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
646
- output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
647
- +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
648
- else:
649
- output_frame = attn_module.dense(context_frame_local_text[..., :frame_len, :])
650
- output = torch.cat((output_text, output_frame), dim=-2)
651
- memkv1 = torch.cat((memkv1_text, memkv1_frame), dim=-2) if memkv1_text is not None else memkv1_frame
652
- else:
653
- output = output_text
654
- memkv1 = memkv1_text
655
- kw_args['output_this_layer']['mem_kv'] = (memkv0, memkv1)
656
-
657
-
658
- else:
659
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
660
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
661
- new_memkv0 = torch.cat((k0, v0), dim=-1)
662
- old_k0, old_v0 = mems[0][layer_id][..., :hidden_size], mems[0][layer_id][..., hidden_size:]
663
-
664
- context_frame_local_text = attention_localframe_and_text_AR(
665
- q0,
666
- torch.cat((old_k0.expand(k0.shape[0], -1, -1), k0), dim=-2),
667
- torch.cat((old_v0.expand(v0.shape[0], -1, -1), v0), dim=-2),
668
- n_head=attn_module.num_attention_heads_per_partition,
669
- text_len=text_len,
670
- frame_len=frame_len,
671
- frame_num=None,
672
- log_text_attention_weights=log_text_attention_weights,
673
- layer_id=layer_id,
674
- limited_spatial_channel_mem=limited_spatial_channel_mem,
675
- )
676
-
677
- old_memkv1 = mems[1][layer_id] if mems[1] is not None else None
678
-
679
- context_frame_swin, new_memkv1 = self.get_mixin('attention_plus').attention_extra_AR_inference(hidden_states,
680
- old_memkv1[..., text_len:, :] if old_memkv1.shape[-2]>text_len else None,
681
- counter-text_len,
682
- layer_id,
683
- memkv_text=old_memkv1[..., :text_len, :],
684
- log_text_attention_weights=log_text_attention_weights)
685
- if not enforce_no_swin:
686
- attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
687
- attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
688
- output = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
689
- +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
690
- else:
691
- output = attn_module.dense(context_frame_local_text)
692
-
693
- kw_args['output_this_layer']['mem_kv'] = (new_memkv0, new_memkv1)
694
-
695
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/cogvideo_model.py DELETED
@@ -1,543 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cogvideo_model.py
4
- @Time : 2022/07/11 16:12:05
5
- @Author : Wenyi Hong
6
- @Version : 1.0
7
- @Contact : [email protected]
8
- '''
9
-
10
- # here put the import lib
11
-
12
- import torch
13
- from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
14
-
15
- from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
16
- from SwissArmyTransformer.model.transformer import unscaled_init_method
17
- from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
18
- import torch.nn.functional as F
19
- from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
20
- import math
21
-
22
- class PositionEmbeddingMixin(BaseMixin):
23
- def __init__(self, additional_sequence_length, hidden_size,
24
- init_method_std=0.02, reinit_slice=slice(512, 912),
25
- ):
26
- super(PositionEmbeddingMixin, self).__init__()
27
- self.reinit_slice = reinit_slice
28
- self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
29
- torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
30
-
31
- def reinit(self, parent_model=None):
32
- old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
33
- old_len, hidden_size = old_weights.shape
34
- assert hidden_size == self.position_embeddings.weight.shape[-1]
35
- self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
36
-
37
- def window_partition(x, window_size):
38
- """
39
- Args:
40
- x: (B, framenum, H, W, C)
41
- window_size (int): window size
42
- Returns:
43
- windows: (num_windows*B, frame_num, window_size, window_size, C)
44
- """
45
- B, framenum, H, W, C = x.shape
46
- x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
47
- windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
48
- return windows
49
-
50
- def window_reverse(windows, window_size, H, W):
51
- """
52
- Args:
53
- windows: (num_windows*B, frame_num, window_size, window_size, C)
54
- window_size (int): Window size
55
- H (int): Height of image
56
- W (int): Width of image
57
- Returns:
58
- x: (B, frame_num, H, W, C)
59
- """
60
- B = int(windows.shape[0] / (H * W / window_size / window_size))
61
- framenum = windows.shape[1]
62
- x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
63
- x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
64
- return x
65
-
66
- class WindowAttentionMixin(BaseMixin):
67
- def __init__(self, num_layers,
68
- hidden_size,
69
- frame_resolution,
70
- window_size,
71
- shift_size,
72
- n_head,
73
- frame_num,
74
- init_method=unscaled_init_method(0.02),
75
- output_layer_init_method=unscaled_init_method(0.02),
76
- ):
77
- super(WindowAttentionMixin, self).__init__()
78
- self.num_layers = num_layers # replace attention in the LAST n layers
79
- self.query_key_value = torch.nn.ModuleList(
80
- [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
81
- gather_output=False,init_method=init_method)
82
- for layer_id in range(num_layers)
83
- ])
84
- self.dense = torch.nn.ModuleList(
85
- [RowParallelLinear(
86
- hidden_size,
87
- hidden_size,
88
- input_is_parallel=True,
89
- init_method=output_layer_init_method,
90
- bias=True,
91
- module=self,
92
- name="dense",
93
- )
94
- for layer_id in range(num_layers)
95
- ])
96
-
97
- self.n_head = n_head
98
- self.window_size = window_size
99
- self.frame_resolution = frame_resolution
100
- self.frame_len = frame_resolution * frame_resolution
101
- assert frame_resolution % window_size == 0
102
- assert 0 < shift_size < window_size
103
- nW = (self.frame_resolution // self.window_size) ** 2
104
- ws_squre = self.window_size * self.window_size
105
-
106
- # odd non-shift, even shift
107
- img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
108
- h_slices = (slice(0, -shift_size),
109
- slice(-shift_size, None))
110
- w_slices = (slice(0, -shift_size),
111
- slice(-shift_size, None))
112
- cnt = 0
113
- for h in h_slices:
114
- for w in w_slices:
115
- img_mask[:, :, h, w, :] = cnt
116
- cnt += 1
117
- mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
118
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
119
- sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
120
- sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
121
- attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
122
-
123
- self.attn_mask_sequential = attn_mask.clone().tril()
124
- self.causal_mask_sequential = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num).tril()
125
-
126
- self.causal_mask_interp = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num)
127
- self.attn_mask_interp = attn_mask.clone()
128
-
129
- # bi-dir
130
- for bi_idx in range(0, frame_num, 2):
131
- for uni_idx in range(1, frame_num, 2):
132
- self.attn_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
133
- self.causal_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
134
- # uni-dir
135
- for uni_idx in range(1, frame_num, 2):
136
- self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
137
- self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
138
- for uni_idx2 in range(uni_idx+2, frame_num, 2):
139
- self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
140
- self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
141
-
142
- # expand dim
143
- self.attn_mask_sequential = self.attn_mask_sequential[None, None, :, None]
144
- self.attn_mask_interp = self.attn_mask_interp[None, None, :, None]
145
- self.causal_mask_sequential = self.causal_mask_sequential[None, None, :, None]
146
- self.causal_mask_interp = self.causal_mask_interp[None, None, :, None]
147
-
148
- self.shift_sizes = [0, shift_size]
149
- # self.register_buffer("attn_mask", attn_mask)
150
- # self.register_buffer("causal_mask", causal_mask)
151
- self.mask_initialized = False
152
-
153
- self.attn_distribution = torch.nn.ParameterList([
154
- torch.nn.Parameter(torch.zeros(hidden_size))
155
- for _ in range(num_layers)
156
- ])
157
-
158
- def reinit(self, *pre_mixins):
159
- start_layer = len(self.transformer.layers) - self.num_layers
160
- assert start_layer >= 0
161
- for layer_id in range(self.num_layers):
162
- old_attention = self.transformer.layers[start_layer + layer_id].attention
163
- self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
164
- self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
165
-
166
- def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
167
- text_attn_mask=None, mode_sequential=True):
168
- # pb relax
169
- swin_pb_relax = True
170
- alpha = 16
171
-
172
- # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
173
- if not self.mask_initialized:
174
- self.attn_mask_sequential = self.attn_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
175
- self.causal_mask_sequential = self.causal_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
176
- self.attn_mask_interp = self.attn_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
177
- self.causal_mask_interp = self.causal_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
178
- self.mask_initialized = True
179
- b0, s1, h0 = frame_hidden_state.shape
180
- h = h0 // self.n_head
181
- frame_len = self.frame_resolution * self.frame_resolution
182
- frame_num = s1 // frame_len
183
- assert frame_num*frame_len == s1
184
- wind_square = self.window_size * self.window_size
185
- nW = frame_len // wind_square
186
- bswin = b0 * nW
187
-
188
- causal_mask = self.causal_mask_sequential if mode_sequential else self.causal_mask_interp
189
- attn_mask = self.attn_mask_sequential if mode_sequential else self.attn_mask_interp
190
- if text_hidden_state is not None:
191
- s0 = text_hidden_state.shape[1]
192
- qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
193
- q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
194
-
195
- # shift
196
- frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
197
- if self.shift_sizes[layer_id%2] > 0:
198
- frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
199
- # window partition
200
- frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
201
- qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
202
- .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
203
- q, k, v = qkv[0], qkv[1], qkv[2]
204
-
205
- # pb-relax
206
- if swin_pb_relax:
207
- attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
208
- else:
209
- attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
210
-
211
- if self.shift_sizes[layer_id%2] > 0:
212
- # attn = attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square) + self.attn_mask.unsqueeze(1).unsqueeze(0)
213
- attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), attn_mask)\
214
- - 10000.0 * (1.0 - attn_mask)
215
- attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
216
- else:
217
- attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), causal_mask)\
218
- - 10000.0 * (1.0 - causal_mask)
219
- attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
220
- if swin_pb_relax:
221
- swin_pb_relax_const = torch.max(attn.reshape(bswin, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
222
- attn = (attn - swin_pb_relax_const)*alpha
223
-
224
- if text_hidden_state is None:
225
- attn = F.softmax(attn, dim=-1)
226
- if attn_dropout is not None:
227
- with get_cuda_rng_tracker().fork():
228
- attn = attn_dropout(attn)
229
- context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
230
- else:
231
- assert text_attn_mask is not None
232
- text_attn_mask = text_attn_mask.unsqueeze(2).unsqueeze(2)
233
- # pb-relax
234
- if swin_pb_relax:
235
- attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / (math.sqrt(h)*alpha), k_text.unsqueeze(1).transpose(-1, -2))
236
- attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, -1, self.n_head, 1, 1))*alpha
237
- else:
238
- attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
239
-
240
- attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
241
- attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
242
- attn = torch.cat((attn, attn_frame2text), dim=-1)
243
- attn = F.softmax(attn, dim=-1)
244
-
245
- if attn_dropout is not None:
246
- with get_cuda_rng_tracker().fork():
247
- attn = attn_dropout(attn)
248
-
249
- context_swin = (torch.matmul(attn[..., :-s0], v) +
250
- torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
251
- .reshape(bswin, self.n_head, frame_num*wind_square, h))\
252
- .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
253
-
254
- context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
255
- # reverse cycle shift
256
- if self.shift_sizes[layer_id%2] > 0:
257
- context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
258
- context_swin = context_swin.reshape(b0, s1, h0)
259
-
260
- return context_swin
261
-
262
-
263
- class FullAttentionMixin(BaseMixin):
264
- def __init__(self, num_layers,
265
- hidden_size,
266
- frame_resolution,
267
- n_head,
268
- frame_num,
269
- init_method=unscaled_init_method(0.02),
270
- output_layer_init_method=unscaled_init_method(0.02),
271
- ):
272
- super(FullAttentionMixin, self).__init__()
273
- self.num_layers = num_layers # replace attention in the LAST n layers
274
- self.query_key_value = torch.nn.ModuleList(
275
- [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
276
- gather_output=False,init_method=init_method)
277
- for layer_id in range(num_layers)
278
- ])
279
- self.dense = torch.nn.ModuleList(
280
- [RowParallelLinear(
281
- hidden_size,
282
- hidden_size,
283
- input_is_parallel=True,
284
- init_method=output_layer_init_method,
285
- bias=True,
286
- module=self,
287
- name="dense",)
288
- for layer_id in range(num_layers)
289
- ])
290
-
291
- self.n_head = n_head
292
- self.frame_resolution = frame_resolution
293
- self.frame_len = frame_resolution * frame_resolution
294
- self.causal_mask = torch.ones(1, 1, self.frame_len*frame_num, self.frame_len*frame_num).tril()
295
-
296
- self.mask_initialized = False
297
-
298
- self.attn_distribution = torch.nn.ParameterList([
299
- torch.nn.Parameter(torch.zeros(hidden_size))
300
- for _ in range(num_layers)
301
- ])
302
-
303
- def reinit(self, *pre_mixins):
304
- start_layer = len(self.transformer.layers) - self.num_layers
305
- assert start_layer >= 0
306
- for layer_id in range(self.num_layers):
307
- base_attention = self.transformer.layers[start_layer + layer_id].attention
308
- self.query_key_value[layer_id].weight.data.copy_(base_attention.query_key_value.weight.data)
309
- self.query_key_value[layer_id].bias.data.copy_(base_attention.query_key_value.bias.data)
310
-
311
- def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
312
- text_attn_mask=None, mode_sequential=False):
313
- # pb relax
314
- # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
315
- assert mode_sequential == True # only
316
- swin_pb_relax = True
317
- alpha = 16
318
-
319
- if not self.mask_initialized:
320
- self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
321
- self.mask_initialized = True
322
- b0, s1, h0 = frame_hidden_state.shape
323
- h = h0 // self.n_head
324
- frame_len = self.frame_resolution * self.frame_resolution
325
- frame_num = s1 // frame_len
326
- assert frame_num*frame_len == s1
327
-
328
- qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
329
- .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
330
- q, k, v = qkv[0], qkv[1], qkv[2]
331
-
332
- # frames-to-frames
333
- if swin_pb_relax:
334
- attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
335
- else:
336
- attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
337
- attn = torch.mul(attn, self.causal_mask) - 10000.0 * (1.0 - self.causal_mask)
338
- if swin_pb_relax:
339
- swin_pb_relax_const = torch.max(attn.reshape(b0, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
340
- attn = (attn - swin_pb_relax_const)*alpha
341
-
342
- if text_hidden_state is None:
343
- attn = F.softmax(attn, dim=-1)
344
- if attn_dropout is not None:
345
- with get_cuda_rng_tracker().fork():
346
- attn = attn_dropout(attn)
347
- context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
348
- else:
349
- # frame-to-text
350
- assert text_attn_mask is not None
351
- s0 = text_hidden_state.shape[1]
352
- qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
353
- q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
354
- text_attn_mask = text_attn_mask.unsqueeze(2)
355
- if swin_pb_relax:
356
- attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / (math.sqrt(h)*alpha), k_text.transpose(-1, -2))
357
- attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, self.n_head, 1, 1))*alpha
358
- else:
359
- attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / math.sqrt(h), k_text.transpose(-1, -2))
360
- attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
361
- attn_frame2text = attn_frame2text.reshape(b0, self.n_head, s1, s0)
362
-
363
- attn = torch.cat((attn, attn_frame2text), dim=-1)
364
- attn = F.softmax(attn, dim=-1)
365
-
366
- if attn_dropout is not None:
367
- with get_cuda_rng_tracker().fork():
368
- attn = attn_dropout(attn)
369
-
370
- context_frame = (torch.matmul(attn[..., :-s0], v) +
371
- torch.matmul(attn[..., -s0:].reshape(b0, self.n_head,s1, s0), v_text))\
372
- .permute(0, 2, 1, 3).reshape(b0, s1, h0)
373
-
374
- return context_frame
375
-
376
-
377
- def attention_localframe_and_text(q0, k0, v0, attention_mask_totxt, attention_mask_local,
378
- n_head, text_len, frame_len, frame_num, attention_dropout=None, layer_id=0, **kwargs):
379
- b, s0, h0 = q0.shape
380
- s1 = s0 - text_len
381
- h = h0 // n_head
382
- assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
383
- # attention_mask_totxt [b, 1, 1, text_len]
384
- # attention_mask_local [1, 1, frame_num, frame_len, frame_len]
385
- # attention_mask: [1, 1, text_len+frame_len, text_len+frame_len]
386
-
387
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
388
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
389
- k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
390
- k0T = k0.transpose(-1, -2)
391
-
392
- # score: any2text
393
- score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
394
- score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask_totxt) \
395
- - 10000.0 * (1.0 - attention_mask_totxt)
396
- score_any2text_part2 = torch.mul(score_any2text[..., text_len:, :], attention_mask_totxt) - \
397
- 10000.0 * (1.0 - attention_mask_totxt)
398
-
399
- # score: frame local
400
- q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
401
- v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
402
- k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
403
- score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
404
- score_frame_local0 = torch.mul(score_frame_local0, attention_mask_local) \
405
- - 10000.0 * (1.0 - attention_mask_local)
406
-
407
- # context for frame
408
- score_frame_all = torch.cat((score_any2text_part2,
409
- score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
410
- attention_probs_frame = F.softmax(score_frame_all, dim=-1)
411
-
412
- if attention_dropout is not None:
413
- with get_cuda_rng_tracker().fork():
414
- attention_probs_frame = attention_dropout(attention_probs_frame)
415
-
416
- context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
417
- context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
418
- view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
419
- context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
420
-
421
- # context for text
422
- attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
423
- if attention_dropout is not None:
424
- with get_cuda_rng_tracker().fork():
425
- attention_probs_text = attention_dropout(attention_probs_text)
426
- context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
427
- context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
428
-
429
- return context_text2text, context_frame
430
-
431
-
432
- class CogVideoModel(BaseModel):
433
- def __init__(self, args, transformer=None, parallel_output=True):
434
- super().__init__(args, transformer=transformer, parallel_output=parallel_output)
435
- self.stage = args.cogvideo_stage # 1 or 2
436
- self.mode_sequential = True if self.stage==1 else False
437
- self.layout = args.layout # [64, 64+400, 64+5*400]
438
- self.n_head = args.num_attention_heads
439
- frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
440
- frame_num = (args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0])
441
- frame_len = self.layout[1]-self.layout[0]
442
-
443
- self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
444
- args.additional_seqlen, args.hidden_size
445
- ))
446
-
447
- if args.window_size == -1:
448
- # full attention
449
- assert self.stage == 1
450
- self.add_mixin('attention_plus', FullAttentionMixin(
451
- num_layers=args.num_layers,
452
- hidden_size=args.hidden_size,
453
- frame_resolution=frame_resolution,
454
- n_head=args.num_attention_heads,
455
- frame_num=frame_num,
456
- ))
457
- else:
458
- self.add_mixin('attention_plus', WindowAttentionMixin(
459
- num_layers=args.num_layers,
460
- hidden_size=args.hidden_size,
461
- frame_resolution=frame_resolution,
462
- window_size=args.window_size,
463
- shift_size=args.window_size//2,
464
- n_head=args.num_attention_heads,
465
- frame_num=frame_num,
466
- ))
467
- # attention_mask_local
468
- self.attention_mask_local_sequential = torch.ones(1, 1, frame_num, frame_len, frame_len).tril().unsqueeze(0)
469
- self.attention_mask_local_interp = torch.ones(1, 1, frame_num, frame_len, frame_len)
470
-
471
- for idx in range(1, frame_num, 2):
472
- self.attention_mask_local_interp[:, :, idx:idx+1].tril_()
473
- self.attention_mask_local_interp = self.attention_mask_local_interp.unsqueeze(0)
474
- self.mask_initialized = False
475
-
476
- @classmethod
477
- def add_model_specific_args(cls, parser):
478
- group = parser.add_argument_group('CogVideoModel', 'CogVideo model configurations')
479
- group.add_argument("--layout", type=str, default='64, 464, 2064', help='text_len, textlen+frame_len, textlen+frame_len*frame_num')
480
- group.add_argument("--window-size", type=int, default=10, help="swin attention's window size in temperal channel, -1 represents full attention")
481
- group.add_argument("--additional-seqlen", type=int, default=2000)
482
- group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2])
483
- return parser
484
-
485
- def disable_untrainable_params(self):
486
- self.transformer.requires_grad_(False)
487
-
488
- def position_embedding_forward(self, position_ids, **kw_args):
489
- position = position_ids[..., :(64+400)]
490
- position_plus = position_ids[..., (64+400):]
491
- position_embeddings = torch.cat(
492
- (
493
- self.transformer.position_embeddings(position),
494
- self.get_mixin('extra_position_embedding').position_embeddings(position_plus-(512+400))
495
- ),
496
- dim=-2
497
- )
498
- return position_embeddings
499
-
500
- def attention_forward(self, hidden_states, mask, layer_id, **kw_args):
501
- # mask.shape=[bs, 1, 1, 64]
502
- if not self.mask_initialized:
503
- self.attention_mask_local_sequential = self.attention_mask_local_sequential.to(device=hidden_states.device, dtype=hidden_states.dtype)
504
- self.attention_mask_local_interp = self.attention_mask_local_interp.to(device=hidden_states.device, dtype=hidden_states.dtype)
505
- self.mask_initialized = True
506
-
507
- attn_module = self.transformer.layers[layer_id].attention
508
- hidden_size = hidden_states.shape[-1]
509
- bs = hidden_states.shape[0]
510
-
511
- # base model qkv
512
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
513
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
514
- dropout_fn = self.transformer.layers[layer_id].attention.attention_dropout if self.training else None
515
-
516
- attention_mask_local = self.attention_mask_local_sequential if self.mode_sequential else self.attention_mask_local_interp
517
- context_text, context_frame_local_text = attention_localframe_and_text(
518
- q0, k0, v0,
519
- attention_mask_totxt=mask,
520
- attention_mask_local=attention_mask_local,
521
- n_head=attn_module.num_attention_heads_per_partition,
522
- text_len=self.layout[0],
523
- frame_len=self.layout[1]-self.layout[0],
524
- frame_num=(self.layout[2]-self.layout[0])//(self.layout[1]-self.layout[0]),
525
- attention_dropout=dropout_fn,
526
- layer_id=layer_id,
527
- )
528
-
529
- context_frame_swin = self.get_mixin('attention_plus').attention_extra(
530
- hidden_states[:, self.layout[0]:], layer_id, dropout_fn,
531
- text_hidden_state=hidden_states[:, :self.layout[0]],
532
- text_attn_mask=mask[..., 0, :],
533
- mode_sequential=self.mode_sequential)
534
-
535
- attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
536
- attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
537
-
538
- output_text = attn_module.dense(context_text)
539
- output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
540
- +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
541
- output = torch.cat((output_text, output_frame), dim=-2)
542
-
543
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pretrain_cogvideo.py DELETED
@@ -1,184 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : pretrain_cogvideo.py
4
- @Time : 2021/10/06 00:58:32
5
- @Author : Wenyi Hong
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
- import argparse
16
- import numpy as np
17
- from icetk import icetk as tokenizer
18
- tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
19
-
20
- from models.cogvideo_model import CogVideoModel
21
- from SwissArmyTransformer import mpu, get_args
22
- from SwissArmyTransformer.training.deepspeed_training import training_main
23
- from SwissArmyTransformer.data_utils import BinaryDataset
24
-
25
- def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
26
- # Extract batch size and sequence length.
27
- batch_size, seq_length = data.size()
28
- assert attention_mask_totxt is not None
29
- layout = args.layout
30
- assert seq_length == layout[-1]
31
- n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long()
32
- frame_len = layout[1]-layout[0]
33
- position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long,
34
- device=data.device)
35
- for i in range(batch_size):
36
- torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]],
37
- dtype=torch.long, device=data.device)
38
- torch.arange(512, 512+layout[2]-layout[0],
39
- out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device)
40
- return position_ids
41
-
42
-
43
- def get_batch(data_iterator, args, timers):
44
- # Items and their type.
45
- keys = ['text', 'loss_mask', 'attention_mask_totxt']
46
- datatype = torch.int64
47
-
48
- # Broadcast data.
49
- timers('data loader').start()
50
- if data_iterator is not None:
51
- data = next(data_iterator)
52
- else:
53
- data = None
54
- timers('data loader').stop()
55
-
56
- data_b = mpu.broadcast_data(keys, data, datatype)
57
- # Unpack.
58
- tokens_ = data_b['text'].long()
59
- loss_mask = data_b['loss_mask'].float()
60
- attention_mask_totxt = data_b['attention_mask_totxt'].float()
61
-
62
- labels = tokens_[:, 1:].clone().contiguous()
63
- loss_mask = loss_mask[:, 1:].contiguous()
64
- tokens = tokens_[:, :-1].clone().contiguous()
65
-
66
- for idx in range(args.layout[0], args.layout[2], 400):
67
- tokens[:, idx] = tokenizer['<start_of_image>']
68
- # Get the masks and postition ids.
69
- position_ids = get_masks_and_position_ids_video(
70
- tokens,
71
- attention_mask_totxt=attention_mask_totxt,
72
- args=args
73
- )
74
- attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1)
75
- # Convert
76
- if args.fp16:
77
- attention_mask_totxt = attention_mask_totxt.half()
78
- return tokens, labels, loss_mask, attention_mask_totxt, position_ids
79
-
80
-
81
- def forward_step(data_iterator, model, args, timers):
82
- """Forward step."""
83
-
84
- # Get the batch.
85
- timers('batch generator').start()
86
- tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch(
87
- data_iterator, args, timers)
88
- timers('batch generator').stop()
89
-
90
- # Forward model.
91
- logits, *mems = model(tokens, position_ids, attention_mask_totxt)
92
- # ======= hyper params =======#
93
- perframe_len = 400
94
- text_len=64
95
- frame_num = 5
96
- logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
97
- losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:])
98
- # scaling loss mask
99
- loss_mask = loss_mask[:, text_len:].reshape(-1)
100
-
101
- losses_1d = losses.reshape(-1) * loss_mask
102
- loss = torch.sum(losses_1d) / loss_mask.sum()
103
- # ===================== Log partial losses ======================== #
104
- log_loss_dict = {}
105
- bs = losses.shape[0]
106
-
107
- if args.cogvideo_stage == 1:
108
- for i in range(frame_num):
109
- log_loss_dict[f'AR_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
110
- else:
111
- for i in range(1, frame_num-1):
112
- log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
113
-
114
- # ===================== END OF BLOCK ======================= #
115
- return loss, log_loss_dict
116
-
117
-
118
- def create_dataset_function(path, args):
119
- dataset_layout = [64, 464, 2064]
120
- input_layout = [64, 464, 2064]
121
- # frame_num = 6
122
- # frame_interval = 2 # DEBUG!!!
123
- def process_fn(row):
124
- row = row.astype(np.int64)
125
- text = row[:dataset_layout[0]]
126
- frames = row[dataset_layout[0]:]
127
-
128
- if text[0] == tokenizer['<pad>']:
129
- text = text[1:] # due to our way of data processing
130
- if args.cogvideo_stage == 1:
131
- text, loss_mask, frames = make_text_video_generation(text, frames)
132
- else:
133
- text, loss_mask, frames = mask_video_frame_interpolation(text, frames)
134
-
135
- n_pad = input_layout[0] - len(text)
136
- parts = [
137
- np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64),
138
- text,
139
- np.array([tokenizer['<start_of_image>']], dtype=np.int64),
140
- frames,
141
- ]
142
- ret = np.concatenate(parts, axis=0)
143
-
144
- attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0]-n_pad))
145
- return {'text': ret,
146
- 'loss_mask': loss_mask,
147
- 'attention_mask_totxt': attention_mask_totxt,
148
- }
149
- return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])
150
-
151
- def make_text_video_generation(text, frames):
152
- input_layout = [64, 464, 2064]
153
- text = text[text!= tokenizer['<pad>']][:input_layout[0]] # dataset format: 1.0秒<n>{text}<pad><pad> ...
154
- loss_mask = np.array([0] * (input_layout[1]+1) + [1] * (input_layout[2] - input_layout[1])) # 按照input的,之后loss_mask会左移一位
155
- return text, loss_mask, frames
156
-
157
- def mask_video_frame_interpolation(text, frames):
158
- input_layout = [64, 464, 2064]
159
- frame_len = input_layout[1]-input_layout[0]
160
- # text format: <pad> 1.0秒 <n> {text} <pad> <pad>
161
- text = text[text!= tokenizer['<pad>']][:input_layout[0]]
162
- loss_mask = np.array([0] * (input_layout[1]+1)
163
- + [1] * (input_layout[1]-input_layout[0])
164
- + [0] * (input_layout[1]-input_layout[0])
165
- + [1] * (input_layout[1]-input_layout[0])
166
- + [0] * (input_layout[1]-input_layout[0]) )# 按照input的,之后loss_mask会左移一位
167
-
168
- return text, loss_mask, frames
169
-
170
-
171
-
172
- if __name__ == '__main__':
173
- py_parser = argparse.ArgumentParser(add_help=False)
174
- py_parser.add_argument('--txt-loss-scale', type=float, default=1)
175
- CogVideoModel.add_model_specific_args(py_parser)
176
-
177
- known, args_list = py_parser.parse_known_args()
178
-
179
- args = get_args(args_list)
180
- args = argparse.Namespace(**vars(args), **vars(known))
181
-
182
- args.layout = [int(x) for x in args.layout.split(',')]
183
-
184
- training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt DELETED
@@ -1,4 +0,0 @@
1
- SwissArmyTransformer>=0.2.9
2
- icetk
3
- gifmaker
4
- torchvision
 
 
 
 
 
scripts/ds_brain_pretrain_cogvideo_stage1.sh DELETED
@@ -1,108 +0,0 @@
1
- #! /bin/bash
2
-
3
- # Change for multinode config
4
-
5
- NUM_WORKERS=1
6
- NUM_GPUS_PER_WORKER=8
7
- MP_SIZE=1
8
-
9
- script_path=$(realpath $0)
10
- script_dir=$(dirname $script_path)
11
- main_dir=$(dirname $script_dir)
12
-
13
- OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
14
- HOST_FILE_PATH="hostfile"
15
- # HOST_FILE_PATH="hostfile_single"
16
-
17
- video_data_test="" # TODO
18
- CHECKPOINT_PATH="" # TODO: CogView2 ckpt
19
-
20
- config_json="$script_dir/ds_config_zero.json"
21
- gpt_options=" \
22
- --experiment-name pretrain-cogvideo-stage1 \
23
- --tokenizer-type fake \
24
- --vocab-size 150010 \
25
- --model-parallel-size ${MP_SIZE} \
26
- --mode finetune \
27
- --num-workers 0 \
28
- --num-layers 48 \
29
- --hidden-size 3072 \
30
- --num-attention-heads 48 \
31
- --layout 64,464,2064 \
32
- --window-size -1 \
33
- --cogvideo-stage 1 \
34
- --additional-seqlen 2000 \
35
- --train-iters 500000 \
36
- --resume-dataloader \
37
- --train-data ${video_data_test} \
38
- --train-data-weights 1 \
39
- --split 949,50,1 \
40
- --distributed-backend nccl \
41
- --lr-decay-style cosine \
42
- --warmup .001 \
43
- --checkpoint-activations \
44
- --max-sequence-length 1024 \
45
- --fp16 \
46
- --save-interval 2000 \
47
- --eval-interval 500 \
48
- --eval-iters 15 \
49
- --log-interval 50 \
50
- --save $main_dir/checkpoints \
51
- --sandwich-ln \
52
- --load $CHECKPOINT_PATH \
53
- "
54
- # --load $CHECKPOINT_PATH \
55
- # \ --sandwich-ln
56
-
57
-
58
- gpt_options="${gpt_options}
59
- --deepspeed \
60
- --deepspeed_config ${config_json} \
61
- "
62
-
63
- #!/bin/bash
64
-
65
- # Distribute Example
66
- #export NCCL_SOCKET_IFNAME=eth0
67
- export NCCL_IB_DISABLE=0
68
- export NCCL_NET_GDR_LEVEL=2
69
- #export NCCL_IB_CUDA_SUPPORT=1
70
- #export NCCL_IB_GID_INDEX=3
71
- #export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null)
72
- export NCCL_DEBUG=info
73
- export OMP_NUM_THREADS=4
74
-
75
- if [ $RLAUNCH_REPLICA == "0" ]; then
76
- ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip
77
- fi
78
-
79
- function finish {
80
- rm -rf master_ip
81
- }
82
-
83
- trap finish EXIT INT TERM
84
-
85
- while [ ! -f master_ip ]; do
86
- echo "wait master_ip..."
87
- ls > /dev/null && sleep 1;
88
- done
89
-
90
- export MASTER_ADDR=$(cat master_ip)
91
- echo "master_ip: $MASTER_ADDR"
92
-
93
- MP_SIZE=1
94
- task_set=$2
95
- source $1
96
- DATESTR=$(date +"%m-%d-%H-%M")
97
-
98
- mkdir logs
99
- run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \
100
- --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \
101
- --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt"
102
-
103
-
104
- # run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}"
105
- echo ${run_cmd}
106
- eval ${run_cmd}
107
-
108
- set +x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/ds_brain_pretrain_cogvideo_stage2.sh DELETED
@@ -1,108 +0,0 @@
1
- #! /bin/bash
2
-
3
- # Change for multinode config
4
-
5
- NUM_WORKERS=1
6
- NUM_GPUS_PER_WORKER=8
7
- MP_SIZE=1
8
-
9
- script_path=$(realpath $0)
10
- script_dir=$(dirname $script_path)
11
- main_dir=$(dirname $script_dir)
12
-
13
- OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
14
- HOST_FILE_PATH="hostfile"
15
- # HOST_FILE_PATH="hostfile_single"
16
-
17
- video_data_test="" # TODO
18
- CHECKPOINT_PATH="" # TODO: CogView2 ckpt
19
-
20
- config_json="$script_dir/ds_config_zero.json"
21
- gpt_options=" \
22
- --experiment-name pretrain-cogvideo-stage2 \
23
- --tokenizer-type fake \
24
- --vocab-size 150010 \
25
- --model-parallel-size ${MP_SIZE} \
26
- --mode finetune \
27
- --num-workers 0 \
28
- --num-layers 48 \
29
- --hidden-size 3072 \
30
- --num-attention-heads 48 \
31
- --layout 64,464,2064 \
32
- --window-size 10 \
33
- --cogvideo-stage 2 \
34
- --additional-seqlen 2000 \
35
- --train-iters 500000 \
36
- --resume-dataloader \
37
- --train-data ${video_data_test} \
38
- --train-data-weights 1 \
39
- --split 949,50,1 \
40
- --distributed-backend nccl \
41
- --lr-decay-style cosine \
42
- --warmup .001 \
43
- --checkpoint-activations \
44
- --max-sequence-length 1024 \
45
- --fp16 \
46
- --save-interval 2000 \
47
- --eval-interval 500 \
48
- --eval-iters 15 \
49
- --log-interval 50 \
50
- --save $main_dir/checkpoints \
51
- --sandwich-ln \
52
- --load $CHECKPOINT_PATH \
53
- "
54
- # --load $CHECKPOINT_PATH \
55
- # \ --sandwich-ln
56
-
57
-
58
- gpt_options="${gpt_options}
59
- --deepspeed \
60
- --deepspeed_config ${config_json} \
61
- "
62
-
63
- #!/bin/bash
64
-
65
- # Distribute Example
66
- #export NCCL_SOCKET_IFNAME=eth0
67
- export NCCL_IB_DISABLE=0
68
- export NCCL_NET_GDR_LEVEL=2
69
- #export NCCL_IB_CUDA_SUPPORT=1
70
- #export NCCL_IB_GID_INDEX=3
71
- #export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null)
72
- export NCCL_DEBUG=info
73
- export OMP_NUM_THREADS=4
74
-
75
- if [ $RLAUNCH_REPLICA == "0" ]; then
76
- ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip
77
- fi
78
-
79
- function finish {
80
- rm -rf master_ip
81
- }
82
-
83
- trap finish EXIT INT TERM
84
-
85
- while [ ! -f master_ip ]; do
86
- echo "wait master_ip..."
87
- ls > /dev/null && sleep 1;
88
- done
89
-
90
- export MASTER_ADDR=$(cat master_ip)
91
- echo "master_ip: $MASTER_ADDR"
92
-
93
- MP_SIZE=1
94
- task_set=$2
95
- source $1
96
- DATESTR=$(date +"%m-%d-%H-%M")
97
-
98
- mkdir logs
99
- run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \
100
- --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \
101
- --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt"
102
-
103
-
104
- # run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}"
105
- echo ${run_cmd}
106
- eval ${run_cmd}
107
-
108
- set +x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/ds_config_zero.json DELETED
@@ -1,42 +0,0 @@
1
- {
2
- "train_micro_batch_size_per_gpu": 4,
3
- "gradient_accumulation_steps": 1,
4
- "steps_per_print": 1,
5
- "gradient_clipping": 0.1,
6
- "zero_optimization": {
7
- "stage": 2,
8
- "cpu_offload": true,
9
- "contiguous_gradients": false,
10
- "overlap_comm": true,
11
- "reduce_scatter": false,
12
- "reduce_bucket_size": 100000000,
13
- "allgather_bucket_size": 1000000000,
14
- "load_from_fp32_weights": false
15
- },
16
- "zero_allow_untested_optimizer": true,
17
- "fp16": {
18
- "enabled": true,
19
- "loss_scale": 0,
20
- "loss_scale_window": 400,
21
- "hysteresis": 2,
22
- "min_loss_scale": 1
23
- },
24
- "optimizer": {
25
- "type": "Adam",
26
- "params": {
27
- "lr": 0.0002,
28
- "betas": [
29
- 0.9,
30
- 0.95
31
- ],
32
- "eps": 1e-8,
33
- "weight_decay": 1e-4
34
- }
35
- },
36
- "activation_checkpointing": {
37
- "partition_activations": false,
38
- "contiguous_memory_optimization": false
39
- },
40
- "wall_clock_breakdown": false
41
- }
42
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/inference_cogvideo_pipeline.sh DELETED
@@ -1,38 +0,0 @@
1
- #!/bin/bash
2
-
3
- NLAYERS=48
4
- NHIDDEN=3072
5
- NATT=48
6
- MAXSEQLEN=1024
7
- MASTER_PORT=$(shuf -n 1 -i 10000-65535)
8
- MPSIZE=1
9
-
10
- #SAMPLING ARGS
11
- TEMP=1.05
12
- TOPK=12
13
-
14
- script_path=$(realpath $0)
15
- script_dir=$(dirname $script_path)
16
-
17
- MASTER_PORT=${MASTER_PORT} SAT_HOME=/sharefs/cogview-new python cogvideo_pipeline.py \
18
- --input-source /home/user/app/CogVideo/prompt.txt \
19
- --output-path ./output \
20
- --parallel-size 1 \
21
- --both-stages \
22
- --use-guidance-stage1 \
23
- --guidance-alpha 3.0 \
24
- --generate-frame-num 5 \
25
- --tokenizer-type fake \
26
- --mode inference \
27
- --distributed-backend nccl \
28
- --fp16 \
29
- --model-parallel-size $MPSIZE \
30
- --temperature $TEMP \
31
- --coglm-temperature2 0.89 \
32
- --top_k $TOPK \
33
- --sandwich-ln \
34
- --seed 1234 \
35
- --num-workers 0 \
36
- --batch-size 1 \
37
- --max-inference-batch-size 1 \
38
- $@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : __init__.py
4
- @Time : 2022/03/02 13:57:09
5
- @Author : Ming Ding
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
-
15
- from .direct_sr import DirectSuperResolution
16
- from .iterative_sr import IterativeSuperResolution
17
- from .sr_group import SRGroup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/direct_sr.py DELETED
@@ -1,117 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : direct_sr.py
4
- @Time : 2022/03/02 13:58:11
5
- @Author : Ming Ding
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
-
16
- # -*- encoding: utf-8 -*-
17
- '''
18
- @File : inference_cogview2.py
19
- @Time : 2021/10/10 16:31:34
20
- @Author : Ming Ding
21
- @Contact : [email protected]
22
- '''
23
-
24
- # here put the import lib
25
- import os
26
- import sys
27
- import math
28
- import random
29
- from PIL import ImageEnhance, Image
30
-
31
- import torch
32
- import argparse
33
- from torchvision import transforms
34
-
35
- from SwissArmyTransformer import get_args
36
- from SwissArmyTransformer.training.model_io import load_checkpoint
37
- from .dsr_sampling import filling_sequence_dsr, IterativeEntfilterStrategy
38
- from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
39
-
40
- from .dsr_model import DsrModel
41
-
42
- from icetk import icetk as tokenizer
43
-
44
- class DirectSuperResolution:
45
- def __init__(self, args, path, max_bz=4, topk=6, onCUDA=False):
46
- args.load = path
47
- args.kernel_size = 5
48
- args.kernel_size2 = 5
49
- args.new_sequence_length = 4624
50
- args.layout = [96,496,4096]
51
-
52
- model = DsrModel(args)
53
- if args.fp16:
54
- model = model.half()
55
-
56
- load_checkpoint(model, args) # on cpu
57
- model.eval()
58
- self.model = model
59
- self.onCUDA = onCUDA
60
- if onCUDA:
61
- self.model = self.model.cuda()
62
-
63
- invalid_slices = [slice(tokenizer.num_image_tokens, None)]
64
-
65
- self.strategy = IterativeEntfilterStrategy(invalid_slices,
66
- temperature=1.0, topk=topk) # temperature not used # Temperature Freezed Here!!
67
- self.max_bz = max_bz
68
-
69
- def __call__(self, text_tokens, image_tokens, enhance=False):
70
- if len(text_tokens.shape) == 1:
71
- text_tokens.unsqueeze_(0)
72
- if len(image_tokens.shape) == 1:
73
- image_tokens.unsqueeze_(0)
74
- # ===================== Debug ======================== #
75
- # new_image_tokens = []
76
- # for small_img in image_tokens:
77
- # decoded = tokenizer.decode(image_ids=small_img)
78
- # decoded = torch.nn.functional.interpolate(decoded, size=(480, 480)).squeeze(0)
79
- # ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
80
- # image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
81
- # small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
82
- # new_image_tokens.append(small_img2)
83
- # image_tokens = torch.stack(new_image_tokens)
84
- # return image_tokens
85
- # ===================== END OF BLOCK ======================= #
86
- if enhance:
87
- new_image_tokens = []
88
- for small_img in image_tokens:
89
- decoded = tokenizer.decode(image_ids=small_img).squeeze(0)
90
- ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
91
- image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
92
- small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.), image_size=160).view(-1)
93
- new_image_tokens.append(small_img2)
94
- image_tokens = torch.stack(new_image_tokens)
95
-
96
- seq = torch.cat((text_tokens,image_tokens), dim=1)
97
- seq1 = torch.tensor([tokenizer['<start_of_image>']]*3601, device=image_tokens.device).unsqueeze(0).expand(text_tokens.shape[0], -1)
98
- if not self.onCUDA:
99
- print('Converting Dsr model...')
100
- model = self.model.cuda()
101
- else:
102
- model = self.model
103
- print('Direct super-resolution...')
104
- output_list = []
105
- for tim in range(max((text_tokens.shape[0]+self.max_bz-1) // self.max_bz, 1)):
106
- output1 = filling_sequence_dsr(model,
107
- seq[tim*self.max_bz:(tim+1)*self.max_bz],
108
- seq1[tim*self.max_bz:(tim+1)*self.max_bz],
109
- warmup_steps=1, block_hw=(1, 0),
110
- strategy=self.strategy
111
- )
112
- output_list.extend(output1[1:])
113
- if not self.onCUDA:
114
- print('Moving back Dsr to cpu...')
115
- model = model.cpu()
116
- torch.cuda.empty_cache()
117
- return torch.cat(output_list, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/dsr_model.py DELETED
@@ -1,225 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cuda2d_model.py
4
- @Time : 2021/10/02 01:36:32
5
- @Author : Ming Ding
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
- import torch.nn.functional as F
16
-
17
-
18
- from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
19
-
20
- from SwissArmyTransformer.model.transformer import split_tensor_along_last_dim, unscaled_init_method
21
- from SwissArmyTransformer.mpu.utils import sqrt
22
- from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
23
- from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
24
-
25
- class PositionEmbeddingMixin(BaseMixin):
26
- def __init__(self, additional_sequence_length, hidden_size,
27
- init_method_std=0.02, reinit_slice=slice(512, 512+400)
28
- ):
29
- super(PositionEmbeddingMixin, self).__init__()
30
- self.reinit_slice = reinit_slice
31
- self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
32
- torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
33
-
34
- def reinit(self, parent_model=None):
35
- old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
36
- old_len, hidden_size = old_weights.shape
37
- assert hidden_size == self.position_embeddings.weight.shape[-1]
38
- old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2])
39
- assert new_edge % old_edge == 0
40
- self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size))
41
- # self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
42
-
43
-
44
- class AttentionMixin(BaseMixin):
45
- def __init__(self, num_layers,
46
- hidden_size,
47
- init_method=unscaled_init_method(0.02),
48
- output_layer_init_method=unscaled_init_method(0.02)
49
- ):
50
- super(AttentionMixin, self).__init__()
51
- self.num_layers = num_layers # replace attention in the LAST n layers
52
- self.query_key_value = torch.nn.ModuleList(
53
- [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
54
- gather_output=False, init_method=init_method)
55
- for layer_id in range(num_layers)
56
- ])
57
- self.dense = torch.nn.ModuleList(
58
- [RowParallelLinear(hidden_size,
59
- hidden_size,
60
- input_is_parallel=True,
61
- init_method=output_layer_init_method)
62
- for layer_id in range(num_layers)
63
- ])
64
-
65
- def reinit(self, parent_model=None):
66
- start_layer = len(self.transformer.layers) - self.num_layers
67
- assert start_layer >= 0
68
- for layer_id in range(self.num_layers):
69
- old_attention = self.transformer.layers[start_layer + layer_id].attention
70
- self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
71
- self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
72
- self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
73
- self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)
74
-
75
- class DsrModel(BaseModel):
76
- def __init__(self, args, transformer=None):
77
- super().__init__(args, transformer=transformer)
78
- self.original_sequence_length = args.max_sequence_length
79
- additional_seqlen = args.new_sequence_length - args.max_sequence_length
80
- self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
81
- additional_seqlen, args.hidden_size
82
- ))
83
- self.add_mixin('attention_plus', AttentionMixin(
84
- num_layers=args.num_layers,
85
- hidden_size=args.hidden_size
86
- ))
87
- self.layout = args.layout
88
- # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
89
- self.kernel_size = args.kernel_size
90
- self.kernel_size2 = args.kernel_size2
91
- self.log_attention_weights = None
92
-
93
- def position_embedding_forward(self, position_ids, **kw_args):
94
- position = position_ids[..., :self.layout[1]]
95
- position_plus = position_ids[..., self.layout[1]:] - self.original_sequence_length
96
- position_embeddings = torch.cat(
97
- (
98
- self.transformer.position_embeddings(position),
99
- self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
100
- ),
101
- dim=-2
102
- )
103
- return position_embeddings
104
-
105
- def attention_forward(self, hidden_states, mask,
106
- layer_id=None, log_attention_weights=None, **kw_args):
107
- attn_module = self.transformer.layers[layer_id].attention
108
- # attention_plus on all layers
109
- query_key_value_plus = self.get_mixin('attention_plus').query_key_value[layer_id]
110
- dense_plus = self.get_mixin('attention_plus').dense[layer_id]
111
- # split two parts
112
- hidden_states_plus = hidden_states[:, self.layout[1]:]
113
- hidden_states = hidden_states[:, :self.layout[1]]
114
- # base model qkv
115
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
116
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
117
- # cuda2d model qkv
118
- mixed_raw_layer = query_key_value_plus(hidden_states_plus)
119
- q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3)
120
-
121
- dropout_fn = attn_module.attention_dropout if self.training else None
122
-
123
- # cuda2d attention
124
- context_layer0, context_layer1 = sparse_attention_2d_light(
125
- q0, k0, v0,
126
- q1, k1, v1,
127
- mask,
128
- n_head=attn_module.num_attention_heads_per_partition,
129
- text_len=self.layout[0],
130
- kernel_size=self.kernel_size,
131
- kernel_size2=self.kernel_size2,
132
- attention_dropout=dropout_fn,
133
- log_attention_weights=log_attention_weights,
134
- add_scalar=(kw_args['add_scalar'] if 'add_scalar' in kw_args else 0)
135
- )
136
-
137
- output_0 = attn_module.dense(context_layer0)
138
- output_1 = dense_plus(context_layer1)
139
- output = torch.cat((output_0, output_1), dim=1)
140
-
141
- return output
142
-
143
- def final_forward(self, logits, **kwargs):
144
- logits_parallel = logits
145
- logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
146
- # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
147
- return logits_parallel
148
-
149
- def disable_untrainable_params(self):
150
- self.transformer.requires_grad_(False)
151
-
152
- @classmethod
153
- def add_model_specific_args(cls, parser):
154
- group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
155
- group.add_argument("--kernel-size", type=int, default=5)
156
- group.add_argument("--kernel-size2", type=int, default=5)
157
- group.add_argument("--layout", type=str, default='96,496,4096')
158
- group.add_argument("--new-sequence-length", type=int, default=4096)
159
- return parser
160
-
161
- def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, add_scalar=0, **kwargs):
162
- '''
163
- q0, k0, v0: [batch_size, 1088, hidden_size]
164
- q1, k1, v1: [batch_size, 4096, h2]
165
- n_head: int
166
- attention_mask: [batch_size, 1088, 1088]
167
- '''
168
- from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
169
-
170
- b, s0, h0 = q0.shape
171
- b, s1, h1 = q1.shape
172
- h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1)
173
-
174
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
175
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
176
- k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
177
-
178
- # standard attention for level 0
179
- attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
180
-
181
- if log_attention_weights is not None:
182
- attention_scores += log_attention_weights
183
- attention_scores = torch.mul(attention_scores, attention_mask) - \
184
- 10000.0 * (1.0 - attention_mask)
185
-
186
- attention_probs0 = F.softmax(attention_scores, dim=-1)
187
-
188
- # local attention for level 1
189
- q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
190
- k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
191
- v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
192
- # scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)
193
- scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
194
-
195
- # cross attention
196
- k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous()
197
- scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field]
198
- scores_1 = torch.cat(
199
- (
200
- scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]) + add_scalar,
201
- scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
202
- ),
203
- dim=-1)
204
- attention_probs1 = F.softmax(scores_1, dim=-1)
205
-
206
- if attention_dropout is not None:
207
- # with get_cuda_rng_tracker().fork():
208
- attention_probs0 = attention_dropout(attention_probs0)
209
- attention_probs1 = attention_dropout(attention_probs1)
210
-
211
- # weighting for level 0
212
- context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
213
- # weighting for level 1
214
- probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
215
- # context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
216
- context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
217
-
218
- context1 = context1_to_1.view(b, n_head * h, l1**2)
219
- # weighting for cross attention
220
- probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0)
221
- v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0)
222
- context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
223
- context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
224
- context1 = context1 + context1_to_0
225
- return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/dsr_sampling.py DELETED
@@ -1,159 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cuda2d_sampling.py
4
- @Time : 2021/10/09 00:46:04
5
- @Author : Ming Ding
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- from cv2 import reduce
15
- import torch
16
-
17
- import torch
18
- import torch.nn.functional as F
19
- import numpy as np
20
-
21
- def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
22
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
23
- logits[indices_to_remove] = filter_value
24
- return logits
25
-
26
- class IterativeEntfilterStrategy:
27
- def __init__(self, invalid_slices=[], temperature=1., topk=6):
28
- self.invalid_slices = invalid_slices
29
- self.temperature = temperature
30
- self.topk = topk
31
- self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
32
-
33
-
34
- def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
35
- # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
36
- if temperature is None:
37
- temperature = self.temperature
38
-
39
- logits = logits_.float() / temperature
40
- for invalid_slice in self.invalid_slices:
41
- logits[..., invalid_slice] = -float('Inf')
42
- logits = logits.view(-1, logits.shape[-1])
43
-
44
- rprobs = F.softmax(logits.float(), dim=-1)
45
- c = self.cluster_labels.expand(*rprobs.shape)
46
- cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
47
-
48
- best_scores, best_clusters = cprobs.topk(self.topk)
49
- bz = logits.shape[0]
50
- best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
51
- sampled_ids = torch.multinomial(best_scores, num_samples=1)
52
- selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
53
- selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500)
54
- logits[selected_mask] = -65504
55
- # for i in range(bz):
56
- # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
57
- # logits[i, self.cluster_labels != selected_cluster] = -65504
58
-
59
- # logits = top_k_logits(logits, self.topk, self.top_p)
60
- probs = F.softmax(logits.float()/0.6, dim=-1) # float is essetial, due to a bug in Pytorch
61
- pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
62
-
63
- assert tokens.shape[1] == pred.shape[1] + 1
64
- tokens = torch.cat((tokens[:, :1], pred), dim=1)
65
- return tokens
66
-
67
- def filling_sequence_dsr(
68
- model,
69
- seq0,
70
- seq1,
71
- warmup_steps=3,
72
- block_hw=(4, 4),
73
- strategy=IterativeEntfilterStrategy(topk=10),
74
- ):
75
- '''
76
- seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
77
- 4095 {layout[2]} final_token.
78
- Attention:
79
- The sampling temperature are changing, temporally we hard code them here.
80
- The temperature in the strategy is not used.
81
- '''
82
- assert hasattr(model, 'layout')
83
- layout = model.layout
84
- assert len(seq0.shape) == 2 and len(seq1.shape) == 2 \
85
- and seq0.shape[0] == seq1.shape[0]
86
- assert len(layout) == 3
87
- assert seq1.shape[1] == layout[-1] - layout[-2] + 1
88
- assert (seq1 >= 0).all() and (seq0 >= 0).all()
89
- device = seq0.device
90
- # concat and pad sequences
91
- batch_size = seq0.shape[0]
92
- n_pad = layout[1] - seq0.shape[1]
93
- assert n_pad > 0, "You should truncate long input before filling."
94
- seq = torch.cat((
95
- torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
96
- .unsqueeze(0).expand(batch_size, n_pad),
97
- seq0, seq1), dim=1) # [b, layout[-1]+1]
98
- assert seq.shape[1] == layout[-1] + 1
99
-
100
- # build initial tokens, attention_mask, and position_ids
101
- tokens = seq.clone()
102
- attention_mask = torch.ones(layout[1], layout[1]).to(device)
103
- attention_mask[:layout[0], layout[0]:] = 0
104
- attention_mask[n_pad:, :n_pad] = 0
105
- attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
106
- position_ids = torch.cat((
107
- torch.zeros(n_pad, dtype=torch.long),
108
- torch.arange(0, layout[0] - n_pad),
109
- torch.arange(513, 513 + layout[1] - layout[0]),
110
- torch.arange(1024, 1024+layout[2]-layout[1]))).to(device)
111
- log_attention_weights = torch.zeros(layout[1], layout[1],
112
- device=device).type_as(next(model.parameters()))
113
- log_attention_weights[layout[0]:, n_pad:layout[0]] = 0.
114
-
115
- # prepare for interation
116
- unfixed = (tokens < 0) # just init an all-False tensor
117
- unfixed[:, -layout[-1] + layout[-2]:] = True
118
-
119
- ll, rr = block_hw
120
- edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
121
- num_steps = warmup_steps + ll - 1 + rr
122
- # interative refining
123
-
124
- # unfixed[..., -(layout[-1] - layout[-2]):].view(
125
- # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
126
-
127
-
128
- ret = []
129
- ret.append(tokens[:, layout[-2]+1:].clone())
130
- for step_cnt in range(1, num_steps+1):
131
- if step_cnt <= warmup_steps:
132
- logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights)
133
- real_temp = 1.
134
- new_tokens = strategy.forward(logits, tokens, real_temp)
135
- tokens[unfixed] = new_tokens[unfixed]
136
- else:
137
- logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights)
138
- real_temp = 1.
139
- new_tokens = strategy.forward(
140
- logits, tokens, real_temp,
141
- entfilter=1.3,
142
- filter_topk=5,
143
- temperature2=0.6
144
- )
145
- # tokens[unfixed] = new_tokens[unfixed]
146
- # fixed tokens (update unfixed)
147
- unfixed2 = (tokens > 10000000)
148
- for x in range(min(ll, step_cnt - warmup_steps)):
149
- y = step_cnt - warmup_steps - x - 1
150
- if y < rr:
151
- unfixed[..., -(layout[-1] - layout[-2]):].view(
152
- batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
153
- unfixed2[..., -(layout[-1] - layout[-2]):].view(
154
- batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = True
155
- tokens[unfixed2] = new_tokens[unfixed2]
156
-
157
- ret.append(tokens[:, layout[-2]+1:].clone())
158
-
159
- return ret
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/iterative_sr.py DELETED
@@ -1,118 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : iterative_sr.py
4
- @Time : 2022/03/02 15:57:45
5
- @Author : Ming Ding
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
-
15
- # here put the import lib
16
- import os
17
- import sys
18
- import math
19
- import random
20
- from PIL import ImageEnhance, Image
21
-
22
- import torch
23
- import argparse
24
- from torchvision import transforms
25
-
26
- from SwissArmyTransformer.training.model_io import load_checkpoint
27
- from SwissArmyTransformer import get_args
28
- from .itersr_sampling import filling_sequence_itersr, IterativeEntfilterStrategy
29
- from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
30
-
31
- from .itersr_model import ItersrModel
32
-
33
- from icetk import icetk as tokenizer
34
-
35
- class IterativeSuperResolution:
36
- def __init__(self, args, path, max_bz=4, shared_transformer=None):
37
- args.load = path
38
- args.kernel_size = 5
39
- args.kernel_size2 = 5
40
- args.new_sequence_length = 4624
41
- args.layout = [16,3616]
42
-
43
- model = ItersrModel(args, transformer=shared_transformer)
44
- if args.fp16:
45
- model = model.half()
46
-
47
- load_checkpoint(model, args) # on cpu
48
- model.eval()
49
- self.model = model.cuda()
50
-
51
- # save cpu weights
52
- self.saved_weights = dict((k,v.cpu())
53
- for k, v in model.named_parameters()
54
- if 'transformer' in k
55
- )
56
-
57
- invalid_slices = [slice(tokenizer.num_image_tokens, None)]
58
-
59
- self.strategy = IterativeEntfilterStrategy(invalid_slices,
60
- temperature=args.temp_all_itersr, topk=args.topk_itersr)
61
- self.max_bz = max_bz
62
-
63
- def _restore_transformer_from_cpu(self, non_blocking=False):
64
- for k, v in self.model.named_parameters():
65
- if k in self.saved_weights:
66
- v.copy_(self.saved_weights[k])
67
-
68
- def __call__(self, text_tokens, image_tokens, enhance=False, input_mask=None):
69
- if len(text_tokens.shape) == 1:
70
- text_tokens.unsqueeze_(0)
71
- text_tokens = text_tokens.clone()[..., :16]
72
- if len(image_tokens.shape) == 1:
73
- image_tokens.unsqueeze_(0)
74
- if enhance:
75
- new_image_tokens = []
76
- for big_img in image_tokens:
77
- decoded = tokenizer.decode(image_ids=big_img).squeeze(0)
78
- ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
79
- image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
80
- big_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
81
- new_image_tokens.append(big_img2)
82
- image_tokens = torch.stack(new_image_tokens)
83
- print('Converting Itersr model...')
84
- self._restore_transformer_from_cpu()
85
- model = self.model
86
- print('iterative super-resolution...')
87
- output_list = []
88
- for tim in range(max(text_tokens.shape[0] // self.max_bz, 1)):
89
- big_img = image_tokens[tim*self.max_bz:(tim+1)*self.max_bz]
90
- text_seq = text_tokens[tim*self.max_bz:(tim+1)*self.max_bz]
91
- mask_raw = torch.tensor(
92
- [
93
- -1, 0, 1, 2, 3, 4,
94
- 0, -1, 2, -1, -2, 5,
95
- 1, -2, 3, 4, 5, 6,
96
- 2, 3, 4, 5, -1, 1,
97
- 3, -1, -2, 0, -1, 2,
98
- 4, 5, 6, 1, 3, -2
99
- ]
100
- ).view(1, 6, 1, 6).expand(10, 6, 10, 6).reshape(-1).contiguous()
101
-
102
- topks = [60, 40, 40, 40, 20, 20, 10]
103
-
104
- for mask_ratio in range(1, 7):
105
- self.strategy.topk = topks[mask_ratio]
106
- mask = (mask_raw.to(big_img.device) >= mask_ratio)
107
- if input_mask is not None:
108
- mask = mask & input_mask
109
- big_img.masked_fill_(mask, tokenizer['<start_of_image>'])
110
- seq1 = big_img
111
- output1 = filling_sequence_itersr(model, text_seq, seq1,
112
- warmup_steps=1, block_hw=(1, 0),
113
- strategy=self.strategy
114
- )
115
- big_img = output1
116
- print(f'Iter {mask_ratio} times.')
117
- output_list.append(output1.clone())
118
- return torch.cat(output_list, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/itersr_model.py DELETED
@@ -1,232 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : itersr_model.py
4
- @Time : 2021/10/02 01:36:32
5
- @Author : Ming Ding
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
- import torch.nn.functional as F
16
-
17
-
18
- from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
19
-
20
- from SwissArmyTransformer.mpu.utils import sqrt
21
- from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
22
- from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
23
- from SwissArmyTransformer.model.transformer import unscaled_init_method, split_tensor_along_last_dim
24
-
25
- class PositionEmbeddingMixin(BaseMixin):
26
- def __init__(self, additional_sequence_length, hidden_size,
27
- init_method_std=0.02, reinit_slice=slice(512, 512+400)
28
- ):
29
- super(PositionEmbeddingMixin, self).__init__()
30
- self.reinit_slice = reinit_slice
31
- self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
32
- torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
33
-
34
- def reinit(self, parent_model=None):
35
- old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
36
- old_len, hidden_size = old_weights.shape
37
- assert hidden_size == self.position_embeddings.weight.shape[-1]
38
- old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2])
39
- assert new_edge % old_edge == 0
40
- self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size))
41
-
42
- class ItersrModel(BaseModel):
43
- def __init__(self, args, transformer=None):
44
- super().__init__(args, transformer=transformer)
45
- self.original_sequence_length = args.max_sequence_length
46
- additional_seqlen = args.new_sequence_length - args.max_sequence_length
47
- self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
48
- additional_seqlen, args.hidden_size
49
- ))
50
- # self.add_mixin('attention_plus', AttentionMixin(
51
- # num_layers=args.num_layers,
52
- # hidden_size=args.hidden_size
53
- # ))
54
- self.layout = args.layout
55
- # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
56
- self.kernel_size = args.kernel_size
57
- self.kernel_size2 = args.kernel_size2
58
- self.log_attention_weights = None
59
-
60
- def position_embedding_forward(self, position_ids, **kw_args):
61
- position = position_ids[..., :self.layout[0]]
62
- position_plus = position_ids[..., self.layout[0]:] - self.original_sequence_length
63
- position_embeddings = torch.cat(
64
- (
65
- self.transformer.position_embeddings(position),
66
- self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
67
- ),
68
- dim=-2
69
- )
70
- return position_embeddings
71
-
72
- def attention_forward(self, hidden_states, mask,
73
- layer_id=None, log_attention_weights=None, **kw_args):
74
- attn_module = self.transformer.layers[layer_id].attention
75
- # base model qkv
76
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
77
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer[:, :self.layout[0]], 3)
78
- # cuda2d model qkv
79
- q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer[:, self.layout[0]:], 3)
80
-
81
- dropout_fn = attn_module.attention_dropout if self.training else None
82
-
83
- # cuda2d attention
84
- context_layer = sparse_attention_2d_text(
85
- q0, k0, v0,
86
- q1, k1, v1,
87
- mask,
88
- n_head=attn_module.num_attention_heads_per_partition,
89
- text_len=self.layout[0],
90
- kernel_size=self.kernel_size,
91
- attention_dropout=dropout_fn,
92
- log_attention_weights=log_attention_weights,
93
- )
94
-
95
- output = attn_module.dense(context_layer)
96
-
97
- return output
98
-
99
- def final_forward(self, logits, **kwargs):
100
- logits_parallel = logits
101
- logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]).float()
102
- # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
103
- return logits_parallel
104
-
105
- # def disable_untrainable_params(self):
106
- # self.transformer.requires_grad_(False)
107
-
108
- @classmethod
109
- def add_model_specific_args(cls, parser):
110
- group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
111
- group.add_argument("--kernel-size", type=int, default=5)
112
- group.add_argument("--kernel-size2", type=int, default=5)
113
- group.add_argument("--layout", type=str, default='16,3616')
114
- group.add_argument("--new-sequence-length", type=int, default=4096)
115
- return parser
116
-
117
- def sparse_attention_2d_text(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
118
- '''
119
- q0, k0, v0: [batch_size, 16, hidden_size]
120
- q1, k1, v1: [batch_size, 3600, hidden_size]
121
- n_head: int
122
- attention_mask: [batch_size, 16]
123
- '''
124
- from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
125
- b, s0, h0 = q0.shape
126
- b, s1, h1 = q1.shape
127
- h, l1 = h0 // n_head, sqrt(s1)
128
- assert attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"
129
-
130
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
131
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
132
- k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
133
-
134
- # standard attention for level 0
135
- attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
136
-
137
- attention_scores = torch.mul(attention_scores, attention_mask) - \
138
- 10000.0 * (1.0 - attention_mask)
139
-
140
- attention_probs0 = F.softmax(attention_scores, dim=-1)
141
-
142
- # local attention for level 1
143
- q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
144
- k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
145
- v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
146
- scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
147
-
148
- # cross attention
149
- scores_1_to_0 = torch.matmul(q1.view(b, n_head, h, s1).transpose(-1, -2), k0T)
150
- if log_attention_weights is not None:
151
- scores_1_to_0 += log_attention_weights
152
- scores_1_to_0 = torch.mul(scores_1_to_0, attention_mask) - \
153
- 10000.0 * (1.0 - attention_mask)
154
- scores_1 = torch.cat(
155
- (
156
- scores_1_to_0.view(b*n_head, s1, s0),
157
- scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
158
- ),
159
- dim=-1)
160
- attention_probs1 = F.softmax(scores_1, dim=-1)
161
-
162
- if attention_dropout is not None:
163
- with get_cuda_rng_tracker().fork():
164
- attention_probs1 = attention_dropout(attention_probs1)
165
-
166
- # weighting for level 0
167
- context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
168
- # weighting for level 1
169
- probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
170
- context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
171
-
172
- context1 = context1_to_1.view(b, n_head, h, l1**2)
173
- # weighting for cross attention
174
- probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view(b, n_head, -1, scores_1_to_0.shape[3])
175
-
176
- context1_to_0 = torch.matmul(probs_1_to_0, v0)
177
- context1 = context1.transpose(-1, -2) + context1_to_0
178
-
179
- output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)
180
-
181
- return output
182
-
183
- def sparse_attention_2d_notext(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
184
- '''
185
- q0, k0, v0: [batch_size, 16, hidden_size]
186
- q1, k1, v1: [batch_size, 3600, hidden_size]
187
- n_head: int
188
- attention_mask: [batch_size, 16]
189
- '''
190
- from SwissArmyTransformer.mpu.local_attention_function import f_similar, f_weighting
191
- b, s0, h0 = q0.shape
192
- b, s1, h1 = q1.shape
193
- h, l1 = h0 // n_head, sqrt(s1)
194
- assert len(attention_mask.shape) == 4 and attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"
195
-
196
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
197
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
198
- k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
199
-
200
- # standard attention for level 0
201
- attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
202
-
203
- attention_scores = torch.mul(attention_scores, attention_mask) - \
204
- 10000.0 * (1.0 - attention_mask)
205
-
206
- attention_probs0 = F.softmax(attention_scores, dim=-1)
207
-
208
- # local attention for level 1
209
- q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
210
- k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
211
- v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
212
- scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
213
-
214
- attention_probs1 = F.softmax(scores_1_to_1, dim=-1)
215
-
216
- if attention_dropout is not None:
217
- with get_cuda_rng_tracker().fork():
218
- attention_probs1 = attention_dropout(attention_probs1)
219
-
220
- # weighting for level 0
221
- context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
222
- # weighting for level 1
223
- probs_1_to_1 = attention_probs1
224
- context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
225
-
226
- context1 = context1_to_1.view(b, n_head, h, l1**2)
227
- # weighting for cross attention
228
- context1 = context1.transpose(-1, -2)
229
-
230
- output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)
231
-
232
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/itersr_sampling.py DELETED
@@ -1,168 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : itersr_sampling.py
4
- @Time : 2022/03/03 14:24:28
5
- @Author : Ming Ding
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import numpy as np
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from icetk import icetk as tokenizer
19
-
20
- def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
21
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
22
- logits[indices_to_remove] = filter_value
23
- return logits
24
-
25
- # class IterativeEntfilterStrategy:
26
- # def __init__(self, invalid_slices=[], temperature=1., topk=10):
27
- # self.invalid_slices = invalid_slices
28
- # self.temperature = temperature
29
- # self.topk = topk
30
- # self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
31
-
32
-
33
- # def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
34
- # # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
35
- # if temperature is None:
36
- # temperature = self.temperature
37
-
38
- # logits = logits_.float() / temperature
39
- # for invalid_slice in self.invalid_slices:
40
- # logits[..., invalid_slice] = -float('Inf')
41
- # logits = logits.view(-1, logits.shape[-1])
42
-
43
- # rprobs = F.softmax(logits.float(), dim=-1)
44
- # c = self.cluster_labels.expand(*rprobs.shape)
45
- # cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
46
-
47
- # best_scores, best_clusters = cprobs.topk(self.topk)
48
- # bz = logits.shape[0]
49
- # best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
50
- # sampled_ids = torch.multinomial(best_scores, num_samples=1)
51
- # selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
52
- # selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500)
53
- # logits[selected_mask] = -65504
54
- # # for i in range(bz):
55
- # # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
56
- # # logits[i, self.cluster_labels != selected_cluster] = -65504
57
-
58
- # # logits = top_k_logits(logits, self.topk, self.top_p)
59
- # probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
60
- # pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
61
-
62
- # assert tokens.shape[1] == pred.shape[1]
63
- # tokens = pred
64
- # return tokens
65
-
66
- class IterativeEntfilterStrategy:
67
- def __init__(self, invalid_slices=[], temperature=1., topk=10):
68
- self.invalid_slices = invalid_slices
69
- self.temperature = temperature
70
- self.topk = topk
71
-
72
- def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
73
- # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
74
- if temperature is None:
75
- temperature = self.temperature
76
- # check entropy filter
77
- # if entfilter is not None:
78
- # assert temperature2 is not None
79
- # topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1)
80
- # ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length]
81
- # temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2
82
-
83
- logits = logits.float() / temperature
84
- for invalid_slice in self.invalid_slices:
85
- logits[..., invalid_slice] = -float('Inf')
86
-
87
- # debiased topk
88
- # probs = F.softmax(logits, dim=-1)
89
- # tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
90
- # pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
91
- # edge_idx = tk_idx[:, :, -1:]
92
- # edge_value = tk_value[:, :, -1:]
93
- # edge_mask = probs.gather(dim=-1, index=pred) < edge_value
94
- # pred[edge_mask] = edge_idx[edge_mask] # replace outliers as the "filter_topk"-th token
95
- # pred.squeeze_(-1) # [batch_size, seq_length]
96
-
97
- top_k_logits_(logits, self.topk)
98
- probs = F.softmax(logits, dim=-1)
99
- pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
100
- pred.squeeze_(-1)
101
-
102
- assert tokens.shape[1] == pred.shape[1]
103
- tokens = pred
104
- return tokens
105
-
106
- def filling_sequence_itersr(
107
- model,
108
- seq0,
109
- seq1,
110
- warmup_steps=3,
111
- block_hw=(4, 4),
112
- strategy=IterativeEntfilterStrategy(topk=10),
113
- ):
114
- '''
115
- seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
116
- 4095 {layout[2]} final_token.
117
- Attention:
118
- The sampling temperature are changing, temporally we hard code them here.
119
- The temperature in the strategy is not used.
120
- '''
121
- assert hasattr(model, 'layout')
122
- layout = model.layout
123
-
124
- device = seq0.device
125
- # concat and pad sequences
126
- batch_size = seq0.shape[0]
127
- n_pad = layout[0] - seq0.shape[1]
128
- assert n_pad >= 0, "You should truncate long input before filling."
129
- seq = torch.cat((
130
- torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
131
- .unsqueeze(0).expand(batch_size, n_pad),
132
- seq0, seq1), dim=1) # [b, layout[-1]+1]
133
- assert seq.shape[1] == layout[-1]
134
-
135
- # build initial tokens, attention_mask, and position_ids
136
- tokens = seq.clone()
137
- attention_mask = torch.ones(layout[0]).to(device)
138
- attention_mask[:n_pad] = 0
139
- attention_mask = attention_mask.unsqueeze(0).type_as(next(model.parameters())) # if fp16
140
- position_ids = torch.cat((
141
- torch.zeros(n_pad, dtype=torch.long),
142
- torch.arange(0, layout[0] - n_pad),
143
- torch.arange(1024, 1024+layout[1]-layout[0]))).to(device)
144
- log_attention_weights = torch.zeros(layout[0], device=device).type_as(next(model.parameters()))
145
- log_attention_weights[n_pad:layout[0]] = 0.
146
- log_attention_weights = log_attention_weights.unsqueeze(0)
147
-
148
- # prepare for interation
149
- unfixed = (tokens == tokenizer['<start_of_image>'])
150
- ll, rr = block_hw
151
- edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
152
- num_steps = 1
153
- # interative refining
154
-
155
- # unfixed[..., -(layout[-1] - layout[-2]):].view(
156
- # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
157
-
158
-
159
- ret = []
160
- # ret.append(tokens[:, layout[-2]:-1].clone())
161
- for step_cnt in range(1, num_steps+1):
162
- logits, *_dump = model(tokens, position_ids, attention_mask, log_attention_weights=log_attention_weights)
163
- real_temp = 1.
164
- new_tokens = strategy.forward(logits, tokens, real_temp)
165
- tokens[unfixed] = new_tokens[unfixed]
166
-
167
- ret.append(tokens[:, layout[-2]:].clone())
168
- return torch.cat(ret, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/sr_group.py DELETED
@@ -1,49 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : sr_group.py
4
- @Time : 2022/04/02 01:17:21
5
- @Author : Ming Ding
6
- @Contact : [email protected]
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
-
15
- import numpy as np
16
- import torch
17
- import torch.nn.functional as F
18
- from SwissArmyTransformer.resources import auto_create
19
- from .direct_sr import DirectSuperResolution
20
- from .iterative_sr import IterativeSuperResolution
21
-
22
- class SRGroup:
23
- def __init__(self, args, home_path=None,):
24
- dsr_path = auto_create('cogview2-dsr', path=home_path)
25
- itersr_path = auto_create('cogview2-itersr', path=home_path)
26
- dsr = DirectSuperResolution(args, dsr_path)
27
- itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer)
28
- self.dsr = dsr
29
- self.itersr = itersr
30
-
31
- def sr_base(self, img_tokens, txt_tokens):
32
- assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2
33
- batch_size = img_tokens.shape[0]
34
- txt_len = txt_tokens.shape[-1]
35
- if len(txt_tokens.shape) == 1:
36
- txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
37
- sred_tokens = self.dsr(txt_tokens, img_tokens)
38
- iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone())
39
- return iter_tokens[-batch_size:]
40
-
41
- # def sr_patch(self, img_tokens, txt_tokens):
42
- # assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
43
- # batch_size = img_tokens.shape[0] * 9
44
- # txt_len = txt_tokens.shape[-1]
45
- # if len(txt_tokens.shape) == 1:
46
- # txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
47
- # img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
48
- # iter_tokens = self.sr_base(img_tokens, txt_tokens)
49
- # return iter_tokens