# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import argparse

import numpy as np
import torch
import torchaudio

from text.g2p_module import G2PModule
from utils.tokenizer import AudioTokenizer, tokenize_audio
from models.tts.valle.valle import VALLE
from models.tts.base.tts_inferece import TTSInference  # spelling matches the upstream filename
from models.tts.valle.valle_dataset import VALLETestDataset, VALLETestCollator
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation


class VALLEInference(TTSInference):
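    """Inference wrapper for VALL-E text-to-speech.

    Converts the input text (optionally prefixed with the transcript of the
    audio prompt) into phone IDs, tokenizes the audio prompt into acoustic
    codes, and decodes the model's predicted codes back into a waveform.
    """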
    def __init__(self, args=None, cfg=None):
        TTSInference.__init__(self, args, cfg)

        self.g2p_module = G2PModule(backend=self.cfg.preprocess.phone_extractor)
        # NOTE: computed but currently unused.
        text_token_path = os.path.join(
            cfg.preprocess.processed_dir, cfg.dataset[0], cfg.preprocess.symbols_dict
        )
        self.audio_tokenizer = AudioTokenizer()

    def _build_model(self):
        return VALLE(self.cfg.model)

    def _build_test_dataset(self):
        return VALLETestDataset, VALLETestCollator

    def inference_one_clip(self, text, text_prompt, audio_file, save_name="pred"):
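        """Synthesize a single utterance.

        Args:
            text: target text to synthesize (ignored in --continual mode).
            text_prompt: transcript of the audio prompt.
            audio_file: path to the audio prompt waveform.
            save_name: filename stem used when --copysyn writes its output.

        Returns:
            The synthesized waveform tensor on CPU.
        """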
        # get the phone symbol file
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            assert os.path.exists(phone_symbol_file)

        # convert text to a phone sequence
        phone_extractor = phoneExtractor(self.cfg)
        # convert the phone sequence to a phone-ID sequence
        phone_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )

        text = f"{text_prompt} {text}".strip()
        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list
        phone_id_seq = phone_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
        phone_id_seq_len = torch.IntTensor([len(phone_id_seq)]).to(self.device)

        # add a batch dimension and move to the target device
        phone_id_seq = np.array([phone_id_seq])
        phone_id_seq = torch.from_numpy(phone_id_seq).to(self.device)

        # extract acoustic tokens from the audio prompt
        encoded_frames = tokenize_audio(self.audio_tokenizer, audio_file)
        audio_prompt_token = encoded_frames[0][0].transpose(2, 1).to(self.device)
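        # NOTE (assumption, added for clarity): with an EnCodec-style
        # tokenizer, tokenize_audio returns a list of (codes, scale) pairs
        # where codes has shape (B, n_q, T); the transpose above yields
        # (B, T, n_q), the layout the model consumes as an acoustic prompt.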
        # copysyn: round-trip the prompt audio through the tokenizer's decoder
        if self.args.copysyn:
            samples = self.audio_tokenizer.decode(encoded_frames)
            audio_copysyn = samples[0].cpu().detach()

            out_path = os.path.join(
                self.args.output_dir, self.infer_type, f"{save_name}_copysyn.wav"
            )
            torchaudio.save(out_path, audio_copysyn, self.cfg.preprocess.sampling_rate)

        if self.args.continual:
            encoded_frames = self.model.continual(
                phone_id_seq,
                phone_id_seq_len,
                audio_prompt_token,
            )
        else:
            enroll_x_lens = None
            if text_prompt:
                prompt_phone_seq = phone_extractor.extract_phone(
                    f"{text_prompt}".strip()
                )  # prompt_phone_seq: list
                prompt_phone_id_seq = phone_id_collator.get_phone_id_sequence(
                    self.cfg, prompt_phone_seq
                )
                enroll_x_lens = torch.IntTensor(
                    [len(prompt_phone_id_seq)]
                ).to(self.device)
            encoded_frames = self.model.inference(
                phone_id_seq,
                phone_id_seq_len,
                audio_prompt_token,
                enroll_x_lens=enroll_x_lens,
                top_k=self.args.top_k,
                temperature=self.args.temperature,
            )

        # transpose back to (B, n_q, T) before decoding to a waveform
        samples = self.audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
        audio = samples[0].squeeze(0).cpu().detach()

        return audio

    def inference_for_single_utterance(self):
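        """Run inference for the single utterance given on the command line.

        In --continual mode the target text is ignored and the model
        continues from the audio prompt instead.

        Returns:
            The synthesized waveform tensor.
        """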
        text = self.args.text
        text_prompt = self.args.text_prompt
        audio_file = self.args.audio_prompt

        if not self.args.continual:
            assert text != ""
        else:
            text = ""
        assert text_prompt != ""
        assert audio_file != ""

        audio = self.inference_one_clip(text, text_prompt, audio_file)
        return audio

    def inference_for_batches(self):
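        """Run inference over every entry in --test_list_file.

        Each line of the file is "|"-separated:
            continual mode:  <text_prompt>|<audio_prompt_path>
            otherwise:       <text_prompt>|<audio_prompt_path>|<text>

        Returns:
            A list of synthesized waveform tensors, one per line.
        """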
        test_list_file = self.args.test_list_file
        assert test_list_file is not None

        pred_res = []
        with open(test_list_file, "r") as fin:
            for idx, line in enumerate(fin):
                fields = line.strip().split("|")
                if self.args.continual:
                    assert len(fields) == 2
                    text_prompt, audio_prompt_path = fields
                    text = ""
                else:
                    assert len(fields) == 3
                    text_prompt, audio_prompt_path, text = fields

                audio = self.inference_one_clip(
                    text, text_prompt, audio_prompt_path, str(idx)
                )
                pred_res.append(audio)

        return pred_res

""" | |
TODO: batch inference | |
###### Construct test_batch ###### | |
n_batch = len(self.test_dataloader) | |
now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) | |
print( | |
"Model eval time: {}, batch_size = {}, n_batch = {}".format( | |
now, self.test_batch_size, n_batch | |
) | |
) | |
###### Inference for each batch ###### | |
pred_res = [] | |
with torch.no_grad(): | |
for i, batch_data in enumerate( | |
self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader) | |
): | |
if self.args.continual: | |
encoded_frames = self.model.continual( | |
batch_data["phone_seq"], | |
batch_data["phone_len"], | |
batch_data["acoustic_token"], | |
) | |
else: | |
encoded_frames = self.model.inference( | |
batch_data["phone_seq"], | |
batch_data["phone_len"], | |
batch_data["acoustic_token"], | |
enroll_x_lens=batch_data["pmt_phone_len"], | |
top_k=self.args.top_k, | |
temperature=self.args.temperature | |
) | |
samples = self.audio_tokenizer.decode( | |
[(encoded_frames.transpose(2, 1), None)] | |
) | |
for idx in range(samples.size(0)): | |
audio = samples[idx].cpu() | |
pred_res.append(audio) | |
return pred_res | |
""" | |
def add_arguments(parser: argparse.ArgumentParser):
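    """Register VALL-E inference options on an existing argument parser."""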
    parser.add_argument(
        "--text_prompt",
        type=str,
        default="",
        help="Text prompt; should be the transcript of --audio_prompt.",
    )
    parser.add_argument(
        "--audio_prompt",
        type=str,
        default="",
        help="Audio prompt; should be aligned with --text_prompt.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=-100,
        help="If > 0, the AR decoder samples from the top_k logits.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Temperature for the AR decoder's top_k sampling.",
    )
    parser.add_argument(
        "--continual",
        action="store_true",
        help="Run the continual task: continue the audio prompt instead of synthesizing new text.",
    )
    parser.add_argument(
        "--copysyn",
        action="store_true",
        help="Copysyn: reconstruct the audio prompt through the audio tokenizer's decoder.",
    )
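

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the upstream
# module). It assumes Amphion's config loader `utils.util.load_config` and a
# hypothetical --config flag; the remaining flags that TTSInference expects
# (e.g. checkpoint and output paths) are project-specific, so this is a
# sketch, not a definitive entry point.
#
# if __name__ == "__main__":
#     from utils.util import load_config  # assumption: Amphion's config loader
#
#     parser = argparse.ArgumentParser()
#     add_arguments(parser)
#     parser.add_argument("--config", type=str, required=True)  # hypothetical flag
#     parser.add_argument("--text", type=str, default="")
#     args = parser.parse_args()
#
#     inference = VALLEInference(args, load_config(args.config))
#     audio = inference.inference_for_single_utterance()
#     torchaudio.save("pred.wav", audio, inference.cfg.preprocess.sampling_rate)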