add image prompt option
app.py CHANGED
```diff
@@ -40,6 +40,9 @@ def main():
                    label='Only First Stage',
                    value=only_first_stage,
                    visible=not only_first_stage)
+                image_prompt = gr.Image(type="filepath",
+                                        label="Image Prompt",
+                                        value=None)
                 run_button = gr.Button('Run')
 
             with gr.Column():
```
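With `type="filepath"`, Gradio hands the event handler the path of a temporary copy of the uploaded image as a `str`, or `None` when the field is left empty; that is what lets the cached examples below fill this column with `None`. A minimal, self-contained sketch of that contract (the `describe` callback is illustrative only, not part of the app):

```python
import gradio as gr

def describe(image_path: str | None) -> str:
    # With type="filepath", the handler receives a path string or None.
    if image_path is None:
        return 'no image prompt: the first frame will be generated from text'
    return f'image prompt stored at {image_path}'

with gr.Blocks() as demo:
    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
    status = gr.Textbox(show_label=False)
    image_prompt.change(fn=describe, inputs=image_prompt, outputs=status)

demo.launch()
```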
```diff
@@ -50,10 +53,10 @@ def main():
                 result_video = gr.Video(show_label=False)
 
         examples = gr.Examples(
-            examples=[['骑滑板的皮卡丘', False, 1234, True],
-                      ['a cat playing chess', True, 1253, True]],
+            examples=[['骑滑板的皮卡丘', False, 1234, True, None],
+                      ['a cat playing chess', True, 1253, True, None]],
             fn=model.run_with_translation,
-            inputs=[text, translate, seed, only_first_stage],
+            inputs=[text, translate, seed, only_first_stage, image_prompt],
             outputs=[translated_text, result_video],
             cache_examples=True)
 
@@ -66,6 +69,7 @@ def main():
                              translate,
                              seed,
                              only_first_stage,
+                             image_prompt,
                          ],
                          outputs=[translated_text, result_video])
 
```
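Because `cache_examples=True` runs `fn` over every example row at startup, each row must supply one value per input component, so both rows gain a `None` in the new `image_prompt` column and stay on the text-only path. (The first example prompt, 骑滑板的皮卡丘, is Chinese for "Pikachu riding a skateboard"; it exercises the untranslated path, while the English prompt exercises translation.) A standalone sketch of that row-to-input correspondence, with a toy `fn` standing in for `model.run_with_translation`:

```python
import gradio as gr

def fn(text: str, image_path: str | None) -> str:
    # Toy stand-in: the real app returns a translated prompt and a video.
    return text if image_path is None else f'{text} (conditioned on {image_path})'

with gr.Blocks() as demo:
    text = gr.Textbox(label='Input Text')
    image_prompt = gr.Image(type="filepath", label="Image Prompt")
    result = gr.Textbox(show_label=False)
    gr.Examples(
        # One value per input component, in the same order as `inputs`.
        examples=[['a cat playing chess', None]],
        fn=fn,
        inputs=[text, image_prompt],
        outputs=result,
        cache_examples=True)

demo.launch()
```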
model.py CHANGED
```diff
@@ -796,7 +796,8 @@ class Model:
                        video_raw_text=None,
                        video_guidance_text='视频',
                        image_text_suffix='',
-                       batch_size=1):
+                       batch_size=1,
+                       image_prompt=None):
         process_start_time = time.perf_counter()
 
         generate_frame_num = self.args.generate_frame_num
@@ -828,33 +829,36 @@ class Model:
 
         seq_1st = torch.tensor(seq_1st, dtype=torch.long,
                                device=self.device).unsqueeze(0)
-        output_list_1st = []
-        for tim in range(max(batch_size // mbz, 1)):
-            start_time = time.perf_counter()
-            output_list_1st.append(
-                my_filling_sequence(
-                    model,
-                    tokenizer,
-                    self.args,
-                    seq_1st.clone(),
-                    batch_size=min(batch_size, mbz),
-                    get_masks_and_position_ids=
-                    get_masks_and_position_ids_stage1,
-                    text_len=text_len_1st,
-                    frame_len=frame_len,
-                    strategy=self.strategy_cogview2,
-                    strategy2=self.strategy_cogvideo,
-                    log_text_attention_weights=1.4,
-                    enforce_no_swin=True,
-                    mode_stage1=True,
-                )[0])
-            elapsed = time.perf_counter() - start_time
-            logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
-        output_tokens_1st = torch.cat(output_list_1st, dim=0)
-        given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
-                                         401].unsqueeze(
-                                             1
-                                         )  # given_tokens.shape: [bs, frame_num, 400]
+        if image_prompt is None:
+            output_list_1st = []
+            for tim in range(max(batch_size // mbz, 1)):
+                start_time = time.perf_counter()
+                output_list_1st.append(
+                    my_filling_sequence(
+                        model,
+                        tokenizer,
+                        self.args,
+                        seq_1st.clone(),
+                        batch_size=min(batch_size, mbz),
+                        get_masks_and_position_ids=
+                        get_masks_and_position_ids_stage1,
+                        text_len=text_len_1st,
+                        frame_len=frame_len,
+                        strategy=self.strategy_cogview2,
+                        strategy2=self.strategy_cogvideo,
+                        log_text_attention_weights=1.4,
+                        enforce_no_swin=True,
+                        mode_stage1=True,
+                    )[0])
+                elapsed = time.perf_counter() - start_time
+                logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
+            output_tokens_1st = torch.cat(output_list_1st, dim=0)
+            given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
+                                             401].unsqueeze(
+                                                 1
+                                             )  # given_tokens.shape: [bs, frame_num, 400]
+        else:
+            given_tokens = tokenizer.encode(image_path=image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)
 
         # generate subsequent frames:
         total_frames = generate_frame_num
```
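Both branches must leave `given_tokens` with the same shape, `[batch_size, frame_num=1, 400]`: the text-only branch slices the 400 first-frame tokens out of the stage-1 output, while the new branch encodes the prompt image once and tiles it across the batch. A toy shape check of that tensor algebra, with random integers standing in for real token ids and assuming, per the comment in the diff, that a 160×160 image encodes to 400 tokens:

```python
import torch

batch_size = 2
text_len_1st = 30  # toy text length; the real value depends on the prompt

# Stand-in for tokenizer.encode(image_path=..., image_size=160):
# one frame of 400 discrete image tokens, shape [1, 400].
encoded = torch.randint(0, 20000, (1, 400))

# Image-prompt branch: tile across the batch, then add the frame axis.
given_tokens = encoded.repeat(batch_size, 1).unsqueeze(1)
assert given_tokens.shape == (batch_size, 1, 400)

# The text-only branch reaches the same shape by slicing the stage-1 output.
output_tokens_1st = torch.randint(0, 20000, (batch_size, text_len_1st + 1 + 400))
sliced = output_tokens_1st[:, text_len_1st + 1:text_len_1st + 401].unsqueeze(1)
assert sliced.shape == given_tokens.shape
```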
```diff
@@ -1167,7 +1171,7 @@ class Model:
                 1, 2, 0).to(torch.uint8).numpy()
 
     def run(self, text: str, seed: int,
-            only_first_stage: bool) -> list[np.ndarray]:
+            only_first_stage: bool, image_prompt: str | None) -> list[np.ndarray]:
         logger.info('==================== run ====================')
         start = time.perf_counter()
 
@@ -1188,7 +1192,8 @@ class Model:
             video_raw_text=text,
             video_guidance_text='视频',
             image_text_suffix=' 高清摄影',
-            batch_size=self.args.batch_size)
+            batch_size=self.args.batch_size,
+            image_prompt=image_prompt)
         if not only_first_stage:
             _, res = self.process_stage2(
                 self.model_stage2,
@@ -1226,12 +1231,13 @@ class AppModel(Model):
 
     def run_with_translation(
             self, text: str, translate: bool, seed: int,
-            only_first_stage: bool) -> tuple[str | None, str | None]:
-        logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=}')
+            only_first_stage: bool,
+            image_prompt: str | None) -> tuple[str | None, str | None]:
+        logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=}, {image_prompt=}')
         if translate:
             text = translated_text = self.translator(text)
         else:
             translated_text = None
-        frames = self.run(text, seed, only_first_stage)
+        frames = self.run(text, seed, only_first_stage, image_prompt)
         video_path = self.to_video(frames)
         return translated_text, video_path
```
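Taken together, the model.py changes follow a single pattern: thread an optional `image_prompt` through `run_with_translation` → `run` → `process_stage1`, with `None` as the sentinel for the original text-only behaviour, so existing callers and the `None`-filled example rows keep working unchanged. A schematic reduction of that plumbing (toy class, not the real `Model`):

```python
from __future__ import annotations

class ToyModel:
    # Each layer forwards image_prompt; None selects the old behaviour.
    def process_stage1(self, text: str, image_prompt: str | None) -> str:
        if image_prompt is None:
            return f'first frame generated from text {text!r}'
        return f'first frame encoded from image {image_prompt!r}'

    def run(self, text: str, image_prompt: str | None) -> str:
        return self.process_stage1(text, image_prompt=image_prompt)

model = ToyModel()
print(model.run('a cat playing chess', None))       # text-only path
print(model.run('a cat playing chess', 'cat.png'))  # image-prompt path
```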