import gradio as gr
import os
import shutil
from omegaconf import OmegaConf, ListConfig
import spaces


from train import main as train_main
from inference import inference as inference_main

import transformers
transformers.utils.move_cache()


@spaces.GPU()
def inference_app(
        embedding_dir,
        prompt, 
        video_round,
        save_dir,
        motion_type,
        seed,
        inference_steps):
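    """Run motion-customized inference on GPU and return the path of the generated video."""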
    
    print('inference info:')
    print('ref video:', embedding_dir)
    print('prompt:', prompt)
    print('motion type:', motion_type)
    print('infer steps:', inference_steps)

    return inference_main(
        embedding_dir=embedding_dir,
        prompt=prompt, 
        video_round=video_round,
        save_dir=save_dir,
        motion_type=motion_type,
        seed=seed,
        inference_steps=inference_steps
        )


def train_model(video, config):
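    """Fine-tune motion embeddings on the uploaded video and return the checkpoint directory."""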
    output_dir = 'results'
    os.makedirs(output_dir, exist_ok=True)
    cur_save_dir = os.path.join(output_dir, 'custom')

    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir
    
    # copy the reference video into the checkpoint directory
    video_name = 'source.mp4'
    video_path = os.path.join(cur_save_dir, video_name)
    shutil.copyfile(video, video_path)  # safer than a shell `cp`; handles paths with spaces

    train_main(config)
    # cur_save_dir = 'results/06'
    return cur_save_dir


def inference_model(text, checkpoint, inference_steps, video_type, seed):
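    """Resolve the selected checkpoint into an embedding directory and round, then generate a video."""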
    
    checkpoint = os.path.join('results', checkpoint)

    embedding_dir = '/'.join(checkpoint.split('/')[:-1])
    video_round = checkpoint.split('/')[-1]

    video_path = inference_app(
        embedding_dir=embedding_dir,
        prompt=text, 
        video_round=video_round,
        save_dir=os.path.join('outputs',embedding_dir.split('/')[-1]),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps
        )

    return video_path


def get_checkpoints(checkpoint_dir):
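    """Collect '<video_name>/<round>' entries for every directory under checkpoint_dir that contains a motion_embed.pt."""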
    
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints


def extract_combinations(motion_embeddings_combinations):
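    """Parse entries like 'down 1280' into [block_name, resolution] pairs."""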
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations


def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
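    """Load the base config and override the motion-embedding combinations, UNet, and training schedule."""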

    default_config = OmegaConf.load('configs/config.yaml')

    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps

    return default_config


def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
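    """Build the inference config; currently mirrors generate_config_train with the same overrides."""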

    default_config = OmegaConf.load('configs/config.yaml')

    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps

    return default_config


def update_preview_video(checkpoint_dir):
    # get the parent dir of the checkpoint
    parent_dir = '/'.join(checkpoint_dir.split('/')[:-1])
    return gr.update(value=f'results/{parent_dir}/source.mp4')


def update_generated_prompt(text):
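    """Copy the input prompt into the 'Generated Prompt' textbox."""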
    return gr.update(value=text)


if __name__ == "__main__":

    if os.path.exists('results/custom'):
        os.system('rm -rf results/custom')
    if os.path.exists('outputs'):
        os.system('rm -rf outputs')

    inject_motion_embeddings_combinations = ['down 1280','up 1280','down 640','up 640']
    default_motion_embeddings_combinations = ['down 1280','up 1280']


    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint'],
        ['results/orbit_shot/source.mp4', 'A micro garden with an orbit shot', 'camera', 'orbit_shot/checkpoint'],

        ['results/walk/source.mp4', 'An elephant walking in a desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in a suit is dancing with his hands', 'object', 'santa_dance/checkpoint'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint'],
    ]

    gradio_theme = gr.themes.Default()
    with gr.Blocks(
        theme=gradio_theme,
        title="Motion Inversion",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        
        gr.Markdown(
            """
# Motion Inversion for Video Customization
<p align="center">
<a href="https://arxiv.org/abs/2403.20193"><img src='https://img.shields.io/badge/arXiv-2403.20193-b31b1b.svg'></a>
<a href=''><img src='https://img.shields.io/badge/Project_Page-MotionInversion(Coming soon)-blue'></a>
<a href='https://github.com/EnVision-Research/MotionInversion'><img src='https://img.shields.io/github/stars/EnVision-Research/MotionInversion?label=GitHub%20%E2%98%85&logo=github&color=C8C'></a>
<br>
<strong>Please consider starring <span style="color: orange">&#9733;</span> the <a href="https://github.com/EnVision-Research/MotionInversion" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this useful!</strong>
</p>
        """
        )
        with gr.Tabs(elem_classes=["tabs"]):
            with gr.Row():
                with gr.Column():
                    preview_video = gr.Video(label="Preview Video")
                    text_input = gr.Textbox(label="Input Text")
                    checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results'))
                    seed = gr.Number(label="Seed", value=0)
                    inference_button = gr.Button("Generate Video")
                
                with gr.Column():
                    
                    output_video = gr.Video(label="Output Video")
                    generated_prompt = gr.Textbox(label="Generated Prompt")

                    with gr.Accordion('Encountering Errors?', open=False):
                        gr.Markdown('''
                                    <strong>Inference for one video typically takes 45~50s on ZeroGPU.</strong>

                                    <br>
                                    <strong>You have exceeded your GPU quota</strong>: a limit set by HF. Retry in an hour.
                                    <br>
                                    <strong>GPU task aborted</strong>: likely caused by too many people using ZeroGPU at once, so inference exceeds the time limit. You may try again later, or clone the repo and run it locally.
                                    <br>

                                    If any other issue occurs, please feel free to contact us through the community or by email ([email protected]). We will do our best to help you :)

                                    ''')


            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    inference_steps = gr.Number(label="Inference Steps", value=30)
                    motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")

        gr.Examples(examples=examples_inference, inputs=[preview_video, text_input, motion_type, checkpoint_dropdown])

        checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
        inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed], outputs=output_video)
        output_video.change(fn=update_generated_prompt, inputs=[text_input], outputs=generated_prompt)
        
        demo.queue(
            api_open=False,
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
        )