# Spaces: Running
import os | |
import time | |
import random | |
from tqdm import tqdm | |
import argparse | |
import torch | |
import torchaudio | |
from accelerate import Accelerator | |
from einops import rearrange | |
from ema_pytorch import EMA | |
from vocos import Vocos | |
from model import CFM, UNetT, DiT | |
from model.utils import ( | |
get_tokenizer, | |
get_seedtts_testset_metainfo, | |
get_librispeech_test_clean_metainfo, | |
get_inference_prompt, | |
) | |
# Accelerate sets up the (optional) multi-process run; each process works on its
# own share of the prompts.
accelerator = Accelerator()
# Use the device Accelerate assigned to this process. Building
# f"cuda:{accelerator.process_index}" by hand uses the *global* rank, which names a
# non-existent GPU on multi-node runs, and it fails outright on CPU-only hosts.
device = accelerator.device
# --------------------- Dataset Settings -------------------- #
# Audio / feature parameters shared by prompt preparation, the CFM mel front end
# and waveform saving.
target_sample_rate = 24_000  # Hz; also the rate the generated wavs are saved at
n_mel_channels = 100  # mel bins per frame
hop_length = 256  # mel hop size
target_rms = 0.1  # reference RMS; outputs for quieter prompts are rescaled by it
tokenizer = "pinyin"  # tokenizer type passed to get_tokenizer / get_inference_prompt
# ---------------------- infer setting ---------------------- #
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument('-s', '--seed', default=None, type=int,
                    help="sampling seed; drawn at random when omitted")
parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN",
                    help="dataset name passed to get_tokenizer")
parser.add_argument('-n', '--expname', required=True,
                    help="experiment name; checkpoints are read from ckpts/<expname>/")
parser.add_argument('-c', '--ckptstep', default=1200000, type=int,
                    help="training step of the checkpoint to load")
parser.add_argument('-nfe', '--nfestep', default=32, type=int,
                    help="number of sampling steps passed to model.sample")
parser.add_argument('-o', '--odemethod', default="euler",
                    help="ODE solver method")
parser.add_argument('-ss', '--swaysampling', default=-1, type=float,
                    help="sway sampling coefficient")
# Restrict --testset to the sets this script knows how to load, so a typo fails
# here with a usage message instead of a NameError after the checkpoint is loaded.
parser.add_argument('-t', '--testset', required=True,
                    choices=["ls_pc_test_clean", "seedtts_test_zh", "seedtts_test_en"],
                    help="evaluation test set")
args = parser.parse_args()
# Unpack the command-line options.
seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling
testset = args.testset

# Fixed inference settings.
infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.  # passed to model.sample
speed = 1.  # passed to get_inference_prompt
use_truth_duration = False
no_ref_audio = False

# Load the training checkpoint onto the CPU. Only its state dicts are needed, and
# load_state_dict() copies them into the model's parameters wherever those live.
# Mapping straight to the GPU would keep a second copy of everything in the
# checkpoint there while this dict stays alive.
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location="cpu")
# Backbone class and hyper-parameters for each supported experiment name.
if exp_name == "F5TTS_Base":
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
else:
    # Fail clearly here instead of with a NameError on model_cls further down.
    raise ValueError(
        f"Unsupported expname {exp_name!r}; expected 'F5TTS_Base' or 'E2TTS_Base'."
    )
# Resolve the metadata list (and, for LibriSpeech-PC, the audio root) of the test set.
if testset == "ls_pc_test_clean":
    metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
    # Placeholder: point this at your local LibriSpeech test-clean directory.
    librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
    metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
elif testset == "seedtts_test_zh":
    metalst = "data/seedtts_testset/zh/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)
elif testset == "seedtts_test_en":
    metalst = "data/seedtts_testset/en/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)
else:
    # Fail clearly here instead of with a NameError on metainfo further down.
    raise ValueError(
        f"Unsupported testset {testset!r}; expected one of "
        "'ls_pc_test_clean', 'seedtts_test_zh', 'seedtts_test_en'."
    )
# Path to save generated wavs.
# The seed is part of output_dir, so every process must use the same one. When no
# seed was given, draw it once on the main process and broadcast it; letting each
# process call random.randint on its own would give each a different directory.
if seed is None:
    # Local import keeps this fix self-contained; accelerate is already a dependency.
    from accelerate.utils import broadcast_object_list

    seed_box = [random.randint(-10000, 10000) if accelerator.is_main_process else None]
    broadcast_object_list(seed_box)  # fills seed_box in place on every process
    seed = seed_box[0]

