X-Portrait / app.py
fffiloni's picture
Update app.py
47baa10 verified
raw
history blame
3.49 kB
import gradio as gr
import subprocess
import os
import gdown
import glob
# Create the local 'checkpoint' directory up front (no error if it is already there),
# so the weight download below has somewhere to write.
os.makedirs(name="checkpoint", exist_ok=True)
# Function to download the model weights from a Google Drive folder
def download_weights_from_folder(google_drive_folder_link):
# Extract the folder ID from the Google Drive link
folder_id = google_drive_folder_link.split('/')[-1]
output_folder = "checkpoint" # Directory to save the downloaded files
# Download all files in the Google Drive folder
gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
try:
gdown.download_folder(gdown_url, quiet=False, output=output_folder)
# Check if the model file exists and rename if necessary
downloaded_model_path = os.path.join(output_folder, "model_state-415001.th")
if os.path.exists(downloaded_model_path):
return f"Downloaded model weights to {downloaded_model_path}"
else:
return "Model file 'model_state-415001.th' not found in the folder."
except Exception as e:
return f"Failed to download weights: {e}"
# Fetch the X-Portrait weights at startup. download_weights_from_folder reports
# problems through its return value instead of raising, so print it; otherwise a
# failed download would go unnoticed until inference fails.
print(download_weights_from_folder("https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O"))
# Define a function to run your script with selected inputs
def run_xportrait(
model_config,
output_dir,
resume_dir,
seed,
uc_scale,
source_image,
driving_video,
best_frame,
out_frames,
num_mix,
ddim_steps
):
# Construct the command
command = [
"python3", "core/test_xportrait.py",
"--model_config", model_config,
"--output_dir", output_dir,
"--resume_dir", resume_dir,
"--seed", str(seed),
"--uc_scale", str(uc_scale),
"--source_image", source_image,
"--driving_video", driving_video,
"--best_frame", str(best_frame),
"--out_frames", str(out_frames),
"--num_mix", str(num_mix),
"--ddim_steps", str(ddim_steps)
]
# Run the command
try:
subprocess.run(command, check=True)
# Find the generated video file in the output directory
video_files = glob.glob(os.path.join(output_dir, "*.mp4")) + glob.glob(os.path.join(output_dir, "*.avi"))
print(video_files)
if video_files:
return f"Output video saved at: {video_files[0]}", video_files[0]
else:
return "No video file was found in the output directory.", None
except subprocess.CalledProcessError as e:
return f"An error occurred: {e}", None
# Set up Gradio interface
app = gr.Interface(
fn=run_xportrait,
inputs=[
gr.Textbox(value="config/cldm_v15_appearance_pose_local_mm.yaml", label="Model Config Path"),
gr.Textbox(value="outputs", label="Output Directory"),
gr.Textbox(value="checkpoint/model_state-415001.th", label="Resume Directory"),
gr.Number(value=999, label="Seed"),
gr.Number(value=5, label="UC Scale"),
gr.Image(label="Source Image", type="filepath"),
gr.Video(label="Driving Video"),
gr.Number(value=36, label="Best Frame"),
gr.Number(value=-1, label="Out Frames"),
gr.Number(value=4, label="Number of Mix"),
gr.Number(value=30, label="DDIM Steps")
],
outputs=["text", "video"],
title="XPortrait Model Runner",
description="Run XPortrait with customizable parameters."
)
# Launch the Gradio app
app.launch()