import gradio as gr
import os
from omegaconf import OmegaConf, ListConfig
import spaces
from train import main as train_main
from inference import inference as inference_main
import transformers

transformers.utils.move_cache()


@spaces.GPU()
def inference_app(
        embedding_dir,
        prompt,
        video_round,
        save_dir,
        motion_type,
        seed,
        inference_steps):
    print('inference info:')
    print('ref video:', embedding_dir)
    print('prompt:', prompt)
    print('motion type:', motion_type)
    print('infer steps:', inference_steps)
    return inference_main(
        embedding_dir=embedding_dir,
        prompt=prompt,
        video_round=video_round,
        save_dir=save_dir,
        motion_type=motion_type,
        seed=seed,
        inference_steps=inference_steps
    )


def train_model(video, config):
    output_dir = 'results'
    os.makedirs(output_dir, exist_ok=True)
    cur_save_dir = os.path.join(output_dir, 'custom')
    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir

    # Copy the reference video into the output directory so it can be previewed later.
    video_name = 'source.mp4'
    video_path = os.path.join(cur_save_dir, video_name)
    os.system(f"cp {video} {video_path}")

    train_main(config)
    # cur_save_dir = 'results/06'
    return cur_save_dir


def inference_model(text, checkpoint, inference_steps, video_type, seed):
    # `checkpoint` has the form '<video_name>/<round>', e.g. 'pan_up/checkpoint'.
    checkpoint = os.path.join('results', checkpoint)
    embedding_dir = '/'.join(checkpoint.split('/')[:-1])
    video_round = checkpoint.split('/')[-1]
    video_path = inference_app(
        embedding_dir=embedding_dir,
        prompt=text,
        video_round=video_round,
        save_dir=os.path.join('outputs', embedding_dir.split('/')[-1]),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps
    )
    return video_path


def get_checkpoints(checkpoint_dir):
    # Collect '<video_name>/<round>' entries for every directory containing a motion embedding.
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints


def extract_combinations(motion_embeddings_combinations):
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        # Each entry has the form '<block> <resolution>', e.g. 'down 1280'.
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations


def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config


def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    # Note: currently identical to generate_config_train.
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config


def update_preview_video(checkpoint_dir):
    # The parent directory of the checkpoint holds the source video.
    parent_dir = '/'.join(checkpoint_dir.split('/')[:-1])
    return gr.update(value=f'results/{parent_dir}/source.mp4')


def update_generated_prompt(text):
    return gr.update(value=text)


if __name__ == "__main__":

    if os.path.exists('results/custom'):
        os.system('rm -rf results/custom')
    if os.path.exists('outputs'):
        os.system('rm -rf outputs')

    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
    default_motion_embeddings_combinations = ['down 1280', 'up 1280']

    # Each example row maps to [preview_video, text_input, motion_type, checkpoint_dropdown].
    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint'],
        ['results/orbit_shot/source.mp4', 'A micro garden with orbit shot', 'camera', 'orbit_shot/checkpoint'],
        ['results/walk/source.mp4', 'An elephant walking in the desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in a suit is dancing with his hands', 'object', 'santa_dance/checkpoint'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint'],
    ]

    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Motion Inversion",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head=""" """,
    ) as demo:
        gr.Markdown(
            """
# Motion Inversion for Video Customization


Please consider starring the GitHub Repo if you find this useful!

""" ) with gr.Tabs(elem_classes=["tabs"]): with gr.Row(): with gr.Column(): preview_video = gr.Video(label="Preview Video") text_input = gr.Textbox(label="Input Text") checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results')) seed = gr.Number(label="Seed", value=0) inference_button = gr.Button("Generate Video") with gr.Column(): output_video = gr.Video(label="Output Video") generated_prompt = gr.Textbox(label="Generated Prompt") with gr.Accordion('Encounter Errors', open=False): gr.Markdown(''' Generally, inference time for one video often takes 45~50s on ZeroGPU.
- **You have exceeded your GPU quota**: a limitation set by Hugging Face. Retry in an hour.
- **GPU task aborted**: possibly caused by too many people using ZeroGPU at once, so the inference time exceeds the time limit. You may try again later, or clone the repo and run it locally.

If any other issues occur, please feel free to contact us through the community tab or by email (ziyangmai06@gmail.com). We will try our best to help you :)
                        ''')

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    inference_steps = gr.Number(label="Inference Steps", value=30)
                    motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")

            gr.Examples(examples=examples_inference, inputs=[preview_video, text_input, motion_type, checkpoint_dropdown])

        # Wire up the UI events.
        checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
        inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed], outputs=output_video)
        output_video.change(fn=update_generated_prompt, inputs=[text_input], outputs=generated_prompt)

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
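
    # A minimal sketch (not wired into the demo above) of how the training helpers in this
    # file would compose, assuming configs/config.yaml defines the fields referenced in
    # generate_config_train. The UNet name and step counts below are hypothetical
    # placeholders for illustration, not values confirmed by this script:
    #
    #     config = generate_config_train(
    #         motion_embeddings_combinations=['down 1280', 'up 1280'],
    #         unet='<unet-name-from-config>',  # placeholder
    #         checkpointing_steps=100,         # illustrative value
    #         max_train_steps=500,             # illustrative value
    #     )
    #     save_dir = train_model('path/to/reference.mp4', config)  # writes into results/custom
    #
    # get_checkpoints('results') lists every '<video>/<round>' directory that contains a
    # motion_embed.pt file, which is how the checkpoint dropdown above is populated.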