import argparse
import json
import os
from functools import partial
from typing import Union

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from fish_audio_preprocess.utils import loudness_norm, separate_audio
from loguru import logger
from mmengine import Config

from fish_diffusion.feature_extractors import FEATURE_EXTRACTORS, PITCH_EXTRACTORS
from fish_diffusion.utils.audio import get_mel_from_audio, slice_audio
from fish_diffusion.utils.inference import load_checkpoint
from fish_diffusion.utils.tensor import repeat_expand


def inference(
    in_sample,
    config_path,
    checkpoint,
    input_path,
    output_path,
    speaker_id=0,
    pitch_adjust=0,
    silence_threshold=60,
    max_slice_duration=30.0,
    extract_vocals=True,
    merge_non_vocals=True,
    vocals_loudness_gain=0.0,
    sampler_interval=None,
    sampler_progress=False,
    device="cuda",
    gradio_progress=None,
):
"""Inference | |
Args: | |
config: config | |
checkpoint: checkpoint path | |
input_path: input path | |
output_path: output path | |
speaker_id: speaker id | |
pitch_adjust: pitch adjust | |
silence_threshold: silence threshold of librosa.effects.split | |
max_slice_duration: maximum duration of each slice | |
extract_vocals: extract vocals | |
merge_non_vocals: merge non-vocals, only works when extract_vocals is True | |
vocals_loudness_gain: loudness gain of vocals (dB) | |
sampler_interval: sampler interval, lower value means higher quality | |
sampler_progress: show sampler progress | |
device: device | |
gradio_progress: gradio progress callback | |
""" | |
    config = Config.fromfile(config_path)

    if sampler_interval is not None:
        config.model.diffusion.sampler_interval = sampler_interval

    if os.path.isdir(checkpoint):
        # Find the latest checkpoint
        checkpoints = sorted(os.listdir(checkpoint))
        logger.info(f"Found {len(checkpoints)} checkpoints, using {checkpoints[-1]}")
        checkpoint = os.path.join(checkpoint, checkpoints[-1])

    audio, sr = librosa.load(input_path, sr=config.sampling_rate, mono=True)
    # sr = in_sample
    # audio = sf.read(input_path)

    # Extract vocals
    if extract_vocals:
        logger.info("Extracting vocals...")

        if gradio_progress is not None:
            gradio_progress(0, "Extracting vocals...")

        model = separate_audio.init_model("htdemucs", device=device)
        audio = librosa.resample(audio, orig_sr=sr, target_sr=model.samplerate)[None]

        # To two channels
        audio = np.concatenate([audio, audio], axis=0)
        audio = torch.from_numpy(audio).to(device)

        tracks = separate_audio.separate_audio(
            model, audio, shifts=1, num_workers=0, progress=True
        )
        audio = separate_audio.merge_tracks(tracks, filter=["vocals"]).cpu().numpy()
        non_vocals = (
            separate_audio.merge_tracks(tracks, filter=["drums", "bass", "other"])
            .cpu()
            .numpy()
        )

        audio = librosa.resample(audio[0], orig_sr=model.samplerate, target_sr=sr)
        non_vocals = librosa.resample(
            non_vocals[0], orig_sr=model.samplerate, target_sr=sr
        )

        # Normalize loudness
        non_vocals = loudness_norm.loudness_norm(non_vocals, sr)

    # Normalize loudness
    audio = loudness_norm.loudness_norm(audio, sr)

    # Slice into segments
    segments = list(
        slice_audio(
            audio, sr, max_duration=max_slice_duration, top_db=silence_threshold
        )
    )
    logger.info(f"Sliced into {len(segments)} segments")

    # Load models
    text_features_extractor = FEATURE_EXTRACTORS.build(
        config.preprocessing.text_features_extractor
    ).to(device)
    text_features_extractor.eval()

    model = load_checkpoint(config, checkpoint, device=device)

    pitch_extractor = PITCH_EXTRACTORS.build(config.preprocessing.pitch_extractor)
    assert pitch_extractor is not None, "Pitch extractor not found"

    generated_audio = np.zeros_like(audio)
    audio_torch = torch.from_numpy(audio).to(device)[None]

    for idx, (start, end) in enumerate(segments):
        if gradio_progress is not None:
            gradio_progress(idx / len(segments), "Generating audio...")

        segment = audio_torch[:, start:end]
        logger.info(
            f"Processing segment {idx + 1}/{len(segments)}, duration: {segment.shape[-1] / sr:.2f}s"
        )

        # Extract mel
        mel = get_mel_from_audio(segment, sr)

        # Extract pitch (f0)
        pitch = pitch_extractor(segment, sr, pad_to=mel.shape[-1]).float()
        pitch *= 2 ** (pitch_adjust / 12)

        # Extract text features
        text_features = text_features_extractor(segment, sr)[0]
        text_features = repeat_expand(text_features, mel.shape[-1]).T

        # Predict
        src_lens = torch.tensor([mel.shape[-1]]).to(device)

        features = model.model.forward_features(
            speakers=torch.tensor([speaker_id]).long().to(device),
            contents=text_features[None].to(device),
            src_lens=src_lens,
            max_src_len=max(src_lens),
            mel_lens=src_lens,
            max_mel_len=max(src_lens),
            pitches=pitch[None].to(device),
        )

        result = model.model.diffusion(features["features"], progress=sampler_progress)
        wav = model.vocoder.spec2wav(result[0].T, f0=pitch).cpu().numpy()

        max_wav_len = generated_audio.shape[-1] - start
        generated_audio[start : start + wav.shape[-1]] = wav[:max_wav_len]

    # Loudness normalization
    generated_audio = loudness_norm.loudness_norm(generated_audio, sr)

    # Loudness gain
    loudness_float = 10 ** (vocals_loudness_gain / 20)
    generated_audio = generated_audio * loudness_float

    # Merge non-vocals
    if extract_vocals and merge_non_vocals:
        generated_audio = (generated_audio + non_vocals) / 2

    logger.info("Done")

    if output_path is not None:
        sf.write(output_path, generated_audio, sr)

    return generated_audio, sr


class SvcFish:
    def __init__(
        self,
        checkpoint_path,
        config_path,
        sampler_interval=None,
        extract_vocals=True,
        merge_non_vocals=True,
        vocals_loudness_gain=0.0,
        silence_threshold=60,
        max_slice_duration=30.0,
    ):
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.sampler_interval = sampler_interval
        self.silence_threshold = silence_threshold
        self.max_slice_duration = max_slice_duration
        self.extract_vocals = extract_vocals
        self.merge_non_vocals = merge_non_vocals
        self.vocals_loudness_gain = vocals_loudness_gain

    def infer(self, input_path, pitch_adjust, speaker_id, in_sample):
        return inference(
            in_sample=in_sample,
            config_path=self.config_path,
            checkpoint=self.checkpoint_path,
            input_path=input_path,
            output_path=None,
            speaker_id=speaker_id,
            pitch_adjust=pitch_adjust,
            silence_threshold=self.silence_threshold,
            max_slice_duration=self.max_slice_duration,
            extract_vocals=self.extract_vocals,
            merge_non_vocals=self.merge_non_vocals,
            vocals_loudness_gain=self.vocals_loudness_gain,
            sampler_interval=self.sampler_interval,
            sampler_progress=True,
            device="cuda",
            gradio_progress=None,
        )
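

# Example usage (a minimal sketch; the checkpoint/config paths and speaker id
# below are placeholders, not files shipped with this Space):
#
#     model = SvcFish(
#         checkpoint_path="checkpoints",        # directory or a single checkpoint file
#         config_path="configs/svc_config.py",  # placeholder config path
#         sampler_interval=10,
#     )
#     audio, sr = model.infer("input.wav", pitch_adjust=0, speaker_id=0, in_sample=44100)
#     sf.write("output.wav", audio, sr)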