"""Gradio demo for Motion Inversion for Video Customization.

The "Train" tab runs ``train.main`` on an uploaded video; the "Inference" tab
runs ``inference.inference`` with a checkpoint found under ``results/``.
"""

import gradio as gr
import os
import shutil
import torch
import tempfile
import random
import string
import json
from omegaconf import OmegaConf, ListConfig
from train import main as train_main
from inference import inference as inference_main

# Root directory holding training runs; each checkpoint dir contains motion_embed.pt.
RESULTS_DIR = 'results'
# Root directory where generated videos are saved.
OUTPUTS_DIR = 'outputs'


def train_model(video, config):
    """Train on the uploaded video and return the run directory.

    The run is always written to ``results/custom``. A copy of the source video
    is stored there as ``source.mp4`` so the Inference tab can preview it.

    Args:
        video: Filesystem path of the uploaded video (as delivered by ``gr.Video``).
        config: OmegaConf config, e.g. from ``generate_config_train``; its
            ``dataset.single_video_path`` and ``train.output_dir`` are overwritten.

    Returns:
        The run directory path (``results/custom``).

    Raises:
        gr.Error: If no video was uploaded.
    """
    if not video:
        raise gr.Error("Please upload a video before training.")

    cur_save_dir = os.path.join(RESULTS_DIR, 'custom')
    # The source video is copied into the run directory before training starts,
    # so the directory must exist now (startup deletes results/custom).
    os.makedirs(cur_save_dir, exist_ok=True)

    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir

    # shutil.copy instead of a shell `cp`: safe for paths with spaces/shell
    # metacharacters, and raises instead of failing silently.
    shutil.copy(video, os.path.join(cur_save_dir, 'source.mp4'))

    train_main(config)
    return cur_save_dir


def inference_model(text, checkpoint, inference_steps, video_type, seed):
    """Generate a video for ``text`` using a trained checkpoint.

    Args:
        text: Text prompt.
        checkpoint: Checkpoint id relative to ``results/``, e.g. ``"pan_up/checkpoint"``
            (the format produced by ``get_checkpoints``).
        inference_steps: Number of inference steps (cast to int).
        video_type: Motion type passed through as ``motion_type`` ("camera"/"object").
        seed: Random seed (cast to int).

    Returns:
        Whatever ``inference_main`` returns; it is fed to the output ``gr.Video``,
        so presumably a path to the generated video — confirm in inference.py.

    Raises:
        gr.Error: If no checkpoint is selected.
    """
    if not checkpoint:
        raise gr.Error("Please select a checkpoint before generating.")

    checkpoint_path = os.path.join(RESULTS_DIR, checkpoint)
    # "results/pan_up/checkpoint" -> ("results/pan_up", "checkpoint")
    embedding_dir, video_round = os.path.split(checkpoint_path)
    return inference_main(
        embedding_dir=embedding_dir,
        prompt=text,
        video_round=video_round,
        save_dir=os.path.join(OUTPUTS_DIR, os.path.basename(embedding_dir)),
        motion_type=video_type,
        # gr.Number yields floats by default; downstream expects integer counts/seeds.
        seed=int(seed),
        inference_steps=int(inference_steps),
    )


def get_checkpoints(checkpoint_dir):
    """List checkpoint ids found under ``checkpoint_dir``.

    A directory counts as a checkpoint when it contains ``motion_embed.pt``. Each id
    is the last two path components joined with ``'/'`` (e.g. ``"pan_up/checkpoint"``).

    Returns:
        Sorted list of checkpoint ids.
    """
    checkpoints = []
    for root, _dirs, files in os.walk(checkpoint_dir):
        if 'motion_embed.pt' in files:
            # os.walk joins with os.sep, so split on it rather than a literal '/'.
            parts = os.path.normpath(root).split(os.sep)
            checkpoints.append('/'.join(parts[-2:]))
    return sorted(checkpoints)


def extract_combinations(motion_embeddings_combinations):
    """Parse ``"<name> <resolution>"`` strings into ``[name, int(resolution)]`` pairs.

    Example: ``["down 1280", "up 640"]`` -> ``[["down", 1280], ["up", 640]]``.

    Raises:
        ValueError: If the selection is empty or an entry is malformed.
    """
    if not motion_embeddings_combinations:
        raise ValueError("At least one motion embedding combination is required")
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split()
        combinations.append([name, int(resolution)])
    return combinations


def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    """Load ``configs/config.yaml`` and override it with the UI's training settings."""
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(
        extract_combinations(motion_embeddings_combinations)
    )
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config


def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    """Build an inference config; currently identical to ``generate_config_train``."""
    return generate_config_train(
        motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps
    )


def update_preview_video(checkpoint_dir):
    """Point the preview player at the ``source.mp4`` saved for a checkpoint's run.

    Clears the preview when no checkpoint is selected or the video is missing.
    """
    if not checkpoint_dir:
        return gr.update(value=None)
    # Parent of the checkpoint id: "pan_up/checkpoint" -> "pan_up".
    parent_dir = checkpoint_dir.rpartition('/')[0]
    video_path = f'{RESULTS_DIR}/{parent_dir}/source.mp4'
    return gr.update(value=video_path if os.path.exists(video_path) else None)


