akhaliq committed
Commit 7d40f91
1 parent: 86b1386

add image prompt option

Files changed (2)
  1. app.py +7 -3
  2. model.py +39 -33
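
In short: app.py gains an image upload whose value rides along with the existing inputs, and model.py threads it from AppModel.run_with_translation through Model.run into the stage-1 generation code, where a supplied image is tokenized and used as the first frame instead of sampling one from text.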
app.py CHANGED
@@ -40,6 +40,9 @@ def main():
                     label='Only First Stage',
                     value=only_first_stage,
                     visible=not only_first_stage)
+                image_prompt = gr.Image(type="filepath",
+                                        label="Image Prompt",
+                                        value=None)
                 run_button = gr.Button('Run')
 
             with gr.Column():
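
(With type="filepath", Gradio hands the callback the uploaded image as a temporary file path, a str, or None when nothing is uploaded, which is why the model-side parameter can default to None.)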
@@ -50,10 +53,10 @@ def main():
                 result_video = gr.Video(show_label=False)
 
         examples = gr.Examples(
-            examples=[['骑滑板的皮卡丘', False, 1234, True],
-                      ['a cat playing chess', True, 1253, True]],
+            examples=[['骑滑板的皮卡丘', False, 1234, True, None],
+                      ['a cat playing chess', True, 1253, True, None]],
             fn=model.run_with_translation,
-            inputs=[text, translate, seed, only_first_stage],
+            inputs=[text, translate, seed, only_first_stage, image_prompt],
             outputs=[translated_text, result_video],
             cache_examples=True)
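
(The first example keeps its Chinese prompt '骑滑板的皮卡丘', "Pikachu riding a skateboard"; both rows gain a trailing None so the cached text-only examples line up with the new five-element input list.)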
 
@@ -66,6 +69,7 @@ def main():
                     translate,
                     seed,
                     only_first_stage,
+                    image_prompt
                 ],
                 outputs=[translated_text, result_video])
 
 
model.py CHANGED
@@ -796,7 +796,8 @@ class Model:
                          video_raw_text=None,
                          video_guidance_text='视频',
                          image_text_suffix='',
-                         batch_size=1):
+                         batch_size=1,
+                         image_prompt=None):
         process_start_time = time.perf_counter()
 
         generate_frame_num = self.args.generate_frame_num
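
(image_prompt defaults to None, so existing text-only callers behave exactly as before. In the Chinese string literals here and below, '视频' means "video" and ' 高清摄影' means "high-definition photography".)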
@@ -828,33 +829,36 @@ class Model:
 
         seq_1st = torch.tensor(seq_1st, dtype=torch.long,
                                device=self.device).unsqueeze(0)
-        output_list_1st = []
-        for tim in range(max(batch_size // mbz, 1)):
-            start_time = time.perf_counter()
-            output_list_1st.append(
-                my_filling_sequence(
-                    model,
-                    tokenizer,
-                    self.args,
-                    seq_1st.clone(),
-                    batch_size=min(batch_size, mbz),
-                    get_masks_and_position_ids=
-                    get_masks_and_position_ids_stage1,
-                    text_len=text_len_1st,
-                    frame_len=frame_len,
-                    strategy=self.strategy_cogview2,
-                    strategy2=self.strategy_cogvideo,
-                    log_text_attention_weights=1.4,
-                    enforce_no_swin=True,
-                    mode_stage1=True,
-                )[0])
-            elapsed = time.perf_counter() - start_time
-            logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
-        output_tokens_1st = torch.cat(output_list_1st, dim=0)
-        given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
-                                         401].unsqueeze(
-                                             1
-                                         )  # given_tokens.shape: [bs, frame_num, 400]
+        if image_prompt is None:
+            output_list_1st = []
+            for tim in range(max(batch_size // mbz, 1)):
+                start_time = time.perf_counter()
+                output_list_1st.append(
+                    my_filling_sequence(
+                        model,
+                        tokenizer,
+                        self.args,
+                        seq_1st.clone(),
+                        batch_size=min(batch_size, mbz),
+                        get_masks_and_position_ids=
+                        get_masks_and_position_ids_stage1,
+                        text_len=text_len_1st,
+                        frame_len=frame_len,
+                        strategy=self.strategy_cogview2,
+                        strategy2=self.strategy_cogvideo,
+                        log_text_attention_weights=1.4,
+                        enforce_no_swin=True,
+                        mode_stage1=True,
+                    )[0])
+                elapsed = time.perf_counter() - start_time
+                logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
+            output_tokens_1st = torch.cat(output_list_1st, dim=0)
+            given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
+                                             401].unsqueeze(
+                                                 1
+                                             )  # given_tokens.shape: [bs, frame_num, 400]
+        else:
+            given_tokens = tokenizer.encode(image_path=image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)
 
         # generate subsequent frames:
         total_frames = generate_frame_num
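
The else branch swaps the sampled first frame for tokens of the uploaded image. A shape sketch of that one line, assuming tokenizer.encode returns the 400 image tokens (a 20x20 latent grid for the 160x160 input, matching the 400 first-frame tokens the text-only path slices out) as a [1, 400] tensor:

import torch

batch_size = 2
image_tokens = torch.zeros(1, 400, dtype=torch.long)  # stand-in for tokenizer.encode(...)
given_tokens = image_tokens.repeat(batch_size, 1).unsqueeze(1)
print(given_tokens.shape)  # torch.Size([2, 1, 400]), i.e. [bs, frame_num, 400]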
@@ -1167,7 +1171,7 @@ class Model:
                 1, 2, 0).to(torch.uint8).numpy()
 
     def run(self, text: str, seed: int,
-            only_first_stage: bool) -> list[np.ndarray]:
+            only_first_stage: bool, image_prompt: str | None) -> list[np.ndarray]:
         logger.info('==================== run ====================')
         start = time.perf_counter()
 
@@ -1188,7 +1192,8 @@ class Model:
             video_raw_text=text,
             video_guidance_text='视频',
             image_text_suffix=' 高清摄影',
-            batch_size=self.args.batch_size)
+            batch_size=self.args.batch_size,
+            image_prompt=image_prompt)
         if not only_first_stage:
             _, res = self.process_stage2(
                 self.model_stage2,
@@ -1226,12 +1231,13 @@ class AppModel(Model):
 
     def run_with_translation(
             self, text: str, translate: bool, seed: int,
-            only_first_stage: bool) -> tuple[str | None, str | None]:
-        logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=}')
+            only_first_stage: bool,
+            image_prompt: str | None) -> tuple[str | None, str | None]:
+        logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=}, {image_prompt=}')
         if translate:
             text = translated_text = self.translator(text)
         else:
             translated_text = None
-        frames = self.run(text, seed, only_first_stage)
+        frames = self.run(text, seed, only_first_stage, image_prompt)
         video_path = self.to_video(frames)
         return translated_text, video_path
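
A hypothetical call into the extended API, with model being the AppModel instance app.py constructs (the image path is illustrative):

translated_text, video_path = model.run_with_translation(
    'a cat playing chess',  # text
    False,                  # translate
    1253,                   # seed
    True,                   # only_first_stage
    'cat.png')              # image_prompt: path to a first-frame image, or None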
 