# The directory name encodes the checkpoint, test set and sampling settings.
output_dir = (
    f"results/{exp_name}_{ckpt_step}/{testset}/"
    f"seed{seed}_{ode_method}_nfe{nfe_step}"
    f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
    f"_cfg{cfg_strength}_speed{speed}"
    f"{'_gt-dur' if use_truth_duration else ''}"
    f"{'_no-ref-audio' if no_ref_audio else ''}"
)
# -------------------------------------------------#
# Load the EMA weights (True) or the raw online weights (False) from the checkpoint.
use_ema = True

# Prepare every prompt batch up front; the batches are later split across the
# processes for inference.
prompts_all = get_inference_prompt(
    metainfo,
    tokenizer=tokenizer,
    speed=speed,
    use_truth_duration=use_truth_duration,
    target_sample_rate=target_sample_rate,
    n_mel_channels=n_mel_channels,
    hop_length=hop_length,
    target_rms=target_rms,
    infer_batch_size=infer_batch_size,
)
# Vocoder model (mel spectrogram -> waveform). It is never moved to `device`, so it
# runs on the CPU.
local = False  # True: build from vocos_local_path instead of Vocos.from_pretrained
if local:
    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    # The vocoder lives on the CPU, so load its weights there directly instead of
    # staging them on the GPU first.
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location="cpu")
    vocos.load_state_dict(state_dict)
else:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# Ensure eval mode regardless of which path built the vocoder.
vocos.eval()
# Tokenizer: get_tokenizer returns the vocab map and its size for this dataset.
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model: CFM wrapped around the selected backbone (model_cls / model_cfg).
model = CFM(
    transformer=model_cls(
        **model_cfg,
        text_num_embeds=vocab_size,
        mel_dim=n_mel_channels,
    ),
    mel_spec_kwargs=dict(
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
    ),
    odeint_kwargs=dict(
        method=ode_method,
    ),
    vocab_char_map=vocab_char_map,
).to(device)

if use_ema:
    # Load the EMA weights into a wrapper around `model`, then copy them into `model`.
    ema_model = EMA(model, include_online_model=False).to(device)
    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
    ema_model.copy_params_from_ema_to_model()
    # Only `model` is used from here on; release the wrapper's own weight copy.
    del ema_model
else:
    model.load_state_dict(checkpoint['model_state_dict'])

# The weights now live in `model`. Drop the checkpoint so its tensors (and any other
# entries it holds) are not kept alive for the whole inference run.
del checkpoint
# Every process saves wavs into output_dir, so each one makes sure it exists.
# exist_ok=True makes this idempotent and avoids the exists()/makedirs() race when
# several processes create the same directory at once.
os.makedirs(output_dir, exist_ok=True)
# start batch inference
accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
    for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
        utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
        ref_mels = ref_mels.to(device)
        # Create the length tensors directly on the target device (no CPU alloc + copy).
        ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long, device=device)
        total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long, device=device)

        # Inference. Vocoding and saving also run inside inference_mode, so no
        # autograd graph is recorded through the vocoder.
        with torch.inference_mode():
            generated, _ = model.sample(
                cond=ref_mels,
                text=final_text_list,
                duration=total_mel_lens,
                lens=ref_mel_lens,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                no_ref_audio=no_ref_audio,
                seed=seed,
            )
            # Final result
            for i, gen in enumerate(generated):
                # Keep frames [ref_mel_lens[i], total_mel_lens[i]), dropping the
                # reference prefix, and reshape to (1, mel, frames) for the vocoder.
                gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
                gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
                generated_wave = vocos.decode(gen_mel_spec.cpu())
                # For references quieter than target_rms, scale the output by
                # ref_rms / target_rms (presumably undoing the prompt loudness
                # normalization in get_inference_prompt; confirm there).
                if ref_rms_list[i] < target_rms:
                    generated_wave = generated_wave * ref_rms_list[i] / target_rms
                torchaudio.save(
                    os.path.join(output_dir, f"{utts[i]}.wav"),
                    generated_wave,
                    target_sample_rate,
                )

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    timediff = time.time() - start
    print(f"Done batch inference in {timediff / 60 :.2f} minutes.")