|
import logging
import threading
from pathlib import Path

from nota_wav2lip.demo import Wav2LipModelComparisonDemo
|
|
|
|
|
class Wav2LipModelComparisonGradio(Wav2LipModelComparisonDemo):
    """Callback adapter exposing :class:`Wav2LipModelComparisonDemo` to a UI.

    Provides two generation callbacks:
    - ``generate_original_model`` runs ``model_type='wav2lip'``.
    - ``generate_compressed_model`` runs ``model_type='nota_wav2lip'``.

    It also provides ``switch_video_samples`` and ``switch_audio_samples``,
    which map a selected label to its sample entry.

    Calls to ``save_as_video`` are serialized through a per-instance lock.

    All callbacks are best-effort. An unexpected error is logged with its
    traceback, and the callback returns ``None`` instead of raising. A
    ``KeyboardInterrupt`` raised inside a callback is converted into
    ``SystemExit``.
    """

    # In a class body ``__name__`` resolves to the defining module's name.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        device='cpu',
        result_dir='./temp',
        video_label_dict=None,
        audio_label_list=None,
        default_video='v1',
        default_audio='a1',
    ) -> None:
        """Initialise the demo and the label lookup tables.

        Args:
            device: Forwarded to the base-class constructor.
            result_dir: Forwarded to the base-class constructor.
            video_label_dict: Mapping of video label -> path. Each path is
                converted to a ``Path`` and its suffix forced to ``.mp4``.
                Defaults to an empty mapping.
            audio_label_list: Despite its name, a *mapping* of audio label ->
                entry; it is queried with ``in`` and ``[]``. It is stored
                as-is, not copied. Defaults to an empty mapping.
            default_video: Label returned by ``switch_video_samples`` when the
                selected label is unknown.
            default_audio: Label returned by ``switch_audio_samples`` when the
                selected label is unknown.
        """
        if audio_label_list is None:
            audio_label_list = {}
        if video_label_dict is None:
            video_label_dict = {}

        super().__init__(device, result_dir)

        self._video_label_dict = {
            label: Path(path).with_suffix('.mp4')
            for label, path in video_label_dict.items()
        }
        self._audio_label_dict = audio_label_list
        self._default_video = default_video
        self._default_audio = default_audio

        # Only one save_as_video() call runs at a time on this instance.
        # NOTE(review): presumably because the underlying models must not be
        # used concurrently -- confirm against the base class.
        self._lock = threading.Lock()

    def _is_valid_input(self, video_selection, audio_selection):
        """Raise ``ValueError`` unless both labels are known.

        Explicit raises are used instead of ``assert`` so that the check is
        not stripped when Python runs with ``-O``.
        """
        if video_selection not in self._video_label_dict:
            raise ValueError(
                f"Your input ({video_selection}) is not in {self._video_label_dict}!!!"
            )
        if audio_selection not in self._audio_label_dict:
            raise ValueError(
                f"Your input ({audio_selection}) is not in {self._audio_label_dict}!!!"
            )

    def _generate(self, video_selection, audio_selection, model_type):
        """Validate the labels, then render a video with ``model_type``.

        Returns:
            On success, the tuple
            ``(str(output_video_path), inference_time, inference_fps)``.
            ``inference_time`` is formatted with ``.2f`` and ``inference_fps``
            with ``.1f``. Units are whatever ``save_as_video`` reports
            (presumably seconds and frames/second -- confirm).
            On failure, ``None``; the error is logged.
        """
        try:
            self._is_valid_input(video_selection, audio_selection)

            with self._lock:
                output_video_path, inference_time, inference_fps = self.save_as_video(
                    audio_name=audio_selection,
                    video_name=video_selection,
                    model_type=model_type,
                )

            return (
                str(output_video_path),
                format(inference_time, ".2f"),
                format(inference_fps, ".1f"),
            )
        except KeyboardInterrupt:
            # Same effect as the previous ``exit()`` call, without depending on
            # the ``site``-injected builtin (absent under ``-S``/frozen apps).
            raise SystemExit
        except Exception:
            # Best-effort callback: keep the traceback in the log, don't raise.
            self._logger.exception(
                "Video generation with model %r failed (video=%r, audio=%r)",
                model_type, video_selection, audio_selection,
            )
            return None

    def generate_original_model(self, video_selection, audio_selection):
        """Generate with the original ``'wav2lip'`` model.

        Returns ``(path, time, fps)`` as strings, or ``None`` on failure.
        """
        return self._generate(video_selection, audio_selection, 'wav2lip')

    def generate_compressed_model(self, video_selection, audio_selection):
        """Generate with the compressed ``'nota_wav2lip'`` model.

        Returns ``(path, time, fps)`` as strings, or ``None`` on failure.
        """
        return self._generate(video_selection, audio_selection, 'nota_wav2lip')

    def _lookup_sample(self, table, selection, default):
        """Look up ``selection`` in ``table``, falling back to ``default``.

        Returns:
            ``table[selection]`` if the label is present, otherwise
            ``table[default]``. Returns ``None`` (and logs the error) if the
            lookup fails, e.g. when ``default`` is not present either.
        """
        try:
            key = selection if selection in table else default
            return table[key]
        except KeyboardInterrupt:
            raise SystemExit
        except Exception:
            self._logger.exception(
                "Sample lookup failed (selection=%r, default=%r)", selection, default
            )
            return None

    def switch_video_samples(self, video_selection):
        """Return the ``.mp4`` ``Path`` for ``video_selection``.

        Falls back to the default video for unknown labels. Returns ``None``
        on failure.
        """
        return self._lookup_sample(
            self._video_label_dict, video_selection, self._default_video
        )

    def switch_audio_samples(self, audio_selection):
        """Return the audio entry for ``audio_selection``.

        Falls back to the default audio for unknown labels. Returns ``None``
        on failure.
        """
        return self._lookup_sample(
            self._audio_label_dict, audio_selection, self._default_audio
        )
|
|