import torchaudio
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal


def decode_tts(tokens, quantizer, n_codebooks, n_original_tokens, start_audio_token_id, end_audio_token_id):
    # find the start and end indices of the audio-token span
    start = torch.nonzero(tokens == start_audio_token_id)
    end = torch.nonzero(tokens == end_audio_token_id)
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # subtract the size of the original text vocabulary -> audio tokens in range [0, codebook_size)
    audio_tokens = tokens[start:end] % n_original_tokens

    remainder = audio_tokens.shape[-1] % n_codebooks
    if remainder:
        # pad so the last frame has a token for every codebook
        pad_tokens = torch.zeros(n_codebooks - remainder, device=audio_tokens.device, dtype=audio_tokens.dtype)
        audio_tokens = torch.cat([audio_tokens, pad_tokens], dim=0)

    # reshape the flat token stream into (n_codebooks, batch=1, n_frames) codes
    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)

    del tokens
    del audio_tokens
    torch.cuda.empty_cache()

    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)


def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    text_tokenized = tokenizer(text, return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    # prompt = text followed by the start-of-audio token; the model continues with audio tokens
    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
        temperature=0.8,
        no_repeat_ngram_size=3,
    )

    audio_signal = decode_tts(
        output_audio_tokens[0],
        quantizer,
        n_codebooks_tts,
        len(tokenizer) - codebook_size,
        soa,
        eoa,
    )
    return audio_signal


def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    audio_data, sample_rate = torchaudio.load(audio_path)

    audio = audio_data.view(1, 1, -1).float().to(device)
    # bandwidth_id = torch.tensor([0])
    codes = quantizer.encode(audio)

    # shift the audio codes into the extended (audio) part of the LM vocabulary
    raw_audio_tokens = codes[:, :n_codebooks_asr] + len(tokenizer) - codebook_size

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    # prompt = start-of-audio token, audio tokens, end-of-audio token; the model continues with text
    tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)
    attention_mask = torch.ones(tokens.size(), device=device)

    output_text_tokens = model.generate(
        tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        temperature=0.6,
        top_p=0.9,
        top_k=top_k,
        no_repeat_ngram_size=4,
        length_penalty=2.0,
        repetition_penalty=1.5,
    )

    # keep only text tokens (ids below the first audio special token) and decode them
    output_text_tokens = output_text_tokens.cpu()[0]
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)

    return decoded_text


device = "cuda"

n_special_tokens = 3
n_codebooks_tts = 3
n_codebooks_asr = 1

# Special tokens that delimit audio spans; set these to the start/end-of-audio
# token strings from the model's vocabulary.
start_audio_token = ""
end_audio_token = ""
end_sequence_token = ""

base_model = "Vikhrmodels/salt-116k"


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        cache_dir=".",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map={"": 0},
    )

    quantizer_speech = SpeechTokenizer.load_from_checkpoint(
        "speechtokenizer/config.json",
        "speechtokenizer/SpeechTokenizer.pt",
    )
    quantizer_speech = quantizer_speech.eval().to(device)
    codebook_size = quantizer_speech.quantizer.bins

    text = ("Say 'COUNT NUMBERS FROM ONE TO TEN' with a male speaker delivers a very monotone and "
            "low-pitched speech with a moderate speed in a setting with almost no noise, "
            "creating a clear and quiet recording.")

    # text -> speech
    audio_signal = infer_text_to_audio(text, model, tokenizer, quantizer_speech, top_k=50)
    audio_signal.write("output.wav")

    # speech -> text
    audio_path = "./input.wav"
    generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer_speech)
    print(generated_text)