|
import os |
|
import subprocess |
|
from pathlib import Path |
|
|
|
import gradio as gr |
|
|
|
from config import hparams as hp |
|
from config import hparams_gradio as hp_gradio |
|
from nota_wav2lip import Wav2LipModelComparisonGradio |
|
|
|
|
|
# Runtime device and the sample video/audio label mappings come from the
# Gradio-specific hyper-parameter config.
device = hp_gradio.device
print(f'Using {device} for inference.')
video_label_dict = hp_gradio.sample.video
audio_label_dict = hp_gradio.sample.audio

# Optional download URLs; when an env var is unset, the corresponding asset
# is assumed to already exist locally and no download is attempted.
LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE', None)


def _fetch(url: str, destination: str) -> None:
    """Best-effort download of *url* to *destination* via wget.

    Uses an argument list with ``shell=False`` so that env-provided URLs
    cannot inject shell metacharacters (the original used an interpolated
    shell string). Return code is intentionally ignored, matching the
    previous best-effort behavior.
    """
    subprocess.call(["wget", "--no-check-certificate", "-O", destination, url])


# Download the original and compressed model checkpoints if missing.
if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
    _fetch(LRS_ORIGINAL_URL, hp.inference.model.wav2lip.checkpoint)
if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
    _fetch(LRS_COMPRESSED_URL, hp.inference.model.nota_wav2lip.checkpoint)

# Download and unpack the inference sample archive if missing.
path_inference_sample = "sample.tar.gz"
if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None:
    _fetch(LRS_INFERENCE_SAMPLE, path_inference_sample)
    # Extract into the current working directory.
    subprocess.call(["tar", "-zxvf", path_inference_sample])
|
|
|
|
|
if __name__ == "__main__":

    # Service object that runs both the original and the compressed Wav2Lip
    # models; also exposes per-model parameter counts via .params.
    # NOTE(review): the keyword is `audio_label_list` but a dict is passed —
    # presumably the servicer accepts any iterable/mapping; confirm against
    # Wav2LipModelComparisonGradio's signature.
    servicer = Wav2LipModelComparisonGradio(
        device=device,
        video_label_dict=video_label_dict,
        audio_label_list=audio_label_dict,
        default_video='v1',
        default_audio='a1'
    )

    # Register each sample video together with a JSON sidecar file sharing the
    # same stem (presumably precomputed face metadata — TODO confirm against
    # update_video's implementation).
    for video_name in sorted(video_label_dict):
        video_stem = Path(video_label_dict[video_name])
        servicer.update_video(video_stem, video_stem.with_suffix('.json'),
                              name=video_name)

    # Register each sample audio clip under its label.
    for audio_name in sorted(audio_label_dict):
        audio_path = Path(audio_label_dict[audio_name])
        servicer.update_audio(audio_path, name=audio_name)

    # Build the Gradio UI: a selection panel on the left and one result panel
    # per model (original vs. compressed) on the right.
    with gr.Blocks(theme='nota-ai/theme', css=Path('docs/main.css').read_text()) as demo:
        gr.Markdown(Path('docs/header.md').read_text())
        gr.Markdown(Path('docs/description.md').read_text())
        with gr.Row():
            # --- Input selection panel ---
            with gr.Column(variant='panel'):

                gr.Markdown('## Select input video and audio', sanitize_html=False)

                # Previews of the currently selected sample inputs
                # (read-only; updated via the Radio change callbacks below).
                sample_video = gr.Video(interactive=False, label="Input Video")
                sample_audio = gr.Audio(interactive=False, label="Input Audio")

                video_selection = gr.components.Radio(video_label_dict,
                                                      type='value', label="Select an input video:")
                audio_selection = gr.components.Radio(audio_label_dict,
                                                      type='value', label="Select an input audio:")

                with gr.Row(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")

            # --- Original model results panel ---
            with gr.Column(variant='panel'):

                gr.Markdown('## Original Wav2Lip')
                original_model_output = gr.Video(label="Original Model", interactive=False)
                with gr.Column():
                    with gr.Row(equal_height=True):
                        original_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
                        original_model_fps = gr.Textbox(value="", label="FPS")
                original_model_params = gr.Textbox(value=servicer.params['wav2lip'], label="# Parameters")

            # --- Compressed model results panel ---
            with gr.Column(variant='panel'):

                gr.Markdown('## Compressed Wav2Lip (Ours)')
                compressed_model_output = gr.Video(label="Compressed Model", interactive=False)
                with gr.Column():
                    with gr.Row(equal_height=True):
                        compressed_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
                        compressed_model_fps = gr.Textbox(value="", label="FPS")
                compressed_model_params = gr.Textbox(value=servicer.params['nota_wav2lip'], label="# Parameters")

        # Keep the preview widgets in sync with the selected radio options.
        video_selection.change(fn=servicer.switch_video_samples, inputs=video_selection, outputs=sample_video)
        audio_selection.change(fn=servicer.switch_audio_samples, inputs=audio_selection, outputs=sample_audio)

        # Run inference with the chosen model and populate its result panel
        # (generated video, total inference time, and FPS).
        generate_original_button.click(servicer.generate_original_model,
                                       inputs=[video_selection, audio_selection],
                                       outputs=[original_model_output, original_model_inference_time, original_model_fps])

        generate_compressed_button.click(servicer.generate_compressed_model,
                                         inputs=[video_selection, audio_selection],
                                         outputs=[compressed_model_output, compressed_model_inference_time, compressed_model_fps])

        gr.Markdown(Path('docs/footer.md').read_text())

    # Enable request queueing (serializes GPU inference) and start the server.
    demo.queue().launch()
|
|