diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..9ba89b39ce4f630b8d10cffd2d594109ac9c3ac5 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,23 @@ +# Contributing to this repository + +## Install linter + +First of all, you need to install `ruff` package to verify that you passed all conditions for formatting. + +``` +pip install ruff==0.0.287 +``` + +### Apply linter before PR + +Please run the ruff check with the following command: + +``` +ruff check . +``` + +### Auto-fix with fixable errors + +``` +ruff check . --fix +``` \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..4fc7fddb2d6fd08342e4f129a3ca6a5ebe78b408 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,9 @@ +FROM nvcr.io/nvidia/pytorch:22.03-py3 + +ARG DEBIAN_FRONTEND=noninteractive +RUN apt-get update +RUN apt-get install ffmpeg libsm6 libxext6 tmux git -y + +WORKDIR /workspace +COPY requirements.txt . +RUN pip install --no-cache -r requirements.txt \ No newline at end of file diff --git a/README.md b/README.md index 9ac8c1de6932f338e179ae1f1cd196e6f92649cd..3952366c0ae0becb1f233a6bbb74d674f4f45025 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,84 @@ --- -title: LipSync -emoji: 🌍 +title: Compressed Wav2Lip +emoji: 🌟 colorFrom: indigo -colorTo: indigo +colorTo: pink sdk: gradio -sdk_version: 4.27.0 +sdk_version: 4.13.0 app_file: app.py -pinned: false -license: unknown +pinned: true +license: apache-2.0 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# 28× Compressed Wav2Lip by Nota AI + +Official codebase for [**Accelerating Speech-Driven Talking Face Generation with 28× Compressed Wav2Lip**](https://arxiv.org/abs/2304.00471). + +- Presented at [ICCV'23 Demo](https://iccv2023.thecvf.com/demos-111.php) Track; [On-Device Intelligence Workshop](https://sites.google.com/g.harvard.edu/on-device-workshop-23/home) @ MLSys'23; [NVIDIA GTC 2023](https://www.nvidia.com/en-us/on-demand/search/?facet.mimetype[]=event%20session&layout=list&page=1&q=52409&sort=relevance&sortDir=desc) Poster. + + +## Installation +#### Docker (recommended) +```bash +git clone https://github.com/Nota-NetsPresso/nota-wav2lip.git +cd nota-wav2lip +docker compose run --service-ports --name nota-compressed-wav2lip compressed-wav2lip bash +``` + +#### Conda +
+Click + +```bash +git clone https://github.com/Nota-NetsPresso/nota-wav2lip.git +cd nota-wav2lip +apt-get update +apt-get install ffmpeg libsm6 libxext6 tmux git -y +conda create -n nota-wav2lip python=3.9 +conda activate nota-wav2lip +pip install -r requirements.txt +``` +
+ +## Gradio Demo +Use the below script to run the [nota-ai/compressed-wav2lip demo](https://huggingface.co/spaces/nota-ai/compressed-wav2lip). The models and sample data will be downloaded automatically. + + ```bash + bash app.sh + ``` + +## Inference +(1) Download YouTube videos in the LRS3-TED label text file and preprocess them properly. + - Download `lrs3_v0.4_txt.zip` from [this link](https://mmai.io/datasets/lip_reading/). + - Unzip the file and make a folder structure: `./data/lrs3_v0.4_txt/lrs3_v0.4/test` + - Run `bash download.sh` + - Run `bash preprocess.sh` + +(2) Run the script to compare the original Wav2Lip with Nota's compressed version. + + ```bash + bash inference.sh + ``` + +## License +- All rights related to this repository and the compressed models are reserved by Nota Inc. +- The intended use is strictly limited to research and non-commercial projects. + +## Contact +- To obtain compression code and assistance, kindly contact Nota AI (contact@nota.ai). These are provided as part of our business solutions. +- For Q&A about this repo, use this board: [Nota-NetsPresso/discussions](https://github.com/orgs/Nota-NetsPresso/discussions) + +## Acknowledgment + - [NVIDIA Applied Research Accelerator Program](https://www.nvidia.com/en-us/industries/higher-education-research/applied-research-program/) for supporting this research. + - [Wav2Lip](https://github.com/Rudrabha/Wav2Lip) and [LRS3-TED](https://www.robots.ox.ac.uk/~vgg/data/lip_reading/) for facilitating the development of the original Wav2Lip. + +## Citation +```bibtex +@article{kim2023unified, + title={A Unified Compression Framework for Efficient Speech-Driven Talking-Face Generation}, + author={Kim, Bo-Kyeong and Kang, Jaemin and Seo, Daeun and Park, Hancheol and Choi, Shinkook and Song, Hyoung-Kyu and Kim, Hyungshin and Lim, Sungsu}, + journal={MLSys Workshop on On-Device Intelligence (ODIW)}, + year={2023}, + url={https://arxiv.org/abs/2304.00471} +} +``` \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..517f54e8518df6e78f1a7162d853188222f0df2d --- /dev/null +++ b/app.py @@ -0,0 +1,105 @@ +import os +import subprocess +from pathlib import Path + +import gradio as gr + +from config import hparams as hp +from config import hparams_gradio as hp_gradio +from nota_wav2lip import Wav2LipModelComparisonGradio + +# device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = hp_gradio.device +print(f'Using {device} for inference.') +video_label_dict = hp_gradio.sample.video +audio_label_dict = hp_gradio.sample.audio + +LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None) +LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None) +LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE', None) + +if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None: + subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True) +if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None: + subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True) + +path_inference_sample = "sample.tar.gz" +if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None: + subprocess.call(f"wget --no-check-certificate -O {path_inference_sample} {LRS_INFERENCE_SAMPLE}", shell=True) +subprocess.call(f"tar -zxvf {path_inference_sample}", shell=True) + + +if __name__ == "__main__": + + servicer = Wav2LipModelComparisonGradio( + device=device, + video_label_dict=video_label_dict, + audio_label_list=audio_label_dict, + default_video='v1', + default_audio='a1' + ) + + for video_name in sorted(video_label_dict): + video_stem = Path(video_label_dict[video_name]) + servicer.update_video(video_stem, video_stem.with_suffix('.json'), + name=video_name) + + for audio_name in sorted(audio_label_dict): + audio_path = Path(audio_label_dict[audio_name]) + servicer.update_audio(audio_path, name=audio_name) + + with gr.Blocks(theme='nota-ai/theme', css=Path('docs/main.css').read_text()) as demo: + gr.Markdown(Path('docs/header.md').read_text()) + gr.Markdown(Path('docs/description.md').read_text()) + with gr.Row(): + with gr.Column(variant='panel'): + + gr.Markdown('## Select input video and audio', sanitize_html=False) + # Define samples + sample_video = gr.Video(interactive=False, label="Input Video") + sample_audio = gr.Audio(interactive=False, label="Input Audio") + + # Define radio inputs + video_selection = gr.components.Radio(video_label_dict, + type='value', label="Select an input video:") + audio_selection = gr.components.Radio(audio_label_dict, + type='value', label="Select an input audio:") + # Define button inputs + with gr.Row(equal_height=True): + generate_original_button = gr.Button(value="Generate with Original Model", variant="primary") + generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary") + with gr.Column(variant='panel'): + # Define original model output components + gr.Markdown('## Original Wav2Lip') + original_model_output = gr.Video(label="Original Model", interactive=False) + with gr.Column(): + with gr.Row(equal_height=True): + original_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)") + original_model_fps = gr.Textbox(value="", label="FPS") + original_model_params = gr.Textbox(value=servicer.params['wav2lip'], label="# Parameters") + with gr.Column(variant='panel'): + # Define compressed model output components + gr.Markdown('## Compressed Wav2Lip (Ours)') + compressed_model_output = gr.Video(label="Compressed Model", interactive=False) + with gr.Column(): + with gr.Row(equal_height=True): + compressed_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)") + compressed_model_fps = gr.Textbox(value="", label="FPS") + compressed_model_params = gr.Textbox(value=servicer.params['nota_wav2lip'], label="# Parameters") + + # Switch video and audio samples when selecting the raido button + video_selection.change(fn=servicer.switch_video_samples, inputs=video_selection, outputs=sample_video) + audio_selection.change(fn=servicer.switch_audio_samples, inputs=audio_selection, outputs=sample_audio) + + # Click the generate button for original model + generate_original_button.click(servicer.generate_original_model, + inputs=[video_selection, audio_selection], + outputs=[original_model_output, original_model_inference_time, original_model_fps]) + # Click the generate button for compressed model + generate_compressed_button.click(servicer.generate_compressed_model, + inputs=[video_selection, audio_selection], + outputs=[compressed_model_output, compressed_model_inference_time, compressed_model_fps]) + + gr.Markdown(Path('docs/footer.md').read_text()) + + demo.queue().launch() diff --git a/app.sh b/app.sh new file mode 100644 index 0000000000000000000000000000000000000000..be817bbf18a2eca401312ed145de2342d458151a --- /dev/null +++ b/app.sh @@ -0,0 +1,4 @@ +export LRS_ORIGINAL_URL=https://netspresso-huggingface-demo-checkpoint.s3.us-east-2.amazonaws.com/compressed-wav2lip/lrs3-wav2lip.pth && \ +export LRS_COMPRESSED_URL=https://netspresso-huggingface-demo-checkpoint.s3.us-east-2.amazonaws.com/compressed-wav2lip/lrs3-nota-wav2lip.pth && \ +export LRS_INFERENCE_SAMPLE=https://netspresso-huggingface-demo-checkpoint.s3.us-east-2.amazonaws.com/data/compressed-wav2lip-inference/sample.tar.gz && \ +python app.py \ No newline at end of file diff --git a/checkpoints/.gitkeep b/checkpoints/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..250149afd8458af76172dd0d4c95511bd3b9d332 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1,5 @@ +from omegaconf import DictConfig, OmegaConf + +hparams: DictConfig = OmegaConf.load("config/nota_wav2lip.yaml") + +hparams_gradio: DictConfig = OmegaConf.load("config/gradio.yaml") diff --git a/config/gradio.yaml b/config/gradio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a0e59eb0148abd172ab603ef29b96e384a048b03 --- /dev/null +++ b/config/gradio.yaml @@ -0,0 +1,14 @@ +device: cpu +sample: + video: + v1: "sample/2145_orig" + v2: "sample/2942_orig" + v3: "sample/4598_orig" + v4: "sample/4653_orig" + v5: "sample/13692_orig" + audio: + a1: "sample/1673_orig.wav" + a2: "sample/9948_orig.wav" + a3: "sample/11028_orig.wav" + a4: "sample/12640_orig.wav" + a5: "sample/5592_orig.wav" diff --git a/config/nota_wav2lip.yaml b/config/nota_wav2lip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..00ef05f6a798b81d90baf58cb20db1b22f55a597 --- /dev/null +++ b/config/nota_wav2lip.yaml @@ -0,0 +1,44 @@ + +inference: + batch_size: 1 + frame: + h: 224 + w: 224 + model: + wav2lip: + checkpoint: "checkpoints/lrs3-wav2lip.pth" + nota_wav2lip: + checkpoint: "checkpoints/lrs3-nota-wav2lip.pth" + +audio: + num_mels: 80 + rescale: True + rescaling_max: 0.9 + + use_lws: False + + n_fft: 800 # Extra window size is filled with 0 paddings to match this parameter + hop_size: 200 # For 16000Hz, 200 : 12.5 ms (0.0125 * sample_rate) + win_size: 800 # For 16000Hz, 800 : 50 ms (If None, win_size : n_fft) (0.05 * sample_rate) + sample_rate: 16000 # 16000Hz (corresponding to librispeech) (sox --i ) + + frame_shift_ms: ~ + + signal_normalization: True + allow_clipping_in_normalization: True + symmetric_mels: True + max_abs_value: 4. + preemphasize: True + preemphasis: 0.97 + + # Limits + min_level_db: -100 + ref_level_db: 20 + fmin: 55 + fmax: 7600 + +face: + video_fps: 25 + img_size: 96 + mel_step_size: 16 + diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..c0820a0452ef1b6e8daa5ed86363b90bcc8f2ef5 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,11 @@ +version: "3.9" +services: + compressed-wav2lip: + image: nota-compressed-wav2lip:dev + build: ./ + container_name: nota-compressed-wav2lip + ipc: host + ports: + - "7860:7860" + volumes: + - ./:/workspace \ No newline at end of file diff --git a/docs/assets/fig5.png b/docs/assets/fig5.png new file mode 100644 index 0000000000000000000000000000000000000000..cfdd6f3dd80373532858e5f1b155b9ab780ed6da Binary files /dev/null and b/docs/assets/fig5.png differ diff --git a/docs/description.md b/docs/description.md new file mode 100644 index 0000000000000000000000000000000000000000..6bf786bc67ad726e14b762d911a84a5fecb2143b --- /dev/null +++ b/docs/description.md @@ -0,0 +1,22 @@ +This demo showcases a lightweight model for speech-driven talking-face synthesis, a **28× Compressed Wav2Lip**. The key features of our approach are: + - compact generator built by removing the residual blocks and reducing the channel width from Wav2Lip. + - knowledge distillation to effectively train the small-capacity generator without adversarial learning. + - selective quantization to accelerate inference on edge GPUs without noticeable performance degradation. + + +The below figure shows a latency comparison at different precisions on NVIDIA Jetson edge GPUs, highlighting a 8× to 17× speedup at FP16 and a 19× speedup on Xavier NX at mixed precision. + +
+ compressed-wav2lip-performance +
+ +
+ +The generation speed may vary depending on network traffic. Nevertheless, our compresed Wav2Lip _consistently_ delivers a faster inference than the original model, while maintaining similar visual quality. Different from the paper, in this demo, we measure **total processing time** and **FPS** throughout loading the preprocessed video and audio, generating with the model, and merging lip-synced facial images with the original video. + +
+ + +### Notice + - This work was accepted to [Demo] [**ICCV 2023 Demo Track**](https://iccv2023.thecvf.com/demos-111.php); [[Paper](https://arxiv.org/abs/2304.00471)] [**On-Device Intelligence Workshop (ODIW) @ MLSys 2023**](https://sites.google.com/g.harvard.edu/on-device-workshop-23/home); [Poster] [**NVIDIA GPU Technology Conference (GTC) as Poster Spotlight**](https://www.nvidia.com/en-us/on-demand/search/?facet.mimetype[]=event%20session&layout=list&page=1&q=52409&sort=relevance&sortDir=desc). + - We thank [NVIDIA Applied Research Accelerator Program](https://www.nvidia.com/en-us/industries/higher-education-research/applied-research-program/) for supporting this research and [Wav2Lip's Authors](https://github.com/Rudrabha/Wav2Lip) for their pioneering research. \ No newline at end of file diff --git a/docs/footer.md b/docs/footer.md new file mode 100644 index 0000000000000000000000000000000000000000..8fc884f2346cfbba76d1a517e5b7137c445df2bb --- /dev/null +++ b/docs/footer.md @@ -0,0 +1,5 @@ +

