Ksenia Sycheva committed
Commit ee7a752
1 Parent(s): 2cd7bfa

Add inference code

Files changed (1)
  1. inference.py +139 -0
inference.py ADDED
@@ -0,0 +1,139 @@
+ import torchaudio
+ import torch
+
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+ )
+
+ from speechtokenizer import SpeechTokenizer
+ from audiotools import AudioSignal
+
+
+ def decode_tts(tokens, quantizer, n_codebooks, n_original_tokens, start_audio_token_id, end_audio_token_id):
+     # find start and end indices of audio tokens
+     start = torch.nonzero(tokens == start_audio_token_id)
+     end = torch.nonzero(tokens == end_audio_token_id)
+
+     start = start[0, -1] + 1 if len(start) else 0
+     end = end[0, -1] if len(end) else tokens.shape[-1]
+
+     # subtract length of original vocabulary -> tokens in range [0, 1024)
+     audio_tokens = tokens[start:end] % n_original_tokens
+     remainder = audio_tokens.shape[-1] % n_codebooks
+
+     if remainder:
+         # pad with zeros so the last frame has a code for every codebook
+         pad_tokens = torch.zeros(n_codebooks - remainder, dtype=audio_tokens.dtype, device=device)
+         audio_tokens = torch.cat([audio_tokens, pad_tokens], dim=0)
+
+     # regroup the flat token stream into (n_codebooks, batch=1, n_frames) codes
+     transposed = audio_tokens.view(-1, n_codebooks).t()
+     codes = transposed.view(n_codebooks, 1, -1).to(device)
+
+     audio = quantizer.decode(codes).squeeze(0)
+
+     del tokens
+     del audio_tokens
+     torch.cuda.empty_cache()
+
+     return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)
+
+
+ def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
+     text_tokenized = tokenizer(text, return_tensors="pt")
+     text_input_tokens = text_tokenized["input_ids"].to(device)
+
+     soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
+     eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
+
+     # prompt = text description followed by the start-of-audio token; the model continues with audio tokens
+     text_tokens = torch.cat([text_input_tokens, soa], dim=1)
+     attention_mask = torch.ones(text_tokens.size(), device=device)
+
+     output_audio_tokens = model.generate(
+         text_tokens,
+         attention_mask=attention_mask,
+         max_new_tokens=max_seq_length,
+         top_k=top_k,
+         do_sample=True,
+         temperature=0.8,
+         no_repeat_ngram_size=3,
+     )
+
+     audio_signal = decode_tts(output_audio_tokens[0], quantizer, n_codebooks_tts, len(tokenizer) - codebook_size, soa, eoa)
+
+     return audio_signal
+
+
+ def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
+     audio_data, sample_rate = torchaudio.load(audio_path)
+
+     audio = audio_data.view(1, 1, -1).float().to(device)
+
+     # encode audio into codebook indices, keep only the codebooks used for ASR,
+     # then shift them into the audio-token range of the extended vocabulary
+     codes = quantizer.encode(audio)
+     raw_audio_tokens = codes[:n_codebooks_asr] + len(tokenizer) - codebook_size
+
+     soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
+     eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
+     tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)
+
+     attention_mask = torch.ones(tokens.size(), device=device)
+
+     output_text_tokens = model.generate(
+         tokens,
+         attention_mask=attention_mask,
+         max_new_tokens=max_seq_length,
+         temperature=0.6,
+         top_p=0.9,
+         top_k=top_k,
+         no_repeat_ngram_size=4,
+         length_penalty=2.0,
+         repetition_penalty=1.5,
+     )
+
+     # drop audio and special tokens (ids at or above the start-of-audio id) and decode the rest
+     output_text_tokens = output_text_tokens.cpu()[0]
+     output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
+     decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)
+
+     return decoded_text
+
+
+ device = "cuda"
+
+ n_special_tokens = 3
+ n_codebooks_tts = 3
+ n_codebooks_asr = 1
+
+ start_audio_token = "<soa>"
+ end_audio_token = "<eoa>"
+ end_sequence_token = "<eos>"
+
+ base_model = "Vikhrmodels/salt-116k"
+
+
+ if __name__ == "__main__":
+     tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".")
+     model = AutoModelForCausalLM.from_pretrained(
+         base_model,
+         cache_dir=".",
+         torch_dtype=torch.bfloat16,
+         attn_implementation="sdpa",
+         device_map={"": 0},
+     )
+
+     quantizer_speech = SpeechTokenizer.load_from_checkpoint(
+         "speechtokenizer/config.json",
+         "speechtokenizer/SpeechTokenizer.pt",
+     )
+     quantizer_speech = quantizer_speech.eval().to(device)
+     codebook_size = quantizer_speech.quantizer.bins
+
+     text = ("Say 'COUNT NUMBERS FROM ONE TO TEN' with a male speaker delivers a very monotone and "
+             "low-pitched speech with a moderate speed in a setting with almost no noise, "
+             "creating a clear and quiet recording.")
+
+     audio_signal = infer_text_to_audio(text, model, tokenizer, quantizer_speech, top_k=50)
+     audio_signal.write("output.wav")
+
+     audio_path = "./input.wav"
+     generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer_speech)
+     print(generated_text)
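
Note: infer_audio_to_text passes the waveform to quantizer.encode at whatever sample rate torchaudio.load returns for ./input.wav. A helper along the lines below keeps the input mono and resampled before calling it. This is a sketch, not part of the commit: the name load_audio_for_asr is made up, and it assumes the expected rate is exposed as quantizer.sample_rate, the same attribute decode_tts already reads.

import torchaudio
import torchaudio.functional as AF

def load_audio_for_asr(audio_path, quantizer, device="cuda"):
    # load the waveform and collapse multi-channel audio to mono
    wav, sr = torchaudio.load(audio_path)
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    # resample to the rate the speech tokenizer expects (assumed: quantizer.sample_rate)
    if sr != quantizer.sample_rate:
        wav = AF.resample(wav, orig_freq=sr, new_freq=quantizer.sample_rate)
    # shape (batch=1, channels=1, samples), matching what infer_audio_to_text builds internally
    return wav.view(1, 1, -1).float().to(device)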