# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import librosa
import numpy as np
import sys
import os
import tqdm
import warnings
import torch
from pydub import AudioSegment
from pyannote.audio import Pipeline
import pandas as pd

from utils.tool import (
    export_to_mp3,
    load_cfg,
    get_audio_files,
    detect_gpu,
    check_env,
    calculate_audio_stats,
)
from utils.logger import Logger, time_logger
from models import separate_fast, dnsmos, whisper_asr, silero_vad

warnings.filterwarnings("ignore")
audio_count = 0


def standardization(audio):
    """
    Preprocess the audio file, including setting sample rate, bit depth, channels, and volume normalization.

    Args:
        audio (str or AudioSegment): Audio file path or AudioSegment object, the audio to be preprocessed.

    Returns:
        dict: A dictionary containing the preprocessed audio waveform, audio file name, and sample rate, formatted as:
            {
                "waveform": np.ndarray, the preprocessed audio waveform, dtype is np.float32, shape is (num_samples,)
                "name": str, the audio file name
                "sample_rate": int, the audio sample rate
            }

    Raises:
        ValueError: If the audio parameter is neither a str nor an AudioSegment.
    """
    global audio_count
    name = "audio"

    if isinstance(audio, str):
        name = os.path.basename(audio)
        audio = AudioSegment.from_file(audio)
    elif isinstance(audio, AudioSegment):
        name = f"audio_{audio_count}"
        audio_count += 1
    else:
        raise ValueError("Invalid audio type")

    logger.debug("Entering the preprocessing of audio")

    # Convert the audio file to WAV format
    audio = audio.set_frame_rate(cfg["entrypoint"]["SAMPLE_RATE"])
    audio = audio.set_sample_width(2)  # Set bit depth to 16bit
    audio = audio.set_channels(1)  # Set to mono
    logger.debug("Audio file converted to WAV format")

    # Calculate the gain to be applied
    target_dBFS = -20
    gain = target_dBFS - audio.dBFS
    logger.info(f"Calculating the gain needed for the audio: {gain} dB")

    # Normalize volume and limit gain range to between -3 and 3
    normalized_audio = audio.apply_gain(min(max(gain, -3), 3))

    waveform = np.array(normalized_audio.get_array_of_samples(), dtype=np.float32)
    max_amplitude = np.max(np.abs(waveform))
    waveform /= max_amplitude  # Normalize

    logger.debug(f"waveform shape: {waveform.shape}")
    logger.debug("waveform in np ndarray, dtype=" + str(waveform.dtype))

    return {
        "waveform": waveform,
        "name": name,
        "sample_rate": cfg["entrypoint"]["SAMPLE_RATE"],
    }
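

# Usage sketch (illustrative only, not executed): once `cfg` and `logger` are
# initialised in the __main__ block below, `standardization` can be called on a
# single file. The path here is hypothetical.
#
#   segment = standardization("recordings/example.wav")
#   # segment["waveform"] is a float32 mono array normalised to [-1, 1]
#   print(segment["name"], segment["sample_rate"], segment["waveform"].shape)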


def source_separation(predictor, audio):
    """
    Separate the audio into vocals and non-vocals using the given predictor.

    Args:
        predictor: The separation model predictor.
        audio (str or dict): The audio file path or a dictionary containing audio waveform and sample rate.

    Returns:
        dict: A dictionary containing the separated vocals and updated audio waveform.
    """
    mix, rate = None, None

    if isinstance(audio, str):
        mix, rate = librosa.load(audio, mono=False, sr=44100)
    else:
        # resample to 44100
        rate = audio["sample_rate"]
        mix = librosa.resample(audio["waveform"], orig_sr=rate, target_sr=44100)

    vocals, no_vocals = predictor.predict(mix)

    # convert vocals back to previous sample rate
    logger.debug(f"vocals shape before resample: {vocals.shape}")
    vocals = librosa.resample(vocals.T, orig_sr=44100, target_sr=rate).T
    logger.debug(f"vocals shape after resample: {vocals.shape}")

    audio["waveform"] = vocals[:, 0]  # vocals is stereo, only use one channel
    return audio
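

# Usage sketch (illustrative only, not executed): the separator works on 44.1 kHz
# input and returns (vocals, accompaniment); only one vocal channel is kept and
# resampled back to the pipeline rate. The path here is hypothetical.
#
#   audio = standardization("recordings/example.wav")
#   audio = source_separation(separate_predictor1, audio)
#   # audio["waveform"] now holds the vocal stem at cfg["entrypoint"]["SAMPLE_RATE"]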


# Step 2: Speaker Diarization
def speaker_diarization(audio):
    """
    Perform speaker diarization on the given audio.

    Args:
        audio (dict): A dictionary containing the audio waveform and sample rate.

    Returns:
        pd.DataFrame: A dataframe containing segments with speaker labels.
    """
    logger.debug("Start speaker diarization")
    logger.debug(f"audio waveform shape: {audio['waveform'].shape}")

    waveform = torch.tensor(audio["waveform"]).to(device)
    waveform = torch.unsqueeze(waveform, 0)

    segments = dia_pipeline(
        {
            "waveform": waveform,
            "sample_rate": audio["sample_rate"],
            "channel": 0,
        }
    )

    diarize_df = pd.DataFrame(
        segments.itertracks(yield_label=True),
        columns=["segment", "label", "speaker"],
    )
    diarize_df["start"] = diarize_df["segment"].apply(lambda x: x.start)
    diarize_df["end"] = diarize_df["segment"].apply(lambda x: x.end)

    logger.debug(f"diarize_df: {diarize_df}")
    return diarize_df
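

# Usage sketch (illustrative only, not executed): diarization yields one row per
# (segment, track, speaker); "start" and "end" are added in seconds.
#
#   diarize_df = speaker_diarization(audio)
#   for _, row in diarize_df.iterrows():
#       print(f"{row['speaker']}: {row['start']:.2f}s -> {row['end']:.2f}s")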


def cut_by_speaker_label(vad_list):
    """
    Merge and trim VAD segments by speaker labels, enforcing constraints on segment length and merge gaps.

    Args:
        vad_list (list): List of VAD segments with start, end, and speaker labels.

    Returns:
        list: A list of updated VAD segments after merging and trimming.
    """
    MERGE_GAP = 2  # merge gap in seconds, if smaller than this, merge
    MIN_SEGMENT_LENGTH = 3  # min segment length in seconds
    MAX_SEGMENT_LENGTH = 30  # max segment length in seconds

    updated_list = []

    for idx, vad in enumerate(vad_list):
        last_start_time = updated_list[-1]["start"] if updated_list else None
        last_end_time = updated_list[-1]["end"] if updated_list else None
        last_speaker = updated_list[-1]["speaker"] if updated_list else None

        if vad["end"] - vad["start"] >= MAX_SEGMENT_LENGTH:
            current_start = vad["start"]
            segment_end = vad["end"]
            logger.warning(
                "cut_by_speaker_label > segment longer than 30s, force trimming into chunks of at most 30s"
            )
            while segment_end - current_start >= MAX_SEGMENT_LENGTH:
                vad["end"] = current_start + MAX_SEGMENT_LENGTH  # update end time
                updated_list.append(vad)
                vad = vad.copy()
                current_start += MAX_SEGMENT_LENGTH
                vad["start"] = current_start  # update start time
                vad["end"] = segment_end
            updated_list.append(vad)
            continue

        if (
            last_speaker is None
            or last_speaker != vad["speaker"]
            or vad["end"] - vad["start"] >= MIN_SEGMENT_LENGTH
        ):
            updated_list.append(vad)
            continue

        if (
            vad["start"] - last_end_time >= MERGE_GAP
            or vad["end"] - last_start_time >= MAX_SEGMENT_LENGTH
        ):
            updated_list.append(vad)
        else:
            updated_list[-1]["end"] = vad["end"]  # merge the time

    logger.debug(
        f"cut_by_speaker_label > merged {len(vad_list) - len(updated_list)} segments"
    )

    filter_list = [
        vad for vad in updated_list if vad["end"] - vad["start"] >= MIN_SEGMENT_LENGTH
    ]
    logger.debug(
        f"cut_by_speaker_label > removed: {len(updated_list) - len(filter_list)} segments by length"
    )
    return filter_list
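

# Behaviour sketch (illustrative only, not executed): segments over 30 s are split,
# a short same-speaker segment within 2 s of the previous one is merged into it, and
# anything still shorter than 3 s is dropped. The input values below are made up.
#
#   vad_list = [
#       {"start": 0.0, "end": 4.0, "speaker": "SPEAKER_00"},
#       {"start": 4.5, "end": 6.0, "speaker": "SPEAKER_00"},
#   ]
#   print(cut_by_speaker_label(vad_list))
#   # -> [{"start": 0.0, "end": 6.0, "speaker": "SPEAKER_00"}]  (short tail merged)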


def asr(vad_segments, audio):
    """
    Perform Automatic Speech Recognition (ASR) on the VAD segments of the given audio.

    Args:
        vad_segments (list): List of VAD segments with start and end times.
        audio (dict): A dictionary containing the audio waveform and sample rate.

    Returns:
        list: A list of ASR results with transcriptions and language details.
    """
    if len(vad_segments) == 0:
        return []

    temp_audio = audio["waveform"]
    start_time = vad_segments[0]["start"]
    end_time = vad_segments[-1]["end"]
    start_frame = int(start_time * audio["sample_rate"])
    end_frame = int(end_time * audio["sample_rate"])
    temp_audio = temp_audio[start_frame:end_frame]  # remove silent start and end

    # Update vad_segments start and end times (a small trick for batched ASR)
    for idx, segment in enumerate(vad_segments):
        vad_segments[idx]["start"] -= start_time
        vad_segments[idx]["end"] -= start_time

    # Resample to 16 kHz
    temp_audio = librosa.resample(
        temp_audio, orig_sr=audio["sample_rate"], target_sr=16000
    )

    if multilingual_flag:
        logger.debug("Multilingual flag is on")
        valid_vad_segments, valid_vad_segments_language = [], []

        # Collect the segments that are valid for transcription
        for idx, segment in enumerate(vad_segments):
            start_frame = int(segment["start"] * 16000)
            end_frame = int(segment["end"] * 16000)
            segment_audio = temp_audio[start_frame:end_frame]
            language, prob = asr_model.detect_language(segment_audio)
            # Keep the segment only if 1. the language is in the supported list and 2. prob > 0.8
            if language in supported_languages and prob > 0.8:
                valid_vad_segments.append(vad_segments[idx])
                valid_vad_segments_language.append(language)

        # If there is no valid segment, return an empty list
        if len(valid_vad_segments) == 0:
            return []

        all_transcribe_result = []
        logger.debug(f"valid_vad_segments_language: {valid_vad_segments_language}")
        unique_languages = list(set(valid_vad_segments_language))
        logger.debug(f"unique_languages: {unique_languages}")

        # Process each language one by one
        for language_token in unique_languages:
            language = language_token
            # Filter out segments in other languages
            vad_segments = [
                valid_vad_segments[i]
                for i, x in enumerate(valid_vad_segments_language)
                if x == language
            ]
            # Batched transcription
            transcribe_result_temp = asr_model.transcribe(
                temp_audio,
                vad_segments,
                batch_size=batch_size,
                language=language,
                print_progress=True,
            )
            result = transcribe_result_temp["segments"]
            # Restore the segment annotation
            for idx, segment in enumerate(result):
                result[idx]["start"] += start_time
                result[idx]["end"] += start_time
                result[idx]["language"] = transcribe_result_temp["language"]
            all_transcribe_result.extend(result)

        # Sort by start time
        all_transcribe_result = sorted(all_transcribe_result, key=lambda x: x["start"])
        return all_transcribe_result
    else:
        logger.debug("Multilingual flag is off")
        language, prob = asr_model.detect_language(temp_audio)
        if language in supported_languages and prob > 0.8:
            transcribe_result = asr_model.transcribe(
                temp_audio,
                vad_segments,
                batch_size=batch_size,
                language=language,
                print_progress=True,
            )
            result = transcribe_result["segments"]
            for idx, segment in enumerate(result):
                result[idx]["start"] += start_time
                result[idx]["end"] += start_time
                result[idx]["language"] = transcribe_result["language"]
            return result
        else:
            return []
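

# Usage sketch (illustrative only, not executed): `asr` depends on the module-level
# `asr_model`, `multilingual_flag`, `supported_languages`, and `batch_size` set in
# the __main__ block, so it is normally reached via `main_process`. Segment field
# names other than "start", "end", and "language" depend on the ASR wrapper.
#
#   asr_result = asr(segment_list, audio)
#   for seg in asr_result:
#       print(seg["language"], seg["start"], seg["end"])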


def mos_prediction(audio, vad_list):
    """
    Predict the Mean Opinion Score (MOS) for the given audio and VAD segments.

    Args:
        audio (dict): A dictionary containing the audio waveform and sample rate.
        vad_list (list): List of VAD segments with start and end times.

    Returns:
        tuple: A tuple containing the average MOS and the updated VAD segments with MOS scores.
    """
    audio = audio["waveform"]
    sample_rate = 16000

    audio = librosa.resample(
        audio, orig_sr=cfg["entrypoint"]["SAMPLE_RATE"], target_sr=sample_rate
    )

    for index, vad in enumerate(tqdm.tqdm(vad_list, desc="DNSMOS")):
        start, end = int(vad["start"] * sample_rate), int(vad["end"] * sample_rate)
        segment = audio[start:end]
        dnsmos_score = dnsmos_compute_score(segment, sample_rate, False)["OVRL"]
        vad_list[index]["dnsmos"] = dnsmos_score

    predict_dnsmos = np.mean([vad["dnsmos"] for vad in vad_list])
    logger.debug(f"avg predict_dnsmos for whole audio: {predict_dnsmos}")
    return predict_dnsmos, vad_list
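

# Usage sketch (illustrative only, not executed): DNSMOS is computed per segment on
# 16 kHz audio; each entry gains a "dnsmos" key and the file-level average is returned.
#
#   avg_mos, mos_list = mos_prediction(audio, asr_result)
#   print(f"average DNSMOS OVRL: {avg_mos:.2f}")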


def filter(mos_list):
    """
    Filter out segments with bad MOS scores, abnormal character duration, or unsuitable total duration.

    Args:
        mos_list (list): List of VAD segments with MOS scores.

    Returns:
        list: A list of VAD segments that pass the filtering criteria.
    """
    filtered_audio_stats, all_audio_stats = calculate_audio_stats(mos_list)
    filtered_segment = len(filtered_audio_stats)
    all_segment = len(all_audio_stats)
    logger.debug(
        f"> {all_segment - filtered_segment}/{all_segment} {(all_segment - filtered_segment) / all_segment:.2%} segments filtered."
    )
    filtered_list = [mos_list[idx] for idx, _ in filtered_audio_stats]
    return filtered_list
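

# Usage sketch (illustrative only, not executed): the actual filtering criteria are
# implemented by `calculate_audio_stats` (imported from utils.tool).
#
#   filtered_list = filter(mos_list)
#   print(f"kept {len(filtered_list)} of {len(mos_list)} segments")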


def main_process(audio_path, save_path=None, audio_name=None):
    """
    Process the audio file, including standardization, source separation, speaker segmentation, VAD, ASR, export to MP3, and MOS prediction.

    Args:
        audio_path (str): Audio file path.
        save_path (str, optional): Save path, defaults to None, which means saving in the "_processed" folder in the audio file's directory.
        audio_name (str, optional): Audio file name, defaults to None, which means using the file name from the audio file path.

    Returns:
        tuple: The path of the saved JSON file and the filtered MOS list.
    """
    if not audio_path.endswith((".mp3", ".wav", ".flac", ".m4a", ".aac")):
        logger.warning(f"Unsupported file type: {audio_path}")

    # For a single audio from path aaa/bbb/ccc.wav ---> save to aaa/bbb_processed/ccc/ccc_0.wav
    audio_name = audio_name or os.path.splitext(os.path.basename(audio_path))[0]
    save_path = save_path or os.path.join(
        os.path.dirname(audio_path) + "_processed", audio_name
    )
    os.makedirs(save_path, exist_ok=True)
    logger.debug(
        f"Processing audio: {audio_name}, from {audio_path}, save to: {save_path}"
    )

    logger.info(
        "Step 0: Preprocess all audio files --> 24k sample rate + wave format + loudnorm + bit depth 16"
    )
    audio = standardization(audio_path)

    logger.info("Step 1: Source Separation")
    audio = source_separation(separate_predictor1, audio)

    logger.info("Step 2: Speaker Diarization")
    speakerdia = speaker_diarization(audio)

    logger.info("Step 3: Fine-grained Segmentation by VAD")
    vad_list = vad.vad(speakerdia, audio)
    segment_list = cut_by_speaker_label(vad_list)  # post-process after VAD

    logger.info("Step 4: ASR")
    asr_result = asr(segment_list, audio)

    logger.info("Step 5: Filter")
    logger.info("Step 5.1: calculate mos_prediction")
    avg_mos, mos_list = mos_prediction(audio, asr_result)
    logger.info(f"Step 5.1: done, average MOS: {avg_mos}")

    logger.info("Step 5.2: Filter out files with below-average MOS")
    filtered_list = filter(mos_list)

    logger.info("Step 6: write result into MP3 and JSON file")
    export_to_mp3(audio, filtered_list, save_path, audio_name)

    final_path = os.path.join(save_path, audio_name + ".json")
    with open(final_path, "w") as f:
        json.dump(filtered_list, f, ensure_ascii=False)

    logger.info(f"All done, Saved to: {final_path}")
    return final_path, filtered_list
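

# Usage sketch (illustrative only, not executed): after the models are loaded in the
# __main__ block, a single file can be processed end to end. The path is hypothetical.
#
#   json_path, segments = main_process("recordings/podcast_001.wav")
#   # -> recordings_processed/podcast_001/podcast_001.json plus per-segment MP3s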


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_folder_path",
        type=str,
        default="",
        help="input folder path, this will override config if set",
    )
    parser.add_argument(
        "--config_path", type=str, default="config.json", help="config path"
    )
    parser.add_argument("--batch_size", type=int, default=16, help="batch size")
    parser.add_argument(
        "--compute_type",
        type=str,
        default="float16",
        help="The compute type to use for the model",
    )
    parser.add_argument(
        "--whisper_arch",
        type=str,
        default="medium",
        help="The name of the Whisper model to load.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=4,
        help="The number of CPU threads to use per worker; multiplied by the number of workers.",
    )
    parser.add_argument(
        "--exit_pipeline",
        type=bool,
        default=False,
        help="Exit pipeline when task done.",
    )
    args = parser.parse_args()

    batch_size = args.batch_size
    cfg = load_cfg(args.config_path)
    logger = Logger.get_logger()

    if args.input_folder_path:
        logger.info(f"Using input folder path: {args.input_folder_path}")
        cfg["entrypoint"]["input_folder_path"] = args.input_folder_path

    logger.debug("Loading models...")

    # Load models
    if detect_gpu():
        logger.info("Using GPU")
        device_name = "cuda"
        device = torch.device(device_name)
    else:
        logger.info("Using CPU")
        device_name = "cpu"
        device = torch.device(device_name)

    check_env(logger)

    # Speaker Diarization
    logger.debug(" * Loading Speaker Diarization Model")
    if not cfg["huggingface_token"].startswith("hf"):
        raise ValueError(
            "huggingface_token must start with 'hf', check the config file. "
            "You can get the token at https://huggingface.co/settings/tokens. "
            "Remember to grant access following https://github.com/pyannote/pyannote-audio?tab=readme-ov-file#tldr"
        )
    dia_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=cfg["huggingface_token"],
    )
    dia_pipeline.to(device)

    # ASR
    logger.debug(" * Loading ASR Model")
    asr_model = whisper_asr.load_asr_model(
        args.whisper_arch,
        device_name,
        compute_type=args.compute_type,
        threads=args.threads,
        asr_options={
            "initial_prompt": "Um, Uh, Ah. Like, you know. I mean, right. Actually. Basically, and right? okay. Alright. Emm. So. Oh. 生于忧患,死于安乐。岂不快哉?当然,嗯,呃,就,这样,那个,哪个,啊,呀,哎呀,哎哟,唉哇,啧,唷,哟,噫!微斯人,吾谁与归?ええと、あの、ま、そう、ええ。äh, hm, so, tja, halt, eigentlich. euh, quoi, bah, ben, tu vois, tu sais, t'sais, eh bien, du coup. genre, comme, style. 응,어,그,음."
        },
    )

    # VAD
    logger.debug(" * Loading VAD Model")
    vad = silero_vad.SileroVAD(device=device)

    # Background Noise Separation
    logger.debug(" * Loading Background Noise Model")
    separate_predictor1 = separate_fast.Predictor(
        args=cfg["separate"]["step1"], device=device_name
    )

    # DNSMOS Scoring
    logger.debug(" * Loading DNSMOS Model")
    primary_model_path = cfg["mos_model"]["primary_model_path"]
    dnsmos_compute_score = dnsmos.ComputeScore(primary_model_path, device_name)
    logger.debug("All models loaded")

    supported_languages = cfg["language"]["supported"]
    multilingual_flag = cfg["language"]["multilingual"]
    logger.debug(f"supported languages: {supported_languages}")
    logger.debug(f"using multilingual ASR: {multilingual_flag}")

    input_folder_path = cfg["entrypoint"]["input_folder_path"]
    if not os.path.exists(input_folder_path):
        raise FileNotFoundError(f"input_folder_path: {input_folder_path} not found")

    audio_paths = get_audio_files(input_folder_path)  # Get all audio files
    logger.debug(f"Scanning {len(audio_paths)} audio files in {input_folder_path}")

    for path in audio_paths:
        main_process(path)
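
# Example invocation (paths are placeholders; replace main.py with this file's name
# if it differs). The config file must provide a valid Hugging Face token with
# access granted to pyannote/speaker-diarization-3.1.
#
#   python main.py --input_folder_path /path/to/audio_folder \
#       --config_path config.json --whisper_arch medium --compute_type float16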