+ +

+ +
\ No newline at end of file diff --git a/docs/header.md b/docs/header.md new file mode 100644 index 0000000000000000000000000000000000000000..8abf23e808e0fd3911f1330e93cde1df165b0ebd --- /dev/null +++ b/docs/header.md @@ -0,0 +1,10 @@ +#
Lightweight Speech-Driven Talking-Face Synthesis Demo
+ +
+ +

+ + +

+ +
\ No newline at end of file diff --git a/docs/main.css b/docs/main.css new file mode 100644 index 0000000000000000000000000000000000000000..906dd4b279f9e3333648487aa78c86481ad10f0e --- /dev/null +++ b/docs/main.css @@ -0,0 +1,4 @@ +h1, h2, h3 { + text-align: center; + display:block; +} \ No newline at end of file diff --git a/download.py b/download.py new file mode 100644 index 0000000000000000000000000000000000000000..536638dbaaa50c353b911977b43ef556c0ac6114 --- /dev/null +++ b/download.py @@ -0,0 +1,44 @@ +import argparse + +from nota_wav2lip.preprocess import get_cropped_face_from_lrs3_label + + +def parse_args(): + + parser = argparse.ArgumentParser(description="NotaWav2Lip: Get LRS3 video sample with the label text file") + + parser.add_argument( + '-i', + '--input-file', + type=str, + required=True, + help="Path of the label text file downloaded from https://mmai.io/datasets/lip_reading" + ) + + parser.add_argument( + '-o', + '--output-dir', + type=str, + default="sample_video_lrs3", + help="Output directory to save the result. Defaults: sample_video_lrs3" + ) + + parser.add_argument( + '--ignore-cache', + action='store_true', + help="Whether to force downloading and resampling video and overwrite pre-existing files" + ) + + args = parser.parse_args() + + return args + + +if __name__ == '__main__': + args = parse_args() + + get_cropped_face_from_lrs3_label( + args.input_file, + video_root_dir=args.output_dir, + ignore_cache = args.ignore_cache + ) diff --git a/download.sh b/download.sh new file mode 100644 index 0000000000000000000000000000000000000000..46713655386dbe183d7eb81306a1067c7ecd02c1 --- /dev/null +++ b/download.sh @@ -0,0 +1,7 @@ +# example for audio source +python download.py\ + -i data/lrs3_v0.4_txt/lrs3_v0.4/test/sxnlvwprfSc/00007.txt + +# example for video source +python download.py\ + -i data/lrs3_v0.4_txt/lrs3_v0.4/test/Li4S1yyrsTI/00010.txt \ No newline at end of file diff --git a/face_detection/README.md b/face_detection/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c073376e4eeda6d4b29cc31c50cb7e88ab42bb73 --- /dev/null +++ b/face_detection/README.md @@ -0,0 +1 @@ +The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time. \ No newline at end of file diff --git a/face_detection/__init__.py b/face_detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4bae29fd5f85b41e4669302bd2603bc6924eddc7 --- /dev/null +++ b/face_detection/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +__author__ = """Adrian Bulat""" +__email__ = 'adrian.bulat@nottingham.ac.uk' +__version__ = '1.0.1' + +from .api import FaceAlignment, LandmarksType, NetworkSize diff --git a/face_detection/api.py b/face_detection/api.py new file mode 100644 index 0000000000000000000000000000000000000000..cb02d5252db5362b9985687a992e128a522e5b63 --- /dev/null +++ b/face_detection/api.py @@ -0,0 +1,79 @@ +from __future__ import print_function +import os +import torch +from torch.utils.model_zoo import load_url +from enum import Enum +import numpy as np +import cv2 +try: + import urllib.request as request_file +except BaseException: + import urllib as request_file + +from .models import FAN, ResNetDepth +from .utils import * + + +class LandmarksType(Enum): + """Enum class defining the type of landmarks to detect. + + ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face + ``_2halfD`` - this points represent the projection of the 3D points into 3D + ``_3D`` - detect the points ``(x,y,z)``` in a 3D space + + """ + _2D = 1 + _2halfD = 2 + _3D = 3 + + +class NetworkSize(Enum): + # TINY = 1 + # SMALL = 2 + # MEDIUM = 3 + LARGE = 4 + + def __new__(cls, value): + member = object.__new__(cls) + member._value_ = value + return member + + def __int__(self): + return self.value + +ROOT = os.path.dirname(os.path.abspath(__file__)) + +class FaceAlignment: + def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, + device='cuda', flip_input=False, face_detector='sfd', verbose=False): + self.device = device + self.flip_input = flip_input + self.landmarks_type = landmarks_type + self.verbose = verbose + + network_size = int(network_size) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + # Get the face detector + face_detector_module = __import__('face_detection.detection.' + face_detector, + globals(), locals(), [face_detector], 0) + self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) + + def get_detections_for_batch(self, images): + images = images[..., ::-1] + detected_faces = self.face_detector.detect_from_batch(images.copy()) + results = [] + + for i, d in enumerate(detected_faces): + if len(d) == 0: + results.append(None) + continue + d = d[0] + d = np.clip(d, 0, None) + + x1, y1, x2, y2 = map(int, d[:-1]) + results.append((x1, y1, x2, y2)) + + return results \ No newline at end of file diff --git a/face_detection/detection/__init__.py b/face_detection/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6b0402dae864a3cc5dc2a90a412fd842a0efc7 --- /dev/null +++ b/face_detection/detection/__init__.py @@ -0,0 +1 @@ +from .core import FaceDetector \ No newline at end of file diff --git a/face_detection/detection/core.py b/face_detection/detection/core.py new file mode 100644 index 0000000000000000000000000000000000000000..0f8275e8e53143f66298f75f0517c234a68778cd --- /dev/null +++ b/face_detection/detection/core.py @@ -0,0 +1,130 @@ +import logging +import glob +from tqdm import tqdm +import numpy as np +import torch +import cv2 + + +class FaceDetector(object): + """An abstract class representing a face detector. + + Any other face detection implementation must subclass it. All subclasses + must implement ``detect_from_image``, that return a list of detected + bounding boxes. Optionally, for speed considerations detect from path is + recommended. + """ + + def __init__(self, device, verbose): + self.device = device + self.verbose = verbose + + if verbose: + if 'cpu' in device: + logger = logging.getLogger(__name__) + logger.warning("Detection running on CPU, this may be potentially slow.") + + if 'cpu' not in device and 'cuda' not in device: + if verbose: + logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) + raise ValueError + + def detect_from_image(self, tensor_or_path): + """Detects faces in a given image. + + This function detects the faces present in a provided BGR(usually) + image. The input can be either the image itself or the path to it. + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path + to an image or the image itself. + + Example:: + + >>> path_to_image = 'data/image_01.jpg' + ... detected_faces = detect_from_image(path_to_image) + [A list of bounding boxes (x1, y1, x2, y2)] + >>> image = cv2.imread(path_to_image) + ... detected_faces = detect_from_image(image) + [A list of bounding boxes (x1, y1, x2, y2)] + + """ + raise NotImplementedError + + def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): + """Detects faces from all the images present in a given directory. + + Arguments: + path {string} -- a string containing a path that points to the folder containing the images + + Keyword Arguments: + extensions {list} -- list of string containing the extensions to be + consider in the following format: ``.extension_name`` (default: + {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the + folder recursively (default: {False}) show_progress_bar {bool} -- + display a progressbar (default: {True}) + + Example: + >>> directory = 'data' + ... detected_faces = detect_from_directory(directory) + {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} + + """ + if self.verbose: + logger = logging.getLogger(__name__) + + if len(extensions) == 0: + if self.verbose: + logger.error("Expected at list one extension, but none was received.") + raise ValueError + + if self.verbose: + logger.info("Constructing the list of images.") + additional_pattern = '/**/*' if recursive else '/*' + files = [] + for extension in extensions: + files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) + + if self.verbose: + logger.info("Finished searching for images. %s images found", len(files)) + logger.info("Preparing to run the detection.") + + predictions = {} + for image_path in tqdm(files, disable=not show_progress_bar): + if self.verbose: + logger.info("Running the face detector on image: %s", image_path) + predictions[image_path] = self.detect_from_image(image_path) + + if self.verbose: + logger.info("The detector was successfully run on all %s images", len(files)) + + return predictions + + @property + def reference_scale(self): + raise NotImplementedError + + @property + def reference_x_shift(self): + raise NotImplementedError + + @property + def reference_y_shift(self): + raise NotImplementedError + + @staticmethod + def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): + """Convert path (represented as a string) or torch.tensor to a numpy.ndarray + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself + """ + if isinstance(tensor_or_path, str): + return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1] + elif torch.is_tensor(tensor_or_path): + # Call cpu in case its coming from cuda + return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() + elif isinstance(tensor_or_path, np.ndarray): + return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path + else: + raise TypeError diff --git a/face_detection/detection/sfd/__init__.py b/face_detection/detection/sfd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a63ecd45658f22e66c171ada751fb33764d4559 --- /dev/null +++ b/face_detection/detection/sfd/__init__.py @@ -0,0 +1 @@ +from .sfd_detector import SFDDetector as FaceDetector \ No newline at end of file diff --git a/face_detection/detection/sfd/bbox.py b/face_detection/detection/sfd/bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd7222e5e5f78a51944cbeed3cccbacddc46bed --- /dev/null +++ b/face_detection/detection/sfd/bbox.py @@ -0,0 +1,129 @@ +from __future__ import print_function +import os +import sys +import cv2 +import random +import datetime +import time +import math +import argparse +import numpy as np +import torch + +try: + from iou import IOU +except BaseException: + # IOU cython speedup 10x + def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): + sa = abs((ax2 - ax1) * (ay2 - ay1)) + sb = abs((bx2 - bx1) * (by2 - by1)) + x1, y1 = max(ax1, bx1), max(ay1, by1) + x2, y2 = min(ax2, bx2), min(ay2, by2) + w = x2 - x1 + h = y2 - y1 + if w < 0 or h < 0: + return 0.0 + else: + return 1.0 * w * h / (sa + sb - w * h) + + +def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): + xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 + dx, dy = (xc - axc) / aww, (yc - ayc) / ahh + dw, dh = math.log(ww / aww), math.log(hh / ahh) + return dx, dy, dw, dh + + +def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): + xc, yc = dx * aww + axc, dy * ahh + ayc + ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh + x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 + return x1, y1, x2, y2 + + +def nms(dets, thresh): + if 0 == len(dets): + return [] + x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) + xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) + + w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) + ovr = w * h / (areas[i] + areas[order[1:]] - w * h) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + +def batch_decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2) + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes diff --git a/face_detection/detection/sfd/detect.py b/face_detection/detection/sfd/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..efef6273adf317bc17f3dd0f02423c0701ca218e --- /dev/null +++ b/face_detection/detection/sfd/detect.py @@ -0,0 +1,112 @@ +import torch +import torch.nn.functional as F + +import os +import sys +import cv2 +import random +import datetime +import math +import argparse +import numpy as np + +import scipy.io as sio +import zipfile +from .net_s3fd import s3fd +from .bbox import * + + +def detect(net, img, device): + img = img - np.array([104, 117, 123]) + img = img.transpose(2, 0, 1) + img = img.reshape((1,) + img.shape) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + img = torch.from_numpy(img).float().to(device) + BB, CC, HH, WW = img.size() + with torch.no_grad(): + olist = net(img) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], dim=1) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.size() # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[0, 1, hindex, windex] + loc = oreg[0, :, hindex, windex].contiguous().view(1, 4) + priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) + variances = [0.1, 0.2] + box = decode(loc, priors, variances) + x1, y1, x2, y2 = box[0] * 1.0 + # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) + bboxlist.append([x1, y1, x2, y2, score]) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, 5)) + + return bboxlist + +def batch_detect(net, imgs, device): + imgs = imgs - np.array([104, 117, 123]) + imgs = imgs.transpose(0, 3, 1, 2) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + imgs = torch.from_numpy(imgs).float().to(device) + BB, CC, HH, WW = imgs.size() + with torch.no_grad(): + olist = net(imgs) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], dim=1) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.size() # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[:, 1, hindex, windex] + loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4) + priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4) + variances = [0.1, 0.2] + box = batch_decode(loc, priors, variances) + box = box[:, 0] * 1.0 + # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) + bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy()) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, BB, 5)) + + return bboxlist + +def flip_detect(net, img, device): + img = cv2.flip(img, 1) + b = detect(net, img, device) + + bboxlist = np.zeros(b.shape) + bboxlist[:, 0] = img.shape[1] - b[:, 2] + bboxlist[:, 1] = b[:, 1] + bboxlist[:, 2] = img.shape[1] - b[:, 0] + bboxlist[:, 3] = b[:, 3] + bboxlist[:, 4] = b[:, 4] + return bboxlist + + +def pts_to_bb(pts): + min_x, min_y = np.min(pts, axis=0) + max_x, max_y = np.max(pts, axis=0) + return np.array([min_x, min_y, max_x, max_y]) diff --git a/face_detection/detection/sfd/net_s3fd.py b/face_detection/detection/sfd/net_s3fd.py new file mode 100644 index 0000000000000000000000000000000000000000..fc64313c277ab594d0257585c70f147606693452 --- /dev/null +++ b/face_detection/detection/sfd/net_s3fd.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class L2Norm(nn.Module): + def __init__(self, n_channels, scale=1.0): + super(L2Norm, self).__init__() + self.n_channels = n_channels + self.scale = scale + self.eps = 1e-10 + self.weight = nn.Parameter(torch.Tensor(self.n_channels)) + self.weight.data *= 0.0 + self.weight.data += self.scale + + def forward(self, x): + norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps + x = x / norm * self.weight.view(1, -1, 1, 1) + return x + + +class s3fd(nn.Module): + def __init__(self): + super(s3fd, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) + self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) + + self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) + self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) + + self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) + self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) + + self.conv3_3_norm = L2Norm(256, scale=10) + self.conv4_3_norm = L2Norm(512, scale=8) + self.conv5_3_norm = L2Norm(512, scale=5) + + self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) + self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + h = F.relu(self.conv1_1(x)) + h = F.relu(self.conv1_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + f3_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + f4_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv5_1(h)) + h = F.relu(self.conv5_2(h)) + h = F.relu(self.conv5_3(h)) + f5_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + ffc7 = h + h = F.relu(self.conv6_1(h)) + h = F.relu(self.conv6_2(h)) + f6_2 = h + h = F.relu(self.conv7_1(h)) + h = F.relu(self.conv7_2(h)) + f7_2 = h + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + cls2 = self.conv4_3_norm_mbox_conf(f4_3) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + cls3 = self.conv5_3_norm_mbox_conf(f5_3) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + cls4 = self.fc7_mbox_conf(ffc7) + reg4 = self.fc7_mbox_loc(ffc7) + cls5 = self.conv6_2_mbox_conf(f6_2) + reg5 = self.conv6_2_mbox_loc(f6_2) + cls6 = self.conv7_2_mbox_conf(f7_2) + reg6 = self.conv7_2_mbox_loc(f7_2) + + # max-out background label + chunk = torch.chunk(cls1, 4, 1) + bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) + cls1 = torch.cat([bmax, chunk[3]], dim=1) + + return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] diff --git a/face_detection/detection/sfd/sfd_detector.py b/face_detection/detection/sfd/sfd_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbce15253251d403754ab4348f93ae85a6ba2fb --- /dev/null +++ b/face_detection/detection/sfd/sfd_detector.py @@ -0,0 +1,59 @@ +import os +import cv2 +from torch.utils.model_zoo import load_url + +from ..core import FaceDetector + +from .net_s3fd import s3fd +from .bbox import * +from .detect import * + +models_urls = { + 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', +} + + +class SFDDetector(FaceDetector): + def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False): + super(SFDDetector, self).__init__(device, verbose) + + # Initialise the face detector + if not os.path.isfile(path_to_detector): + model_weights = load_url(models_urls['s3fd']) + else: + model_weights = torch.load(path_to_detector) + + self.face_detector = s3fd() + self.face_detector.load_state_dict(model_weights) + self.face_detector.to(device) + self.face_detector.eval() + + def detect_from_image(self, tensor_or_path): + image = self.tensor_or_path_to_ndarray(tensor_or_path) + + bboxlist = detect(self.face_detector, image, device=self.device) + keep = nms(bboxlist, 0.3) + bboxlist = bboxlist[keep, :] + bboxlist = [x for x in bboxlist if x[-1] > 0.5] + + return bboxlist + + def detect_from_batch(self, images): + bboxlists = batch_detect(self.face_detector, images, device=self.device) + keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])] + bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)] + bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists] + + return bboxlists + + @property + def reference_scale(self): + return 195 + + @property + def reference_x_shift(self): + return 0 + + @property + def reference_y_shift(self): + return 0 diff --git a/face_detection/models.py b/face_detection/models.py new file mode 100644 index 0000000000000000000000000000000000000000..ee2dde32bdf72c25a4600e48efa73ffc0d4a3893 --- /dev/null +++ b/face_detection/models.py @@ -0,0 +1,261 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=strd, padding=padding, bias=bias) + + +class ConvBlock(nn.Module): + def __init__(self, in_planes, out_planes): + super(ConvBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = conv3x3(in_planes, int(out_planes / 2)) + self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) + self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) + self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) + self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) + + if in_planes != out_planes: + self.downsample = nn.Sequential( + nn.BatchNorm2d(in_planes), + nn.ReLU(True), + nn.Conv2d(in_planes, out_planes, + kernel_size=1, stride=1, bias=False), + ) + else: + self.downsample = None + + def forward(self, x): + residual = x + + out1 = self.bn1(x) + out1 = F.relu(out1, True) + out1 = self.conv1(out1) + + out2 = self.bn2(out1) + out2 = F.relu(out2, True) + out2 = self.conv2(out2) + + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) + + out3 = torch.cat((out1, out2, out3), 1) + + if self.downsample is not None: + residual = self.downsample(residual) + + out3 += residual + + return out3 + + +class Bottleneck(nn.Module): + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HourGlass(nn.Module): + def __init__(self, num_modules, depth, num_features): + super(HourGlass, self).__init__() + self.num_modules = num_modules + self.depth = depth + self.features = num_features + + self._generate_network(self.depth) + + def _generate_network(self, level): + self.add_module('b1_' + str(level), ConvBlock(self.features, self.features)) + + self.add_module('b2_' + str(level), ConvBlock(self.features, self.features)) + + if level > 1: + self._generate_network(level - 1) + else: + self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features)) + + self.add_module('b3_' + str(level), ConvBlock(self.features, self.features)) + + def _forward(self, level, inp): + # Upper branch + up1 = inp + up1 = self._modules['b1_' + str(level)](up1) + + # Lower branch + low1 = F.avg_pool2d(inp, 2, stride=2) + low1 = self._modules['b2_' + str(level)](low1) + + if level > 1: + low2 = self._forward(level - 1, low1) + else: + low2 = low1 + low2 = self._modules['b2_plus_' + str(level)](low2) + + low3 = low2 + low3 = self._modules['b3_' + str(level)](low3) + + up2 = F.interpolate(low3, scale_factor=2, mode='nearest') + + return up1 + up2 + + def forward(self, x): + return self._forward(self.depth, x) + + +class FAN(nn.Module): + + def __init__(self, num_modules=1): + super(FAN, self).__init__() + self.num_modules = num_modules + + # Base part + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.conv2 = ConvBlock(64, 128) + self.conv3 = ConvBlock(128, 128) + self.conv4 = ConvBlock(128, 256) + + # Stacking part + for hg_module in range(self.num_modules): + self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) + self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) + self.add_module('conv_last' + str(hg_module), + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) + self.add_module('l' + str(hg_module), nn.Conv2d(256, + 68, kernel_size=1, stride=1, padding=0)) + + if hg_module < self.num_modules - 1: + self.add_module( + 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('al' + str(hg_module), nn.Conv2d(68, + 256, kernel_size=1, stride=1, padding=0)) + + def forward(self, x): + x = F.relu(self.bn1(self.conv1(x)), True) + x = F.avg_pool2d(self.conv2(x), 2, stride=2) + x = self.conv3(x) + x = self.conv4(x) + + previous = x + + outputs = [] + for i in range(self.num_modules): + hg = self._modules['m' + str(i)](previous) + + ll = hg + ll = self._modules['top_m_' + str(i)](ll) + + ll = F.relu(self._modules['bn_end' + str(i)] + (self._modules['conv_last' + str(i)](ll)), True) + + # Predict heatmaps + tmp_out = self._modules['l' + str(i)](ll) + outputs.append(tmp_out) + + if i < self.num_modules - 1: + ll = self._modules['bl' + str(i)](ll) + tmp_out_ = self._modules['al' + str(i)](tmp_out) + previous = previous + ll + tmp_out_ + + return outputs + + +class ResNetDepth(nn.Module): + + def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68): + self.inplanes = 64 + super(ResNetDepth, self).__init__() + self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x diff --git a/face_detection/utils.py b/face_detection/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc4cf3e328efaa227cbcfdd969e1056688adad5 --- /dev/null +++ b/face_detection/utils.py @@ -0,0 +1,313 @@ +from __future__ import print_function +import os +import sys +import time +import torch +import math +import numpy as np +import cv2 + + +def _gaussian( + size=3, sigma=0.25, amplitude=1, normalize=False, width=None, + height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, + mean_vert=0.5): + # handle some defaults + if width is None: + width = size + if height is None: + height = size + if sigma_horz is None: + sigma_horz = sigma + if sigma_vert is None: + sigma_vert = sigma + center_x = mean_horz * width + 0.5 + center_y = mean_vert * height + 0.5 + gauss = np.empty((height, width), dtype=np.float32) + # generate kernel + for i in range(height): + for j in range(width): + gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( + sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0)) + if normalize: + gauss = gauss / np.sum(gauss) + return gauss + + +def draw_gaussian(image, point, sigma): + # Check if the gaussian is inside + ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)] + br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)] + if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1): + return image + size = 6 * sigma + 1 + g = _gaussian(size) + g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))] + g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))] + img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] + img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] + assert (g_x[0] > 0 and g_y[1] > 0) + image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] + ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] + image[image > 1] = 1 + return image + + +def transform(point, center, scale, resolution, invert=False): + """Generate and affine transformation matrix. + + Given a set of points, a center, a scale and a targer resolution, the + function generates and affine transformation matrix. If invert is ``True`` + it will produce the inverse transformation. + + Arguments: + point {torch.tensor} -- the input 2D point + center {torch.tensor or numpy.array} -- the center around which to perform the transformations + scale {float} -- the scale of the face/object + resolution {float} -- the output resolution + + Keyword Arguments: + invert {bool} -- define wherever the function should produce the direct or the + inverse transformation matrix (default: {False}) + """ + _pt = torch.ones(3) + _pt[0] = point[0] + _pt[1] = point[1] + + h = 200.0 * scale + t = torch.eye(3) + t[0, 0] = resolution / h + t[1, 1] = resolution / h + t[0, 2] = resolution * (-center[0] / h + 0.5) + t[1, 2] = resolution * (-center[1] / h + 0.5) + + if invert: + t = torch.inverse(t) + + new_point = (torch.matmul(t, _pt))[0:2] + + return new_point.int() + + +def crop(image, center, scale, resolution=256.0): + """Center crops an image or set of heatmaps + + Arguments: + image {numpy.array} -- an rgb image + center {numpy.array} -- the center of the object, usually the same as of the bounding box + scale {float} -- scale of the face + + Keyword Arguments: + resolution {float} -- the size of the output cropped image (default: {256.0}) + + Returns: + [type] -- [description] + """ # Crop around the center point + """ Crops the image around the center. Input is expected to be an np.ndarray """ + ul = transform([1, 1], center, scale, resolution, True) + br = transform([resolution, resolution], center, scale, resolution, True) + # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0) + if image.ndim > 2: + newDim = np.array([br[1] - ul[1], br[0] - ul[0], + image.shape[2]], dtype=np.int32) + newImg = np.zeros(newDim, dtype=np.uint8) + else: + newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int) + newImg = np.zeros(newDim, dtype=np.uint8) + ht = image.shape[0] + wd = image.shape[1] + newX = np.array( + [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) + newY = np.array( + [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) + oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) + oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) + newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] + ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :] + newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), + interpolation=cv2.INTER_LINEAR) + return newImg + + +def get_preds_fromhm(hm, center=None, scale=None): + """Obtain (x,y) coordinates given a set of N heatmaps. If the center + and the scale is provided the function will return the points also in + the original coordinate frame. + + Arguments: + hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] + + Keyword Arguments: + center {torch.tensor} -- the center of the bounding box (default: {None}) + scale {float} -- face scale (default: {None}) + """ + max, idx = torch.max( + hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) + idx += 1 + preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() + preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) + preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) + + for i in range(preds.size(0)): + for j in range(preds.size(1)): + hm_ = hm[i, j, :] + pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 + if pX > 0 and pX < 63 and pY > 0 and pY < 63: + diff = torch.FloatTensor( + [hm_[pY, pX + 1] - hm_[pY, pX - 1], + hm_[pY + 1, pX] - hm_[pY - 1, pX]]) + preds[i, j].add_(diff.sign_().mul_(.25)) + + preds.add_(-.5) + + preds_orig = torch.zeros(preds.size()) + if center is not None and scale is not None: + for i in range(hm.size(0)): + for j in range(hm.size(1)): + preds_orig[i, j] = transform( + preds[i, j], center, scale, hm.size(2), True) + + return preds, preds_orig + +def get_preds_fromhm_batch(hm, centers=None, scales=None): + """Obtain (x,y) coordinates given a set of N heatmaps. If the centers + and the scales is provided the function will return the points also in + the original coordinate frame. + + Arguments: + hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] + + Keyword Arguments: + centers {torch.tensor} -- the centers of the bounding box (default: {None}) + scales {float} -- face scales (default: {None}) + """ + max, idx = torch.max( + hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) + idx += 1 + preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() + preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) + preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) + + for i in range(preds.size(0)): + for j in range(preds.size(1)): + hm_ = hm[i, j, :] + pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 + if pX > 0 and pX < 63 and pY > 0 and pY < 63: + diff = torch.FloatTensor( + [hm_[pY, pX + 1] - hm_[pY, pX - 1], + hm_[pY + 1, pX] - hm_[pY - 1, pX]]) + preds[i, j].add_(diff.sign_().mul_(.25)) + + preds.add_(-.5) + + preds_orig = torch.zeros(preds.size()) + if centers is not None and scales is not None: + for i in range(hm.size(0)): + for j in range(hm.size(1)): + preds_orig[i, j] = transform( + preds[i, j], centers[i], scales[i], hm.size(2), True) + + return preds, preds_orig + +def shuffle_lr(parts, pairs=None): + """Shuffle the points left-right according to the axis of symmetry + of the object. + + Arguments: + parts {torch.tensor} -- a 3D or 4D object containing the + heatmaps. + + Keyword Arguments: + pairs {list of integers} -- [order of the flipped points] (default: {None}) + """ + if pairs is None: + pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35, + 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, + 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63, + 62, 61, 60, 67, 66, 65] + if parts.ndimension() == 3: + parts = parts[pairs, ...] + else: + parts = parts[:, pairs, ...] + + return parts + + +def flip(tensor, is_label=False): + """Flip an image or a set of heatmaps left-right + + Arguments: + tensor {numpy.array or torch.tensor} -- [the input image or heatmaps] + + Keyword Arguments: + is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False}) + """ + if not torch.is_tensor(tensor): + tensor = torch.from_numpy(tensor) + + if is_label: + tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1) + else: + tensor = tensor.flip(tensor.ndimension() - 1) + + return tensor + +# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py) + + +def appdata_dir(appname=None, roaming=False): + """ appdata_dir(appname=None, roaming=False) + + Get the path to the application directory, where applications are allowed + to write user specific files (e.g. configurations). For non-user specific + data, consider using common_appdata_dir(). + If appname is given, a subdir is appended (and created if necessary). + If roaming is True, will prefer a roaming directory (Windows Vista/7). + """ + + # Define default user directory + userDir = os.getenv('FACEALIGNMENT_USERDIR', None) + if userDir is None: + userDir = os.path.expanduser('~') + if not os.path.isdir(userDir): # pragma: no cover + userDir = '/var/tmp' # issue #54 + + # Get system app data dir + path = None + if sys.platform.startswith('win'): + path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA') + path = (path2 or path1) if roaming else (path1 or path2) + elif sys.platform.startswith('darwin'): + path = os.path.join(userDir, 'Library', 'Application Support') + # On Linux and as fallback + if not (path and os.path.isdir(path)): + path = userDir + + # Maybe we should store things local to the executable (in case of a + # portable distro or a frozen application that wants to be portable) + prefix = sys.prefix + if getattr(sys, 'frozen', None): + prefix = os.path.abspath(os.path.dirname(sys.executable)) + for reldir in ('settings', '../settings'): + localpath = os.path.abspath(os.path.join(prefix, reldir)) + if os.path.isdir(localpath): # pragma: no cover + try: + open(os.path.join(localpath, 'test.write'), 'wb').close() + os.remove(os.path.join(localpath, 'test.write')) + except IOError: + pass # We cannot write in this directory + else: + path = localpath + break + + # Get path specific for this app + if appname: + if path == userDir: + appname = '.' + appname.lstrip('.') # Make it a hidden directory + path = os.path.join(path, appname) + if not os.path.isdir(path): # pragma: no cover + os.mkdir(path) + + # Done + return path diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d20423199428f5c0743ccca65eafd41c567d7020 --- /dev/null +++ b/inference.py @@ -0,0 +1,82 @@ +import argparse +import os +import subprocess +from pathlib import Path + +from config import hparams as hp +from nota_wav2lip import Wav2LipModelComparisonDemo + +LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None) +LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None) + +if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None: + subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True) +if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None: + subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True) + +def parse_args(): + + parser = argparse.ArgumentParser(description="NotaWav2Lip: Inference snippet for your own video and audio pair") + + parser.add_argument( + '-a', + '--audio-input', + type=str, + required=True, + help="Path of the audio file" + ) + + parser.add_argument( + '-v', + '--video-frame-input', + type=str, + required=True, + help="Input directory with face image sequence. We recommend to extract the face image sequence with `preprocess.py`." + ) + + parser.add_argument( + '-b', + '--bbox-input', + type=str, + help="Path of the file with bbox coordinates. We recommend to extract the json file with `preprocess.py`." + "If None, it pretends that the json file is located at the same directory with face images: {VIDEO_FRAME_INPUT}.with_suffix('.json')." + ) + + parser.add_argument( + '-m', + '--model', + choices=['wav2lip', 'nota_wav2lip'], + default='nota_wav2ilp', + help="Model for generating talking video. Defaults: nota_wav2lip" + ) + + parser.add_argument( + '-o', + '--output-dir', + type=str, + default="result", + help="Output directory to save the result. Defaults: result" + ) + + parser.add_argument( + '-d', + '--device', + choices=['cpu', 'cuda'], + default='cpu', + help="Device setting for model inference. Defaults: cpu" + ) + + args = parser.parse_args() + + return args + +if __name__ == "__main__": + args = parse_args() + bbox_input = args.bbox_input if args.bbox_input is not None \ + else Path(args.video_frame_input).with_suffix('.json') + + servicer = Wav2LipModelComparisonDemo(device=args.device, result_dir=args.output_dir, model_list=args.model) + servicer.update_audio(args.audio_input, name='a0') + servicer.update_video(args.video_frame_input, bbox_input, name='v0') + + servicer.save_as_video('a0', 'v0', args.model) diff --git a/inference.sh b/inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..64cf0450278caa2f58d6a934ba20d3a44267a182 --- /dev/null +++ b/inference.sh @@ -0,0 +1,15 @@ +# Original Wav2Lip +python inference.py\ + -a "sample_video_lrs3/sxnlvwprf_c-00007.wav"\ + -v "sample_video_lrs3/Li4-1yyrsTI-00010"\ + -m "wav2lip"\ + -o "result_original"\ + --device cpu + +# Nota's Wav2Lip (28× Compressed) +python inference.py\ + -a "sample_video_lrs3/sxnlvwprf_c-00007.wav"\ + -v "sample_video_lrs3/Li4-1yyrsTI-00010"\ + -m "nota_wav2lip"\ + -o "result_nota"\ + --device cpu \ No newline at end of file diff --git a/nota_wav2lip/__init__.py b/nota_wav2lip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c2ef88257236c98ccca46dcde9546b61f6e69f9 --- /dev/null +++ b/nota_wav2lip/__init__.py @@ -0,0 +1,2 @@ +from nota_wav2lip.demo import Wav2LipModelComparisonDemo +from nota_wav2lip.gradio import Wav2LipModelComparisonGradio diff --git a/nota_wav2lip/audio.py b/nota_wav2lip/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..d00e7cba5306d6f75437025b336595337c0b603f --- /dev/null +++ b/nota_wav2lip/audio.py @@ -0,0 +1,135 @@ +import librosa +import librosa.filters +import numpy as np +from scipy import signal +from scipy.io import wavfile + +from config import hparams + +hp = hparams.audio + +def load_wav(path, sr): + return librosa.core.load(path, sr=sr)[0] + +def save_wav(wav, path, sr): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + #proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + +def save_wavenet_wav(wav, path, sr): + librosa.output.write_wav(path, wav, sr=sr) + +def preemphasis(wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + +def inv_preemphasis(wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + return wav + +def get_hop_size(): + hop_size = hp.hop_size + if hop_size is None: + assert hp.frame_shift_ms is not None + hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) + return hop_size + +def linearspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(np.abs(D)) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def melspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def _lws_processor(): + import lws + return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") + +def _stft(y): + if hp.use_lws: + return _lws_processor(hp).stft(y).T + else: + return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) + +########################################################## +#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) +def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + M = (length + pad * 2 - fsize) // fshift + 1 if length % fshift == 0 else (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + +# Conversions +_mel_basis = None + +def _linear_to_mel(spectogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectogram) + +def _build_mel_basis(): + assert hp.fmax <= hp.sample_rate // 2 + return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, + fmin=hp.fmin, fmax=hp.fmax) + +def _amp_to_db(x): + min_level = np.exp(hp.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + +def _normalize(S): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, + -hp.max_abs_value, hp.max_abs_value) + else: + return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) + + assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 + if hp.symmetric_mels: + return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value + else: + return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) + +def _denormalize(D): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return (((np.clip(D, -hp.max_abs_value, + hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + + hp.min_level_db) + else: + return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) + + if hp.symmetric_mels: + return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) + else: + return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) diff --git a/nota_wav2lip/demo.py b/nota_wav2lip/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..8985781f3052208232b0e13facd2cb6173bc27a2 --- /dev/null +++ b/nota_wav2lip/demo.py @@ -0,0 +1,91 @@ +import argparse +import platform +import subprocess +import time +from pathlib import Path +from typing import Dict, Iterator, List, Literal, Optional, Union + +import cv2 +import numpy as np + +from config import hparams as hp +from nota_wav2lip.inference import Wav2LipInferenceImpl +from nota_wav2lip.util import FFMPEG_LOGGING_MODE +from nota_wav2lip.video import AudioSlicer, VideoSlicer + + +class Wav2LipModelComparisonDemo: + def __init__(self, device='cpu', result_dir='./temp', model_list: Optional[Union[str, List[str]]]=None): + if model_list is None: + model_list: List[str] = ['wav2lip', 'nota_wav2lip'] + if isinstance(model_list, str) and len(model_list) != 0: + model_list: List[str] = [model_list] + super().__init__() + self.video_dict: Dict[str, VideoSlicer] = {} + self.audio_dict: Dict[str, AudioSlicer] = {} + + self.model_zoo: Dict[str, Wav2LipInferenceImpl] = {} + for model_name in model_list: + assert model_name in hp.inference.model, f"{model_name} not in hp.inference_model: {hp.inference.model}" + self.model_zoo[model_name] = Wav2LipInferenceImpl( + model_name, hp_inference_model=hp.inference.model[model_name], device=device + ) + + self._params_zoo: Dict[str, str] = { + model_name: self.model_zoo[model_name].params for model_name in self.model_zoo + } + + self.result_dir: Path = Path(result_dir) + self.result_dir.mkdir(exist_ok=True) + + @property + def params(self): + return self._params_zoo + + def _infer( + self, + audio_name: str, + video_name: str, + model_type: Literal['wav2lip', 'nota_wav2lip'] + ) -> Iterator[np.ndarray]: + audio_iterable: AudioSlicer = self.audio_dict[audio_name] + video_iterable: VideoSlicer = self.video_dict[video_name] + target_model = self.model_zoo[model_type] + return target_model.inference_with_iterator(audio_iterable, video_iterable) + + def update_audio(self, audio_path, name=None): + _name = name if name is not None else Path(audio_path).stem + self.audio_dict.update( + {_name: AudioSlicer(audio_path)} + ) + + def update_video(self, frame_dir_path, bbox_path, name=None): + _name = name if name is not None else Path(frame_dir_path).stem + self.video_dict.update( + {_name: VideoSlicer(frame_dir_path, bbox_path)} + ) + + def save_as_video(self, audio_name, video_name, model_type): + + output_video_path = self.result_dir / 'generated_with_audio.mp4' + frame_only_video_path = self.result_dir / 'generated.mp4' + audio_path = self.audio_dict[audio_name].audio_path + + out = cv2.VideoWriter(str(frame_only_video_path), + cv2.VideoWriter_fourcc(*'mp4v'), + hp.face.video_fps, + (hp.inference.frame.w, hp.inference.frame.h)) + start = time.time() + for frame in self._infer(audio_name=audio_name, video_name=video_name, model_type=model_type): + out.write(frame) + inference_time = time.time() - start + out.release() + + command = f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {audio_path} -i {frame_only_video_path} -strict -2 -q:v 1 {output_video_path}" + subprocess.call(command, shell=platform.system() != 'Windows') + + # The number of frames of generated video + video_frames_num = len(self.audio_dict[audio_name]) + inference_fps = video_frames_num / inference_time + + return output_video_path, inference_time, inference_fps diff --git a/nota_wav2lip/gradio.py b/nota_wav2lip/gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..d7781d44ef0a2e99c3c7b386a1ed724a1fa21bc9 --- /dev/null +++ b/nota_wav2lip/gradio.py @@ -0,0 +1,91 @@ +import threading +from pathlib import Path + +from nota_wav2lip.demo import Wav2LipModelComparisonDemo + + +class Wav2LipModelComparisonGradio(Wav2LipModelComparisonDemo): + def __init__( + self, + device='cpu', + result_dir='./temp', + video_label_dict=None, + audio_label_list=None, + default_video='v1', + default_audio='a1' + ) -> None: + if audio_label_list is None: + audio_label_list = {} + if video_label_dict is None: + video_label_dict = {} + super().__init__(device, result_dir) + self._video_label_dict = {k: Path(v).with_suffix('.mp4') for k, v in video_label_dict.items()} + self._audio_label_dict = audio_label_list + self._default_video = default_video + self._default_audio = default_audio + + self._lock = threading.Lock() # lock for asserting that concurrency_count == 1 + + def _is_valid_input(self, video_selection, audio_selection): + assert video_selection in self._video_label_dict, \ + f"Your input ({video_selection}) is not in {self._video_label_dict}!!!" + assert audio_selection in self._audio_label_dict, \ + f"Your input ({audio_selection}) is not in {self._audio_label_dict}!!!" + + def generate_original_model(self, video_selection, audio_selection): + try: + self._is_valid_input(video_selection, audio_selection) + + with self._lock: + output_video_path, inference_time, inference_fps = \ + self.save_as_video(audio_name=audio_selection, + video_name=video_selection, + model_type='wav2lip') + + return str(output_video_path), format(inference_time, ".2f"), format(inference_fps, ".1f") + except KeyboardInterrupt: + exit() + except Exception as e: + print(e) + pass + + def generate_compressed_model(self, video_selection, audio_selection): + try: + self._is_valid_input(video_selection, audio_selection) + + with self._lock: + output_video_path, inference_time, inference_fps = \ + self.save_as_video(audio_name=audio_selection, + video_name=video_selection, + model_type='nota_wav2lip') + + return str(output_video_path), format(inference_time, ".2f"), format(inference_fps, ".1f") + except KeyboardInterrupt: + exit() + except Exception as e: + print(e) + pass + + def switch_video_samples(self, video_selection): + try: + if video_selection not in self._video_label_dict: + return self._video_label_dict[self._default_video] + return self._video_label_dict[video_selection] + + except KeyboardInterrupt: + exit() + except Exception as e: + print(e) + pass + + def switch_audio_samples(self, audio_selection): + try: + if audio_selection not in self._audio_label_dict: + return self._audio_label_dict[self._default_audio] + return self._audio_label_dict[audio_selection] + + except KeyboardInterrupt: + exit() + except Exception as e: + print(e) + pass diff --git a/nota_wav2lip/inference.py b/nota_wav2lip/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa124261adee7fe40330335800c2367b302e7ae --- /dev/null +++ b/nota_wav2lip/inference.py @@ -0,0 +1,111 @@ +from typing import Iterable, Iterator, List, Tuple + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from omegaconf import DictConfig +from tqdm import tqdm + +from config import hparams as hp +from nota_wav2lip.models.util import count_params, load_model + + +class Wav2LipInferenceImpl: + def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'): + self.model: nn.Module = load_model( + model_name, + device=device, + **hp_inference_model + ) + self.device = device + self._params: str = self._format_param(count_params(self.model)) + + @property + def params(self): + return self._params + + @staticmethod + def _format_param(num_params: int) -> str: + params_in_million = num_params / 1e6 + return f"{params_in_million:.1f}M" + + @staticmethod + def _reset_batch() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[List[int]]]: + return [], [], [], [] + + def get_data_iterator( + self, + audio_iterable: Iterable[np.ndarray], + video_iterable: List[Tuple[np.ndarray, List[int]]] + ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]]: + img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch() + + for i, m in enumerate(audio_iterable): + idx = i % len(video_iterable) + _frame_to_save, coords = video_iterable[idx] + frame_to_save = _frame_to_save.copy() + face = frame_to_save[coords[0]:coords[1], coords[2]:coords[3]].copy() + + face: np.ndarray = cv2.resize(face, (hp.face.img_size, hp.face.img_size)) + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= hp.inference.batch_size: + img_batch = np.asarray(img_batch) + mel_batch = np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, hp.face.img_size // 2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch() + + if len(img_batch) > 0: + img_batch = np.asarray(img_batch) + mel_batch = np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, hp.face.img_size // 2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + + @torch.no_grad() + def inference_with_iterator( + self, + audio_iterable: Iterable[np.ndarray], + video_iterable: List[Tuple[np.ndarray, List[int]]] + ) -> Iterator[np.ndarray]: + data_iterator = self.get_data_iterator(audio_iterable, video_iterable) + + for (img_batch, mel_batch, frames, coords) in \ + tqdm(data_iterator, total=int(np.ceil(float(len(audio_iterable)) / hp.inference.batch_size))): + + img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device) + mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device) + + preds: torch.Tensor = self.forward(mel_batch, img_batch) + + preds = preds.cpu().numpy().transpose(0, 2, 3, 1) * 255. + for pred, frame, coord in zip(preds, frames, coords): + y1, y2, x1, x2 = coord + pred = cv2.resize(pred.astype(np.uint8), (x2 - x1, y2 - y1)) + + frame[y1:y2, x1:x2] = pred + yield frame + + @torch.no_grad() + def forward(self, audio_sequences: torch.Tensor, face_sequences: torch.Tensor) -> torch.Tensor: + return self.model(audio_sequences, face_sequences) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) diff --git a/nota_wav2lip/models/__init__.py b/nota_wav2lip/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac703b9a34f23255eda3f2aa709e3775b1db902f --- /dev/null +++ b/nota_wav2lip/models/__init__.py @@ -0,0 +1,3 @@ +from .base import Wav2LipBase +from .wav2lip import Wav2Lip +from .wav2lip_compressed import NotaWav2Lip diff --git a/nota_wav2lip/models/base.py b/nota_wav2lip/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..704c46a205e6246270303148cd29bb6dc8a8ce5b --- /dev/null +++ b/nota_wav2lip/models/base.py @@ -0,0 +1,55 @@ +from typing import final + +import torch +from torch import nn + + +class Wav2LipBase(nn.Module): + def __init__(self) -> None: + super().__init__() + + self.audio_encoder = nn.Sequential() + self.face_encoder_blocks = nn.ModuleList([]) + self.face_decoder_blocks = nn.ModuleList([]) + self.output_block = nn.Sequential() + + @final + def forward(self, audio_sequences, face_sequences): + # audio_sequences = (B, T, 1, 80, 16) + B = audio_sequences.size(0) + + input_dim_size = len(face_sequences.size()) + if input_dim_size > 4: + audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) + face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) + + audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 + + feats = [] + x = face_sequences + for f in self.face_encoder_blocks: + x = f(x) + feats.append(x) + + x = audio_embedding + for f in self.face_decoder_blocks: + x = f(x) + try: + x = torch.cat((x, feats[-1]), dim=1) + except Exception as e: + print(x.size()) + print(feats[-1].size()) + raise e + + feats.pop() + + x = self.output_block(x) + + if input_dim_size > 4: + x = torch.split(x, B, dim=0) # [(B, C, H, W)] + outputs = torch.stack(x, dim=2) # (B, C, T, H, W) + + else: + outputs = x + + return outputs diff --git a/nota_wav2lip/models/conv.py b/nota_wav2lip/models/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1e086d7ffca95f4398857e150ee91fd41cd5b6 --- /dev/null +++ b/nota_wav2lip/models/conv.py @@ -0,0 +1,34 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + + +class Conv2dTranspose(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + + def forward(self, x): + out = self.conv_block(x) + return self.act(out) diff --git a/nota_wav2lip/models/util.py b/nota_wav2lip/models/util.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c438d99b9a45d6fe31c5f3df0a69e668de5d03 --- /dev/null +++ b/nota_wav2lip/models/util.py @@ -0,0 +1,32 @@ +from typing import Dict, Type + +import torch + +from nota_wav2lip.models import NotaWav2Lip, Wav2Lip, Wav2LipBase + +MODEL_REGISTRY: Dict[str, Type[Wav2LipBase]] = { + 'wav2lip': Wav2Lip, + 'nota_wav2lip': NotaWav2Lip +} + +def _load(checkpoint_path, device): + assert device in ['cpu', 'cuda'] + + print(f"Load checkpoint from: {checkpoint_path}") + if device == 'cuda': + return torch.load(checkpoint_path) + return torch.load(checkpoint_path, map_location=lambda storage, _: storage) + +def load_model(model_name: str, device, checkpoint, **kwargs) -> Wav2LipBase: + + cls = MODEL_REGISTRY[model_name.lower()] + assert issubclass(cls, Wav2LipBase) + + model = cls(**kwargs) + checkpoint = _load(checkpoint, device) + model.load_state_dict(checkpoint) + model = model.to(device) + return model.eval() + +def count_params(model): + return sum(p.numel() for p in model.parameters()) diff --git a/nota_wav2lip/models/wav2lip.py b/nota_wav2lip/models/wav2lip.py new file mode 100644 index 0000000000000000000000000000000000000000..ddb924af64b971e2c7397496c10edb7b51ac9d98 --- /dev/null +++ b/nota_wav2lip/models/wav2lip.py @@ -0,0 +1,85 @@ +import torch +from torch import nn + +from nota_wav2lip.models.base import Wav2LipBase +from nota_wav2lip.models.conv import Conv2d, Conv2dTranspose + + +class Wav2Lip(Wav2LipBase): + def __init__(self): + super().__init__() + + self.face_encoder_blocks = nn.ModuleList([ + nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96 + + nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48 + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)), + + nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24 + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)), + + nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12 + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)), + + nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6 + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)), + + nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3 + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), + + nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 + Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) + + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + self.face_decoder_blocks = nn.ModuleList([ + nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),), + + nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3 + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), + + nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6 + + nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12 + + nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24 + + nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48 + + nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96 + + self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1), + nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), + nn.Sigmoid()) diff --git a/nota_wav2lip/models/wav2lip_compressed.py b/nota_wav2lip/models/wav2lip_compressed.py new file mode 100644 index 0000000000000000000000000000000000000000..7e1393743dc2fdce65454a971fc4b128209978bf --- /dev/null +++ b/nota_wav2lip/models/wav2lip_compressed.py @@ -0,0 +1,72 @@ +import torch +from torch import nn + +from nota_wav2lip.models.base import Wav2LipBase +from nota_wav2lip.models.conv import Conv2d, Conv2dTranspose + + +class NotaWav2Lip(Wav2LipBase): + def __init__(self, nef=4, naf=8, ndf=8, x_size=96): + super().__init__() + + assert x_size in [96, 128] + self.ker_sz_last = x_size // 32 + + self.face_encoder_blocks = nn.ModuleList([ + nn.Sequential(Conv2d(6, nef, kernel_size=7, stride=1, padding=3)), # 96,96 + + nn.Sequential(Conv2d(nef, nef * 2, kernel_size=3, stride=2, padding=1),), # 48,48 + + nn.Sequential(Conv2d(nef * 2, nef * 4, kernel_size=3, stride=2, padding=1),), # 24,24 + + nn.Sequential(Conv2d(nef * 4, nef * 8, kernel_size=3, stride=2, padding=1),), # 12,12 + + nn.Sequential(Conv2d(nef * 8, nef * 16, kernel_size=3, stride=2, padding=1),), # 6,6 + + nn.Sequential(Conv2d(nef * 16, nef * 32, kernel_size=3, stride=2, padding=1),), # 3,3 + + nn.Sequential(Conv2d(nef * 32, nef * 32, kernel_size=self.ker_sz_last, stride=1, padding=0), # 1, 1 + Conv2d(nef * 32, nef * 32, kernel_size=1, stride=1, padding=0)), ]) + + self.audio_encoder = nn.Sequential( + Conv2d(1, naf, kernel_size=3, stride=1, padding=1), + + Conv2d(naf, naf * 2, kernel_size=3, stride=(3, 1), padding=1), + + Conv2d(naf * 2, naf * 4, kernel_size=3, stride=3, padding=1), + + Conv2d(naf * 4, naf * 8, kernel_size=3, stride=(3, 2), padding=1), + + Conv2d(naf * 8, naf * 16, kernel_size=3, stride=1, padding=0), + Conv2d(naf * 16, naf * 16, kernel_size=1, stride=1, padding=0), ) + + self.face_decoder_blocks = nn.ModuleList([ + nn.Sequential(Conv2d(naf * 16, naf * 16, kernel_size=1, stride=1, padding=0), ), + + nn.Sequential(Conv2dTranspose(nef * 32 + naf * 16, ndf * 16, kernel_size=self.ker_sz_last, stride=1, padding=0),), + # 3,3 # 512+512 = 1024 + + nn.Sequential( + Conv2dTranspose(nef * 32 + ndf * 16, ndf * 16, kernel_size=3, stride=2, padding=1, output_padding=1),), # 6, 6 + # 512+512 = 1024 + + nn.Sequential( + Conv2dTranspose(nef * 16 + ndf * 16, ndf * 12, kernel_size=3, stride=2, padding=1, output_padding=1),), # 12, 12 + # 256+512 = 768 + + nn.Sequential( + Conv2dTranspose(nef * 8 + ndf * 12, ndf * 8, kernel_size=3, stride=2, padding=1, output_padding=1),), # 24, 24 + # 128+384 = 512 + + nn.Sequential( + Conv2dTranspose(nef * 4 + ndf * 8, ndf * 4, kernel_size=3, stride=2, padding=1, output_padding=1),), # 48, 48 + # 64+256 = 320 + + nn.Sequential( + Conv2dTranspose(nef * 2 + ndf * 4, ndf * 2, kernel_size=3, stride=2, padding=1, output_padding=1),), # 96,96 + # 32+128 = 160 + ]) + + self.output_block = nn.Sequential(Conv2d(nef + ndf * 2, ndf, kernel_size=3, stride=1, padding=1), # 16+64 = 80 + nn.Conv2d(ndf, 3, kernel_size=1, stride=1, padding=0), + nn.Sigmoid()) diff --git a/nota_wav2lip/preprocess/__init__.py b/nota_wav2lip/preprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..190e5b03eb627c9a9fd95ea0f0605d930cecbb53 --- /dev/null +++ b/nota_wav2lip/preprocess/__init__.py @@ -0,0 +1,2 @@ +from nota_wav2lip.preprocess.core import get_preprocessed_data +from nota_wav2lip.preprocess.lrs3_download import get_cropped_face_from_lrs3_label diff --git a/nota_wav2lip/preprocess/core.py b/nota_wav2lip/preprocess/core.py new file mode 100644 index 0000000000000000000000000000000000000000..9839d0007d76703c1578e61ac68ed07a961fda9d --- /dev/null +++ b/nota_wav2lip/preprocess/core.py @@ -0,0 +1,98 @@ +import json +import platform +import subprocess +from pathlib import Path + +import cv2 +import numpy as np +from loguru import logger +from tqdm import tqdm + +import face_detection +from nota_wav2lip.util import FFMPEG_LOGGING_MODE + +detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cpu') +PADDING = [0, 10, 0, 0] + + +def get_smoothened_boxes(boxes, T): + for i in range(len(boxes)): + window = boxes[len(boxes) - T:] if i + T > len(boxes) else boxes[i:i + T] + boxes[i] = np.mean(window, axis=0) + return boxes + + +def face_detect(images, pads, no_smooth=False, batch_size=1): + + predictions = [] + images_array = [cv2.imread(str(image)) for image in images] + for i in tqdm(range(0, len(images_array), batch_size)): + predictions.extend(detector.get_detections_for_batch(np.array(images_array[i:i + batch_size]))) + + results = [] + pady1, pady2, padx1, padx2 = pads + for rect, image_array in zip(predictions, images_array): + if rect is None: + cv2.imwrite('temp/faulty_frame.jpg', image_array) # check this frame where the face was not detected. + raise ValueError('Face not detected! Ensure the video contains a face in all the frames.') + + y1 = max(0, rect[1] - pady1) + y2 = min(image_array.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image_array.shape[1], rect[2] + padx2) + results.append([x1, y1, x2, y2]) + + boxes = np.array(results) + bbox_format = "(y1, y2, x1, x2)" + if not no_smooth: + boxes = get_smoothened_boxes(boxes, T=5) + outputs = { + 'bbox': {str(image_path): tuple(map(int, (y1, y2, x1, x2))) for image_path, (x1, y1, x2, y2) in zip(images, boxes)}, + 'format': bbox_format + } + return outputs + + +def save_video_frame(video_path, output_dir=None): + video_path = Path(video_path) + output_dir = output_dir if output_dir is not None else video_path.with_suffix('') + output_dir.mkdir(exist_ok=True) + return subprocess.call( + f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -r 25 -f image2 {output_dir}/%05d.jpg", + shell=platform.system() != 'Windows' + ) + + +def save_audio_file(video_path, output_path=None): + video_path = Path(video_path) + output_path = output_path if output_path is not None else video_path.with_suffix('.wav') + subprocess.call( + f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -vn -acodec pcm_s16le -ar 16000 -ac 1 {output_path}", + shell=platform.system() != 'Windows' + ) + + +def save_bbox_file(video_path, bbox_dict, output_path=None): + video_path = Path(video_path) + output_path = output_path if output_path is not None else video_path.with_suffix('.json') + + with open(output_path, 'w') as f: + json.dump(bbox_dict, f, indent=4) + +def get_preprocessed_data(video_path: Path): + video_path = Path(video_path) + + image_sequence_dir = video_path.with_suffix('') + audio_path = video_path.with_suffix('.wav') + face_bbox_json_path = video_path.with_suffix('.json') + + logger.info(f"Save 25 FPS video frames as image files ... will be saved at {video_path}") + save_video_frame(video_path=video_path, output_dir=image_sequence_dir) + + logger.info(f"Save the audio as wav file ... will be saved at {audio_path}") + save_audio_file(video_path=video_path, output_path=audio_path) # bonus + + # Load images, extract bboxes and save the coords(to directly use as array indicies) + logger.info(f"Extract face boxes and save the coords with json format ... will be saved at {face_bbox_json_path}") + results = face_detect(sorted(image_sequence_dir.glob("*.jpg")), pads=PADDING) + save_bbox_file(video_path, results, output_path=face_bbox_json_path) diff --git a/nota_wav2lip/preprocess/ffmpeg.py b/nota_wav2lip/preprocess/ffmpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..baabcc41cfcffef90a94e6cedc1958449bf1666e --- /dev/null +++ b/nota_wav2lip/preprocess/ffmpeg.py @@ -0,0 +1,5 @@ +FFMPEG_LOGGING_MODE = { + 'DEBUG': "", + 'INFO': "-v quiet -stats", + 'ERROR': "-hide_banner -loglevel error", +} \ No newline at end of file diff --git a/nota_wav2lip/preprocess/lrs3_download.py b/nota_wav2lip/preprocess/lrs3_download.py new file mode 100644 index 0000000000000000000000000000000000000000..a564813fe39ad2c1cb4504b9cdaaf79ab54627f4 --- /dev/null +++ b/nota_wav2lip/preprocess/lrs3_download.py @@ -0,0 +1,259 @@ +import platform +import subprocess +from pathlib import Path +from typing import Dict, List, Tuple, TypedDict, Union + +import cv2 +import numpy as np +import yt_dlp +from loguru import logger +from tqdm import tqdm + +from nota_wav2lip.util import FFMPEG_LOGGING_MODE + + +class LabelInfo(TypedDict): + text: str + conf: int + url: str + bbox_xywhn: Dict[int, Tuple[float, float, float, float]] + +def frame_to_time(frame_id: int, fps=25) -> str: + seconds = frame_id / fps + + hours = int(seconds // 3600) + seconds -= 3600 * hours + + minutes = int(seconds // 60) + seconds -= 60 * minutes + + seconds_int = int(seconds) + seconds_milli = int((seconds - int(seconds)) * 1e3) + + return f"{hours:02d}:{minutes:02d}:{seconds_int:02d}.{seconds_milli:03d}" # HH:MM:SS.mmm + +def save_audio_file(input_path, start_frame_id, to_frame_id, output_path=None): + input_path = Path(input_path) + output_path = output_path if output_path is not None else input_path.with_suffix('.wav') + + ss = frame_to_time(start_frame_id) + to = frame_to_time(to_frame_id) + subprocess.call( + f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {input_path} -vn -acodec pcm_s16le -ss {ss} -to {to} -ar 16000 -ac 1 {output_path}", + shell=platform.system() != 'Windows' + ) + +def merge_video_audio(video_path, audio_path, output_path): + subprocess.call( + f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -i {audio_path} -strict experimental {output_path}", + shell=platform.system() != 'Windows' + ) + +def parse_lrs3_label(label_path) -> LabelInfo: + label_text = Path(label_path).read_text() + label_splitted = label_text.split('\n') + + # Label validation + assert label_splitted[0].startswith("Text:") + assert label_splitted[1].startswith("Conf:") + assert label_splitted[2].startswith("Ref:") + assert label_splitted[4].startswith("FRAME") + + label_info = LabelInfo(bbox_xywhn={}) + label_info['text'] = label_splitted[0][len("Text: "):].strip() + label_info['conf'] = int(label_splitted[1][len("Conf: "):]) + label_info['url'] = label_splitted[2][len("Ref: "):].strip() + + for label_line in label_splitted[5:]: + bbox_splitted = [x.strip() for x in label_line.split('\t')] + if len(bbox_splitted) != 5: + continue + frame_index = int(bbox_splitted[0]) + bbox_xywhn = tuple(map(float, bbox_splitted[1:])) + label_info['bbox_xywhn'][frame_index] = bbox_xywhn + + return label_info + +def _get_cropped_bbox(bbox_info_xywhn, original_width, original_height): + + bbox_info = bbox_info_xywhn + x = bbox_info[0] * original_width + y = bbox_info[1] * original_height + w = bbox_info[2] * original_width + h = bbox_info[3] * original_height + + x_min = max(0, int(x - 0.5 * w)) + y_min = max(0, int(y)) + x_max = min(original_width, int(x + 1.5 * w)) + y_max = min(original_height, int(y + 1.5 * h)) + + cropped_width = x_max - x_min + cropped_height = y_max - y_min + + if cropped_height > cropped_width: + offset = cropped_height - cropped_width + offset_low = min(x_min, offset // 2) + offset_high = min(offset - offset_low, original_width - x_max) + x_min -= offset_low + x_max += offset_high + else: + offset = cropped_width - cropped_height + offset_low = min(y_min, offset // 2) + offset_high = min(offset - offset_low, original_width - y_max) + y_min -= offset_low + y_max += offset_high + + return x_min, y_min, x_max, y_max + +def _get_smoothened_boxes(bbox_dict, bbox_smoothen_window): + boxes = [np.array(bbox_dict[frame_id]) for frame_id in sorted(bbox_dict)] + for i in range(len(boxes)): + window = boxes[len(boxes) - bbox_smoothen_window:] if i + bbox_smoothen_window > len(boxes) else boxes[i:i + bbox_smoothen_window] + boxes[i] = np.mean(window, axis=0) + + for idx, frame_id in enumerate(sorted(bbox_dict)): + bbox_dict[frame_id] = (np.rint(boxes[idx])).astype(int).tolist() + return bbox_dict + +def download_video_from_youtube(youtube_ref, output_path): + ydl_url = f"https://www.youtube.com/watch?v={youtube_ref}" + ydl_opts = { + 'format': 'bestvideo[ext=mp4][height<=720]+bestaudio[ext=m4a]/best[ext=mp4][height<=720]', + 'outtmpl': str(output_path), + } + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + ydl.download([ydl_url]) + +def resample_video(input_path, output_path): + subprocess.call( + f"ffmpeg {FFMPEG_LOGGING_MODE['INFO']} -y -i {input_path} -r 25 -preset veryfast {output_path}", + shell=platform.system() != 'Windows' + ) + +def _get_smoothen_xyxy_bbox( + label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]], + original_width: int, + original_height: int, + bbox_smoothen_window: int = 5 +) -> Dict[int, Tuple[float, float, float, float]]: + + label_bbox_xyxy: Dict[int, Tuple[float, float, float, float]] = {} + for frame_id in sorted(label_bbox_xywhn): + frame_bbox_xywhn = label_bbox_xywhn[frame_id] + bbox_xyxy = _get_cropped_bbox(frame_bbox_xywhn, original_width, original_height) + label_bbox_xyxy[frame_id] = bbox_xyxy + + label_bbox_xyxy = _get_smoothened_boxes(label_bbox_xyxy, bbox_smoothen_window=bbox_smoothen_window) + return label_bbox_xyxy + +def get_start_end_frame_id( + label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]], +) -> Tuple[int, int]: + frame_ids = list(label_bbox_xywhn.keys()) + start_frame_id = min(frame_ids) + to_frame_id = max(frame_ids) + return start_frame_id, to_frame_id + +def crop_video_with_bbox( + input_path, + label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]], + start_frame_id, + to_frame_id, + output_path, + bbox_smoothen_window = 5, + frame_width = 224, + frame_height = 224, + fps = 25, + interpolation = cv2.INTER_CUBIC, +): + def frame_generator(cap): + if not cap.isOpened(): + raise IOError("Error: Could not open video.") + + while True: + ret, frame = cap.read() + if not ret: + break + yield frame + + cap.release() + + cap = cv2.VideoCapture(str(input_path)) + original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + label_bbox_xyxy = _get_smoothen_xyxy_bbox(label_bbox_xywhn, original_width, original_height, bbox_smoothen_window=bbox_smoothen_window) + + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height)) + + for frame_id, frame in tqdm(enumerate(frame_generator(cap))): + if start_frame_id <= frame_id <= to_frame_id: + x_min, y_min, x_max, y_max = label_bbox_xyxy[frame_id] + + frame_cropped = frame[y_min:y_max, x_min:x_max] + frame_cropped = cv2.resize(frame_cropped, (frame_width, frame_height), interpolation=interpolation) + out.write(frame_cropped) + + out.release() + + +def get_cropped_face_from_lrs3_label( + label_text_path: Union[Path, str], + video_root_dir: Union[Path, str], + bbox_smoothen_window: int = 5, + frame_width: int = 224, + frame_height: int = 224, + fps: int = 25, + interpolation = cv2.INTER_CUBIC, + ignore_cache: bool = False, +): + label_text_path = Path(label_text_path) + label_info = parse_lrs3_label(label_text_path) + start_frame_id, to_frame_id = get_start_end_frame_id(label_info['bbox_xywhn']) + + video_root_dir = Path(video_root_dir) + video_cache_dir = video_root_dir / ".cache" + video_cache_dir.mkdir(parents=True, exist_ok=True) + + output_video: Path = video_cache_dir / f"{label_info['url']}.mp4" + output_resampled_video: Path = output_video.with_name(f"{output_video.stem}-25fps.mp4") + output_cropped_audio: Path = output_video.with_name(f"{output_video.stem}-{label_text_path.stem}-cropped.wav") + output_cropped_video: Path = output_video.with_name(f"{output_video.stem}-{label_text_path.stem}-cropped.mp4") + output_cropped_with_audio: Path = video_root_dir / output_video.with_name(f"{output_video.stem}-{label_text_path.stem}.mp4").name + + if not output_video.exists() or ignore_cache: + youtube_ref = label_info['url'] + logger.info(f"Download Youtube video(https://www.youtube.com/watch?v={youtube_ref}) ... will be saved at {output_video}") + download_video_from_youtube(youtube_ref, output_path=output_video) + + if not output_resampled_video.exists() or ignore_cache: + logger.info(f"Resampling video to 25 FPS ... will be saved at {output_resampled_video}") + resample_video(input_path=output_video, output_path=output_resampled_video) + + if not output_cropped_audio.exists() or ignore_cache: + logger.info(f"Cut audio file with the given timestamps ... will be saved at {output_cropped_audio}") + save_audio_file( + output_resampled_video, + start_frame_id=start_frame_id, + to_frame_id=to_frame_id, + output_path=output_cropped_audio + ) + + logger.info(f"Naive crop the face region with the given frame labels ... will be saved at {output_cropped_video}") + crop_video_with_bbox( + output_resampled_video, + label_info['bbox_xywhn'], + start_frame_id, + to_frame_id, + output_path=output_cropped_video, + bbox_smoothen_window=bbox_smoothen_window, + frame_width=frame_width, + frame_height=frame_height, + fps=fps, + interpolation=interpolation + ) + + if not output_cropped_with_audio.exists() or ignore_cache: + logger.info(f"Merge an audio track with the cropped face sequence ... will be saved at {output_cropped_with_audio}") + merge_video_audio(output_cropped_video, output_cropped_audio, output_cropped_with_audio) diff --git a/nota_wav2lip/util.py b/nota_wav2lip/util.py new file mode 100644 index 0000000000000000000000000000000000000000..21ad21f1c803f9baaec6af5a85410bd6cb7a9476 --- /dev/null +++ b/nota_wav2lip/util.py @@ -0,0 +1,5 @@ +FFMPEG_LOGGING_MODE = { + 'DEBUG': "", + 'INFO': "-v quiet -stats", + 'ERROR': "-hide_banner -loglevel error", +} diff --git a/nota_wav2lip/video.py b/nota_wav2lip/video.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9c274d7a799c6039a0627f39ed2effea10009f --- /dev/null +++ b/nota_wav2lip/video.py @@ -0,0 +1,68 @@ +import json +from pathlib import Path +from typing import List, Tuple, Union + +import cv2 +import numpy as np + +import nota_wav2lip.audio as audio +from config import hparams as hp + + +class VideoSlicer: + def __init__(self, frame_dir: Union[Path, str], bbox_path: Union[Path, str]): + self.fps = hp.face.video_fps + self.frame_dir = frame_dir + self.frame_path_list = sorted(Path(self.frame_dir).glob("*.jpg")) + self.frame_array_list: List[np.ndarray] = [cv2.imread(str(image)) for image in self.frame_path_list] + + with open(bbox_path, 'r') as f: + metadata = json.load(f) + self.bbox: List[List[int]] = [metadata['bbox'][key] for key in sorted(metadata['bbox'].keys())] + self.bbox_format = metadata['format'] + assert len(self.bbox) == len(self.frame_array_list) + + def __len__(self): + return len(self.frame_array_list) + + def __getitem__(self, idx) -> Tuple[np.ndarray, List[int]]: + bbox = self.bbox[idx] + frame_original: np.ndarray = self.frame_array_list[idx] + # return frame_original[bbox[0]:bbox[1], bbox[2]:bbox[3], :] + return frame_original, bbox + + +class AudioSlicer: + def __init__(self, audio_path: Union[Path, str]): + self.fps = hp.face.video_fps + self.mel_chunks = self._audio_chunk_generator(audio_path) + self._audio_path = audio_path + + @property + def audio_path(self): + return self._audio_path + + def __len__(self): + return len(self.mel_chunks) + + def _audio_chunk_generator(self, audio_path): + wav: np.ndarray = audio.load_wav(audio_path, hp.audio.sample_rate) + mel: np.ndarray = audio.melspectrogram(wav) + + if np.isnan(mel.reshape(-1)).sum() > 0: + raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') + + mel_chunks: List[np.ndarray] = [] + mel_idx_multiplier = 80. / self.fps + + i = 0 + while True: + start_idx = int(i * mel_idx_multiplier) + if start_idx + hp.face.mel_step_size > len(mel[0]): + mel_chunks.append(mel[:, len(mel[0]) - hp.face.mel_step_size:]) + return mel_chunks + mel_chunks.append(mel[:, start_idx: start_idx + hp.face.mel_step_size]) + i += 1 + + def __getitem__(self, idx: int) -> np.ndarray: + return self.mel_chunks[idx] diff --git a/preprocess.py b/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..93000a37d858e37f35411536c7af6a60bf581a3c --- /dev/null +++ b/preprocess.py @@ -0,0 +1,28 @@ +import argparse + +from nota_wav2lip.preprocess import get_preprocessed_data + + +def parse_args(): + + parser = argparse.ArgumentParser(description="NotaWav2Lip: Preprocess the facial video with face detection") + + parser.add_argument( + '-i', + '--input-file', + type=str, + required=True, + help="Path of the facial video. We recommend that the video is one of LRS3 data samples, which is the result of `download.py`." + "The extracted features and facial image sequences are saved at the same location with the input file." + ) + + args = parser.parse_args() + + return args + +if __name__ == '__main__': + args = parse_args() + + get_preprocessed_data( + args.input_file, + ) diff --git a/preprocess.sh b/preprocess.sh new file mode 100644 index 0000000000000000000000000000000000000000..df4a52f332ceb15dfe6e102d559a22ea157bb94e --- /dev/null +++ b/preprocess.sh @@ -0,0 +1,7 @@ +# example for audio source +python preprocess.py\ + -i sample_video_lrs3/sxnlvwprf_c-00007.mp4 + +# example for video source +python preprocess.py\ + -i sample_video_lrs3/Li4-1yyrsTI-00010.mp4 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..e369940be5664912740f5508ac9cb3e5866e8699 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,28 @@ +[tool.ruff] +target-version = "py38" +line-length = 120 + +extend-select = [ + "B", + "C", + "I", + "SIM", + "INP001", + "W" +] + +ignore = [ + "E501", + "F401", + "C901", +] + +extend-exclude = [ + "face_detection/*.py", +] + +[tool.ruff.per-file-ignores] + +"models/__init__.py" = [ + "F401", # "Imported but unused" +] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2c736196a7f10061089b3455d03fbc5150dfd81e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +gradio==4.13.0 +iou==0.1.0 +librosa==0.9.1 +numpy==1.22.3 +opencv-python==4.5.5.64 +scipy==1.7.3 +torch==1.12.0 +tqdm==4.63.0 +lws==1.2.7 +omegaconf==2.3.0 +yt-dlp==2022.6.22 +loguru==0.7.2 \ No newline at end of file