import argparse
import platform
import subprocess
import time
from pathlib import Path
from typing import Dict, Iterator, List, Literal, Optional, Union

import cv2
import numpy as np

from config import hparams as hp
from nota_wav2lip.inference import Wav2LipInferenceImpl
from nota_wav2lip.util import FFMPEG_LOGGING_MODE
from nota_wav2lip.video import AudioSlicer, VideoSlicer


class Wav2LipModelComparisonDemo:
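    """Runs lip-sync inference with multiple Wav2Lip variants side by side.

    Keeps a zoo of loaded models (by default the original `wav2lip` and the compressed
    `nota_wav2lip`) plus the registered audio/video inputs, and saves generated clips to
    `result_dir` while reporting inference time and effective FPS.
    """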
    def __init__(self, device='cpu', result_dir='./temp', model_list: Optional[Union[str, List[str]]] = None):
        if model_list is None:
            model_list: List[str] = ['wav2lip', 'nota_wav2lip']
        if isinstance(model_list, str) and len(model_list) != 0:
            model_list: List[str] = [model_list]
        super().__init__()
        self.video_dict: Dict[str, VideoSlicer] = {}
        self.audio_dict: Dict[str, AudioSlicer] = {}

        self.model_zoo: Dict[str, Wav2LipInferenceImpl] = {}
        for model_name in model_list:
            assert model_name in hp.inference.model, f"{model_name} not in hp.inference.model: {hp.inference.model}"
            self.model_zoo[model_name] = Wav2LipInferenceImpl(
                model_name, hp_inference_model=hp.inference.model[model_name], device=device
            )

        self._params_zoo: Dict[str, str] = {
            model_name: self.model_zoo[model_name].params for model_name in self.model_zoo
        }

        self.result_dir: Path = Path(result_dir)
        self.result_dir.mkdir(parents=True, exist_ok=True)

    @property
    def params(self):
        return self._params_zoo

    def _infer(
        self,
        audio_name: str,
        video_name: str,
        model_type: Literal['wav2lip', 'nota_wav2lip']
    ) -> Iterator[np.ndarray]:
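        """Return an iterator of generated frames for a registered audio/video pair and the chosen model."""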
        audio_iterable: AudioSlicer = self.audio_dict[audio_name]
        video_iterable: VideoSlicer = self.video_dict[video_name]
        target_model = self.model_zoo[model_type]
        return target_model.inference_with_iterator(audio_iterable, video_iterable)

    def update_audio(self, audio_path, name=None):
        _name = name if name is not None else Path(audio_path).stem
        self.audio_dict.update(
            {_name: AudioSlicer(audio_path)}
        )

    def update_video(self, frame_dir_path, bbox_path, name=None):
        _name = name if name is not None else Path(frame_dir_path).stem
        self.video_dict.update(
            {_name: VideoSlicer(frame_dir_path, bbox_path)}
        )

    def save_as_video(self, audio_name, video_name, model_type):
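        """Generate lip-synced frames, write them to a video file, and mux in the source audio.

        Returns the path of the muxed video, the model inference time in seconds, and the
        effective inference FPS (generated frames divided by inference time).
        """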

        output_video_path = self.result_dir / 'generated_with_audio.mp4'
        frame_only_video_path = self.result_dir / 'generated.mp4'
        audio_path = self.audio_dict[audio_name].audio_path

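        # Stage 1: run the selected model and write the generated frames to a silent video with OpenCV.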
        out = cv2.VideoWriter(str(frame_only_video_path),
                              cv2.VideoWriter_fourcc(*'mp4v'),
                              hp.face.video_fps,
                              (hp.inference.frame.w, hp.inference.frame.h))
        start = time.time()
        for frame in self._infer(audio_name=audio_name, video_name=video_name, model_type=model_type):
            out.write(frame)
        inference_time = time.time() - start
        out.release()

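        # Stage 2: mux the original audio into the frame-only video with ffmpeg.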
        command = f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {audio_path} -i {frame_only_video_path} -strict -2 -q:v 1 {output_video_path}"
        subprocess.call(command, shell=platform.system() != 'Windows')

        # Number of frames in the generated video
        video_frames_num = len(self.audio_dict[audio_name])
        inference_fps = video_frames_num / inference_time

        return output_video_path, inference_time, inference_fps
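

# --- Usage sketch ---------------------------------------------------------------
# A minimal example of how this helper is intended to be driven. The sample paths
# below are hypothetical placeholders: `update_video` expects a directory of
# pre-extracted frames plus a face-bbox file, and `update_audio` expects an audio file.
if __name__ == "__main__":
    demo = Wav2LipModelComparisonDemo(device='cpu', result_dir='./temp')

    # Register one preprocessed video and one driving audio clip under short names.
    demo.update_video('sample/frames', 'sample/bbox.json', name='speaker')  # hypothetical paths
    demo.update_audio('sample/speech.wav', name='speech')                   # hypothetical path

    # Generate with the compressed model and report timing statistics.
    video_path, elapsed, fps = demo.save_as_video('speech', 'speaker', model_type='nota_wav2lip')
    print(f"Saved {video_path} | inference time: {elapsed:.2f}s | inference FPS: {fps:.1f}")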