from typing import Iterable, Iterator, List, Sequence, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from tqdm import tqdm

from config import hparams as hp
from nota_wav2lip.models.util import count_params, load_model
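

# Thin inference wrapper around a Wav2Lip generator selected by `model_name`:
# it loads the model through nota_wav2lip.models.util.load_model, batches
# (mel chunk, face crop) pairs, and yields lip-synced frames one at a time.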
class Wav2LipInferenceImpl:
    def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'):
        self.model: nn.Module = load_model(
            model_name,
            device=device,
            **hp_inference_model
        )
        self.device = device
        self._params: str = self._format_param(count_params(self.model))
    @property
    def params(self) -> str:
        return self._params
    @staticmethod
    def _format_param(num_params: int) -> str:
        params_in_million = num_params / 1e6
        return f"{params_in_million:.1f}M"
    @staticmethod
    def _reset_batch() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[List[int]]]:
        return [], [], [], []
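
    # get_data_iterator pairs each mel-spectrogram chunk with a (frame, face-box)
    # entry, cycling over the video frames, and yields ready-to-batch arrays:
    # images of shape (N, img_size, img_size, 6) in [0, 1] (masked target stacked
    # with the unmasked reference) and mels of shape (N, n_mel, T, 1).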
    def get_data_iterator(
        self,
        audio_iterable: Iterable[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]]:
        img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()

        for i, m in enumerate(audio_iterable):
            # Cycle through the video frames if there are more mel chunks than frames.
            idx = i % len(video_iterable)
            _frame_to_save, coords = video_iterable[idx]
            frame_to_save = _frame_to_save.copy()
            face = frame_to_save[coords[0]:coords[1], coords[2]:coords[3]].copy()

            face: np.ndarray = cv2.resize(face, (hp.face.img_size, hp.face.img_size))

            img_batch.append(face)
            mel_batch.append(m)
            frame_batch.append(frame_to_save)
            coords_batch.append(coords)

            if len(img_batch) >= hp.inference.batch_size:
                img_batch = np.asarray(img_batch)
                mel_batch = np.asarray(mel_batch)

                # Zero out the lower half of each face so the model inpaints the
                # mouth region, conditioned on the unmasked frame as a reference.
                img_masked = img_batch.copy()
                img_masked[:, hp.face.img_size // 2:] = 0

                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                yield img_batch, mel_batch, frame_batch, coords_batch
                img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()

        # Flush the final, possibly smaller, batch.
        if len(img_batch) > 0:
            img_batch = np.asarray(img_batch)
            mel_batch = np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, hp.face.img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
    def inference_with_iterator(
        self,
        audio_iterable: Sequence[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[np.ndarray]:
        data_iterator = self.get_data_iterator(audio_iterable, video_iterable)

        for (img_batch, mel_batch, frames, coords) in \
                tqdm(data_iterator, total=int(np.ceil(float(len(audio_iterable)) / hp.inference.batch_size))):
            # NHWC float arrays -> NCHW float tensors on the target device.
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)

            preds: torch.Tensor = self.forward(mel_batch, img_batch)

            preds = preds.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            for pred, frame, coord in zip(preds, frames, coords):
                y1, y2, x1, x2 = coord
                # Paste the generated mouth region back into the original frame.
                pred = cv2.resize(pred.astype(np.uint8), (x2 - x1, y2 - y1))
                frame[y1:y2, x1:x2] = pred
                yield frame
    @torch.no_grad()
    def forward(self, audio_sequences: torch.Tensor, face_sequences: torch.Tensor) -> torch.Tensor:
        return self.model(audio_sequences, face_sequences)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
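

if __name__ == '__main__':
    # Minimal smoke-test sketch with dummy data. Assumptions not taken from this
    # file: the 'wav2lip' model name and the hp.inference.model config node, an
    # available checkpoint for load_model, and 80 x 16 mel chunks as in stock
    # Wav2Lip. See the repo's actual entry point for the real pipeline.
    engine = Wav2LipInferenceImpl('wav2lip', hp_inference_model=hp.inference.model['wav2lip'], device='cpu')
    print(f"Loaded a {engine.params}-parameter model")

    rng = np.random.default_rng(0)
    mels = [rng.random((80, 16), dtype=np.float32) for _ in range(4)]   # dummy mel chunks
    frame = rng.integers(0, 255, size=(256, 256, 3), dtype=np.uint8)    # dummy video frame
    faces = [(frame, [64, 192, 64, 192])]                               # (frame, [y1, y2, x1, x2])

    for out_frame in engine.inference_with_iterator(mels, faces):
        print(out_frame.shape)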