def update_generated_prompt(text):
    """Echo the prompt into the "Generated Prompt" textbox."""
    return gr.update(value=text)


def update_checkpoints(checkpoint_dir):
    """Refresh the checkpoint dropdown choices (the argument is unused)."""
    return gr.update(choices=get_checkpoints(RESULTS_DIR))


def _reset_workspace():
    """Delete the previous session's custom training run and generated outputs."""
    for path in (os.path.join(RESULTS_DIR, 'custom'), OUTPUTS_DIR):
        if os.path.exists(path):
            shutil.rmtree(path, ignore_errors=True)


def build_demo():
    """Construct and return the Gradio Blocks app."""
    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
    default_motion_embeddings_combinations = ['down 1280', 'up 1280']

    examples_train = [
        'assets/train/car_turn.mp4',
        'assets/train/pan_up.mp4',
        'assets/train/run_up.mp4',
        'assets/train/train_ride.mp4',
        'assets/train/orbit_shot.mp4',
        'assets/train/dolly_zoom_out.mp4',
        'assets/train/santa_dance.mp4',
    ]

    # Each row: [preview video, prompt, motion type, checkpoint id].
    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint'],
        ['results/orbit_shot/source.mp4', 'A micro graden with orbit shot', 'camera', 'orbit_shot/checkpoint'],
        ['results/walk/source.mp4', 'A elephant walking in desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in suit is dancing with his hands', 'object', 'santa_dance/checkpoint'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint'],
    ]

    gradio_theme = gr.themes.Default()

    css = """
    #download {
        height: 118px;
    }
    .slider .inner {
        width: 5px;
        background: #FFF;
    }
    .viewport {
        aspect-ratio: 4/3;
    }
    .tabs button.selected {
        font-size: 20px !important;
        color: crimson !important;
    }
    h1 {
        text-align: center;
        display: block;
    }
    h2 {
        text-align: center;
        display: block;
    }
    h3 {
        text-align: center;
        display: block;
    }
    .md_feedback li {
        margin-bottom: 0px !important;
    }
    """

    # Create the Gradio interface.
    with gr.Blocks(
        theme=gradio_theme,
        title="Motion Inversion",
        css=css,
        head="",
    ) as demo:
        gr.Markdown("# Motion Inversion for Video Customization")

        with gr.Tab("Train"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Upload Video")
                    train_button = gr.Button("Train")
                with gr.Column():
                    checkpoint_output = gr.Textbox(label="Checkpoint Directory")

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    motion_embeddings_combinations = gr.Dropdown(
                        label="Motion Embeddings Combinations",
                        choices=inject_motion_embeddings_combinations,
                        multiselect=True,
                        value=default_motion_embeddings_combinations,
                    )
                    unet_dropdown = gr.Dropdown(
                        label="Unet",
                        choices=["videoCrafter2", "zeroscope_v2_576w"],
                        value="videoCrafter2",
                    )
                    checkpointing_steps = gr.Dropdown(
                        label="Checkpointing Steps", choices=[100, 50], value=100
                    )
                    max_train_steps = gr.Slider(
                        label="Max Train Steps", minimum=200, maximum=500, value=200, step=50
                    )

            gr.Examples(examples=examples_train, inputs=[video_input])

            train_button.click(
                lambda video, mec, u, cs, mts: train_model(
                    video, generate_config_train(mec, u, cs, mts)
                ),
                inputs=[
                    video_input,
                    motion_embeddings_combinations,
                    unet_dropdown,
                    checkpointing_steps,
                    max_train_steps,
                ],
                outputs=checkpoint_output,
            )

        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column():
                    preview_video = gr.Video(label="Preview Video")
                    text_input = gr.Textbox(label="Input Text")
                    checkpoint_dropdown = gr.Dropdown(
                        label="Select Checkpoint", choices=get_checkpoints(RESULTS_DIR)
                    )
                    # precision=0 makes gr.Number deliver an int instead of a float.
                    seed = gr.Number(label="Seed", value=0, precision=0)
                    inference_button = gr.Button("Generate Video")
                with gr.Column():
                    output_video = gr.Video(label="Output Video")
                    generated_prompt = gr.Textbox(label="Generated Prompt")

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    inference_steps = gr.Number(label="Inference Steps", value=30, precision=0)
                    motion_type = gr.Dropdown(
                        label="Motion Type", choices=["camera", "object"], value="object"
                    )

            gr.Examples(
                examples=examples_inference,
                inputs=[preview_video, text_input, motion_type, checkpoint_dropdown],
            )

        checkpoint_dropdown.change(
            fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video
        )
        # A finished training run writes its directory here; rescan checkpoints then.
        checkpoint_output.change(
            update_checkpoints, inputs=checkpoint_output, outputs=checkpoint_dropdown
        )
        inference_button.click(
            inference_model,
            inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed],
            outputs=output_video,
        )
        output_video.change(
            fn=update_generated_prompt, inputs=[text_input], outputs=generated_prompt
        )

    return demo


if __name__ == "__main__":
    _reset_workspace()
    demo = build_demo()
    # Launch the Gradio interface.
    demo.launch()