import argparse
import platform
import subprocess
import time
from pathlib import Path
from typing import Dict, Iterator, List, Literal, Optional, Union

import cv2
import numpy as np

from config import hparams as hp
from nota_wav2lip.inference import Wav2LipInferenceImpl
from nota_wav2lip.util import FFMPEG_LOGGING_MODE
from nota_wav2lip.video import AudioSlicer, VideoSlicer


class Wav2LipModelComparisonDemo:
    """Demo wrapper that runs the original Wav2Lip and the compressed Nota Wav2Lip
    models on the same audio/video inputs for side-by-side comparison."""

    def __init__(self, device='cpu', result_dir='./temp', model_list: Optional[Union[str, List[str]]] = None):
        super().__init__()
        if model_list is None:
            model_list = ['wav2lip', 'nota_wav2lip']
        if isinstance(model_list, str) and len(model_list) != 0:
            model_list = [model_list]
        self.video_dict: Dict[str, VideoSlicer] = {}
        self.audio_dict: Dict[str, AudioSlicer] = {}

        # Load one inference wrapper per requested model, validating each name
        # against the inference config before instantiating it.
        self.model_zoo: Dict[str, Wav2LipInferenceImpl] = {}
        for model_name in model_list:
            assert model_name in hp.inference.model, f"{model_name} not in hp.inference.model: {hp.inference.model}"
            self.model_zoo[model_name] = Wav2LipInferenceImpl(
                model_name, hp_inference_model=hp.inference.model[model_name], device=device
            )

        self._params_zoo: Dict[str, str] = {
            model_name: self.model_zoo[model_name].params for model_name in self.model_zoo
        }

        self.result_dir: Path = Path(result_dir)
        self.result_dir.mkdir(parents=True, exist_ok=True)

    @property
    def params(self):
        # Parameter-count summary for each loaded model.
        return self._params_zoo

    def _infer(
        self,
        audio_name: str,
        video_name: str,
        model_type: Literal['wav2lip', 'nota_wav2lip']
    ) -> Iterator[np.ndarray]:
        # Yield generated frames by pairing the registered audio and video slicers.
        audio_iterable: AudioSlicer = self.audio_dict[audio_name]
        video_iterable: VideoSlicer = self.video_dict[video_name]
        target_model = self.model_zoo[model_type]
        return target_model.inference_with_iterator(audio_iterable, video_iterable)

    def update_audio(self, audio_path, name=None):
        # Register an audio source under `name` (defaults to the file stem).
        _name = name if name is not None else Path(audio_path).stem
        self.audio_dict[_name] = AudioSlicer(audio_path)

    def update_video(self, frame_dir_path, bbox_path, name=None):
        # Register a video source: a directory of frames plus its bounding-box file.
        _name = name if name is not None else Path(frame_dir_path).stem
        self.video_dict[_name] = VideoSlicer(frame_dir_path, bbox_path)

    def save_as_video(self, audio_name, video_name, model_type):
        output_video_path = self.result_dir / 'generated_with_audio.mp4'
        frame_only_video_path = self.result_dir / 'generated.mp4'
        audio_path = self.audio_dict[audio_name].audio_path

        # Write the generated frames to a silent video first; audio is muxed in below.
        out = cv2.VideoWriter(str(frame_only_video_path),
                              cv2.VideoWriter_fourcc(*'mp4v'),
                              hp.face.video_fps,
                              (hp.inference.frame.w, hp.inference.frame.h))
        start = time.time()
        for frame in self._infer(audio_name=audio_name, video_name=video_name, model_type=model_type):
            out.write(frame)
        inference_time = time.time() - start
        out.release()

        # Mux the original audio with the silent frame video via ffmpeg. On POSIX
        # the command string needs shell=True; on Windows it is passed to
        # CreateProcess directly.
        command = f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {audio_path} -i {frame_only_video_path} -strict -2 -q:v 1 {output_video_path}"
        subprocess.call(command, shell=platform.system() != 'Windows')

        # One frame is generated per audio chunk, so the audio slicer's length
        # equals the number of frames written above.
        video_frames_num = len(self.audio_dict[audio_name])
        inference_fps = video_frames_num / inference_time

        return output_video_path, inference_time, inference_fps
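

# A minimal, hypothetical CLI driver sketch (not part of the original class): it
# wires argparse (imported above but otherwise unused in this module) to the demo
# methods. The flag names, defaults, and registered names ('sample_audio',
# 'sample_video') are illustrative assumptions, not shipped conventions.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run Wav2Lip / Nota Wav2Lip on one audio-video pair.')
    parser.add_argument('--audio-path', required=True, help='Driving audio file (e.g. a .wav)')
    parser.add_argument('--frame-dir', required=True, help='Directory of pre-extracted face frames')
    parser.add_argument('--bbox-path', required=True, help='Per-frame face bounding-box file')
    parser.add_argument('--model', default='nota_wav2lip', choices=['wav2lip', 'nota_wav2lip'])
    parser.add_argument('--device', default='cpu')
    args = parser.parse_args()

    demo = Wav2LipModelComparisonDemo(device=args.device, model_list=args.model)
    demo.update_audio(args.audio_path, name='sample_audio')
    demo.update_video(args.frame_dir, args.bbox_path, name='sample_video')
    output_path, elapsed, fps = demo.save_as_video('sample_audio', 'sample_video', model_type=args.model)
    print(f"Saved {output_path} ({elapsed:.1f}s, {fps:.1f} FPS)")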