import torchaudio
import torch

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal


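# decode_tts: slice out the generated audio tokens between <soa> and <eoa>,
# map them back to SpeechTokenizer codebook indices, and decode to a waveform.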
def decode_tts(tokens, quantizer, n_codebooks, n_original_tokens, start_audio_token_id, end_audio_token_id):
    # keep only the tokens between the start- and end-of-audio markers
    start = torch.nonzero(tokens == start_audio_token_id)
    end = torch.nonzero(tokens == end_audio_token_id)

    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # shift ids back into the quantizer's codebook range
    audio_tokens = tokens[start:end] % n_original_tokens
    remainder = audio_tokens.shape[-1] % n_codebooks

    if remainder:
        # pad the last frame so the length is a multiple of n_codebooks
        pad_tokens = torch.zeros(n_codebooks - remainder, dtype=audio_tokens.dtype, device=audio_tokens.device)
        audio_tokens = torch.cat([audio_tokens, pad_tokens], dim=0)

    # (T * n_codebooks,) -> (n_codebooks, 1, T), the layout expected by the quantizer
    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)

    del tokens
    del audio_tokens
    torch.cuda.empty_cache()

    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)


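# Text-to-speech: tokenize the text prompt, append <soa>, let the language model
# generate audio tokens, then decode them with the quantizer.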
def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    text_tokenized = tokenizer(text, return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    # append <soa> so the model continues the prompt with audio tokens
    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
        temperature=0.8,
        no_repeat_ngram_size=3,
    )

    audio_signal = decode_tts(
        output_audio_tokens[0],
        quantizer,
        n_codebooks_tts,
        len(tokenizer) - codebook_size,
        soa.item(),
        eoa.item(),
    )

    return audio_signal


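# Speech-to-text: encode the input waveform into discrete audio tokens, wrap them
# in <soa>/<eoa>, and let the language model generate the transcription.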
def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    audio_data, sample_rate = torchaudio.load(audio_path)

    # resample to the quantizer's expected rate if needed; assumes a mono recording
    if sample_rate != quantizer.sample_rate:
        audio_data = torchaudio.functional.resample(audio_data, sample_rate, quantizer.sample_rate)
    audio = audio_data.view(1, 1, -1).float().to(device)

    # codes: (n_q, batch, frames) -> keep only the first n_codebooks_asr codebooks
    codes = quantizer.encode(audio)
    raw_audio_tokens = codes[:n_codebooks_asr] + len(tokenizer) - codebook_size

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)

    attention_mask = torch.ones(tokens.size(), device=device)

    output_text_tokens = model.generate(
        tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        do_sample=True,  # enable sampling so temperature/top_p/top_k take effect
        temperature=0.6,
        top_p=0.9,
        top_k=top_k,
        no_repeat_ngram_size=4,
        length_penalty=2.0,
        repetition_penalty=1.5,
    )

    # keep only text ids: everything below the <soa> id belongs to the text vocabulary
    output_text_tokens = output_text_tokens.cpu()[0]
    output_text_tokens = output_text_tokens[output_text_tokens < soa.item()]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)

    return decoded_text


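# Global configuration: device, number of codebooks used for TTS/ASR,
# audio delimiter tokens, and the base checkpoint.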
device = "cuda"

n_special_tokens = 3
n_codebooks_tts = 3
n_codebooks_asr = 1

start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"

base_model = "Vikhrmodels/salt-116k"


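# Load the language model, its tokenizer and the SpeechTokenizer quantizer,
# then run one text-to-audio and one audio-to-text example.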
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        cache_dir=".",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map={"": 0},
    )

    quantizer_speech = SpeechTokenizer.load_from_checkpoint(
        "speechtokenizer/config.json",
        "speechtokenizer/SpeechTokenizer.pt",
    )
    quantizer_speech = quantizer_speech.eval().to(device)
    codebook_size = quantizer_speech.quantizer.bins

    text = ("Say 'COUNT NUMBERS FROM ONE TO TEN' with a male speaker delivers a very monotone and "
            "low-pitched speech with a moderate speed in a setting with almost no noise, "
            "creating a clear and quiet recording.")

    audio_signal = infer_text_to_audio(text, model, tokenizer, quantizer_speech, top_k=50)
    audio_signal.write("output.wav")

    audio_path = "./input.wav"
    generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer_speech)
    print(generated_text)