import gradio as gr
import argparse
import os

from musepose_inference import MusePoseInference
from pose_align import PoseAlignmentInference
from downloading_weights import download_models


class App:
    def __init__(self, args):
        self.args = args
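        # Step 1 pipeline: extracts the pose from the dance video and aligns it
        # to the reference image.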
        self.pose_alignment_infer = PoseAlignmentInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir
        )
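        # Step 2 pipeline: generates the final video with MusePose from the
        # reference image and the aligned pose video.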
        self.musepose_infer = MusePoseInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir
        )
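        # Fetch the pretrained weights at startup unless disabled via CLI flag.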
        if not args.disable_model_download_at_start:
            download_models(model_dir=args.model_dir)

    @staticmethod
    def on_step1_complete(input_img: str, input_pose_vid: str):
        # Carry the Step 1 results (reference image and aligned pose video)
        # over to the corresponding Step 2 input components.
        return [gr.Image(label="Input Image", value=input_img, type="filepath", scale=5),
                gr.Video(label="Input Aligned Pose Video", value=input_pose_vid, scale=5)]

    def musepose_demo(self):
        with gr.Blocks() as demo:
            md_header = self.header()
            with gr.Tabs():
                with gr.TabItem('Step1: Pose Alignment'):
                    with gr.Row():
                        with gr.Column(scale=3):
                            img_pose_input = gr.Image(label="Input Image", type="filepath", scale=5)
                            vid_dance_input = gr.Video(label="Input Dance Video", max_length=4, scale=5)
                        with gr.Column(scale=3):
                            vid_dance_output = gr.Video(label="Aligned Pose Output", scale=5, interactive=False)
                            vid_dance_output_demo = gr.Video(label="Aligned Pose Output Demo", scale=5)
                        with gr.Column(scale=3):
                            with gr.Column():
                                nb_detect_resolution = gr.Number(label="Detect Resolution", value=512, precision=0)
                                nb_image_resolution = gr.Number(label="Image Resolution", value=720, precision=0)
                                nb_align_frame = gr.Number(label="Align Frame", value=0, precision=0)
                                nb_max_frame = gr.Number(label="Max Frame", value=300, precision=0)

                            with gr.Row():
                                btn_align_pose = gr.Button("ALIGN POSE", variant="primary")

                    with gr.Column():
                        examples = [
                            [os.path.join("examples", "dance.mp4"), os.path.join("examples", "ref.png"),
                             512, 720, 0, 300]]
                        ex_step1 = gr.Examples(examples=examples,
                                               inputs=[vid_dance_input, img_pose_input, nb_detect_resolution,
                                                       nb_image_resolution, nb_align_frame, nb_max_frame],
                                               outputs=[vid_dance_output, vid_dance_output_demo],
                                               fn=self.pose_alignment_infer.align_pose,
                                               cache_examples="lazy")

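                # Step 1 event: run pose alignment on the inputs and show the result.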
                btn_align_pose.click(fn=self.pose_alignment_infer.align_pose,
                                     inputs=[vid_dance_input, img_pose_input, nb_detect_resolution, nb_image_resolution,
                                             nb_align_frame, nb_max_frame],
                                     outputs=[vid_dance_output, vid_dance_output_demo])

                with gr.TabItem('Step2: MusePose Inference'):
                    with gr.Row():
                        with gr.Column(scale=3):
                            img_musepose_input = gr.Image(label="Input Image", type="filepath", scale=5)
                            vid_pose_input = gr.Video(label="Input Aligned Pose Video", max_length=4, scale=5)
                        with gr.Column(scale=3):
                            vid_output = gr.Video(label="MusePose Output", scale=5)
                            vid_output_demo = gr.Video(label="MusePose Output Demo", scale=5)

                        with gr.Column(scale=3):
                            with gr.Column():
                                weight_dtype = gr.Dropdown(label="Compute Type", choices=["fp16", "fp32"],
                                                           value="fp16")
                                nb_width = gr.Number(label="Width", value=512, precision=0)
                                nb_height = gr.Number(label="Height", value=512, precision=0)
                                nb_video_frame_length = gr.Number(label="Video Frame Length", value=300, precision=0)
                                nb_video_slice_frame_length = gr.Number(label="Video Slice Frame Number", value=48,
                                                                        precision=0)
                                nb_video_slice_overlap_frame_number = gr.Number(
                                    label="Video Slice Overlap Frame Number", value=4, precision=0)
                                nb_cfg = gr.Number(label="CFG (Classifier-Free Guidance)", value=3.5,
                                                   precision=1)  # precision=0 would round the 3.5 default to an integer
                                nb_seed = gr.Number(label="Seed", value=99, precision=0)
                                nb_steps = gr.Number(label="DDIM Sampling Steps", value=20, precision=0)
                                nb_fps = gr.Number(label="FPS (Frames Per Second)", value=-1, precision=0,
                                                   info="Set to -1 to use the same FPS as the pose video")
                                nb_skip = gr.Number(label="SKIP (Frame Sample Rate = SKIP+1)", value=1, precision=0)

                            with gr.Row():
                                btn_generate = gr.Button("GENERATE", variant="primary")

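                # Step 2 event: run MusePose inference with the selected parameters.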
                btn_generate.click(fn=self.musepose_infer.infer_musepose,
                                   inputs=[img_musepose_input, vid_pose_input, weight_dtype, nb_width, nb_height,
                                           nb_video_frame_length, nb_video_slice_frame_length,
                                           nb_video_slice_overlap_frame_number, nb_cfg, nb_seed, nb_steps, nb_fps,
                                           nb_skip],
                                   outputs=[vid_output, vid_output_demo])
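                # When Step 1 produces an aligned pose video, auto-fill the
                # Step 2 inputs with it and the reference image.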
                vid_dance_output.change(fn=self.on_step1_complete,
                                        inputs=[img_pose_input, vid_dance_output],
                                        outputs=[img_musepose_input, vid_pose_input])

        return demo

    @staticmethod
    def header():
        header = gr.HTML(
            """
            <h1 style="font-size: 23px;">
                <a href="https://github.com/jhj0517/MusePose-WebUI" target="_blank">MusePose WebUI</a>
            </h1>

            <p style="font-size: 18px;">
                <strong>Note</strong>: This space only accepts video inputs up to <strong>3 seconds</strong>, because ZeroGPU limits the function runtime to 2 minutes. <br>
                For longer videos, run the app locally: click the link above and follow the README.
            </p>
            """
        )
        return header

    def launch(self):
        demo = self.musepose_demo()
        demo.queue().launch(
            share=self.args.share
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default="pretrained_weights", help='Directory of the pretrained weights for MusePose')
    parser.add_argument('--output_dir', type=str, default="outputs", help='Output directory for the results')
    # Boolean flags use action='store_true': argparse's type=bool would treat
    # any non-empty string (including "False") as True.
    parser.add_argument('--disable_model_download_at_start', action='store_true', help='Skip downloading the model weights at startup')
    parser.add_argument('--share', action='store_true', help='Create a shareable public Gradio link')
    args = parser.parse_args()

    app = App(args=args)
    app.launch()
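
# Example usage (a minimal sketch; assumes this file is saved as app.py):
#   python app.py                                    # defaults: ./pretrained_weights, ./outputs
#   python app.py --share                            # also create a public Gradio link
#   python app.py --disable_model_download_at_start  # skip the startup weight download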