freddyaboulton (HF staff) committed
Commit 4a37dab
1 Parent(s): 521091c
Files changed (4)
  1. .gitignore +1 -0
  2. app.py +52 -0
  3. inference.py +666 -0
  4. server.py +57 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ checkpoint/
app.py CHANGED
@@ -1,4 +1,56 @@
  import gradio as gr
+ from huggingface_hub import snapshot_download
+ from threading import Thread
+ import os
+ import time
+ import gradio as gr
+ import base64
+ import numpy as np
+ import requests
+
+ from server import serve
+
+ repo_id = "gpt-omni/mini-omni"
+ snapshot_download(repo_id, local_dir="./checkpoint", revision="main")
+
+ IP='0.0.0.0'
+ PORT=60808
+
+ thread = Thread(target=serve, daemon=True)
+ thread.start()
+
+ API_URL = "http://0.0.0.0:60808/chat"
+
+ OUT_CHUNK = 4096
+ OUT_RATE = 24000
+ OUT_CHANNELS = 1
+
+ def process_audio(audio):
+     filepath = audio
+     print(f"filepath: {filepath}")
+     if filepath is None:
+         return
+
+     cnt = 0
+     with open(filepath, "rb") as f:
+         data = f.read()
+         base64_encoded = str(base64.b64encode(data), encoding="utf-8")
+         files = {"audio": base64_encoded}
+         tik = time.time()
+         with requests.post(API_URL, json=files, stream=True) as response:
+             try:
+                 for chunk in response.iter_content(chunk_size=OUT_CHUNK):
+                     if chunk:
+                         # Convert chunk to numpy array
+                         if cnt == 0:
+                             print(f"first chunk time cost: {time.time() - tik:.3f}")
+                         cnt += 1
+                         audio_data = np.frombuffer(chunk, dtype=np.int16)
+                         audio_data = audio_data.reshape(-1, OUT_CHANNELS)
+                         yield OUT_RATE, audio_data.astype(np.int16)
+
+             except Exception as e:
+                 print(f"error: {e}")
 
  def greet(name):
      return "Hello " + name + "!!"
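Note: the hunk above only defines the streaming process_audio helper next to the original greet demo; how it is attached to a Gradio interface is not shown in this diff. One possible wiring, as a hypothetical sketch (the demo name and the gr.Audio arguments are assumptions about a recent Gradio API, not part of this commit):

    # Hypothetical wiring, not part of this commit: feed a recorded file into
    # process_audio and stream its (sample_rate, int16 array) chunks back out.
    demo = gr.Interface(
        fn=process_audio,                            # generator defined in app.py above
        inputs=gr.Audio(type="filepath"),            # hand the recording over as a file path
        outputs=gr.Audio(streaming=True, autoplay=True),
    )

    if __name__ == "__main__":
        demo.launch()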
inference.py ADDED
@@ -0,0 +1,666 @@
+ import os
+ import lightning as L
+ import torch
+ import time
+ from snac import SNAC
+ from litgpt import Tokenizer
+ from litgpt.utils import (
+     num_parameters,
+ )
+ from litgpt.generate.base import (
+     generate_AA,
+     generate_ASR,
+     generate_TA,
+     generate_TT,
+     generate_AT,
+     generate_TA_BATCH,
+     next_token_batch
+ )
+ import soundfile as sf
+ from litgpt.model import GPT, Config
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
+ from utils.snac_utils import get_snac, generate_audio_data
+ import whisper
+ from tqdm import tqdm
+ from huggingface_hub import snapshot_download
+
+
+ torch.set_printoptions(sci_mode=False)
+
+
+ # TODO
+ text_vocabsize = 151936
+ text_specialtokens = 64
+ audio_vocabsize = 4096
+ audio_specialtokens = 64
+
+ padded_text_vocabsize = text_vocabsize + text_specialtokens
+ padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
+
+ _eot = text_vocabsize
+ _pad_t = text_vocabsize + 1
+ _input_t = text_vocabsize + 2
+ _answer_t = text_vocabsize + 3
+ _asr = text_vocabsize + 4
+
+ _eoa = audio_vocabsize
+ _pad_a = audio_vocabsize + 1
+ _input_a = audio_vocabsize + 2
+ _answer_a = audio_vocabsize + 3
+ _split = audio_vocabsize + 4
+
+
+ def get_input_ids_TA(text, text_tokenizer):
+     input_ids_item = [[] for _ in range(8)]
+     text_tokens = text_tokenizer.encode(text)
+     for i in range(7):
+         input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
+             layershift(_answer_a, i)
+         ]
+         input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
+     input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
+     input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
+     return input_ids_item
+
+
+ def get_input_ids_TT(text, text_tokenizer):
+     input_ids_item = [[] for i in range(8)]
+     text_tokens = text_tokenizer.encode(text).tolist()
+
+     for i in range(7):
+         input_ids_item[i] = torch.tensor(
+             [layershift(_pad_a, i)] * (len(text_tokens) + 3)
+         ).unsqueeze(0)
+     input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
+     input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
+
+     return input_ids_item
+
+
+ def get_input_ids_whisper(
+     mel, leng, whispermodel, device,
+     special_token_a=_answer_a, special_token_t=_answer_t,
+ ):
+
+     with torch.no_grad():
+         mel = mel.unsqueeze(0).to(device)
+         # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
+         audio_feature = whispermodel.embed_audio(mel)[0][:leng]
+
+     T = audio_feature.size(0)
+     input_ids = []
+     for i in range(7):
+         input_ids_item = []
+         input_ids_item.append(layershift(_input_a, i))
+         input_ids_item += [layershift(_pad_a, i)] * T
+         input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
+         input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
+     input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
+     input_ids.append(input_id_T.unsqueeze(0))
+     return audio_feature.unsqueeze(0), input_ids
+
+
+ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
+     with torch.no_grad():
+         mel = mel.unsqueeze(0).to(device)
+         # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
+         audio_feature = whispermodel.embed_audio(mel)[0][:leng]
+     T = audio_feature.size(0)
+     input_ids_AA = []
+     for i in range(7):
+         input_ids_item = []
+         input_ids_item.append(layershift(_input_a, i))
+         input_ids_item += [layershift(_pad_a, i)] * T
+         input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
+         input_ids_AA.append(torch.tensor(input_ids_item))
+     input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
+     input_ids_AA.append(input_id_T)
+
+     input_ids_AT = []
+     for i in range(7):
+         input_ids_item = []
+         input_ids_item.append(layershift(_input_a, i))
+         input_ids_item += [layershift(_pad_a, i)] * T
+         input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
+         input_ids_AT.append(torch.tensor(input_ids_item))
+     input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
+     input_ids_AT.append(input_id_T)
+
+     input_ids = [input_ids_AA, input_ids_AT]
+     stacked_inputids = [[] for _ in range(8)]
+     for i in range(2):
+         for j in range(8):
+             stacked_inputids[j].append(input_ids[i][j])
+     stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
+     return torch.stack([audio_feature, audio_feature]), stacked_inputids
+
+
+ def load_audio(path):
+     audio = whisper.load_audio(path)
+     duration_ms = (len(audio) / 16000) * 1000
+     audio = whisper.pad_or_trim(audio)
+     mel = whisper.log_mel_spectrogram(audio)
+     return mel, int(duration_ms / 20) + 1
+
+
+ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
+                 snacmodel, out_dir=None):
+     with fabric.init_tensor():
+         model.set_kv_cache(batch_size=2)
+     tokenlist = generate_TA_BATCH(
+         model,
+         audio_feature,
+         input_ids,
+         [leng, leng],
+         ["A1A2", "A1T2"],
+         max_returned_tokens=2048,
+         temperature=0.9,
+         top_k=1,
+         eos_id_a=_eoa,
+         eos_id_t=_eot,
+         pad_id_t=_pad_t,
+         shift=padded_text_vocabsize,
+         include_prompt=True,
+         generate_text=True,
+     )
+     text_tokenlist = tokenlist[-1]
+     if text_vocabsize in text_tokenlist:
+         text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
+     text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()
+
+     audio_tokenlist = tokenlist[:-1]
+     audiolist = reconscruct_snac(audio_tokenlist)
+     audio = reconstruct_tensors(audiolist)
+     if out_dir is None:
+         out_dir = "./output/default/A1-A2-batch"
+     else:
+         out_dir = out_dir + "/A1-A2-batch"
+     if not os.path.exists(out_dir):
+         os.makedirs(out_dir)
+     with torch.inference_mode():
+         audio_hat = snacmodel.decode(audio)
+     sf.write(
+         f"{out_dir}/{step:02d}.wav",
+         audio_hat.squeeze().cpu().numpy(),
+         24000,
+     )
+     model.clear_kv_cache()
+     return text
+
+
+ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
+     with fabric.init_tensor():
+         model.set_kv_cache(batch_size=1)
+     tokenlist = generate_AT(
+         model,
+         audio_feature,
+         input_ids,
+         [leng],
+         ["AT"],
+         max_returned_tokens=2048,
+         temperature=0.9,
+         top_k=1,
+         eos_id_a=_eoa,
+         eos_id_t=_eot,
+         pad_id_t=_pad_t,
+         shift=padded_text_vocabsize,
+         include_prompt=True,
+         generate_text=True,
+     )
+     return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
+           snacmodel, out_dir=None):
+     with fabric.init_tensor():
+         model.set_kv_cache(batch_size=1)
+     tokenlist = generate_AA(
+         model,
+         audio_feature,
+         input_ids,
+         [leng],
+         ["A1T2"],
+         max_returned_tokens=2048,
+         temperature=0.9,
+         top_k=1,
+         eos_id_a=_eoa,
+         eos_id_t=_eot,
+         pad_id_t=_pad_t,
+         shift=padded_text_vocabsize,
+         include_prompt=True,
+         generate_text=True,
+     )
+     audiolist = reconscruct_snac(tokenlist)
+     tokenlist = tokenlist[-1]
+     if text_vocabsize in tokenlist:
+         tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
+     if out_dir is None:
+         out_dir = "./output/default/A1-A2"
+     else:
+         out_dir = out_dir + "/A1-A2"
+     if not os.path.exists(out_dir):
+         os.makedirs(out_dir)
+
+     audio = reconstruct_tensors(audiolist)
+     with torch.inference_mode():
+         audio_hat = snacmodel.decode(audio)
+     sf.write(
+         f"{out_dir}/{step:02d}.wav",
+         audio_hat.squeeze().cpu().numpy(),
+         24000,
+     )
+     model.clear_kv_cache()
+     return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
+     with fabric.init_tensor():
+         model.set_kv_cache(batch_size=1)
+     tokenlist = generate_ASR(
+         model,
+         audio_feature,
+         input_ids,
+         [leng],
+         ["A1T1"],
+         max_returned_tokens=2048,
+         temperature=0.9,
+         top_k=1,
+         eos_id_a=_eoa,
+         eos_id_t=_eot,
+         pad_id_t=_pad_t,
+         shift=padded_text_vocabsize,
+         include_prompt=True,
+         generate_text=True,
+     )
+     model.clear_kv_cache()
+     return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
+           snacmodel, out_dir=None):
+     with fabric.init_tensor():
+         model.set_kv_cache(batch_size=1)
+     tokenlist = generate_TA(
+         model,
+         None,
+         input_ids,
+         None,
+         ["T1A2"],
+         max_returned_tokens=2048,
+         temperature=0.9,
+         top_k=1,
+         eos_id_a=_eoa,
+         eos_id_t=_eot,
+         pad_id_t=_pad_t,
+         shift=padded_text_vocabsize,
+         include_prompt=True,
+         generate_text=True,
+     )
+
+     audiolist = reconscruct_snac(tokenlist)
+     tokenlist = tokenlist[-1]
+
+     if text_vocabsize in tokenlist:
+         tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
+     audio = reconstruct_tensors(audiolist)
+     if out_dir is None:
+         out_dir = "./output/default/T1-A2"
+     else:
+         out_dir = out_dir + "/T1-A2"
+     if not os.path.exists(out_dir):
+         os.makedirs(out_dir)
+
+     with torch.inference_mode():
+         audio_hat = snacmodel.decode(audio)
+     sf.write(
+         f"{out_dir}/{step:02d}.wav",
+         audio_hat.squeeze().cpu().numpy(),
+         24000,
+     )
+     model.clear_kv_cache()
+     return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
+
+     with fabric.init_tensor():
+         model.set_kv_cache(batch_size=1)
+     tokenlist = generate_TT(
+         model,
+         None,
+         input_ids,
+         None,
+         ["T1T2"],
+         max_returned_tokens=2048,
+         temperature=0.9,
+         top_k=1,
+         eos_id_a=_eoa,
+         eos_id_t=_eot,
+         pad_id_t=_pad_t,
+         shift=padded_text_vocabsize,
+         include_prompt=True,
+         generate_text=True,
+     )
+     model.clear_kv_cache()
+     return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+ def load_model(ckpt_dir, device):
+     snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
+     whispermodel = whisper.load_model("small").to(device)
+     text_tokenizer = Tokenizer(ckpt_dir)
+     fabric = L.Fabric(devices=1, strategy="auto")
+     config = Config.from_file(ckpt_dir + "/model_config.yaml")
+     config.post_adapter = False
+
+     with fabric.init_module(empty_init=False):
+         model = GPT(config)
+
+     model = fabric.setup(model)
+     state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
+     model.load_state_dict(state_dict, strict=True)
+     model.to(device).eval()
+
+     return fabric, model, text_tokenizer, snacmodel, whispermodel
+
+
+ def download_model(ckpt_dir):
+     repo_id = "gpt-omni/mini-omni"
+     snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
+
+
+ class OmniInference:
+
+     def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
+         self.device = device
+         if not os.path.exists(ckpt_dir):
+             print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
+             download_model(ckpt_dir)
+         self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
+
+     def warm_up(self, sample='./data/samples/output1.wav'):
+         for _ in self.run_AT_batch_stream(sample):
+             pass
+
+     @torch.inference_mode()
+     def run_AT_batch_stream(self,
+                             audio_path,
+                             stream_stride=4,
+                             max_returned_tokens=2048,
+                             temperature=0.9,
+                             top_k=1,
+                             top_p=1.0,
+                             eos_id_a=_eoa,
+                             eos_id_t=_eot,
+                             ):
+
+         assert os.path.exists(audio_path), f"audio file {audio_path} not found"
+         model = self.model
+
+         with self.fabric.init_tensor():
+             model.set_kv_cache(batch_size=2,device=self.device)
+
+         mel, leng = load_audio(audio_path)
+         audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
+         T = input_ids[0].size(1)
+         device = input_ids[0].device
+
+         assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
+
+         if model.max_seq_length < max_returned_tokens - 1:
+             raise NotImplementedError(
+                 f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
+             )
+
+         input_pos = torch.tensor([T], device=device)
+         list_output = [[] for i in range(8)]
+         tokens_A, token_T = next_token_batch(
+             model,
+             audio_feature.to(torch.float32).to(model.device),
+             input_ids,
+             [T - 3, T - 3],
+             ["A1T2", "A1T2"],
+             input_pos=torch.arange(0, T, device=device),
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+         )
+
+         for i in range(7):
+             list_output[i].append(tokens_A[i].tolist()[0])
+         list_output[7].append(token_T.tolist()[0])
+
+         model_input_ids = [[] for i in range(8)]
+         for i in range(7):
+             tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
+             model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
+             model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
+             model_input_ids[i] = torch.stack(model_input_ids[i])
+
+         model_input_ids[-1].append(token_T.clone().to(torch.int32))
+         model_input_ids[-1].append(token_T.clone().to(torch.int32))
+         model_input_ids[-1] = torch.stack(model_input_ids[-1])
+
+         text_end = False
+         index = 1
+         nums_generate = stream_stride
+         begin_generate = False
+         current_index = 0
+         for _ in tqdm(range(2, max_returned_tokens - T + 1)):
+             tokens_A, token_T = next_token_batch(
+                 model,
+                 None,
+                 model_input_ids,
+                 None,
+                 None,
+                 input_pos=input_pos,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+             )
+
+             if text_end:
+                 token_T = torch.tensor([_pad_t], device=device)
+
+             if tokens_A[-1] == eos_id_a:
+                 break
+
+             if token_T == eos_id_t:
+                 text_end = True
+
+             for i in range(7):
+                 list_output[i].append(tokens_A[i].tolist()[0])
+             list_output[7].append(token_T.tolist()[0])
+
+             model_input_ids = [[] for i in range(8)]
+             for i in range(7):
+                 tokens_A[i] = tokens_A[i].clone() +padded_text_vocabsize + i * padded_audio_vocabsize
+                 model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
+                 model_input_ids[i].append(
+                     torch.tensor([layershift(4097, i)], device=device)
+                 )
+                 model_input_ids[i] = torch.stack(model_input_ids[i])
+
+             model_input_ids[-1].append(token_T.clone().to(torch.int32))
+             model_input_ids[-1].append(token_T.clone().to(torch.int32))
+             model_input_ids[-1] = torch.stack(model_input_ids[-1])
+
+             if index == 7:
+                 begin_generate = True
+
+             if begin_generate:
+                 current_index += 1
+                 if current_index == nums_generate:
+                     current_index = 0
+                     snac = get_snac(list_output, index, nums_generate)
+                     audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
+                     yield audio_stream
+
+             input_pos = input_pos.add_(1)
+             index += 1
+         text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
+         print(f"text output: {text}")
+         model.clear_kv_cache()
+         return list_output
+
+
+ def test_infer():
+     device = "cuda:0"
+     out_dir = f"./output/{get_time_str()}"
+     ckpt_dir = f"./checkpoint"
+     if not os.path.exists(ckpt_dir):
+         print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
+         download_model(ckpt_dir)
+
+     fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)
+
+     task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']
+
+     # prepare test data
+     # TODO
+     test_audio_list = sorted(os.listdir('./data/samples'))
+     test_audio_list = [os.path.join('./data/samples', path) for path in test_audio_list]
+     test_audio_transcripts = [
+         "What is your name?",
+         "what are your hobbies?",
+         "Do you like beijing",
+         "How are you feeling today?",
+         "what is the weather like today?",
+     ]
+     test_text_list = [
+         "What is your name?",
+         "How are you feeling today?",
+         "Can you describe your surroundings?",
+         "What did you do yesterday?",
+         "What is your favorite book and why?",
+         "How do you make a cup of tea?",
+         "What is the weather like today?",
+         "Can you explain the concept of time?",
+         "Can you tell me a joke?",
+     ]
+
+     # LOAD MODEL
+     with torch.no_grad():
+         if "A1A2" in task:
+             print("===============================================================")
+             print(" testing A1A2")
+             print("===============================================================")
+             step = 0
+             for path in test_audio_list:
+                 try:
+                     mel, leng = load_audio(path)
+                     audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
+                     text = A1_A2(
+                         fabric,
+                         audio_feature,
+                         input_ids,
+                         leng,
+                         model,
+                         text_tokenizer,
+                         step,
+                         snacmodel,
+                         out_dir=out_dir,
+                     )
+                     print(f"input: {test_audio_transcripts[step]}")
+                     print(f"output: {text}")
+                     step += 1
+                     print(
+                         "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
+                     )
+                 except:
+                     print(f"[error] failed to process {path}")
+             print("===============================================================")
+
+         if 'asr' in task:
+             print("===============================================================")
+             print(" testing asr")
+             print("===============================================================")
+
+             index = 0
+             step = 0
+             for path in test_audio_list:
+                 mel, leng = load_audio(path)
+                 audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
+                 output = A1_T1(fabric, audio_feature, input_ids ,leng, model, text_tokenizer, index).lower().replace(',','').replace('.','').replace('?','')
+                 print(f"audio_path: {path}")
+                 print(f"audio transcript: {test_audio_transcripts[index]}")
+                 print(f"asr output: {output}")
+                 print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+                 index += 1
+
+         if "T1A2" in task:
+             step = 0
+             print("\n")
+             print("===============================================================")
+             print(" testing T1A2")
+             print("===============================================================")
+             for text in test_text_list:
+                 input_ids = get_input_ids_TA(text, text_tokenizer)
+                 text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
+                                     snacmodel, out_dir=out_dir)
+                 print(f"input: {text}")
+                 print(f"output: {text_output}")
+                 print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+                 step += 1
+             print("===============================================================")
+
+         if "T1T2" in task:
+             step = 0
+             print("\n")
+             print("===============================================================")
+             print(" testing T1T2")
+             print("===============================================================")
+
+             for text in test_text_list:
+                 input_ids = get_input_ids_TT(text, text_tokenizer)
+                 text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
+                 print(f" Input: {text}")
+                 print(f"Output: {text_output}")
+                 print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+             print("===============================================================")
+
+         if "AT" in task:
+             print("===============================================================")
+             print(" testing A1T2")
+             print("===============================================================")
+             step = 0
+             for path in test_audio_list:
+                 mel, leng = load_audio(path)
+                 audio_feature, input_ids = get_input_ids_whisper(
+                     mel, leng, whispermodel, device,
+                     special_token_a=_pad_a, special_token_t=_answer_t
+                 )
+                 text = A1_T2(
+                     fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
+                 )
+                 print(f"input: {test_audio_transcripts[step]}")
+                 print(f"output: {text}")
+                 step += 1
+                 print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+             print("===============================================================")
+
+         if "AA-BATCH" in task:
+             print("===============================================================")
+             print(" testing A1A2-BATCH")
+             print("===============================================================")
+             step = 0
+             for path in test_audio_list:
+                 mel, leng = load_audio(path)
+                 audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
+                 text = A1_A2_batch(
+                     fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
+                     snacmodel, out_dir=out_dir
+                 )
+                 print(f"input: {test_audio_transcripts[step]}")
+                 print(f"output: {text}")
+                 step += 1
+                 print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+             print("===============================================================")
+
+     print("*********************** test end *****************************")
+
+
+
+ if __name__ == "__main__":
+     test_infer()
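Note: OmniInference.run_AT_batch_stream yields raw audio bytes (int16 PCM produced by the SNAC decoder; the rest of this commit treats the stream as 24 kHz mono) rather than finished WAV files, and app.py and server.py consume it that way. A minimal usage sketch, assuming a downloaded ./checkpoint and an available CUDA device (the reply.pcm filename is only illustrative):

    from inference import OmniInference

    client = OmniInference(ckpt_dir="./checkpoint", device="cuda:0")
    client.warm_up()  # one pass over the bundled sample before serving requests

    pcm = b""
    for chunk in client.run_AT_batch_stream("./data/samples/output1.wav"):
        pcm += chunk  # each chunk is a block of 16-bit PCM bytes

    with open("reply.pcm", "wb") as f:
        f.write(pcm)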
server.py ADDED
@@ -0,0 +1,57 @@
+ import flask
+ import base64
+ import tempfile
+ import traceback
+ from flask import Flask, Response, stream_with_context
+ from inference import OmniInference
+
+
+ class OmniChatServer(object):
+     def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
+                  ckpt_dir='./checkpoint', device='cuda:0') -> None:
+         server = Flask(__name__)
+         # CORS(server, resources=r"/*")
+         # server.config["JSON_AS_ASCII"] = False
+
+         self.client = OmniInference(ckpt_dir, device)
+         self.client.warm_up()
+
+         server.route("/chat", methods=["POST"])(self.chat)
+
+         if run_app:
+             server.run(host=ip, port=port, threaded=False)
+         else:
+             self.server = server
+
+     def chat(self) -> Response:
+
+         req_data = flask.request.get_json()
+         try:
+             data_buf = req_data["audio"].encode("utf-8")
+             data_buf = base64.b64decode(data_buf)
+             stream_stride = req_data.get("stream_stride", 4)
+             max_tokens = req_data.get("max_tokens", 2048)
+
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                 f.write(data_buf)
+                 audio_generator = self.client.run_AT_batch_stream(f.name, stream_stride, max_tokens)
+                 return Response(stream_with_context(audio_generator), mimetype="audio/wav")
+         except Exception as e:
+             print(traceback.format_exc())
+
+
+ # CUDA_VISIBLE_DEVICES=1 gunicorn -w 2 -b 0.0.0.0:60808 'server:create_app()'
+ def create_app():
+     server = OmniChatServer(run_app=False)
+     return server.server
+
+
+ def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
+
+     OmniChatServer(ip, port=port,run_app=True, device=device)
+
+
+ if __name__ == "__main__":
+     import fire
+     fire.Fire(serve)
+
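Note: the /chat route expects a JSON body whose "audio" field is a base64-encoded WAV file (with optional stream_stride and max_tokens), and it streams back raw int16 PCM rather than a WAV container. A hypothetical standalone client, mirroring process_audio from app.py but saving the reply with soundfile (question.wav and answer.wav are illustrative names, not files in this repo):

    import base64

    import numpy as np
    import requests
    import soundfile as sf

    # encode the input recording the same way app.py does
    with open("question.wav", "rb") as f:
        payload = {"audio": base64.b64encode(f.read()).decode("utf-8")}

    pcm = b""
    with requests.post("http://0.0.0.0:60808/chat", json=payload, stream=True) as r:
        for chunk in r.iter_content(chunk_size=4096):
            pcm += chunk

    # interpret the byte stream as 24 kHz mono int16 samples and write a WAV file
    audio = np.frombuffer(pcm, dtype=np.int16)
    sf.write("answer.wav", audio, 24000)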