# Source: maskgct/models/tts/valle_v2/valle_inference.py
# Uploaded by Hecheng0625 ("Upload 409 files", commit c968fc3, verified), 6.32 kB.
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torchaudio
class ValleInference(torch.nn.Module):
    """Prompt-conditioned VALL-E text-to-speech inference.

    Bundles three components:

    - a pretrained VALL-E AR model, prompted with the first codebook of the
      speech prompt;
    - a pretrained NAR model, which produces the full multi-codebook sequence
      from the AR output;
    - a neural codec that tokenizes the speech prompt and turns generated
      codes back into audio. The codec is either SpeechTokenizer, or Encodec
      (24 kHz, 6 kbps) optionally decoded with Vocos.
    """

    # Constructor arguments shared by the AR and NAR models. The special token
    # ids start right after the 300 phone ids and 1024 codec ids
    # (300 + 1024 = 1324). These presumably must match the configuration the
    # checkpoints were trained with.
    _VALLE_KWARGS = dict(
        phone_vocab_size=300,
        target_vocab_size=1024,
        pad_token_id=1324,
        bos_target_id=1325,
        eos_target_id=1326,
        bos_phone_id=1327,
        eos_phone_id=1328,
        bos_prompt_id=1329,
        eos_prompt_id=1330,
        num_hidden_layers=16,
    )

    def __init__(
        self,
        use_vocos=False,
        use_speechtokenizer=True,
        ar_path=None,
        nar_path=None,
        speechtokenizer_path=None,
        device="cuda",
    ):
        """Build the AR/NAR models and the codec, and load their weights.

        Args:
            use_vocos: decode Encodec codes with Vocos instead of the Encodec
                decoder. Mutually exclusive with ``use_speechtokenizer``.
            use_speechtokenizer: use SpeechTokenizer as the codec; otherwise
                Encodec 24 kHz at 6 kbps is used.
            ar_path: path to the trained AR model state dict (required).
            nar_path: path to the trained NAR model state dict (required).
            speechtokenizer_path: directory containing ``config.json`` and
                ``SpeechTokenizer.pt``; required when ``use_speechtokenizer``.
            device: device that all models are moved to.

        Raises:
            ValueError: if both codec options are enabled, or if a required
                path is missing.
        """
        super().__init__()
        # Validate the whole configuration before any (slow) model
        # construction or checkpoint loading happens.
        if use_speechtokenizer and use_vocos:
            raise ValueError("Only one of use_speechtokenizer and use_vocos can be True")
        if ar_path is None:
            raise ValueError("ar_path must point to a trained VALL-E AR checkpoint")
        if nar_path is None:
            raise ValueError("nar_path must point to a trained VALL-E NAR checkpoint")
        if use_speechtokenizer and speechtokenizer_path is None:
            raise ValueError(
                "speechtokenizer_path is required when use_speechtokenizer=True"
            )

        self.device = device
        self.use_speechtokenizer = use_speechtokenizer
        self.use_vocos = use_vocos

        from .valle_ar import ValleAR
        from .valle_nar import ValleNAR

        self.ar_model = self._load_checkpoint(ValleAR(**self._VALLE_KWARGS), ar_path)
        self.nar_model = self._load_checkpoint(ValleNAR(**self._VALLE_KWARGS), nar_path)

        self._init_codec(speechtokenizer_path)

    def _load_checkpoint(self, model, path):
        """Load the state dict at *path* into *model*; return it in eval mode on self.device."""
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (newer PyTorch versions support weights_only=True).
        model.load_state_dict(torch.load(path, map_location="cpu"))
        return model.eval().to(self.device)

    def _init_codec(self, speechtokenizer_path):
        """Create the codec encoder and, for Vocos, the separate decoder."""
        if self.use_speechtokenizer:
            import os

            from models.codec.speechtokenizer.model import SpeechTokenizer

            # download from https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg
            config_path = os.path.join(speechtokenizer_path, "config.json")
            ckpt_path = os.path.join(speechtokenizer_path, "SpeechTokenizer.pt")
            self.codec_encoder = SpeechTokenizer.load_from_checkpoint(
                config_path, ckpt_path
            )
            self.codec_encoder.eval()
            self.codec_encoder.to(self.device)
            print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}")
            return

        # Encodec is used to tokenize the prompt in both remaining modes.
        from encodec import EncodecModel

        self.codec_encoder = EncodecModel.encodec_model_24khz()
        # 6 kbps at 24 kHz -> 8 codebooks per frame.
        self.codec_encoder.set_target_bandwidth(6.0)
        self.codec_encoder.to(self.device)
        if self.use_vocos:
            from vocos import Vocos

            self.codec_decoder = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
            self.codec_decoder.to(self.device)
            print("Loaded Vocos")
        print("Loaded EncodecModel")

    def decode(self, vq_ids):
        """Turn codec ids back into audio.

        Args:
            vq_ids: codec ids laid out as [n_q, B, T] (8 codebooks for the
                6 kbps Encodec setting).

        Returns:
            Waveform tensor of shape [B, 1, T_audio].
        """
        if self.use_speechtokenizer:
            return self.codec_encoder.decode(vq_ids)  # [B, 1, T]
        if not self.use_vocos:
            # Plain Encodec decoder: takes a list of (codes [B, n_q, T], scale) frames.
            return self.codec_encoder.decode([(vq_ids.transpose(0, 1), None)])
        # Vocos vocoder on Encodec codes. squeeze(1) drops a singleton batch
        # axis; codes_to_features accepts both [n_q, T] and [n_q, B, T].
        features = self.codec_decoder.codes_to_features(vq_ids.squeeze(1))
        # Index 2 selects 6 kbps in Vocos' bandwidth table, matching the
        # set_target_bandwidth(6.0) used for the Encodec encoder.
        bandwidth_id = torch.tensor([2], device=vq_ids.device)
        audio = self.codec_decoder.decode(features, bandwidth_id=bandwidth_id)  # [B, T]
        # Insert the channel axis at dim 1 so the result is [B, 1, T] for any
        # batch size, consistent with the other decoders.
        return audio.unsqueeze(1)

    def _encode_prompt(self, speech):
        """Encode a [B, T] waveform into codec ids laid out as [n_q, B, T_frames]."""
        wav = speech.unsqueeze(1)  # [B, 1, T]
        if self.use_speechtokenizer:
            return self.codec_encoder.encode(wav)  # (n_q, B, T)
        # Encodec returns a list of (codes [B, n_q, T], scale) frames: join
        # them along time, then move the codebook axis first.
        frames = self.codec_encoder.encode(wav)
        return torch.cat([frame[0] for frame in frames], dim=-1).transpose(0, 1)

    def forward(self, batch, chunk_configs: list, return_prompt=False, prompt_len=None):
        """Synthesize speech for ``batch["phone_ids"]`` in the voice of the prompt.

        Args:
            batch: dict with at least
                ``speech``: prompt waveform, [B, T]
                ``phone_ids``: phone token ids, [B, T_phone]
                Tensor values are moved to ``self.device``. The caller's dict
                is not modified.
            chunk_configs: sampling configs, each a dict with keys ``top_p``,
                ``top_k``, ``temperature``, ``num_beams``, ``repeat_penalty``
                and ``max_length``. Only the first config is used; per-position
                configs are not implemented.
            return_prompt: if True, the prompt's codec frames are prepended to
                the generated frames before decoding, so the audio starts with
                the prompt decoded from its codes.
            prompt_len: number of prompt codec frames to condition on. ``None``
                uses the whole prompt.

        Returns:
            Waveform tensor [B, 1, T_audio].

        Raises:
            ValueError: if ``chunk_configs`` is empty.
        """
        chunk = next(iter(chunk_configs), None)
        if chunk is None:
            raise ValueError("chunk_configs must contain at least one sampling config")

        batch = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
        with torch.no_grad():
            vq_id = self._encode_prompt(batch["speech"])  # [n_q, B, T_prompt]
            # Slicing with prompt_len=None keeps the whole prompt.
            prompt_ids = vq_id[..., :prompt_len]

            # The AR model is prompted with the first codebook only.
            ar_vq_ids = self.ar_model.sample_hf(
                batch["phone_ids"],
                prompt_ids[0],
                top_p=chunk["top_p"],
                top_k=chunk["top_k"],
                temperature=chunk["temperature"],
                num_beams=chunk["num_beams"],
                repeat_penalty=chunk["repeat_penalty"],
                max_length=chunk["max_length"],
            )
            # The NAR model sees every prompt codebook plus the AR output.
            nar_vq_ids = self.nar_model.sample_hf(
                phone_ids=batch["phone_ids"],
                prompt_ids=prompt_ids,
                first_stage_ids=ar_vq_ids,
            )
            if return_prompt:
                nar_vq_ids = torch.cat([prompt_ids, nar_vq_ids], dim=-1)
            return self.decode(nar_vq_ids)  # [B, 1, T]