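# Gradio demo app for Motion Inversion (EnVision-Research/MotionInversion).
# A motion embedding trained on a reference video is combined with a text prompt
# to generate a new video. `train.main`, `inference.inference`, and
# `configs/config.yaml` are assumed to be provided by this repository; `spaces`
# is the Hugging Face ZeroGPU helper the Space runs on.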
import gradio as gr
import os
from omegaconf import OmegaConf, ListConfig
import spaces
from train import main as train_main
from inference import inference as inference_main
import transformers

transformers.utils.move_cache()
def inference_app(
        embedding_dir,
        prompt,
        video_round,
        save_dir,
        motion_type,
        seed,
        inference_steps):
    print('inference info:')
    print('ref video:', embedding_dir)
    print('prompt:', prompt)
    print('motion type:', motion_type)
    print('infer steps:', inference_steps)
    return inference_main(
        embedding_dir=embedding_dir,
        prompt=prompt,
        video_round=video_round,
        save_dir=save_dir,
        motion_type=motion_type,
        seed=seed,
        inference_steps=inference_steps
    )
def train_model(video, config):
    output_dir = 'results'
    cur_save_dir = os.path.join(output_dir, 'custom')
    os.makedirs(cur_save_dir, exist_ok=True)  # ensure the target dir exists before copying
    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir

    # Copy the reference video into the save directory as source.mp4
    video_name = 'source.mp4'
    video_path = os.path.join(cur_save_dir, video_name)
    os.system(f"cp {video} {video_path}")

    train_main(config)
    return cur_save_dir
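# Usage sketch (hypothetical paths and step counts; generate_config_train()
# below builds the config this expects):
#   cfg = generate_config_train(['down 1280', 'up 1280'], unet='<unet_id>',
#                               checkpointing_steps=100, max_train_steps=400)
#   train_model('/path/to/reference.mp4', cfg)   # trains into results/custom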
def inference_model(text, checkpoint, inference_steps, video_type, seed):
    checkpoint = os.path.join('results', checkpoint)
    embedding_dir = '/'.join(checkpoint.split('/')[:-1])
    video_round = checkpoint.split('/')[-1]
    video_path = inference_app(
        embedding_dir=embedding_dir,
        prompt=text,
        video_round=video_round,
        save_dir=os.path.join('outputs', embedding_dir.split('/')[-1]),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps
    )
    return video_path
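# How a dropdown value such as 'pan_up/checkpoint' (see get_checkpoints below) is resolved:
#   checkpoint     -> 'results/pan_up/checkpoint'
#   embedding_dir  -> 'results/pan_up', video_round -> 'checkpoint'
#   the generated video is written under 'outputs/pan_up'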
def get_checkpoints(checkpoint_dir):
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints
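# Expected layout, matching the bundled examples:
#   results/pan_up/checkpoint/motion_embed.pt  ->  dropdown entry 'pan_up/checkpoint'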
def extract_combinations(motion_embeddings_combinations):
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations
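# e.g. ['down 1280', 'up 1280'] -> [['down', 1280], ['up', 1280]]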
def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config

def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config
def update_preview_video(checkpoint_dir):
    # get the parent dir of the checkpoint; the source video sits next to it
    parent_dir = '/'.join(checkpoint_dir.split('/')[:-1])
    return gr.update(value=f'results/{parent_dir}/source.mp4')

def update_generated_prompt(text):
    return gr.update(value=text)
if __name__ == "__main__":
    # Start from a clean state: drop any previous custom training run and generated outputs
    if os.path.exists('results/custom'):
        os.system('rm -rf results/custom')
    if os.path.exists('outputs'):
        os.system('rm -rf outputs')

    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
    default_motion_embeddings_combinations = ['down 1280', 'up 1280']

    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint'],
        ['results/orbit_shot/source.mp4', 'A micro garden with orbit shot', 'camera', 'orbit_shot/checkpoint'],
        ['results/walk/source.mp4', 'An elephant walking in desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in suit is dancing with his hands', 'object', 'santa_dance/checkpoint'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint'],
    ]
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Motion Inversion",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        gr.Markdown(
            """
            # Motion Inversion for Video Customization
            <p align="center">
            <a href="https://arxiv.org/abs/2403.20193"><img src='https://img.shields.io/badge/arXiv-2403.20193-b31b1b.svg'></a>
            <a href=''><img src='https://img.shields.io/badge/Project_Page-MotionInversion(Coming soon)-blue'></a>
            <a href='https://github.com/EnVision-Research/MotionInversion'><img src='https://img.shields.io/github/stars/EnVision-Research/MotionInversion?label=GitHub%20%E2%98%85&logo=github&color=C8C'></a>
            <br>
            <strong>Please consider starring <span style="color: orange">★</span> the <a href="https://github.com/EnVision-Research/MotionInversion" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this useful!</strong>
            </p>
            """
        )
        with gr.Tabs(elem_classes=["tabs"]):
            with gr.Row():
                with gr.Column():
                    preview_video = gr.Video(label="Preview Video")
                    text_input = gr.Textbox(label="Input Text")
                    checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results'))
                    seed = gr.Number(label="Seed", value=0)
                    inference_button = gr.Button("Generate Video")
                with gr.Column():
                    output_video = gr.Video(label="Output Video")
                    generated_prompt = gr.Textbox(label="Generated Prompt")
                    with gr.Accordion('Encounter Errors', open=False):
                        gr.Markdown('''
                            <strong>Generally, inference for one video takes 45~50s on ZeroGPU.</strong>
                            <br>
                            <strong>You have exceeded your GPU quota</strong>: a limit set by HF; retry in an hour.
                            <br>
                            <strong>GPU task aborted</strong>: possibly ZeroGPU is being used by too many people and the inference time exceeded the limit. Try again later, or clone the repo and run it locally.
                            <br>
                            If any other issue occurs, please feel free to contact us through the community or by email ([email protected]). We will try our best to help you :)
                        ''')

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    inference_steps = gr.Number(label="Inference Steps", value=30)
                    motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")

            gr.Examples(examples=examples_inference, inputs=[preview_video, text_input, motion_type, checkpoint_dropdown])

        checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
        inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed], outputs=output_video)
        output_video.change(fn=update_generated_prompt, inputs=[text_input], outputs=generated_prompt)
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
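# Local run (outside Spaces), as suggested in the error notes above, assuming this
# file is saved as app.py with the repository's dependencies installed:
#   python app.py   # then open http://localhost:7860 (port set in launch() above)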