# Hugging Face Space — status: Running on L40S
import glob
import os
import subprocess
import sys
from urllib.parse import parse_qs, urlparse

import gdown
import gradio as gr
# Make sure the local checkpoint directory exists before anything writes into
# it; exist_ok=True turns this into a no-op when the directory is already there.
os.makedirs(name="checkpoint", mode=0o777, exist_ok=True)
# Function to download the model weights from a Google Drive folder | |
def download_weights_from_folder(google_drive_folder_link): | |
# Extract the folder ID from the Google Drive link | |
folder_id = google_drive_folder_link.split('/')[-1] | |
output_folder = "checkpoint" # Directory to save the downloaded files | |
# Download all files in the Google Drive folder | |
gdown_url = f"https://drive.google.com/drive/folders/{folder_id}" | |
try: | |
gdown.download_folder(gdown_url, quiet=False, output=output_folder) | |
# Check if the model file exists and rename if necessary | |
downloaded_model_path = os.path.join(output_folder, "model_state-415001.th") | |
if os.path.exists(downloaded_model_path): | |
return f"Downloaded model weights to {downloaded_model_path}" | |
else: | |
return "Model file 'model_state-415001.th' not found in the folder." | |
except Exception as e: | |
return f"Failed to download weights: {e}" | |
download_weights_from_folder("https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O") | |
# Define a function to run your script with selected inputs | |
def run_xportrait( | |
model_config, | |
output_dir, | |
resume_dir, | |
seed, | |
uc_scale, | |
source_image, | |
driving_video, | |
best_frame, | |
out_frames, | |
num_mix, | |
ddim_steps | |
): | |
# Construct the command | |
command = [ | |
"python3", "core/test_xportrait.py", | |
"--model_config", model_config, | |
"--output_dir", output_dir, | |
"--resume_dir", resume_dir, | |
"--seed", str(seed), | |
"--uc_scale", str(uc_scale), | |
"--source_image", source_image, | |
"--driving_video", driving_video, | |
"--best_frame", str(best_frame), | |
"--out_frames", str(out_frames), | |
"--num_mix", str(num_mix), | |
"--ddim_steps", str(ddim_steps) | |
] | |
# Run the command | |
try: | |
subprocess.run(command, check=True) | |
# Find the generated video file in the output directory | |
video_files = glob.glob(os.path.join(output_dir, "*.mp4")) + glob.glob(os.path.join(output_dir, "*.avi")) | |
print(video_files) | |
if video_files: | |
return f"Output video saved at: {video_files[0]}", video_files[0] | |
else: | |
return "No video file was found in the output directory.", None | |
except subprocess.CalledProcessError as e: | |
return f"An error occurred: {e}", None | |
# Set up Gradio interface | |
app = gr.Interface( | |
fn=run_xportrait, | |
inputs=[ | |
gr.Textbox(value="config/cldm_v15_appearance_pose_local_mm.yaml", label="Model Config Path"), | |
gr.Textbox(value="outputs", label="Output Directory"), | |
gr.Textbox(value="checkpoint/model_state-415001.th", label="Resume Directory"), | |
gr.Number(value=999, label="Seed"), | |
gr.Number(value=5, label="UC Scale"), | |
gr.Image(label="Source Image", type="filepath"), | |
gr.Video(label="Driving Video"), | |
gr.Number(value=36, label="Best Frame"), | |
gr.Number(value=-1, label="Out Frames"), | |
gr.Number(value=4, label="Number of Mix"), | |
gr.Number(value=30, label="DDIM Steps") | |
], | |
outputs=["text", "video"], | |
title="XPortrait Model Runner", | |
description="Run XPortrait with customizable parameters." | |
) | |
# Launch the Gradio app | |
app.launch() |