File size: 3,489 Bytes
2945355
 
 
 
0aa3e03
2945355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23fa30e
2945355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aa3e03
 
 
 
 
 
 
 
2945355
0aa3e03
2945355
 
 
 
 
 
 
 
 
 
3cc9c1d
47baa10
2945355
 
 
 
 
0aa3e03
2945355
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import glob
import os
import subprocess
import sys

import gdown
import gradio as gr

# Create the local 'checkpoint' directory up front; exist_ok makes this a
# no-op when the directory is already present.
os.makedirs(name="checkpoint", mode=0o777, exist_ok=True)

def download_weights_from_folder(
    google_drive_folder_link,
    output_folder="checkpoint",
    model_filename="model_state-415001.th",
):
    """Download every file in a Google Drive folder and check for the model.

    Args:
        google_drive_folder_link: Share link whose last path segment is the
            folder ID, e.g. ``https://drive.google.com/drive/folders/<ID>``.
            A trailing slash, query string (``?usp=sharing``) or fragment is
            tolerated.
        output_folder: Local directory the folder contents are written to.
        model_filename: File expected inside ``output_folder`` once the
            download has finished.

    Returns:
        A human-readable status message. Failures are reported through the
        message instead of being raised, so start-up can continue.
    """
    # Strip query string / fragment and trailing slashes before taking the
    # last path segment; otherwise "<id>?usp=sharing" or "<id>/" would yield
    # a wrong or empty folder ID.
    path = google_drive_folder_link.split("?", 1)[0].split("#", 1)[0].rstrip("/")
    folder_id = path.rsplit("/", 1)[-1]
    if not folder_id:
        return f"Could not extract a folder ID from {google_drive_folder_link!r}."

    gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
    try:
        gdown.download_folder(gdown_url, quiet=False, output=output_folder)
    except Exception as e:  # boundary: report any download failure as text
        return f"Failed to download weights: {e}"

    # Only verify presence of the expected checkpoint; nothing is renamed.
    downloaded_model_path = os.path.join(output_folder, model_filename)
    if os.path.exists(downloaded_model_path):
        return f"Downloaded model weights to {downloaded_model_path}"
    return f"Model file '{model_filename}' not found in the folder."

# Fetch the checkpoint at start-up and surface the outcome in the console:
# download_weights_from_folder reports failures via its return value rather
# than raising, so discarding the result would hide a failed download.
WEIGHTS_FOLDER_URL = "https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O"
print(download_weights_from_folder(WEIGHTS_FOLDER_URL))

def _format_int_arg(value):
    """Render a numeric CLI argument, dropping a spurious trailing ``.0``.

    Gradio ``Number`` inputs created without ``precision`` deliver floats
    (``999`` arrives as ``999.0``). An integer-typed option in the inference
    script (presumably argparse ``type=int`` -- not visible here) would reject
    ``"999.0"``, so integral floats are rendered as plain integers. Any other
    value is passed through ``str`` unchanged.
    """
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value)


def run_xportrait(
    model_config,
    output_dir,
    resume_dir,
    seed,
    uc_scale,
    source_image,
    driving_video,
    best_frame,
    out_frames,
    num_mix,
    ddim_steps,
):
    """Run ``core/test_xportrait.py`` with the given settings and find its video.

    Args:
        model_config: Path to the model config passed as ``--model_config``.
        output_dir: Directory the script writes into; searched for results.
        resume_dir: Checkpoint path passed as ``--resume_dir``.
        seed, best_frame, out_frames, num_mix, ddim_steps: Integer-valued
            options; integral floats are formatted without a ``.0``.
        uc_scale: Passed through as ``str(uc_scale)``.
        source_image: File path of the source portrait (required).
        driving_video: File path of the driving video (required).

    Returns:
        ``(status_message, video_path_or_None)``. The video is the most
        recently modified ``.mp4``/``.avi`` file in ``output_dir``.
    """
    # Reject missing uploads up front: a None in the argv list would make
    # subprocess.run raise TypeError instead of returning a message.
    missing = [
        label
        for label, value in (
            ("source image", source_image),
            ("driving video", driving_video),
        )
        if not value
    ]
    if missing:
        return f"Missing required input(s): {', '.join(missing)}.", None

    # sys.executable runs the script with the same interpreter/venv as the app.
    command = [
        sys.executable, "core/test_xportrait.py",
        "--model_config", model_config,
        "--output_dir", output_dir,
        "--resume_dir", resume_dir,
        "--seed", _format_int_arg(seed),
        "--uc_scale", str(uc_scale),
        "--source_image", source_image,
        "--driving_video", driving_video,
        "--best_frame", _format_int_arg(best_frame),
        "--out_frames", _format_int_arg(out_frames),
        "--num_mix", _format_int_arg(num_mix),
        "--ddim_steps", _format_int_arg(ddim_steps),
    ]

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        return f"An error occurred: {e}", None

    video_files = glob.glob(os.path.join(output_dir, "*.mp4")) + glob.glob(
        os.path.join(output_dir, "*.avi")
    )
    if not video_files:
        return "No video file was found in the output directory.", None
    # glob order is arbitrary and output_dir may hold videos from earlier
    # runs, so report the newest file rather than whichever comes first.
    latest_video = max(video_files, key=os.path.getmtime)
    return f"Output video saved at: {latest_video}", latest_video

# Gradio UI: the inputs map positionally onto run_xportrait's parameters.
# Integer-valued settings use precision=0 so Gradio hands back ints; without
# it gr.Number yields floats (e.g. 999.0), which integer CLI options reject.
app = gr.Interface(
    fn=run_xportrait,
    inputs=[
        gr.Textbox(value="config/cldm_v15_appearance_pose_local_mm.yaml", label="Model Config Path"),
        gr.Textbox(value="outputs", label="Output Directory"),
        gr.Textbox(value="checkpoint/model_state-415001.th", label="Resume Directory"),
        gr.Number(value=999, label="Seed", precision=0),
        gr.Number(value=5, label="UC Scale"),  # stays a float
        gr.Image(label="Source Image", type="filepath"),
        gr.Video(label="Driving Video"),
        gr.Number(value=36, label="Best Frame", precision=0),
        gr.Number(value=-1, label="Out Frames", precision=0),
        gr.Number(value=4, label="Number of Mix", precision=0),
        gr.Number(value=30, label="DDIM Steps", precision=0),
    ],
    outputs=["text", "video"],
    title="XPortrait Model Runner",
    description="Run XPortrait with customizable parameters.",
)

# Launch the Gradio app
app.launch()