import argparse
import json
import os
from functools import partial
from typing import Union

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from fish_audio_preprocess.utils import loudness_norm, separate_audio
from loguru import logger
from mmengine import Config

from fish_diffusion.feature_extractors import FEATURE_EXTRACTORS, PITCH_EXTRACTORS
from fish_diffusion.utils.audio import get_mel_from_audio, slice_audio
from fish_diffusion.utils.inference import load_checkpoint
from fish_diffusion.utils.tensor import repeat_expand


def inference(
    config,
    checkpoint,
    input_path,
    output_path,
    speaker_id=0,
    pitch_adjust=0,
    silence_threshold=30,
    max_slice_duration=5,
    extract_vocals=True,
    merge_non_vocals=True,
    vocals_loudness_gain=0.0,
    sampler_interval=None,
    sampler_progress=False,
    device="cuda",
    gradio_progress=None,
):
"""Inference | |
Args: | |
config: config | |
checkpoint: checkpoint path | |
input_path: input path | |
output_path: output path | |
speaker_id: speaker id | |
pitch_adjust: pitch adjust | |
silence_threshold: silence threshold of librosa.effects.split | |
max_slice_duration: maximum duration of each slice | |
extract_vocals: extract vocals | |
merge_non_vocals: merge non-vocals, only works when extract_vocals is True | |
vocals_loudness_gain: loudness gain of vocals (dB) | |
sampler_interval: sampler interval, lower value means higher quality | |
sampler_progress: show sampler progress | |
device: device | |
gradio_progress: gradio progress callback | |
""" | |
    if sampler_interval is not None:
        config.model.diffusion.sampler_interval = sampler_interval

    if os.path.isdir(checkpoint):
        # Find the latest checkpoint (assumes checkpoint file names sort chronologically)
        checkpoints = sorted(os.listdir(checkpoint))
        logger.info(f"Found {len(checkpoints)} checkpoints, using {checkpoints[-1]}")
        checkpoint = os.path.join(checkpoint, checkpoints[-1])

    audio, sr = librosa.load(input_path, sr=config.sampling_rate, mono=True)

    # Extract vocals
    if extract_vocals:
        logger.info("Extracting vocals...")

        if gradio_progress is not None:
            gradio_progress(0, "Extracting vocals...")

        model = separate_audio.init_model("htdemucs", device=device)
        audio = librosa.resample(audio, orig_sr=sr, target_sr=model.samplerate)[None]

        # Duplicate the mono signal into two channels, since the separation model expects stereo input
        audio = np.concatenate([audio, audio], axis=0)
        audio = torch.from_numpy(audio).to(device)
        tracks = separate_audio.separate_audio(
            model, audio, shifts=1, num_workers=0, progress=True
        )
        audio = separate_audio.merge_tracks(tracks, filter=["vocals"]).cpu().numpy()
        non_vocals = (
            separate_audio.merge_tracks(tracks, filter=["drums", "bass", "other"])
            .cpu()
            .numpy()
        )

        audio = librosa.resample(audio[0], orig_sr=model.samplerate, target_sr=sr)
        non_vocals = librosa.resample(
            non_vocals[0], orig_sr=model.samplerate, target_sr=sr
        )

        # Normalize the loudness of the accompaniment
        non_vocals = loudness_norm.loudness_norm(non_vocals, sr)

    # Normalize the loudness of the (possibly separated) vocals
    audio = loudness_norm.loudness_norm(audio, sr)

    # Slice on silence so that each segment stays under max_slice_duration
    segments = list(
        slice_audio(
            audio, sr, max_duration=max_slice_duration, top_db=silence_threshold
        )
    )
    logger.info(f"Sliced into {len(segments)} segments")

    # Load models
    text_features_extractor = FEATURE_EXTRACTORS.build(
        config.preprocessing.text_features_extractor
    ).to(device)
    text_features_extractor.eval()

    model = load_checkpoint(config, checkpoint, device=device)

    pitch_extractor = PITCH_EXTRACTORS.build(config.preprocessing.pitch_extractor)
    assert pitch_extractor is not None, "Pitch extractor not found"
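
    # Output buffer: each converted segment is written back at its original
    # offset, so gaps between segments remain silent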
    generated_audio = np.zeros_like(audio)
    audio_torch = torch.from_numpy(audio).to(device)[None]

    for idx, (start, end) in enumerate(segments):
        if gradio_progress is not None:
            gradio_progress(idx / len(segments), "Generating audio...")

        segment = audio_torch[:, start:end]
        logger.info(
            f"Processing segment {idx + 1}/{len(segments)}, duration: {segment.shape[-1] / sr:.2f}s"
        )

        # Extract mel
        mel = get_mel_from_audio(segment, sr)

        # Extract pitch (f0)
        pitch = pitch_extractor(segment, sr, pad_to=mel.shape[-1]).float()
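        # Shift the f0 curve by pitch_adjust semitones (one semitone = a factor of 2 ** (1 / 12))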
        pitch *= 2 ** (pitch_adjust / 12)

        # Extract text features and align them to the number of mel frames
        text_features = text_features_extractor(segment, sr)[0]
        text_features = repeat_expand(text_features, mel.shape[-1]).T

        # Predict
        src_lens = torch.tensor([mel.shape[-1]]).to(device)

        features = model.model.forward_features(
            speakers=torch.tensor([speaker_id]).long().to(device),
            contents=text_features[None].to(device),
            src_lens=src_lens,
            max_src_len=max(src_lens),
            mel_lens=src_lens,
            max_mel_len=max(src_lens),
            pitches=pitch[None].to(device),
        )
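
        # Run the diffusion sampler on the conditioned features, then vocode the
        # resulting mel spectrogram (together with the f0 curve) back to a waveform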
        result = model.model.diffusion(features["features"], progress=sampler_progress)
        wav = model.vocoder.spec2wav(result[0].T, f0=pitch).cpu().numpy()
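
        # Trim the generated waveform so it cannot overrun the end of the output buffer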
        max_wav_len = generated_audio.shape[-1] - start
        generated_audio[start : start + wav.shape[-1]] = wav[:max_wav_len]

    # Loudness normalization
    generated_audio = loudness_norm.loudness_norm(generated_audio, sr)

    # Convert the vocals gain from dB to a linear amplitude factor (amplitude = 10 ** (dB / 20))
    loudness_float = 10 ** (vocals_loudness_gain / 20)
    generated_audio = generated_audio * loudness_float

    # Merge non-vocals
    if extract_vocals and merge_non_vocals:
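        # Average the converted vocals with the accompaniment (a simple 50/50 mix)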
        generated_audio = (generated_audio + non_vocals) / 2

    logger.info("Done")

    if output_path is not None:
        sf.write(output_path, generated_audio, sr)

    return generated_audio, sr
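

# A minimal sketch of calling inference() directly from Python; the config and
# checkpoint paths below are placeholders, not files shipped with this script:
#
#   config = Config.fromfile("configs/svc_hubert_soft.py")
#   audio, sr = inference(
#       config,
#       "checkpoints",            # a directory: the latest checkpoint is used
#       input_path="input.wav",
#       output_path="output.wav",
#       pitch_adjust=2,           # transpose up two semitones
#   )
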
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the config file",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint file or directory",
    )
    parser.add_argument(
        "--gradio",
        action="store_true",
        help="Run in gradio mode",
    )
    parser.add_argument(
        "--gradio_share",
        action="store_true",
        help="Share the gradio app publicly",
    )
    parser.add_argument(
        "--input",
        type=str,
        required=False,
        help="Path to the input audio file",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=False,
        help="Path to the output audio file",
    )
    parser.add_argument(
        "--speaker_id",
        type=int,
        default=0,
        help="Speaker id",
    )
    parser.add_argument(
        "--speaker_mapping",
        type=str,
        default=None,
        help="Speaker mapping file (gradio mode only)",
    )
    parser.add_argument(
        "--pitch_adjust",
        type=int,
        default=0,
        help="Pitch adjustment in semitones",
    )
    parser.add_argument(
        "--extract_vocals",
        action="store_true",
        help="Extract vocals before conversion",
    )
    parser.add_argument(
        "--merge_non_vocals",
        action="store_true",
        help="Merge the accompaniment back into the output",
    )
    parser.add_argument(
        "--vocals_loudness_gain",
        type=float,
        default=0,
        help="Loudness gain for vocals (dB)",
    )
    parser.add_argument(
        "--sampler_interval",
        type=int,
        default=None,
        required=False,
        help="Sampler interval; if not specified, taken from the config",
    )
    parser.add_argument(
        "--sampler_progress",
        action="store_true",
        help="Show sampler progress",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        required=False,
        help="Device to use",
    )

    return parser.parse_args()


def run_inference(
    config_path: str,
    model_path: str,
    input_path: str,
    speaker: Union[int, str],
    pitch_adjust: int,
    sampler_interval: int,
    extract_vocals: bool,
    device: str,
    progress=gr.Progress(),
    speaker_mapping: Union[dict, None] = None,
):
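    # Resolve a speaker name to its numeric id when a mapping file was provided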
    if speaker_mapping is not None and isinstance(speaker, str):
        speaker = speaker_mapping[speaker]

    audio, sr = inference(
        Config.fromfile(config_path),
        model_path,
        input_path=input_path,
        output_path=None,
        speaker_id=speaker,
        pitch_adjust=pitch_adjust,
        sampler_interval=round(sampler_interval),  # the Gradio slider yields a float
        extract_vocals=extract_vocals,
        merge_non_vocals=False,
        device=device,
        gradio_progress=progress,
    )

    return (sr, audio)


def launch_gradio(args):
    with gr.Blocks(title="Fish Diffusion") as app:
        gr.Markdown("# Fish Diffusion SVC Inference")

        with gr.Row():
            with gr.Column():
                input_audio = gr.Audio(
                    label="Input Audio",
                    type="filepath",
                    value=args.input,
                )
                output_audio = gr.Audio(label="Output Audio")

            with gr.Column():
                if args.speaker_mapping is not None:
                    with open(args.speaker_mapping) as f:
                        speaker_mapping = json.load(f)

                    speaker = gr.Dropdown(
                        label="Speaker Name (Used for Multi-Speaker Models)",
                        choices=list(speaker_mapping.keys()),
                        value=list(speaker_mapping.keys())[0],
                    )
                else:
                    speaker_mapping = None
                    speaker = gr.Number(
                        label="Speaker ID (Used for Multi-Speaker Models)",
                        value=args.speaker_id,
                    )

                pitch_adjust = gr.Number(
                    label="Pitch Adjust (Semitones)", value=args.pitch_adjust
                )
                sampler_interval = gr.Slider(
                    label="Sampler Interval (⬆️ Faster Generation, ⬇️ Better Quality)",
                    value=args.sampler_interval or 10,
                    minimum=1,
                    maximum=100,
                )
                extract_vocals = gr.Checkbox(
                    label="Extract Vocals (For low quality audio)",
                    value=args.extract_vocals,
                )
                device = gr.Radio(
                    label="Device", choices=["cuda", "cpu"], value=args.device or "cuda"
                )
                run_btn = gr.Button("Run")

        run_btn.click(
            partial(
                run_inference,
                args.config,
                args.checkpoint,
                speaker_mapping=speaker_mapping,
            ),
            [
                input_audio,
                speaker,
                pitch_adjust,
                sampler_interval,
                extract_vocals,
                device,
            ],
            output_audio,
        )

    app.queue(concurrency_count=2).launch(share=args.gradio_share)


if __name__ == "__main__": | |
args = parse_args() | |
assert args.gradio or ( | |
args.input is not None and args.output is not None | |
), "Either --gradio or --input and --output should be specified" | |
if args.device is None: | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
else: | |
device = torch.device(args.device) | |
if args.gradio: | |
args.device = device | |
launch_gradio(args) | |
else: | |
inference( | |
Config.fromfile(args.config), | |
args.checkpoint, | |
args.input, | |
args.output, | |
speaker_id=args.speaker_id, | |
pitch_adjust=args.pitch_adjust, | |
extract_vocals=args.extract_vocals, | |
merge_non_vocals=args.merge_non_vocals, | |
vocals_loudness_gain=args.vocals_loudness_gain, | |
sampler_interval=args.sampler_interval, | |
sampler_progress=args.sampler_progress, | |
device=device, | |
) | |
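
# Example invocations (script name and paths are placeholders; flags match
# parse_args() above):
#
#   # Command-line conversion, transposed up two semitones:
#   python inference.py --config configs/svc_config.py \
#       --checkpoint checkpoints/model.ckpt \
#       --input input.wav --output output.wav --pitch_adjust 2
#
#   # Gradio web UI with a public share link:
#   python inference.py --config configs/svc_config.py \
#       --checkpoint checkpoints/ --gradio --gradio_share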