freddyaboulton (HF staff) committed
Commit 54011d4
Parent: 8622ec6
__init__.py ADDED
@@ -0,0 +1,4 @@
+ import sys
+ import os
+
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
app.py ADDED
@@ -0,0 +1,137 @@
+ import logging
+ import base64
+ import io
+ import os
+ from threading import Thread
+
+ import gradio as gr
+ import numpy as np
+ import requests
+ from gradio_webrtc import ReplyOnPause, WebRTC, AdditionalOutputs
+ from pydub import AudioSegment
+ from twilio.rest import Client
+
+ from server import serve
+
+ logging.basicConfig(level=logging.WARNING)
+ file_handler = logging.FileHandler("gradio_webrtc.log")
+ file_handler.setLevel(logging.DEBUG)
+ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ file_handler.setFormatter(formatter)
+ logger = logging.getLogger("gradio_webrtc")
+ logger.setLevel(logging.DEBUG)
+ logger.addHandler(file_handler)
+
+
+ IP = "0.0.0.0"
+ PORT = 60808
+
+ thread = Thread(target=serve, daemon=True)
+ thread.start()
+
+
+ API_URL = "http://0.0.0.0:60808/chat"
+
+ account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
+ auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
+
+ if account_sid and auth_token:
+     client = Client(account_sid, auth_token)
+
+     token = client.tokens.create()
+
+     rtc_configuration = {
+         "iceServers": token.ice_servers,
+         "iceTransportPolicy": "relay",
+     }
+ else:
+     rtc_configuration = None
+
+ OUT_CHANNELS = 1
+ OUT_RATE = 24000
+ OUT_SAMPLE_WIDTH = 2
+ OUT_CHUNK = 20 * 4096
+
+
+ def response(audio: tuple[int, np.ndarray], conversation: list[dict], img: str | None):
+     conversation.append({"role": "user", "content": gr.Audio(audio)})
+     yield AdditionalOutputs(conversation)
+
+     sampling_rate, audio_np = audio
+     audio_np = audio_np.squeeze()
+
+     audio_buffer = io.BytesIO()
+     segment = AudioSegment(
+         audio_np.tobytes(),
+         frame_rate=sampling_rate,
+         sample_width=audio_np.dtype.itemsize,
+         channels=1,
+     )
+
+     segment.export(audio_buffer, format="wav")
+     conversation.append({"role": "assistant", "content": ""})
+
+     base64_encoded = str(base64.b64encode(audio_buffer.getvalue()), encoding="utf-8")
+     if API_URL is not None:
+         output_audio_bytes = b""
+         files = {"audio": base64_encoded}
+         if img is not None:
+             files["image"] = str(base64.b64encode(open(img, "rb").read()), encoding="utf-8")
+         print("sending request to server")
+         resp_text = ""
+         with requests.post(API_URL, json=files, stream=True) as response:
+             try:
+                 buffer = b''
+                 for chunk in response.iter_content(chunk_size=2048):
+                     buffer += chunk
+                     while b'\r\n--frame\r\n' in buffer:
+                         frame, buffer = buffer.split(b'\r\n--frame\r\n', 1)
+                         if b'Content-Type: audio/wav' in frame:
+                             audio_data = frame.split(b'\r\n\r\n', 1)[1]
+                             # audio_data = base64.b64decode(audio_data)
+                             output_audio_bytes += audio_data
+                             audio_array = np.frombuffer(audio_data, dtype=np.int8).reshape(1, -1)
+                             yield (OUT_RATE, audio_array, "mono")
+                         elif b'Content-Type: text/plain' in frame:
+                             text_data = frame.split(b'\r\n\r\n', 1)[1].decode()
+                             resp_text += text_data
+                             if len(text_data) > 0:
+                                 conversation[-1]["content"] = resp_text
+                                 yield AdditionalOutputs(conversation)
+             except Exception as e:
+                 raise Exception(f"Error during audio streaming: {e}") from e
+
+
+ with gr.Blocks() as demo:
+     gr.HTML(
+         """
+     <h1 style='text-align: center'>
+     Mini-Omni-2 Chat (Powered by WebRTC ⚡️)
+     </h1>
+     """
+     )
+     with gr.Row():
+         with gr.Column():
+             with gr.Group():
+                 audio = WebRTC(
+                     label="Stream",
+                     rtc_configuration=rtc_configuration,
+                     mode="send-receive",
+                     modality="audio",
+                 )
+             img = gr.Image(label="Image", type="filepath")
+         with gr.Column():
+             conversation = gr.Chatbot(label="Conversation", type="messages")
+
+     audio.stream(
+         fn=ReplyOnPause(
+             response, output_sample_rate=OUT_RATE, output_frame_size=480
+         ),
+         inputs=[audio, conversation, img],
+         outputs=[audio],
+         time_limit=90,
+     )
+     audio.on_additional_outputs(lambda c: c, outputs=[conversation])
+
+
+ demo.launch()
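Note on the streaming protocol: the handler in response() above splits the chunked HTTP response on the multipart boundary "\r\n--frame\r\n" and dispatches on each part's Content-Type header. Below is a minimal sketch of that parsing step factored into a standalone helper; the name split_frames and the default boundary argument are illustrative only and this helper is not part of the commit.

def split_frames(buffer: bytes, boundary: bytes = b"\r\n--frame\r\n"):
    """Split accumulated bytes on the multipart boundary.

    Returns (frames, leftover): each frame is a (headers, body) pair, and
    `leftover` holds the trailing partial frame to prepend to the next chunk.
    """
    frames = []
    while boundary in buffer:
        frame, buffer = buffer.split(boundary, 1)
        if b"\r\n\r\n" in frame:
            headers, body = frame.split(b"\r\n\r\n", 1)
            frames.append((headers, body))
    return frames, buffer

Each headers blob can then be checked for "audio/wav" or "text/plain" exactly as the loop in response() does.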
inference.py ADDED
@@ -0,0 +1,704 @@
1
+ import os
2
+ import lightning as L
3
+ import torch
4
+ import glob
5
+ import time
6
+ from snac import SNAC
7
+ from litgpt import Tokenizer
8
+ from litgpt.utils import (
9
+ num_parameters,
10
+ )
11
+ from litgpt.generate.base import (
12
+ generate_AA,
13
+ generate_ASR,
14
+ generate_TA,
15
+ generate_TT,
16
+ generate_AT,
17
+ generate_TA_BATCH,
18
+ next_token_image_batch
19
+ )
20
+ import soundfile as sf
21
+ from litgpt.model import GPT, Config
22
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
23
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
24
+ from utils.snac_utils import get_snac, generate_audio_data
25
+ import whisper
26
+ from tqdm import tqdm
27
+ from huggingface_hub import snapshot_download
28
+
29
+
30
+ torch.set_printoptions(sci_mode=False)
31
+
32
+
33
+ # TODO
34
+ text_vocabsize = 151936
35
+ text_specialtokens = 64
36
+ audio_vocabsize = 4096
37
+ audio_specialtokens = 64
38
+
39
+ padded_text_vocabsize = text_vocabsize + text_specialtokens
40
+ padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
41
+
42
+ _eot = text_vocabsize
43
+ _pad_t = text_vocabsize + 1
44
+ _input_t = text_vocabsize + 2
45
+ _answer_t = text_vocabsize + 3
46
+ _asr = text_vocabsize + 4
47
+
48
+ _eoa = audio_vocabsize
49
+ _pad_a = audio_vocabsize + 1
50
+ _input_a = audio_vocabsize + 2
51
+ _answer_a = audio_vocabsize + 3
52
+ _split = audio_vocabsize + 4
53
+ _image = audio_vocabsize + 5
54
+ _eoimage = audio_vocabsize + 6
55
+
56
+
57
+ def get_input_ids_TA(text, text_tokenizer):
58
+ input_ids_item = [[] for _ in range(8)]
59
+ text_tokens = text_tokenizer.encode(text)
60
+ for i in range(7):
61
+ input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
62
+ layershift(_answer_a, i)
63
+ ]
64
+ input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
65
+ input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
66
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
67
+ return input_ids_item
68
+
69
+
70
+ def get_input_ids_TT(text, text_tokenizer):
71
+ input_ids_item = [[] for i in range(8)]
72
+ text_tokens = text_tokenizer.encode(text).tolist()
73
+
74
+ for i in range(7):
75
+ input_ids_item[i] = torch.tensor(
76
+ [layershift(_pad_a, i)] * (len(text_tokens) + 3)
77
+ ).unsqueeze(0)
78
+ input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
79
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
80
+
81
+ return input_ids_item
82
+
83
+
84
+ def get_input_ids_whisper(
85
+ mel, leng, whispermodel, device,
86
+ special_token_a=_answer_a, special_token_t=_answer_t,
87
+ ):
88
+
89
+ with torch.no_grad():
90
+ mel = mel.unsqueeze(0).to(device)
91
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
92
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
93
+
94
+ T = audio_feature.size(0)
95
+ input_ids = []
96
+ for i in range(7):
97
+ input_ids_item = []
98
+ input_ids_item.append(layershift(_input_a, i))
99
+ input_ids_item += [layershift(_pad_a, i)] * T
100
+ input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
101
+ input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
102
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
103
+ input_ids.append(input_id_T.unsqueeze(0))
104
+ return audio_feature.unsqueeze(0), input_ids
105
+
106
+
107
+ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
108
+ with torch.no_grad():
109
+ mel = mel.unsqueeze(0).to(device)
110
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
111
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
112
+ T = audio_feature.size(0)
113
+ input_ids_AA = []
114
+ for i in range(7):
115
+ input_ids_item = []
116
+ input_ids_item.append(layershift(_input_a, i))
117
+ input_ids_item += [layershift(_pad_a, i)] * T
118
+ input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
119
+ input_ids_AA.append(torch.tensor(input_ids_item))
120
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
121
+ input_ids_AA.append(input_id_T)
122
+
123
+ input_ids_AT = []
124
+ for i in range(7):
125
+ input_ids_item = []
126
+ input_ids_item.append(layershift(_input_a, i))
127
+ input_ids_item += [layershift(_pad_a, i)] * T
128
+ input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
129
+ input_ids_AT.append(torch.tensor(input_ids_item))
130
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
131
+ input_ids_AT.append(input_id_T)
132
+
133
+ input_ids = [input_ids_AA, input_ids_AT]
134
+ stacked_inputids = [[] for _ in range(8)]
135
+ for i in range(2):
136
+ for j in range(8):
137
+ stacked_inputids[j].append(input_ids[i][j])
138
+ stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
139
+ return torch.stack([audio_feature, audio_feature]), stacked_inputids
140
+
141
+
142
+ def load_audio(path):
143
+ audio = whisper.load_audio(path)
144
+ duration_ms = (len(audio) / 16000) * 1000
145
+ audio = whisper.pad_or_trim(audio)
146
+ mel = whisper.log_mel_spectrogram(audio)
147
+ return mel, int(duration_ms / 20) + 1
148
+
149
+
150
+ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
151
+ snacmodel, out_dir=None):
152
+ with fabric.init_tensor():
153
+ model.set_kv_cache(batch_size=2)
154
+ tokenlist = generate_TA_BATCH(
155
+ model,
156
+ audio_feature,
157
+ input_ids,
158
+ [leng, leng],
159
+ ["A1A2", "A1T2"],
160
+ max_returned_tokens=2048,
161
+ temperature=0.9,
162
+ top_k=1,
163
+ eos_id_a=_eoa,
164
+ eos_id_t=_eot,
165
+ pad_id_t=_pad_t,
166
+ shift=padded_text_vocabsize,
167
+ include_prompt=True,
168
+ generate_text=True,
169
+ )
170
+ text_tokenlist = tokenlist[-1]
171
+ if text_vocabsize in text_tokenlist:
172
+ text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
173
+ text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()
174
+
175
+ audio_tokenlist = tokenlist[:-1]
176
+ audiolist = reconscruct_snac(audio_tokenlist)
177
+ audio = reconstruct_tensors(audiolist)
178
+ if out_dir is None:
179
+ out_dir = "./output/default/A1-A2-batch"
180
+ else:
181
+ out_dir = out_dir + "/A1-A2-batch"
182
+ if not os.path.exists(out_dir):
183
+ os.makedirs(out_dir)
184
+ with torch.inference_mode():
185
+ audio_hat = snacmodel.decode(audio)
186
+ sf.write(
187
+ f"{out_dir}/{step:02d}.wav",
188
+ audio_hat.squeeze().cpu().numpy(),
189
+ 24000,
190
+ )
191
+ model.clear_kv_cache()
192
+ return text
193
+
194
+
195
+ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
196
+ with fabric.init_tensor():
197
+ model.set_kv_cache(batch_size=1)
198
+ tokenlist = generate_AT(
199
+ model,
200
+ audio_feature,
201
+ input_ids,
202
+ [leng],
203
+ ["AT"],
204
+ max_returned_tokens=2048,
205
+ temperature=0.9,
206
+ top_k=1,
207
+ eos_id_a=_eoa,
208
+ eos_id_t=_eot,
209
+ pad_id_t=_pad_t,
210
+ shift=padded_text_vocabsize,
211
+ include_prompt=True,
212
+ generate_text=True,
213
+ )
214
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
215
+
216
+
217
+ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
218
+ snacmodel, out_dir=None):
219
+ with fabric.init_tensor():
220
+ model.set_kv_cache(batch_size=1)
221
+ tokenlist = generate_AA(
222
+ model,
223
+ audio_feature,
224
+ input_ids,
225
+ [leng],
226
+ ["A1T2"],
227
+ max_returned_tokens=2048,
228
+ temperature=0.9,
229
+ top_k=1,
230
+ eos_id_a=_eoa,
231
+ eos_id_t=_eot,
232
+ pad_id_t=_pad_t,
233
+ shift=padded_text_vocabsize,
234
+ include_prompt=True,
235
+ generate_text=True,
236
+ )
237
+ audiolist = reconscruct_snac(tokenlist)
238
+ tokenlist = tokenlist[-1]
239
+ if text_vocabsize in tokenlist:
240
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
241
+ if out_dir is None:
242
+ out_dir = "./output/default/A1-A2"
243
+ else:
244
+ out_dir = out_dir + "/A1-A2"
245
+ if not os.path.exists(out_dir):
246
+ os.makedirs(out_dir)
247
+
248
+ audio = reconstruct_tensors(audiolist)
249
+ with torch.inference_mode():
250
+ audio_hat = snacmodel.decode(audio)
251
+ sf.write(
252
+ f"{out_dir}/{step:02d}.wav",
253
+ audio_hat.squeeze().cpu().numpy(),
254
+ 24000,
255
+ )
256
+ model.clear_kv_cache()
257
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
258
+
259
+
260
+ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
261
+ with fabric.init_tensor():
262
+ model.set_kv_cache(batch_size=1)
263
+ tokenlist = generate_ASR(
264
+ model,
265
+ audio_feature,
266
+ input_ids,
267
+ [leng],
268
+ ["A1T1"],
269
+ max_returned_tokens=2048,
270
+ temperature=0.9,
271
+ top_k=1,
272
+ eos_id_a=_eoa,
273
+ eos_id_t=_eot,
274
+ pad_id_t=_pad_t,
275
+ shift=padded_text_vocabsize,
276
+ include_prompt=True,
277
+ generate_text=True,
278
+ )
279
+ model.clear_kv_cache()
280
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
281
+
282
+
283
+ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
284
+ snacmodel, out_dir=None):
285
+ with fabric.init_tensor():
286
+ model.set_kv_cache(batch_size=1)
287
+ tokenlist = generate_TA(
288
+ model,
289
+ None,
290
+ input_ids,
291
+ None,
292
+ ["T1A2"],
293
+ max_returned_tokens=2048,
294
+ temperature=0.9,
295
+ top_k=1,
296
+ eos_id_a=_eoa,
297
+ eos_id_t=_eot,
298
+ pad_id_t=_pad_t,
299
+ shift=padded_text_vocabsize,
300
+ include_prompt=True,
301
+ generate_text=True,
302
+ )
303
+
304
+ audiolist = reconscruct_snac(tokenlist)
305
+ tokenlist = tokenlist[-1]
306
+
307
+ if text_vocabsize in tokenlist:
308
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
309
+ audio = reconstruct_tensors(audiolist)
310
+ if out_dir is None:
311
+ out_dir = "./output/default/T1-A2"
312
+ else:
313
+ out_dir = out_dir + "/T1-A2"
314
+ if not os.path.exists(out_dir):
315
+ os.makedirs(out_dir)
316
+
317
+ with torch.inference_mode():
318
+ audio_hat = snacmodel.decode(audio)
319
+ sf.write(
320
+ f"{out_dir}/{step:02d}.wav",
321
+ audio_hat.squeeze().cpu().numpy(),
322
+ 24000,
323
+ )
324
+ model.clear_kv_cache()
325
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
326
+
327
+
328
+ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
329
+
330
+ with fabric.init_tensor():
331
+ model.set_kv_cache(batch_size=1)
332
+ tokenlist = generate_TT(
333
+ model,
334
+ None,
335
+ input_ids,
336
+ None,
337
+ ["T1T2"],
338
+ max_returned_tokens=2048,
339
+ temperature=0.9,
340
+ top_k=1,
341
+ eos_id_a=_eoa,
342
+ eos_id_t=_eot,
343
+ pad_id_t=_pad_t,
344
+ shift=padded_text_vocabsize,
345
+ include_prompt=True,
346
+ generate_text=True,
347
+ )
348
+ model.clear_kv_cache()
349
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
350
+
351
+
352
+ def load_model(ckpt_dir, device):
353
+ snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
354
+ whisper_model_path = ckpt_dir + "/small.pt"
355
+ if not os.path.exists(whisper_model_path):
356
+ whisper_model_path = "small"
357
+ whispermodel = whisper.load_model(whisper_model_path).to(device)
358
+ text_tokenizer = Tokenizer(ckpt_dir)
359
+ fabric = L.Fabric(devices=1, strategy="auto")
360
+ config = Config.from_file(ckpt_dir + "/model_config.yaml")
361
+ config.post_adapter = False
362
+
363
+ with fabric.init_module(empty_init=False):
364
+ model = GPT(config)
365
+
366
+ model = fabric.setup(model)
367
+ state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
368
+ model.load_state_dict(state_dict, strict=True)
369
+ model.to(device).eval()
370
+
371
+ return fabric, model, text_tokenizer, snacmodel, whispermodel
372
+
373
+
374
+ def download_model(ckpt_dir):
375
+ repo_id = "gpt-omni/mini-omni2"
376
+ snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
377
+
378
+
379
+ def get_text_stream(list_output, index, text_tokenizer):
380
+ text_tokens = list_output[-1][index:]
381
+ index += len(text_tokens)
382
+ is_text_end = False
383
+ if text_vocabsize in text_tokens:
384
+ text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
385
+ is_text_end = True
386
+ if len(text_tokens) == 0:
387
+ return "", index, is_text_end
388
+ res_text = text_tokenizer.decode(torch.tensor(text_tokens))
389
+ return res_text, index, is_text_end
390
+
391
+
392
+ class OmniInference:
393
+
394
+ def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
395
+ self.device = device
396
+ if not os.path.exists(ckpt_dir):
397
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
398
+ download_model(ckpt_dir)
399
+ self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
400
+
401
+ def warm_up(self, sample='./data/samples/output1.wav'):
402
+ for _ in self.run_AT_batch_stream(sample):
403
+ pass
404
+
405
+ @torch.inference_mode()
406
+ def run_AT_batch_stream(self,
407
+ audio_path,
408
+ stream_stride=4,
409
+ max_returned_tokens=2048,
410
+ temperature=0.9,
411
+ top_k=1,
412
+ top_p=1.0,
413
+ eos_id_a=_eoa,
414
+ eos_id_t=_eot,
415
+ save_path=None
416
+ ):
417
+
418
+ assert os.path.exists(audio_path), f"audio file {audio_path} not found"
419
+ model = self.model
420
+
421
+ with self.fabric.init_tensor():
422
+ model.set_kv_cache(batch_size=2,device=self.device)
423
+
424
+ mel, leng = load_audio(audio_path)
425
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
426
+ T = input_ids[0].size(1)
427
+ device = input_ids[0].device
428
+
429
+ assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
430
+
431
+ if model.max_seq_length < max_returned_tokens - 1:
432
+ raise NotImplementedError(
433
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
434
+ )
435
+
436
+ input_pos = torch.tensor([T], device=device)
437
+ list_output = [[] for i in range(8)]
438
+ tokens_A, token_T = next_token_image_batch(
439
+ model,
440
+ audio_feature.to(torch.float32).to(model.device),
441
+ None,
442
+ input_ids,
443
+ [T - 3, T - 3],
444
+ ["A1T2", "A1T2"],
445
+ input_pos=torch.arange(0, T, device=device),
446
+ temperature=temperature,
447
+ top_k=top_k,
448
+ top_p=top_p,
449
+ )
450
+
451
+ for i in range(7):
452
+ list_output[i].append(tokens_A[i].tolist()[0])
453
+ list_output[7].append(token_T.tolist()[0])
454
+
455
+ model_input_ids = [[] for i in range(8)]
456
+ for i in range(7):
457
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
458
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
459
+ model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
460
+ model_input_ids[i] = torch.stack(model_input_ids[i])
461
+
462
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
463
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
464
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
465
+
466
+ text_end = False
467
+ index = 1
468
+ nums_generate = stream_stride
469
+ begin_generate = False
470
+ current_index = 0
471
+
472
+ text_index = 0
473
+ is_text_end = False
474
+
475
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
476
+ tokens_A, token_T = next_token_image_batch(
477
+ model,
478
+ None,
479
+ None,
480
+ model_input_ids,
481
+ None,
482
+ None,
483
+ input_pos=input_pos,
484
+ temperature=temperature,
485
+ top_k=top_k,
486
+ top_p=top_p,
487
+ )
488
+
489
+ if text_end:
490
+ token_T = torch.tensor([_pad_t], device=device)
491
+
492
+ if tokens_A[-1] == eos_id_a:
493
+ break
494
+
495
+ if token_T == eos_id_t:
496
+ text_end = True
497
+
498
+ for i in range(7):
499
+ list_output[i].append(tokens_A[i].tolist()[0])
500
+ list_output[7].append(token_T.tolist()[0])
501
+
502
+ model_input_ids = [[] for i in range(8)]
503
+ for i in range(7):
504
+ tokens_A[i] = tokens_A[i].clone() +padded_text_vocabsize + i * padded_audio_vocabsize
505
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
506
+ model_input_ids[i].append(
507
+ torch.tensor([layershift(4097, i)], device=device)
508
+ )
509
+ model_input_ids[i] = torch.stack(model_input_ids[i])
510
+
511
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
512
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
513
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
514
+
515
+ if index == 7:
516
+ begin_generate = True
517
+
518
+ if begin_generate:
519
+ current_index += 1
520
+ if current_index == nums_generate:
521
+ current_index = 0
522
+ snac = get_snac(list_output, index, nums_generate)
523
+ audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
524
+ if is_text_end:
525
+ text_stream = ""
526
+ else:
527
+ text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
528
+
529
+ yield (audio_stream, text_stream)
530
+
531
+ input_pos = input_pos.add_(1)
532
+ index += 1
533
+ text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
534
+ print(f"text output: {text}")
535
+
536
+ if save_path is not None:
537
+ audiolist = reconscruct_snac(list_output)
538
+ audio = reconstruct_tensors(audiolist)
539
+ with torch.inference_mode():
540
+ audio_hat = self.snacmodel.decode(audio)
541
+ sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
542
+
543
+ model.clear_kv_cache()
544
+ return list_output
545
+
546
+
547
+ def test_infer():
548
+ device = "cuda:0"
549
+ out_dir = f"./output/{get_time_str()}"
550
+ ckpt_dir = f"./checkpoint"
551
+ if not os.path.exists(ckpt_dir):
552
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
553
+ download_model(ckpt_dir)
554
+
555
+ fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)
556
+
557
+ task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']
558
+
559
+ # prepare test data
560
+ # TODO
561
+ test_audio_list = sorted(glob.glob('./data/samples/output*.wav'))
562
+ test_audio_transcripts = [
563
+ "What is your name?",
564
+ "what are your hobbies?",
565
+ "Do you like beijing",
566
+ "How are you feeling today?",
567
+ "what is the weather like today?",
568
+ ]
569
+ test_text_list = [
570
+ "What is your name?",
571
+ "How are you feeling today?",
572
+ "Can you describe your surroundings?",
573
+ "What did you do yesterday?",
574
+ "What is your favorite book and why?",
575
+ "How do you make a cup of tea?",
576
+ "What is the weather like today?",
577
+ "Can you explain the concept of time?",
578
+ "Can you tell me a joke?",
579
+ ]
580
+
581
+ # LOAD MODEL
582
+ with torch.no_grad():
583
+ if "A1A2" in task:
584
+ print("===============================================================")
585
+ print(" testing A1A2")
586
+ print("===============================================================")
587
+ step = 0
588
+ for path in test_audio_list:
589
+ try:
590
+ mel, leng = load_audio(path)
591
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
592
+ text = A1_A2(
593
+ fabric,
594
+ audio_feature,
595
+ input_ids,
596
+ leng,
597
+ model,
598
+ text_tokenizer,
599
+ step,
600
+ snacmodel,
601
+ out_dir=out_dir,
602
+ )
603
+ print(f"input: {test_audio_transcripts[step]}")
604
+ print(f"output: {text}")
605
+ step += 1
606
+ print(
607
+ "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
608
+ )
609
+ except:
610
+ print(f"[error] failed to process {path}")
611
+ print("===============================================================")
612
+
613
+ if 'asr' in task:
614
+ print("===============================================================")
615
+ print(" testing asr")
616
+ print("===============================================================")
617
+
618
+ index = 0
619
+ step = 0
620
+ for path in test_audio_list:
621
+ mel, leng = load_audio(path)
622
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
623
+ output = A1_T1(fabric, audio_feature, input_ids ,leng, model, text_tokenizer, index).lower().replace(',','').replace('.','').replace('?','')
624
+ print(f"audio_path: {path}")
625
+ print(f"audio transcript: {test_audio_transcripts[index]}")
626
+ print(f"asr output: {output}")
627
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
628
+ index += 1
629
+
630
+ if "T1A2" in task:
631
+ step = 0
632
+ print("\n")
633
+ print("===============================================================")
634
+ print(" testing T1A2")
635
+ print("===============================================================")
636
+ for text in test_text_list:
637
+ input_ids = get_input_ids_TA(text, text_tokenizer)
638
+ text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
639
+ snacmodel, out_dir=out_dir)
640
+ print(f"input: {text}")
641
+ print(f"output: {text_output}")
642
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
643
+ step += 1
644
+ print("===============================================================")
645
+
646
+ if "T1T2" in task:
647
+ step = 0
648
+ print("\n")
649
+ print("===============================================================")
650
+ print(" testing T1T2")
651
+ print("===============================================================")
652
+
653
+ for text in test_text_list:
654
+ input_ids = get_input_ids_TT(text, text_tokenizer)
655
+ text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
656
+ print(f" Input: {text}")
657
+ print(f"Output: {text_output}")
658
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
659
+ print("===============================================================")
660
+
661
+ if "AT" in task:
662
+ print("===============================================================")
663
+ print(" testing A1T2")
664
+ print("===============================================================")
665
+ step = 0
666
+ for path in test_audio_list:
667
+ mel, leng = load_audio(path)
668
+ audio_feature, input_ids = get_input_ids_whisper(
669
+ mel, leng, whispermodel, device,
670
+ special_token_a=_pad_a, special_token_t=_answer_t
671
+ )
672
+ text = A1_T2(
673
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
674
+ )
675
+ print(f"input: {test_audio_transcripts[step]}")
676
+ print(f"output: {text}")
677
+ step += 1
678
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
679
+ print("===============================================================")
680
+
681
+ if "AA-BATCH" in task:
682
+ print("===============================================================")
683
+ print(" testing A1A2-BATCH")
684
+ print("===============================================================")
685
+ step = 0
686
+ for path in test_audio_list:
687
+ mel, leng = load_audio(path)
688
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
689
+ text = A1_A2_batch(
690
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
691
+ snacmodel, out_dir=out_dir
692
+ )
693
+ print(f"input: {test_audio_transcripts[step]}")
694
+ print(f"output: {text}")
695
+ step += 1
696
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
697
+ print("===============================================================")
698
+
699
+ print("*********************** test end *****************************")
700
+
701
+
702
+
703
+ if __name__ == "__main__":
704
+ test_infer()
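test_infer() above exercises the offline tasks; the streaming entry point is OmniInference.run_AT_batch_stream, which yields (audio_stream, text_stream) pairs as tokens are decoded. A minimal consumption sketch follows; the checkpoint directory and the output path are placeholders, and this snippet is not part of the commit.

client = OmniInference(ckpt_dir="./checkpoint", device="cuda:0")
client.warm_up()  # one full pass over the bundled sample audio

full_text = ""
for audio_chunk, text_chunk in client.run_AT_batch_stream(
    "./data/samples/output1.wav",  # sample shipped with the repo
    save_path="./answer.wav",      # placeholder; also writes the decoded audio to disk
):
    full_text += text_chunk  # text arrives incrementally alongside the audio chunks
print(full_text)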
inference_vision.py ADDED
@@ -0,0 +1,259 @@
1
+ import os
2
+ import torch
3
+ from litgpt.generate.base import next_token_image_batch
4
+ import soundfile as sf
5
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
6
+ from utils.snac_utils import get_snac, generate_audio_data
7
+ import clip
8
+ import inference
9
+ from tqdm import tqdm
10
+ from inference import OmniInference, load_model, load_audio, download_model
11
+ from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
12
+ from PIL import Image
13
+
14
+
15
+ torch.set_printoptions(sci_mode=False)
16
+
17
+ _image = inference._image
18
+ _eoimage = inference._eoimage
19
+ _pad_t = inference._pad_t
20
+ _input_t = inference._input_t
21
+ _answer_t = inference._answer_t
22
+ _eot = inference._eot
23
+ _eoa = inference._eoa
24
+ _pad_a = inference._pad_a
25
+ _input_a = inference._input_a
26
+ _answer_a = inference._answer_a
27
+
28
+
29
+ def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
30
+
31
+ with torch.no_grad():
32
+ mel = mel.unsqueeze(0).to(device)
33
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
34
+
35
+ audio_len = audio_feature.size(0)
36
+
37
+ input_ids = []
38
+ input_ids_item = [[] for i in range(8)]
39
+ for i in range(7):
40
+ input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
41
+ input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)]
42
+ input_ids_item[i] += [layershift(_answer_a,i)]
43
+
44
+ input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
45
+ input_ids_item = [torch.tensor(item) for item in input_ids_item]
46
+
47
+ input_ids.append(input_ids_item)
48
+
49
+ input_ids_item = [[] for i in range(8)]
50
+ for i in range(7):
51
+ input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
52
+ input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)] + [layershift(_pad_a,i)]
53
+
54
+ input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
55
+
56
+ input_ids_item = [torch.tensor(item) for item in input_ids_item]
57
+ input_ids.append(input_ids_item)
58
+
59
+ stacked_inputids = [[] for _ in range(8)]
60
+ for i in range(2):
61
+ for j in range(8):
62
+ stacked_inputids[j].append(input_ids[i][j])
63
+ stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
64
+
65
+ return torch.stack([audio_feature,audio_feature]), stacked_inputids
66
+
67
+
68
+ def load_clip_model(ckpt_dir, device):
69
+ clip_model_path = ckpt_dir + "/ViT-B-32.pt"
70
+ if not os.path.exists(clip_model_path):
71
+ clip_model_path = "ViT-B/32"
72
+ clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
73
+ return clipmodel, clippreprocess
74
+
75
+
76
+ class OmniVisionInference(OmniInference):
77
+
78
+ def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
79
+ self.device = device
80
+ if not os.path.exists(ckpt_dir):
81
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
82
+ download_model(ckpt_dir)
83
+ self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
84
+ self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)
85
+
86
+ def warm_up(self,
87
+ audio_sample='./data/samples/vision_qa_audio.wav',
88
+ image_sample='./data/samples/vision_qa_image.jpg'
89
+ ):
90
+ for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
91
+ save_path="./data/samples/vision_qa_output.wav",
92
+ warm_up=True):
93
+ pass
94
+
95
+ @torch.inference_mode()
96
+ def run_vision_AA_batch_stream(self, audio_path, image_path,
97
+ stream_stride=4,
98
+ max_returned_tokens=2048,
99
+ temperature=0.9,
100
+ top_k=1,
101
+ top_p=1.0,
102
+ eos_id_a=_eoa,
103
+ eos_id_t=_eot,
104
+ pad_id=_pad_t,
105
+ save_path=None,
106
+ warm_up=False
107
+ ):
108
+ with self.fabric.init_tensor():
109
+ self.model.set_kv_cache(batch_size=2)
110
+
111
+ model = self.model
112
+
113
+ mel, leng = load_audio(audio_path)
114
+ img = Image.open(image_path)
115
+
116
+ audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
117
+ ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
118
+ ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)
119
+
120
+ ima_feature = torch.stack([ima_feature.clone(),ima_feature.clone()]).to(self.device)
121
+ leng = [leng,leng]
122
+ task = ['ImageQA_A','ImageQA_AT']
123
+
124
+ T = input_ids[0].size(1)
125
+ assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
126
+
127
+ if model.max_seq_length < max_returned_tokens - 1:
128
+ raise NotImplementedError(
129
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
130
+ )
131
+
132
+ list_output = [[] for i in range(8)]
133
+
134
+ tokens_A , token_T = next_token_image_batch(
135
+ model,
136
+ audio_feature.to(torch.float32).to(self.device),
137
+ ima_feature.to(torch.float32).to(self.device) ,
138
+ input_ids ,
139
+ whisper_lens = leng ,
140
+ task = task,
141
+ input_pos = torch.arange(0, T, device=self.device),
142
+ temperature=temperature,
143
+ top_k=top_k,
144
+ top_p=top_p
145
+ )
146
+ for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
147
+ list_output[7].append(token_T.tolist()[0])
148
+
149
+ text_end = False
150
+ index = 1
151
+ nums_generate = stream_stride
152
+ begin_generate = False
153
+ current_index = 0
154
+ input_pos = torch.tensor([T], device=self.device)
155
+
156
+ model_input_ids = [[] for i in range(8)]
157
+ for i in range(7):
158
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
159
+ model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
160
+ model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
161
+ model_input_ids[i] = torch.stack(model_input_ids[i])
162
+
163
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
164
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
165
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
166
+
167
+ text_index = 0
168
+ is_text_end = False
169
+
170
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
171
+
172
+ tokens_A , token_T = next_token_image_batch(model, None , None ,
173
+ input_ids = model_input_ids,
174
+ whisper_lens= None,
175
+ task = None,
176
+ input_pos = input_pos,
177
+ temperature=temperature,
178
+ top_k=top_k,
179
+ top_p=top_p)
180
+
181
+ if text_end:
182
+ token_T = torch.tensor([_pad_t], device=self.device)
183
+
184
+ if tokens_A[-1] == eos_id_a:
185
+ break
186
+ if token_T == eos_id_t:
187
+ text_end = True
188
+
189
+ for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
190
+ list_output[7].append(token_T.tolist()[0])
191
+
192
+
193
+ if index == 7:
194
+ begin_generate = True
195
+
196
+ if begin_generate:
197
+ current_index += 1
198
+ if current_index == nums_generate:
199
+ current_index = 0
200
+ snac = get_snac(list_output,index,nums_generate)
201
+ audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
202
+ if is_text_end:
203
+ text_stream = ""
204
+ else:
205
+ text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
206
+
207
+ yield (audio_stream, text_stream)
208
+
209
+ if warm_up:
210
+ break
211
+
212
+ input_pos = input_pos.add_(1)
213
+ model_input_ids = [[] for i in range(8)]
214
+ for i in range(7):
215
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
216
+ model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
217
+ model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
218
+ model_input_ids[i] = torch.stack(model_input_ids[i])
219
+
220
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
221
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
222
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
223
+
224
+ index += 1
225
+
226
+ text_tokens = list_output[-1]
227
+ if text_vocabsize in text_tokens:
228
+ text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
229
+ res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
230
+ print(f"text output: {res_text}")
231
+
232
+ if save_path is not None:
233
+ audiolist = reconscruct_snac(list_output)
234
+ audio = reconstruct_tensors(audiolist)
235
+ with torch.inference_mode():
236
+ audio_hat = self.snacmodel.decode(audio)
237
+ sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
238
+
239
+ model.clear_kv_cache()
240
+
241
+
242
+ def test_vision_infer():
243
+ client = OmniVisionInference()
244
+ client.warm_up()
245
+ input_audio_path = './data/samples/vision_qa_audio.wav'
246
+ input_image_path = './data/samples/vision_qa_image.jpg'
247
+
248
+ res_text = ""
249
+ for audio_stream, text_stream in client.run_vision_AA_batch_stream(
250
+ input_audio_path,
251
+ input_image_path,
252
+ save_path="./vision_qa_output.wav"
253
+ ):
254
+ res_text += text_stream
255
+ print(f"text_output: {res_text}")
256
+
257
+
258
+ if __name__ == "__main__":
259
+ test_vision_infer()
litgpt/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+
+ import logging
+ import re
+ from litgpt.model import GPT # needs to be imported before config
+ from litgpt.config import Config
+ from litgpt.tokenizer import Tokenizer
+
+ # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
+ pattern = re.compile(".*Profiler function .* will be ignored")
+ logging.getLogger("torch._dynamo.variables.torch").addFilter(
+     lambda record: not pattern.search(record.getMessage())
+ )
+
+ # Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
+ logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
+ logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True
+
+ __all__ = ["GPT", "Config", "Tokenizer"]
litgpt/config.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any, Literal, Optional, Type, Union
7
+
8
+ import torch
9
+ import yaml
10
+ from typing_extensions import Self
11
+
12
+ import litgpt.model
13
+ from litgpt.utils import find_multiple
14
+
15
+
16
+ @dataclass
17
+ class Config:
18
+ name: str = ""
19
+ hf_config: dict = field(default_factory=dict)
20
+ scale_embeddings: bool = False
21
+ block_size: int = 4096
22
+ vocab_size: int = 50254
23
+ padding_multiple: int = 512
24
+ padded_vocab_size: Optional[int] = None
25
+ n_layer: int = 16
26
+ n_head: int = 32
27
+ head_size: Optional[int] = None
28
+ n_embd: int = 4096
29
+ rotary_percentage: float = 0.25
30
+ parallel_residual: bool = True
31
+ bias: bool = True
32
+ lm_head_bias: bool = False
33
+ # to use multi-head attention (MHA), set this to `n_head` (default)
34
+ # to use multi-query attention (MQA), set this to 1
35
+ # to use grouped-query attention (GQA), set this to a value in between
36
+ # Example with `n_head=4`
37
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
38
+ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
39
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
40
+ # │ │ │ │ │ │ │
41
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
42
+ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
43
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
44
+ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
45
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
46
+ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
47
+ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
48
+ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
49
+ # MHA GQA MQA
50
+ # n_query_groups=4 n_query_groups=2 n_query_groups=1
51
+ #
52
+ # credit https://arxiv.org/pdf/2305.13245.pdf
53
+ n_query_groups: Optional[int] = None
54
+ shared_attention_norm: bool = False
55
+ norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
56
+ norm_eps: float = 1e-5
57
+ mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
58
+ "GptNeoxMLP"
59
+ )
60
+ gelu_approximate: str = "none"
61
+ intermediate_size: Optional[int] = None
62
+ rope_condense_ratio: int = 1
63
+ rope_base: int = 10000
64
+ n_expert: int = 0
65
+ n_expert_per_token: int = 0
66
+
67
+ add_qkv_bias: Optional[bool] = None
68
+ prompt_vocab_size: Optional[int] = None
69
+ attn_dropout: float = 0.0
70
+ pos_type: str = "rope"
71
+ force_align: bool = False
72
+ use_pretrain_phoneme_emb: bool = False
73
+ tie_word_embeddings: bool = False
74
+
75
+ # setting for mini-omni
76
+ text_vocab_size:int = 152000
77
+ cat_audio_vocab_size: int = 29120
78
+ audio_vocab_size: int = 4160
79
+ whisper_adapter_dim: int = 768
80
+ vision_adapter_dim: int = 512
81
+
82
+ post_adapter: bool = False
83
+ post_adapter_layers: int = 6
84
+ asr_adapter: str = "llamamlp"
85
+
86
+ def __post_init__(self):
87
+ if not self.name:
88
+ self.name = self.hf_config.get("name", self.name)
89
+
90
+ if self.head_size is None:
91
+ assert self.n_embd % self.n_head == 0
92
+ self.head_size = self.n_embd // self.n_head
93
+
94
+ # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
95
+ if self.padded_vocab_size is None:
96
+ self.padded_vocab_size = find_multiple(
97
+ self.vocab_size, self.padding_multiple
98
+ )
99
+ else:
100
+ # vocab size shouldn't be larger than padded vocab size
101
+ self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
102
+
103
+ # compute the number of query groups
104
+ if self.n_query_groups is not None:
105
+ assert self.n_head % self.n_query_groups == 0
106
+ else:
107
+ self.n_query_groups = self.n_head
108
+
109
+ # compute the intermediate size for MLP if not set
110
+ if self.intermediate_size is None:
111
+ if self.mlp_class_name == "LLaMAMLP":
112
+ raise ValueError(
113
+ f"The config {self.name!r}, needs to set the `intermediate_size`"
114
+ )
115
+ self.intermediate_size = 4 * self.n_embd
116
+
117
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
118
+
119
+ if self.add_qkv_bias is None:
120
+ self.add_qkv_bias = self.bias
121
+
122
+ @classmethod
123
+ def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
124
+ if name not in name_to_config:
125
+ # search through all `config['hf_config']['name']`
126
+ try:
127
+ conf_dict = next(
128
+ config
129
+ for config in configs
130
+ if name == config["hf_config"]["name"]
131
+ or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
132
+ == name
133
+ )
134
+ except StopIteration:
135
+ raise ValueError(f"{name!r} is not a supported config name")
136
+ else:
137
+ conf_dict = name_to_config[name]
138
+
139
+ conf_dict = conf_dict.copy()
140
+ conf_dict.update(kwargs)
141
+ return cls(**conf_dict)
142
+
143
+ @classmethod
144
+ def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
145
+ with open(path, encoding="utf-8") as fp:
146
+ file_kwargs = yaml.safe_load(fp)
147
+ if file_kwargs is None:
148
+ raise ValueError(f"{path} is empty which is likely unexpected.")
149
+ file_kwargs.update(kwargs)
150
+ return cls(**file_kwargs)
151
+
152
+ @classmethod
153
+ def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
154
+ """Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
155
+ if (config_path := path / "model_config.yaml").is_file():
156
+ return cls.from_file(config_path, **kwargs)
157
+ if (model_name := path.name) in name_to_config:
158
+ return cls.from_name(model_name, **kwargs)
159
+ raise FileNotFoundError(
160
+ f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
161
+ )
162
+
163
+ @property
164
+ def mlp_class(self) -> Type:
165
+ # `self.mlp_class_name` cannot be the type to keep the config serializable
166
+ return getattr(litgpt.model, self.mlp_class_name)
167
+
168
+ @property
169
+ def norm_class(self) -> Type:
170
+ # `self.norm_class_name` cannot be the type to keep the config serializable
171
+ if self.norm_class_name == "RMSNorm":
172
+ from functools import partial
173
+
174
+ from litgpt.model import RMSNorm
175
+
176
+ return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
177
+ return getattr(torch.nn, self.norm_class_name)
178
+
179
+
180
+ configs = []
181
+ name_to_config = {config["name"]: config for config in configs}
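Since configs is left empty here, Config.from_name has nothing to look up; configurations are expected to be loaded from a model_config.yaml via Config.from_file, as inference.py does. For the derived fields, a quick sanity-check sketch of what __post_init__ computes from the defaults; the values are worked out from the dataclass above, assuming litgpt.utils.find_multiple rounds up to the next multiple, and this snippet is not part of the commit.

cfg = Config(name="example", n_embd=4096, n_head=32)

assert cfg.head_size == 4096 // 32          # 128, derived when head_size is None
assert cfg.padded_vocab_size == 50688       # find_multiple(50254, 512)
assert cfg.n_query_groups == 32             # defaults to n_head, i.e. plain MHA
assert cfg.intermediate_size == 4 * 4096    # GptNeoxMLP fallback
assert cfg.rope_n_elem == 32                # int(0.25 * 128)
assert cfg.add_qkv_bias is True             # falls back to `bias`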
litgpt/generate/__init__.py ADDED
File without changes
litgpt/generate/base.py ADDED
@@ -0,0 +1,795 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from typing import Any, Literal, Optional
4
+
5
+ import torch
6
+ # import torch._dynamo.config
7
+ # import torch._inductor.config
8
+
9
+ from litgpt.model import GPT
10
+ from utils.snac_utils import layershift, snac_config
11
+ from tqdm import tqdm
12
+
13
+
14
+ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
15
+ if torch._dynamo.is_compiling():
16
+ # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
17
+ distribution = torch.empty_like(probs).exponential_(1)
18
+ return torch.argmax(probs / distribution, dim=-1, keepdim=True)
19
+ return torch.multinomial(probs, num_samples=1)
20
+
21
+
22
+ def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
23
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
24
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
25
+ # Example:
26
+ # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
27
+ # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
28
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
29
+ # Keep at least 1 token always to prevent the case where no token is selected
30
+ # In this case the most probable one is always kept
31
+ sorted_indices_to_remove[-1:] = 0
32
+ indices_to_remove = sorted_indices_to_remove.scatter(
33
+ 0, sorted_indices, sorted_indices_to_remove
34
+ )
35
+ logits = logits.masked_fill(indices_to_remove, float("-inf"))
36
+ return logits
37
+
38
+
39
+ def sample(
40
+ logits: torch.Tensor,
41
+ temperature: float = 1.0,
42
+ top_k: Optional[int] = None,
43
+ top_p: float = 1.0,
44
+ ) -> torch.Tensor:
45
+ if top_p < 0.0 or top_p > 1.0:
46
+ raise ValueError(f"top_p must be in [0, 1], got {top_p}")
47
+ logits = logits[0, -1]
48
+ # optionally crop the logits to only the top k options
49
+ if top_k is not None:
50
+ v, i = torch.topk(logits, min(top_k, logits.size(-1)))
51
+ # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
52
+ logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
53
+ # optionally scale the logits and sample from a probability distribution
54
+ if temperature > 0.0 or top_p > 0.0:
55
+ if temperature > 0.0:
56
+ logits = logits / temperature
57
+ # optionally crop the logits to smallest set of logits with a cumulative probability above top_p
58
+ if top_p < 1.0:
59
+ logits = sample_top_p(logits, top_p)
60
+ probs = torch.nn.functional.softmax(logits, dim=-1)
61
+ return multinomial_num_samples_1(probs)
62
+ return torch.argmax(logits, dim=-1, keepdim=True)
63
+
64
+
65
+ def next_token(
66
+ model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
67
+ ) -> torch.Tensor:
68
+ input_pos = input_pos.to(model.device)
69
+ logits_a, logit_t = model(None, x, None, input_pos)
70
+
71
+ next_audio_tokens = []
72
+ for logit_a in logits_a:
73
+ next_a = sample(logit_a, **kwargs).to(dtype=x[0].dtype)
74
+ next_audio_tokens.append(next_a)
75
+ next_t = sample(logit_t, **kwargs).to(dtype=x[0].dtype)
76
+ return next_audio_tokens, next_t
77
+
78
+
79
+ def next_token_asr(
80
+ model: GPT,
81
+ input_pos: torch.Tensor,
82
+ audio_features: torch.tensor,
83
+ lens: int,
84
+ input_ids: list,
85
+ **kwargs: Any,
86
+ ) -> torch.Tensor:
87
+ input_pos = input_pos.to(model.device)
88
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
89
+ logits_a, logit_t = model(audio_features, input_ids, None, input_pos, whisper_lens=lens)
90
+
91
+ next_audio_tokens = []
92
+ for logit_a in logits_a:
93
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
94
+ next_audio_tokens.append(next_a)
95
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
96
+ return next_audio_tokens, next_t
97
+
98
+
99
+ def next_token_A1T2(
100
+ model: GPT,
101
+ audio_features: torch.tensor,
102
+ input_ids: list,
103
+ whisper_lens: int,
104
+ task: list,
105
+ input_pos: torch.Tensor,
106
+ **kwargs: Any,
107
+ ) -> torch.Tensor:
108
+ input_pos = input_pos.to(model.device)
109
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
110
+ logits_a, logit_t = model(
111
+ audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
112
+ )
113
+
114
+ next_audio_tokens = []
115
+ for logit_a in logits_a:
116
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
117
+ next_audio_tokens.append(next_a)
118
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
119
+ return next_audio_tokens, next_t
120
+
121
+
122
+ def next_token_A1T1(
123
+ model: GPT,
124
+ audio_features: torch.tensor,
125
+ input_ids: list,
126
+ whisper_lens: int,
127
+ task: list,
128
+ input_pos: torch.Tensor,
129
+ **kwargs: Any,
130
+ ) -> torch.Tensor:
131
+ input_pos = input_pos.to(model.device)
132
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
133
+ logits_a, logit_t = model(
134
+ audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
135
+ )
136
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
137
+ return next_t
138
+
139
+
140
+ def next_token_image_batch(model: GPT,
141
+ audio_features: torch.tensor,
142
+ clip_features: torch.tensor,
143
+ input_ids: list,
144
+ whisper_lens: int,
145
+ task: list,
146
+ input_pos: torch.Tensor,
147
+ **kwargs: Any) -> torch.Tensor:
148
+ input_pos = input_pos.to(model.device)
149
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
150
+ logits_a,logit_t = model(audio_features, input_ids, clip_features,
151
+ input_pos, whisper_lens=whisper_lens, task=task)
152
+
153
+ for i in range(7):
154
+ logits_a[i] = logits_a[i][0].unsqueeze(0)
155
+ logit_t = logit_t[1].unsqueeze(0)
156
+
157
+ next_audio_tokens = []
158
+ for logit_a in logits_a:
159
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
160
+ next_audio_tokens.append(next_a)
161
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
162
+ return next_audio_tokens, next_t
163
+
164
+
165
+ # torch._dynamo.config.automatic_dynamic_shapes = True
166
+ # torch._inductor.config.triton.unique_kernel_names = True
167
+ # torch._inductor.config.coordinate_descent_tuning = True
168
+ # next_token = torch.compile(next_token, mode="reduce-overhead")
169
+
170
+
171
+ @torch.inference_mode()
172
+ def generate(
173
+ model: GPT,
174
+ input_ids: list,
175
+ max_returned_tokens: int,
176
+ *,
177
+ temperature: float = 1.0,
178
+ top_k: Optional[int] = None,
179
+ top_p: float = 1.0,
180
+ eos_id_a: Optional[int] = None,
181
+ eos_id_t: Optional[int] = None,
182
+ pad_id: Optional[int] = None,
183
+ shift: Optional[int] = None,
184
+ include_prompt: bool = True,
185
+ generate_text=False,
186
+ ) -> torch.Tensor:
187
+ # print("eos_id_a:", eos_id_a)
188
+ # print("eos_id_t:", eos_id_t)
189
+ # print("pad_id:", pad_id)
190
+ """
191
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
192
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
193
+
194
+ Args:
195
+ model: The model to use.
196
+ prompt: Tensor of shape (T) with indices of the prompt sequence.
197
+ max_returned_tokens: The maximum number of tokens to return (given plus generated).
198
+ temperature: Scales the predicted logits by 1 / temperature.
199
+ top_k: If specified, only sample among the tokens with the k highest probabilities.
200
+ top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
201
+ In top-p sampling, the next token is sampled from the highest probability tokens
202
+ whose cumulative probability exceeds the threshold `top_p`. When specified,
203
+ it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
204
+ to sampling the most probable token, while `top_p=1` samples from the whole distribution.
205
+ It can be used in conjunction with `top_k` and `temperature` with the following order
206
+ of application:
207
+
208
+ 1. `top_k` sampling
209
+ 2. `temperature` scaling
210
+ 3. `top_p` sampling
211
+
212
+ For more details, see https://arxiv.org/abs/1904.09751
213
+ or https://huyenchip.com/2024/01/16/sampling.html#top_p
214
+ eos_id_a / eos_id_t: If specified, stop generating any more tokens once the corresponding audio/text <eos> token is produced.
215
+ include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
216
+ """
217
+ T = input_ids[0].size(0)
218
+ device = input_ids[0].device
219
+ assert max_returned_tokens > T
220
+ if model.max_seq_length < max_returned_tokens - 1:
221
+ # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
222
+ # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
223
+ # not support it to avoid negatively impacting the overall speed
224
+ raise NotImplementedError(
225
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
226
+ )
227
+
228
+ for input_id in input_ids:
229
+ input_id = [input_id]  # note: rebinding the loop variable here has no effect on input_ids
230
+ (
231
+ tokens_A1,
232
+ tokens_A2,
233
+ tokens_A3,
234
+ tokens_A4,
235
+ tokens_A5,
236
+ tokens_A6,
237
+ tokens_A7,
238
+ tokens_T,
239
+ ) = input_ids
240
+
241
+ tokens_A1_output = [tokens_A1]
242
+ tokens_A2_output = [tokens_A2]
243
+ tokens_A3_output = [tokens_A3]
244
+ tokens_A4_output = [tokens_A4]
245
+ tokens_A5_output = [tokens_A5]
246
+ tokens_A6_output = [tokens_A6]
247
+ tokens_A7_output = [tokens_A7]
248
+ tokens_T_output = [tokens_T]
249
+
250
+ list_output = [
251
+ tokens_A1_output,
252
+ tokens_A2_output,
253
+ tokens_A3_output,
254
+ tokens_A4_output,
255
+ tokens_A5_output,
256
+ tokens_A6_output,
257
+ tokens_A7_output,
258
+ tokens_T_output,
259
+ ]
260
+
261
+ input_pos = torch.tensor([T], device=device)
262
+ model_input_ids = [
263
+ tokens_A1.view(1, -1),
264
+ tokens_A2.view(1, -1),
265
+ tokens_A3.view(1, -1),
266
+ tokens_A4.view(1, -1),
267
+ tokens_A5.view(1, -1),
268
+ tokens_A6.view(1, -1),
269
+ tokens_A7.view(1, -1),
270
+ tokens_T.view(1, -1),
271
+ ]
272
+
273
+ tokens_A, token_T = next_token(
274
+ model,
275
+ torch.arange(0, T, device=device),
276
+ model_input_ids,
277
+ temperature=temperature,
278
+ top_k=top_k,
279
+ top_p=top_p,
280
+ )
281
+ for i in range(7):
282
+ list_output[i].append(tokens_A[i].clone())
283
+ list_output[7].append(token_T.clone())
284
+
285
+ # prepare the input for the next iteration
286
+ for i in range(7):
287
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
288
+ token_T = token_T.clone()
289
+
290
+ text_end = False
291
+ max_returned_tokens = 1000  # note: hard-coded cap that overrides the max_returned_tokens argument
292
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
293
+ model_input_ids = [
294
+ token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
295
+ ] + [token_T.view(1, -1).to(torch.int32)]
296
+ tokens_A, token_T = next_token(
297
+ model,
298
+ input_pos,
299
+ model_input_ids,
300
+ temperature=temperature,
301
+ top_k=top_k,
302
+ top_p=top_p,
303
+ )
304
+ if text_end:
305
+ token_T = torch.tensor([pad_id], device=device)
306
+
307
+ for i in range(7):
308
+ list_output[i].append(tokens_A[i].clone())
309
+ list_output[7].append(token_T.clone())
310
+
311
+ if tokens_A[-1] == eos_id_a:
312
+ break
313
+ if token_T == eos_id_t:
314
+ if generate_text:
315
+ break
316
+ text_end = True
317
+
318
+ for i in range(7):
319
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
320
+ token_T = token_T.clone()
321
+ input_pos = input_pos.add_(1)
322
+
323
+ for i in range(len(list_output)):
324
+ list_output[i] = torch.cat(list_output[i])
325
+ return list_output
326
+
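generate keeps eight parallel streams (seven SNAC codebooks plus one text stream). Before a sampled audio token is fed back in, it is offset into its codebook's slice of the shared vocabulary via `shift + i * snac_config.padded_vocab_size`. A minimal sketch of that mapping, with both constants treated as assumptions taken from the code above:

def to_shared_vocab(token_id: int, layer: int, shift: int, padded_vocab_size: int) -> int:
    # codebook `layer` (0..6) occupies its own contiguous block after the text vocabulary
    return token_id + shift + layer * padded_vocab_size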
327
+
328
+ @torch.inference_mode()
329
+ def generate_TA_BATCH(
330
+ model: GPT,
331
+ audio_features: torch.Tensor,
332
+ input_ids: list,
333
+ leng,
334
+ task,
335
+ max_returned_tokens: int = 1000,
336
+ *,
337
+ temperature: float = 1.0,
338
+ top_k: Optional[int] = None,
339
+ top_p: float = 1.0,
340
+ eos_id_a: Optional[int] = None,
341
+ eos_id_t: Optional[int] = None,
342
+ pad_id_t: Optional[int] = None,
343
+ shift: Optional[int] = None,
344
+ include_prompt: bool = True,
345
+ generate_text=False,
346
+ ) -> torch.Tensor:
347
+
348
+ T = input_ids[0].size(1)
349
+ device = input_ids[0].device
350
+ assert max_returned_tokens > T
351
+ if model.max_seq_length < max_returned_tokens - 1:
352
+ raise NotImplementedError(
353
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
354
+ )
355
+
356
+ input_pos = torch.tensor([T], device=device)
357
+ model_input_ids = input_ids
358
+
359
+ list_output = [[] for i in range(8)]
360
+
361
+ tokens_A, token_T = next_token_image_batch(
362
+ model,
363
+ audio_features.to(torch.float32).to(model.device),
364
+ None,
365
+ input_ids,
366
+ [T - 3, T - 3],
367
+ ["A1T2", "A1T2"],
368
+ input_pos=torch.arange(0, T, device=device),
369
+ temperature=temperature,
370
+ top_k=top_k,
371
+ top_p=top_p,
372
+ )
373
+
374
+ for i in range(7):
375
+ list_output[i].append(tokens_A[i].tolist()[0])
376
+ list_output[7].append(token_T.tolist()[0])
377
+
378
+ model_input_ids = [[] for i in range(8)]
379
+ for i in range(7):
380
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
381
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
382
+ model_input_ids[i].append(torch.tensor([layershift(snac_config.end_of_audio, i)], device=device))
383
+ model_input_ids[i] = torch.stack(model_input_ids[i])
384
+
385
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
386
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
387
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
388
+
389
+ text_end = False
390
+
391
+ for _ in range(2, max_returned_tokens - T + 1):
392
+ tokens_A, token_T = next_token_image_batch(
393
+ model,
394
+ None,
395
+ None,
396
+ model_input_ids,
397
+ None,
398
+ None,
399
+ input_pos=input_pos,
400
+ temperature=temperature,
401
+ top_k=top_k,
402
+ top_p=top_p,
403
+ )
404
+
405
+ if text_end:
406
+ token_T = torch.tensor([pad_id_t], device=device)
407
+
408
+ if tokens_A[-1] == eos_id_a:
409
+ break
410
+ if token_T == eos_id_t:
411
+ text_end = True
412
+
413
+ for i in range(7):
414
+ list_output[i].append(tokens_A[i].tolist()[0])
415
+ list_output[7].append(token_T.tolist()[0])
416
+
417
+ model_input_ids = [[] for i in range(8)]
418
+ for i in range(7):
419
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
420
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
421
+ model_input_ids[i].append(
422
+ torch.tensor([layershift(snac_config.end_of_audio, i)], device=device)
423
+ )
424
+ model_input_ids[i] = torch.stack(model_input_ids[i])
425
+
426
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
427
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
428
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
429
+
430
+ input_pos = input_pos.add_(1)
431
+
432
+ return list_output
433
+
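After the first step, generate_TA_BATCH feeds each audio stream back as a two-row batch: the sampled, layer-shifted token in row 0 and that codebook's end-of-audio marker in row 1, while the text row is duplicated. A sketch of building one such stream, assuming the `layershift` helper and `snac_config` defined earlier in this file:

import torch

def build_stream_input(next_audio, layer, device):
    # row 0: freshly sampled audio token; row 1: end-of-audio marker for this codebook
    eoa = torch.tensor([layershift(snac_config.end_of_audio, layer)], device=device)
    return torch.stack([next_audio.to(device).to(torch.int32), eoa])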
434
+
435
+ @torch.inference_mode()
436
+ def generate_TT(
437
+ model: GPT,
438
+ audio_features: torch.Tensor,
439
+ input_ids: list,
440
+ leng,
441
+ task,
442
+ max_returned_tokens: int = 2048,
443
+ *,
444
+ temperature: float = 1.0,
445
+ top_k: Optional[int] = None,
446
+ top_p: float = 1.0,
447
+ eos_id_a: Optional[int] = None,
448
+ eos_id_t: Optional[int] = None,
449
+ pad_id_t: Optional[int] = None,
450
+ shift: Optional[int] = None,
451
+ include_prompt: bool = True,
452
+ generate_text=False,
453
+ ) -> torch.Tensor:
454
+
455
+ T = input_ids[0].size(1)
456
+ device = input_ids[0].device
457
+
458
+ output = []
459
+ token_T = next_token_A1T1(
460
+ model,
461
+ None,
462
+ input_ids,
463
+ None,
464
+ None,
465
+ input_pos=torch.arange(0, T, device=device),
466
+ temperature=temperature,
467
+ top_k=top_k,
468
+ top_p=top_p,
469
+ )
470
+
471
+ output.append(token_T.clone().tolist()[0])
472
+ input_pos = torch.tensor([T], device=device)
473
+
474
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
475
+ model_input_ids = []
476
+ for i in range(7):
477
+ model_input_ids.append(
478
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
479
+ .view(1, -1)
480
+ .to(torch.int32)
481
+ .to(device)
482
+ )
483
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
484
+ token_T = next_token_A1T1(
485
+ model,
486
+ None,
487
+ model_input_ids,
488
+ None,
489
+ None,
490
+ input_pos=input_pos,
491
+ temperature=temperature,
492
+ top_k=top_k,
493
+ top_p=top_p,
494
+ )
495
+ if token_T == eos_id_t:
496
+ break
497
+ output.append(token_T.clone().tolist()[0])
498
+ input_pos = input_pos.add_(1)
499
+ return output
500
+
501
+
502
+ @torch.inference_mode()
503
+ def generate_AT(
504
+ model: GPT,
505
+ audio_features: torch.Tensor,
506
+ input_ids: list,
507
+ leng,
508
+ task,
509
+ max_returned_tokens: int = 2048,
510
+ *,
511
+ temperature: float = 1.0,
512
+ top_k: Optional[int] = None,
513
+ top_p: float = 1.0,
514
+ eos_id_a: Optional[int] = None,
515
+ eos_id_t: Optional[int] = None,
516
+ pad_id_t: Optional[int] = None,
517
+ shift: Optional[int] = None,
518
+ include_prompt: bool = True,
519
+ generate_text=False,
520
+ ) -> torch.Tensor:
521
+
522
+ T = input_ids[0].size(1)
523
+ device = input_ids[0].device
524
+
525
+ output = []
526
+ token_T = next_token_A1T1(
527
+ model,
528
+ audio_features.to(torch.float32).to(model.device),
529
+ input_ids,
530
+ [T - 3],
531
+ ["AT"],
532
+ input_pos=torch.arange(0, T, device=device),
533
+ temperature=temperature,
534
+ top_k=top_k,
535
+ top_p=top_p,
536
+ )
537
+ output.append(token_T.clone().tolist()[0])
538
+ input_pos = torch.tensor([T], device=device)
539
+ text_end = False
540
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
541
+ model_input_ids = []
542
+ for i in range(7):
543
+ model_input_ids.append(
544
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
545
+ .view(1, -1)
546
+ .to(torch.int32)
547
+ .to(device)
548
+ )
549
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
550
+ token_T = next_token_A1T1(
551
+ model,
552
+ None,
553
+ model_input_ids,
554
+ None,
555
+ None,
556
+ input_pos=input_pos,
557
+ temperature=temperature,
558
+ top_k=top_k,
559
+ top_p=top_p,
560
+ )
561
+ if token_T == eos_id_t:
562
+ break
563
+ output.append(token_T.clone().tolist()[0])
564
+ input_pos = input_pos.add_(1)
565
+ return output
566
+
567
+
568
+ @torch.inference_mode()
569
+ def generate_TA(
570
+ model: GPT,
571
+ audio_features: torch.Tensor,
572
+ input_ids: list,
573
+ leng,
574
+ task,
575
+ max_returned_tokens: int = 2048,
576
+ *,
577
+ temperature: float = 1.0,
578
+ top_k: Optional[int] = None,
579
+ top_p: float = 1.0,
580
+ eos_id_a: Optional[int] = None,
581
+ eos_id_t: Optional[int] = None,
582
+ pad_id_t: Optional[int] = None,
583
+ shift: Optional[int] = None,
584
+ include_prompt: bool = True,
585
+ generate_text=False,
586
+ ) -> torch.Tensor:
587
+
588
+ T = input_ids[0].size(1)
589
+ device = input_ids[0].device
590
+
591
+ output = [[] for _ in range(8)]
592
+ tokens_A, token_T = next_token_A1T2(
593
+ model,
594
+ None,
595
+ input_ids,
596
+ None,
597
+ None,
598
+ input_pos=torch.arange(0, T, device=device),
599
+ temperature=temperature,
600
+ top_k=top_k,
601
+ top_p=top_p,
602
+ )
603
+ for i in range(7):
604
+ output[i].append(tokens_A[i].clone().tolist()[0])
605
+ output[7].append(token_T.clone().tolist()[0])
606
+
607
+ input_pos = torch.tensor([T], device=device)
608
+ text_end = False
609
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
610
+
611
+ model_input_ids = []
612
+ for i in range(7):
613
+ model_input_ids.append(
614
+ layershift(tokens_A[i].clone(), i)
615
+ .view(1, -1)
616
+ .to(torch.int32)
617
+ .to(device)
618
+ )
619
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
620
+
621
+ tokens_A, token_T = next_token_A1T2(
622
+ model,
623
+ None,
624
+ model_input_ids,
625
+ None,
626
+ None,
627
+ input_pos=input_pos,
628
+ temperature=temperature,
629
+ top_k=top_k,
630
+ top_p=top_p,
631
+ )
632
+
633
+ if text_end:
634
+ token_T = torch.tensor([pad_id_t], device=device)
635
+
636
+ if tokens_A[-1] == eos_id_a:
637
+ break
638
+
639
+ if token_T == eos_id_t:
640
+ text_end = True
641
+
642
+ for i in range(7):
643
+ output[i].append(tokens_A[i].clone().tolist()[0])
644
+ output[7].append(token_T.clone().tolist()[0])
645
+ input_pos = input_pos.add_(1)
646
+
647
+ return output
648
+
649
+
650
+ @torch.inference_mode()
651
+ def generate_AA(
652
+ model: GPT,
653
+ audio_features: torch.Tensor,
654
+ input_ids: list,
655
+ leng,
656
+ task,
657
+ max_returned_tokens: int = 2048,
658
+ *,
659
+ temperature: float = 1.0,
660
+ top_k: Optional[int] = None,
661
+ top_p: float = 1.0,
662
+ eos_id_a: Optional[int] = None,
663
+ eos_id_t: Optional[int] = None,
664
+ pad_id_t: Optional[int] = None,
665
+ shift: Optional[int] = None,
666
+ include_prompt: bool = True,
667
+ generate_text=False,
668
+ ) -> torch.Tensor:
669
+
670
+ T = input_ids[0].size(1)
671
+ device = input_ids[0].device
672
+
673
+ output = [[] for _ in range(8)]
674
+ tokens_A, token_T = next_token_A1T2(
675
+ model,
676
+ audio_features.to(torch.float32).to(model.device),
677
+ input_ids,
678
+ [T - 3],
679
+ ["A1T2"],
680
+ input_pos=torch.arange(0, T, device=device),
681
+ temperature=temperature,
682
+ top_k=top_k,
683
+ top_p=top_p,
684
+ )
685
+ for i in range(7):
686
+ output[i].append(tokens_A[i].clone().tolist()[0])
687
+ output[7].append(token_T.clone().tolist()[0])
688
+
689
+ input_pos = torch.tensor([T], device=device)
690
+
691
+ text_end = False
692
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
693
+
694
+ model_input_ids = []
695
+ for i in range(7):
696
+ model_input_ids.append(
697
+ layershift(tokens_A[i].clone(), i)
698
+ .view(1, -1)
699
+ .to(torch.int32)
700
+ .to(device)
701
+ )
702
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
703
+
704
+ tokens_A, token_T = next_token_A1T2(
705
+ model,
706
+ None,
707
+ model_input_ids,
708
+ None,
709
+ None,
710
+ input_pos=input_pos,
711
+ temperature=temperature,
712
+ top_k=top_k,
713
+ top_p=top_p,
714
+ )
715
+
716
+ if text_end:
717
+ token_T = torch.tensor([pad_id_t], device=device)
718
+
719
+ if tokens_A[-1] == eos_id_a:
720
+ break
721
+ if token_T == eos_id_t:
722
+ # print("text_end")
723
+ text_end = True
724
+
725
+ for i in range(7):
726
+ output[i].append(tokens_A[i].clone().tolist()[0])
727
+ output[7].append(token_T.clone().tolist()[0])
728
+ input_pos = input_pos.add_(1)
729
+
730
+ return output
731
+
732
+
733
+ @torch.inference_mode()
734
+ def generate_ASR(
735
+ model: GPT,
736
+ audio_features: torch.Tensor,
737
+ input_ids: list,
738
+ leng,
739
+ task,
740
+ max_returned_tokens: int = 1200,
741
+ *,
742
+ temperature: float = 1.0,
743
+ top_k: Optional[int] = None,
744
+ top_p: float = 1.0,
745
+ eos_id_a: Optional[int] = None,
746
+ eos_id_t: Optional[int] = None,
747
+ pad_id_t: Optional[int] = None,
748
+ shift: Optional[int] = None,
749
+ include_prompt: bool = True,
750
+ generate_text=False,
751
+ ) -> torch.Tensor:
752
+
753
+ T = input_ids[0].size(1)
754
+ device = input_ids[0].device
755
+ output = []
756
+ token_T = next_token_A1T1(
757
+ model,
758
+ audio_features.to(torch.float32).to(model.device),
759
+ input_ids,
760
+ [T - 3],
761
+ ["asr"],
762
+ input_pos=torch.arange(0, T, device=device),
763
+ temperature=temperature,
764
+ top_k=top_k,
765
+ top_p=top_p,
766
+ )
767
+ output.append(token_T.clone().tolist()[0])
768
+ input_pos = torch.tensor([T], device=device)
769
+ text_end = False
770
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
771
+ model_input_ids = []
772
+ for i in range(7):
773
+ model_input_ids.append(
774
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
775
+ .view(1, -1)
776
+ .to(torch.int32)
777
+ .to(device)
778
+ )
779
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
780
+ token_T = next_token_A1T1(
781
+ model,
782
+ None,
783
+ model_input_ids,
784
+ None,
785
+ None,
786
+ input_pos=input_pos,
787
+ temperature=temperature,
788
+ top_k=top_k,
789
+ top_p=top_p,
790
+ )
791
+ if token_T == eos_id_t:
792
+ break
793
+ output.append(token_T.clone().tolist()[0])
794
+ input_pos = input_pos.add_(1)
795
+ return output
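Read off the signatures and the inputs they consume, the generate_* helpers appear to cover the following task combinations (a rough mapping, inferred rather than documented in the source):

TASKS = {
    "generate_TT":       "text prompt  -> text tokens",
    "generate_AT":       "audio prompt -> text tokens",
    "generate_TA":       "text prompt  -> 7 audio streams + text",
    "generate_AA":       "audio prompt -> 7 audio streams + text",
    "generate_ASR":      "audio prompt -> transcript text tokens",
    "generate_TA_BATCH": "batched audio/image prompt -> 7 audio streams + text",
}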
litgpt/model.py ADDED
@@ -0,0 +1,654 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Full definition of a decoder-only transformer-based language model, all of it in this single file.
4
+
5
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
6
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
7
+ """
8
+
9
+ import math
10
+ from typing import Any, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing_extensions import Self
15
+ from litgpt.config import Config
16
+
17
+
18
+ class GPT(nn.Module):
19
+ def __init__(self, config: Config) -> None:
20
+ super().__init__()
21
+ assert config.padded_vocab_size is not None
22
+ self.config = config
23
+ if self.config.asr_adapter == "mlp":
24
+ print("Using MLP adapter for ASR feature")
25
+ self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
26
+ elif self.config.asr_adapter == "llamamlp":
27
+ print("using LLAMA MLP adapter for ASR feature")
28
+ self.whisper_adapter = whisperMLP(config=config)
29
+ else:
30
+ raise ValueError("asr_adapter should be mlp or llamamlp")
31
+ self.lm_head = nn.Linear(
32
+ config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
33
+ )
34
+
35
+ self.vision_adapter = visionMLP(config=config)
36
+ if config.post_adapter:
37
+ self.transformer = nn.ModuleDict(
38
+ dict(
39
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
40
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
41
+ post_adapter=nn.ModuleList(
42
+ Block(config) for _ in range(config.post_adapter_layers)
43
+ ),
44
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
45
+ post_adapter_audio_ln=config.norm_class(
46
+ config.n_embd, eps=config.norm_eps
47
+ ),
48
+ post_adapter_audio_lm_head=nn.Linear(
49
+ config.n_embd, config.cat_audio_vocab_size, bias=config.lm_head_bias
50
+ ),
51
+ )
52
+ )
53
+ else:
54
+ self.transformer = nn.ModuleDict(
55
+ dict(
56
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
57
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
58
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
59
+ )
60
+ )
61
+ self.max_seq_length = self.config.block_size
62
+ self.mask_cache: Optional[torch.Tensor] = None
63
+ if config.tie_word_embeddings:
64
+ self.lm_head.weight = self.transformer.wte.weight
65
+
66
+ @property
67
+ def max_seq_length(self) -> int:
68
+ return self._max_seq_length
69
+
70
+ @max_seq_length.setter
71
+ def max_seq_length(self, value: int) -> None:
72
+ """
73
+ When doing inference, the sequences used might be shorter than the model's context length.
74
+ This allows setting a smaller number to avoid allocating unused memory
75
+ """
76
+ if value > self.config.block_size:
77
+ raise ValueError(
78
+ f"Cannot attend to {value}, block size is only {self.config.block_size}"
79
+ )
80
+ self._max_seq_length = value
81
+ if not hasattr(self, "cos"):
82
+ # first call
83
+ cos, sin = self.rope_cache()
84
+ self.register_buffer("cos", cos, persistent=False)
85
+ self.register_buffer("sin", sin, persistent=False)
86
+ # override
87
+ elif value != self.cos.size(0):
88
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
89
+ # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
90
+ # if the kv cache is expected
91
+
92
+ def reset_parameters(self) -> None:
93
+ # Trigger resetting the rope-cache
94
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
95
+
96
+ def _init_weights(self, module: nn.Module) -> None:
97
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
98
+ if isinstance(module, nn.Linear):
99
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
100
+ if module.bias is not None:
101
+ torch.nn.init.zeros_(module.bias)
102
+ elif isinstance(module, nn.Embedding):
103
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
104
+
105
+ def concat_feat(self, audio_feature, clip_feature, input_ids, T, task):
106
+
107
+ for j in range(len(T)):
108
+ if task[j] not in ('T1T2', 'T1A2', 'ImageQA_T', 'ImageCAP', 'ImageQA_A', 'ImageQA_AT'):
109
+ for i in range(7):
110
+ input_ids[i][j,1:T[j]+1,:] = audio_feature[j][:T[j]].clone()
111
+ assert task[j] != 'ImageQ', "ImageQ should be concat with audio feature"
112
+
113
+ elif task[j] == 'ImageQA_A' or task[j] == 'ImageQA_AT':
114
+ print("concat ImageQA_A feature")
115
+ for i in range(7):
116
+ input_ids[i][j,1:51,:] = clip_feature[j].clone()
117
+
118
+ input_ids[i][j,52 : 52 + T[j],:] = audio_feature[j][:T[j]].clone()
119
+
120
+ elif task[j] == 'ImageQA_T' or task[j] =='ImageCAP':
121
+ for i in range(7):
122
+ input_ids[i][j,1:51,:] = clip_feature[j].clone()
123
+
124
+ return input_ids
125
+
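concat_feat splices the adapter outputs directly over placeholder token embeddings: whisper features fill positions 1..T[j], and for image tasks the 50 CLIP vectors fill positions 1..50 with audio starting at position 52. A toy illustration of the same splicing on one stream (all sizes below are assumptions, not the real config):

import torch

emb = torch.zeros(1, 128, 896)       # (batch, seq, n_embd) placeholder embeddings
clip_feat = torch.randn(50, 896)     # 50 vision tokens out of the vision adapter
audio_feat = torch.randn(60, 896)    # T = 60 frames out of the whisper adapter

emb[0, 1:51] = clip_feat             # image features right after the first slot
emb[0, 52:52 + 60] = audio_feat      # audio features after a one-token separator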
126
+ def forward(
127
+ self,
128
+ audio_features: torch.Tensor,
129
+ input_ids: torch.Tensor,
130
+ clip_features: torch.Tensor,
131
+ input_pos: Optional[torch.Tensor] = None,
132
+ whisper_lens: Optional[list] = None,
133
+ task: Optional[str] = None,
134
+ ) -> torch.Tensor:
135
+
136
+ show = False
137
+ T = input_ids[0].size(1)
138
+ if self.max_seq_length < T:
139
+ raise ValueError(
140
+ f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
141
+ )
142
+
143
+ if input_pos is not None: # use the kv cache
144
+ cos = self.cos.index_select(0, input_pos)
145
+ sin = self.sin.index_select(0, input_pos)
146
+ if self.mask_cache is None:
147
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
148
+ mask = self.mask_cache.index_select(2, input_pos)
149
+ else:
150
+ cos = self.cos[:T]
151
+ sin = self.sin[:T]
152
+ mask = None
153
+
154
+ if audio_features is not None:
155
+ # get whisper feature
156
+ x_a = self.whisper_adapter(audio_features)
157
+ if clip_features is not None:
158
+ x_v = self.vision_adapter(clip_features)
159
+ else:
160
+ x_v = None
161
+ # get input_ids embedding
162
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
163
+
164
+ x0 = self.transformer.wte(x0)
165
+ x1 = self.transformer.wte(x1)
166
+ x2 = self.transformer.wte(x2)
167
+ x3 = self.transformer.wte(x3)
168
+ x4 = self.transformer.wte(x4)
169
+ x5 = self.transformer.wte(x5)
170
+ x6 = self.transformer.wte(x6)
171
+ x7 = self.transformer.wte(x7)
172
+
173
+ # concat whisper feature
174
+ input_emb = self.concat_feat(
175
+ x_a, x_v, [x0, x1, x2, x3, x4, x5, x6, x7], whisper_lens, task
176
+ )
177
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_emb
178
+
179
+ else:
180
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
181
+
182
+ x0 = self.transformer.wte(x0)
183
+ x1 = self.transformer.wte(x1)
184
+ x2 = self.transformer.wte(x2)
185
+ x3 = self.transformer.wte(x3)
186
+ x4 = self.transformer.wte(x4)
187
+ x5 = self.transformer.wte(x5)
188
+ x6 = self.transformer.wte(x6)
189
+ x7 = self.transformer.wte(x7)
190
+
191
+ x = (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
192
+
193
+ if self.config.scale_embeddings:
194
+ x = x * (self.config.n_embd**0.5)
195
+
196
+ for block in self.transformer.h:
197
+ x = block(x, cos, sin, mask, input_pos)
198
+
199
+
200
+ text_vocab_size = self.config.text_vocab_size
201
+ audio_vocab_size = self.config.audio_vocab_size
202
+
203
+ x_ori = x
204
+ x_ori = self.transformer.ln_f(x_ori)
205
+ x_ori = self.lm_head(x_ori) # (b, t, vocab_size)
206
+ xt = x_ori[..., :text_vocab_size]
207
+
208
+ if self.config.post_adapter:
209
+ for block in self.transformer.post_adapter:
210
+ x = block(x, cos, sin, mask, input_pos)
211
+ x = self.transformer.post_adapter_audio_ln(x)
212
+ x = self.transformer.post_adapter_audio_lm_head(x) # (b, t, vocab_size)
213
+ xa = []
214
+ for i in range(7):
215
+ xa.append(x[..., audio_vocab_size * i : audio_vocab_size * (i + 1)])
216
+ else:
217
+ xa = []
218
+ for i in range(7):
219
+ xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
220
+
221
+ return xa, xt
222
+
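forward returns one text head and seven audio heads sliced out of a single fused lm_head output (or, with post_adapter, out of a separate audio head). A minimal sketch of the fused slicing, with made-up vocabulary sizes:

import torch

text_vocab_size, audio_vocab_size = 152_000, 4_160   # assumptions, not the shipped config
logits = torch.randn(1, 10, text_vocab_size + 7 * audio_vocab_size)

xt = logits[..., :text_vocab_size]
xa = [logits[..., text_vocab_size + i * audio_vocab_size : text_vocab_size + (i + 1) * audio_vocab_size]
      for i in range(7)]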
223
+ @classmethod
224
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
225
+ return cls(Config.from_name(name, **kwargs))
226
+
227
+ def rope_cache(
228
+ self, device: Optional[torch.device] = None
229
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
230
+ return build_rope_cache(
231
+ seq_len=self.max_seq_length,
232
+ n_elem=self.config.rope_n_elem,
233
+ device=device,
234
+ condense_ratio=self.config.rope_condense_ratio,
235
+ base=self.config.rope_base,
236
+ )
237
+
238
+ def set_kv_cache(
239
+ self,
240
+ batch_size: int,
241
+ rope_cache_length: Optional[int] = None,
242
+ device: Optional[torch.device] = None,
243
+ dtype: Optional[torch.dtype] = None,
244
+ ) -> None:
245
+ if rope_cache_length is None:
246
+ rope_cache_length = self.cos.size(-1)
247
+ max_seq_length = self.max_seq_length
248
+
249
+ # initialize the kv cache for all blocks
250
+ for block in self.transformer.h:
251
+ block.attn.kv_cache = block.attn.build_kv_cache(
252
+ batch_size, max_seq_length, rope_cache_length, device, dtype
253
+ )
254
+ if self.config.post_adapter:
255
+ for block in self.transformer.post_adapter:
256
+ block.attn.kv_cache = block.attn.build_kv_cache(
257
+ batch_size, max_seq_length, rope_cache_length, device, dtype
258
+ )
259
+
260
+ if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
261
+ # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
262
+ # for the kv-cache support (only during inference), we only create it in that situation
263
+ self.mask_cache = build_mask_cache(max_seq_length, device)
264
+
265
+ def clear_kv_cache(self) -> None:
266
+ self.mask_cache = None
267
+ for block in self.transformer.h:
268
+ block.attn.kv_cache = None
269
+
270
+
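A typical inference setup around the cache helpers above, sketched under the assumption of a single-sequence decode loop (`device` is whatever the model runs on):

model.max_seq_length = 2048                      # shrink the rope/mask caches to what is needed
model.set_kv_cache(batch_size=1, device=device)  # allocate per-block KV buffers and the mask cache
# ... decode step by step, passing an explicit input_pos so each step writes into the cache ...
model.clear_kv_cache()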
271
+ class visionMLP(nn.Module):
272
+ def __init__(self, config: Config) -> None:
273
+ super().__init__()
274
+ vision_adapter_dim = config.vision_adapter_dim
275
+ self.fc_1 = nn.Linear(vision_adapter_dim, config.intermediate_size, bias=config.bias)
276
+ self.fc_2 = nn.Linear(vision_adapter_dim, config.intermediate_size, bias=config.bias)
277
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
278
+
279
+ self.config = config
280
+
281
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
282
+ x_fc_1 = self.fc_1(x)
283
+ x_fc_2 = self.fc_2(x)
284
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
285
+ return self.proj(x)
286
+
287
+
288
+ class Block(nn.Module):
289
+
290
+ def __init__(self, config: Config) -> None:
291
+ super().__init__()
292
+ if not config.parallel_residual and config.shared_attention_norm:
293
+ raise NotImplementedError(
294
+ "No checkpoint amongst the ones we support uses this configuration"
295
+ " (non-parallel residual and shared attention norm)."
296
+ )
297
+
298
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
299
+ self.attn = CausalSelfAttention(config)
300
+ self.norm_2 = (
301
+ None
302
+ if config.shared_attention_norm
303
+ else config.norm_class(config.n_embd, eps=config.norm_eps)
304
+ )
305
+ self.mlp = config.mlp_class(config)
306
+
307
+ self.config = config
308
+
309
+ def forward(
310
+ self,
311
+ x: torch.Tensor,
312
+ cos: torch.Tensor,
313
+ sin: torch.Tensor,
314
+ mask: Optional[torch.Tensor] = None,
315
+ input_pos: Optional[torch.Tensor] = None,
316
+ ) -> torch.Tensor:
317
+ """
318
+ Non-parallel residual Parallel residual
319
+ ┌─ x ┌─ x ────────────┐ Note: if `shared_attention_norm` is True,
320
+ │ ↓ │ ↓ ↓ the output from `norm_1` is reused
321
+ │ norm_1 │ norm_1 ───► norm_2
322
+ │ ↓ │ ↓ ↓
323
+ │ attn │ attn mlp
324
+ │ ↓ │ ↓ │
325
+ ┌─ └► + └► + ◄───────────┘
326
+ │ norm_2
327
+ │ ↓
328
+ │ mlp
329
+ │ ↓
330
+ └───► +
331
+ """
332
+
333
+ x_normed = self.norm_1(x)
334
+ attention_output = self.attn(x_normed, cos, sin, mask, input_pos)
335
+
336
+ if self.config.parallel_residual:
337
+ x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x)
338
+ x = self.mlp(x_normed) + attention_output + x
339
+ else:
340
+ x = attention_output + x
341
+ x = self.mlp(self.norm_2(x)) + x
342
+ return x
343
+
344
+
345
+ class CausalSelfAttention(nn.Module):
346
+ def __init__(self, config: Config) -> None:
347
+ super().__init__()
348
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
349
+ # key, query, value projections for all heads, but in a batch
350
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias)
351
+ # output projection
352
+ # if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
353
+ self.proj = nn.Linear(
354
+ config.head_size * config.n_head, config.n_embd, bias=config.bias
355
+ )
356
+ # disabled by default
357
+ self.kv_cache: Optional[KVCache] = None
358
+
359
+ self.config = config
360
+
361
+ def forward(
362
+ self,
363
+ x: torch.Tensor,
364
+ cos: torch.Tensor,
365
+ sin: torch.Tensor,
366
+ mask: Optional[torch.Tensor] = None,
367
+ input_pos: Optional[torch.Tensor] = None,
368
+ ) -> torch.Tensor:
369
+ B, T, C = (
370
+ x.size()
371
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
372
+
373
+ qkv = self.attn(x)
374
+
375
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
376
+ q_per_kv = self.config.n_head // self.config.n_query_groups
377
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
378
+ qkv = qkv.view(
379
+ B, T, self.config.n_query_groups, total_qkv, self.config.head_size
380
+ )
381
+ qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
382
+
383
+ # split batched computation into three
384
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
385
+
386
+ # maybe repeat k and v for the non-multi-head-attention cases
387
+ # training: flash attention requires it
388
+ # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
389
+ if self.config.n_query_groups != self.config.n_head and (
390
+ input_pos is None or self.config.n_query_groups != 1
391
+ ):
392
+ k = k.expand(
393
+ B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
394
+ )
395
+ v = v.expand(
396
+ B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
397
+ )
398
+
399
+ q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
400
+ k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
401
+ v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
402
+
403
+ q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
404
+ k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
405
+ q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
406
+ k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
407
+
408
+ if input_pos is not None:
409
+ if not isinstance(self.kv_cache, KVCache):
410
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
411
+ k, v = self.kv_cache(input_pos, k, v)
412
+
413
+ y = self.scaled_dot_product_attention(q, k, v, mask)
414
+
415
+ y = y.reshape(
416
+ B, T, self.config.head_size * self.config.n_head
417
+ ) # re-assemble all head outputs side by side
418
+
419
+ # output projection
420
+ return self.proj(y)
421
+
422
+ def scaled_dot_product_attention(
423
+ self,
424
+ q: torch.Tensor,
425
+ k: torch.Tensor,
426
+ v: torch.Tensor,
427
+ mask: Optional[torch.Tensor] = None,
428
+ ) -> torch.Tensor:
429
+ scale = 1.0 / math.sqrt(self.config.head_size)
430
+ y = torch.nn.functional.scaled_dot_product_attention(
431
+ q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
432
+ )
433
+ return y.transpose(1, 2)
434
+
435
+ def build_kv_cache(
436
+ self,
437
+ batch_size: int,
438
+ max_seq_length: int,
439
+ rope_cache_length: Optional[int] = None,
440
+ device: Optional[torch.device] = None,
441
+ dtype: Optional[torch.dtype] = None,
442
+ ) -> "KVCache":
443
+ heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
444
+ v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
445
+ if rope_cache_length is None:
446
+ if self.config.rotary_percentage != 1.0:
447
+ raise TypeError(
448
+ "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
449
+ )
450
+ k_shape = v_shape
451
+ else:
452
+ k_shape = (
453
+ batch_size,
454
+ heads,
455
+ max_seq_length,
456
+ rope_cache_length + self.config.head_size - self.config.rope_n_elem,
457
+ )
458
+ return KVCache(k_shape, v_shape, device=device, dtype=dtype)
459
+
460
+
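The fused QKV projection above sizes itself for grouped-query attention. A small numeric sketch (example numbers only, not the shipped config):

n_head, n_query_groups, head_size = 14, 2, 64
q_per_kv = n_head // n_query_groups                    # 7 query heads share each K/V head
qkv_width = (n_head + 2 * n_query_groups) * head_size  # width of self.attn's output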
461
+ class GptNeoxMLP(nn.Module):
462
+ def __init__(self, config: Config) -> None:
463
+ super().__init__()
464
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
465
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
466
+
467
+ self.config = config
468
+
469
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
470
+ x = self.fc(x)
471
+ x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
472
+ return self.proj(x)
473
+
474
+
475
+ class LLaMAMLP(nn.Module):
476
+ def __init__(self, config: Config) -> None:
477
+ super().__init__()
478
+ self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
479
+ self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
480
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
481
+
482
+ self.config = config
483
+
484
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
485
+ x_fc_1 = self.fc_1(x)
486
+ x_fc_2 = self.fc_2(x)
487
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
488
+ return self.proj(x)
489
+
490
+
491
+ class whisperMLP(nn.Module):
492
+ def __init__(self, config: Config) -> None:
493
+ super().__init__()
494
+ self.fc_1 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
495
+ self.fc_2 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
496
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
497
+
498
+ self.config = config
499
+
500
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
501
+ x_fc_1 = self.fc_1(x)
502
+ x_fc_2 = self.fc_2(x)
503
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
504
+ return self.proj(x)
505
+
506
+
507
+ class GemmaMLP(LLaMAMLP):
508
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
509
+ x_fc_1 = self.fc_1(x)
510
+ x_fc_2 = self.fc_2(x)
511
+ x = (
512
+ torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate)
513
+ * x_fc_2
514
+ )
515
+ return self.proj(x)
516
+
517
+
518
+ class LLaMAMoE(nn.Module):
519
+ def __init__(self, config: Config) -> None:
520
+ super().__init__()
521
+ self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
522
+ self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
523
+
524
+ self.config = config
525
+
526
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
527
+ """
528
+ Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
529
+ See also figure 1 in https://arxiv.org/abs/2211.15841
530
+ """
531
+ B, T, C = (
532
+ x.size()
533
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
534
+ x = x.view(-1, C) # (B*T, C)
535
+ router = self.gate(x) # (B*T, n_expert)
536
+ probs, indices = torch.topk(
537
+ router, self.config.n_expert_per_token
538
+ ) # (B*T, n_expert_per_token)
539
+ probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
540
+ masks = indices.unsqueeze(-1) == torch.arange(
541
+ self.config.n_expert, device=x.device
542
+ )
543
+ masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
544
+ y = torch.zeros_like(x) # (B*T, C)
545
+ for mask, expert in zip(masks, self.experts):
546
+ token_idx, expert_idx = torch.where(mask)
547
+ y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
548
+ return y.view(B, T, C)
549
+
550
+
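The routing step above picks n_expert_per_token experts per token and mixes their outputs with renormalized gate probabilities. The selection step in isolation, on toy sizes:

import torch

router = torch.randn(5, 8)                   # (tokens, n_expert) gate scores
probs, indices = torch.topk(router, k=2)     # keep the 2 best experts per token
probs = probs.softmax(dim=1)                 # renormalize over the selected experts only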
551
+ def build_rope_cache(
552
+ seq_len: int,
553
+ n_elem: int,
554
+ device: Optional[torch.device] = None,
555
+ base: int = 10000,
556
+ condense_ratio: int = 1,
557
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
558
+ """Enhanced Transformer with Rotary Position Embedding.
559
+
560
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
561
+ transformers/rope/__init__.py. MIT License:
562
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
563
+ """
564
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
565
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
566
+
567
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
568
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
569
+
570
+ # Calculate the product of position index and $\theta_i$
571
+ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
572
+
573
+ return torch.cos(idx_theta), torch.sin(idx_theta)
574
+
575
+
576
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
577
+ head_size = x.size(-1)
578
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
579
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
580
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
581
+ roped = (x * cos) + (rotated * sin)
582
+ return roped.to(dtype=x.dtype)
583
+
584
+
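build_rope_cache and apply_rope pair as below; the cos/sin tables broadcast over the batch and head dimensions (toy sizes):

import torch

cos, sin = build_rope_cache(seq_len=16, n_elem=8)
q = torch.randn(1, 2, 16, 8)          # (B, heads, T, rotary dims)
q_roped = apply_rope(q, cos, sin)     # same shape, positions encoded by rotation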
585
+ class KVCache(nn.Module):
586
+ def __init__(
587
+ self,
588
+ k_shape: Tuple[int, int, int, int],
589
+ v_shape: Tuple[int, int, int, int],
590
+ device: Optional[torch.device] = None,
591
+ dtype: Optional[torch.dtype] = None,
592
+ ) -> None:
593
+ super().__init__()
594
+ self.register_buffer(
595
+ "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
596
+ )
597
+ self.register_buffer(
598
+ "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
599
+ )
600
+
601
+ def forward(
602
+ self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
603
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
604
+ # move the buffer to the activation dtype for when AMP is used
605
+ self.k = self.k.to(k.dtype)
606
+ self.v = self.v.to(v.dtype)
607
+ # update the cache
608
+ k = self.k.index_copy_(2, input_pos, k)
609
+ v = self.v.index_copy_(2, input_pos, v)
610
+ return k, v
611
+
612
+ def reset_parameters(self) -> None:
613
+ torch.nn.init.zeros_(self.k)
614
+ torch.nn.init.zeros_(self.v)
615
+
616
+
617
+ def build_mask_cache(
618
+ max_seq_length: int, device: Optional[torch.device] = None
619
+ ) -> torch.Tensor:
620
+ ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
621
+ return torch.tril(ones).unsqueeze(0).unsqueeze(0)
622
+
623
+
624
+ class RMSNorm(torch.nn.Module):
625
+ """Root Mean Square Layer Normalization.
626
+
627
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
628
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
629
+ """
630
+
631
+ def __init__(
632
+ self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
633
+ ) -> None:
634
+ super().__init__()
635
+ self.weight = torch.nn.Parameter(torch.ones(size))
636
+ self.eps = eps
637
+ self.dim = dim
638
+ self.add_unit_offset = add_unit_offset
639
+
640
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
641
+ dtype = x.dtype
642
+ x = x.float()
643
+ # NOTE: the original RMSNorm paper implementation is not equivalent
644
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
645
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
646
+ x_normed = x_normed.to(dtype=dtype)
647
+ if self.add_unit_offset:
648
+ # Gemma model requires a unit offset
649
+ # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
650
+ return x_normed * (1 + self.weight)
651
+ return x_normed * self.weight
652
+
653
+ def reset_parameters(self) -> None:
654
+ torch.nn.init.ones_(self.weight)
litgpt/tokenizer.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Optional, Union
6
+
7
+ import torch
8
+
9
+
10
+ class Tokenizer:
11
+ def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
12
+ checkpoint_dir = Path(checkpoint_dir)
13
+ if not checkpoint_dir.exists():
14
+ raise NotADirectoryError(
15
+ f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
16
+ )
17
+
18
+ self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
19
+ self.bos_id = None
20
+ self.eos_id = None
21
+
22
+ # some checkpoints have both files, `.json` takes precedence
23
+ if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
24
+ from tokenizers import Tokenizer as HFTokenizer
25
+
26
+ self.processor = HFTokenizer.from_file(str(vocabulary_path))
27
+ self.backend = "huggingface"
28
+
29
+ if (
30
+ special_tokens_path := checkpoint_dir / "tokenizer_config.json"
31
+ ).is_file():
32
+ with open(special_tokens_path, encoding="utf-8") as fp:
33
+ config = json.load(fp)
34
+ bos_token = config.get("bos_token")
35
+ eos_token = config.get("eos_token")
36
+ if bos_token is not None and isinstance(bos_token, dict):
37
+ bos_token = bos_token.get("content")
38
+ if eos_token is not None and isinstance(eos_token, dict):
39
+ eos_token = eos_token.get("content")
40
+ self.bos_id = (
41
+ self.token_to_id(bos_token) if bos_token is not None else None
42
+ )
43
+ self.eos_id = (
44
+ self.token_to_id(eos_token) if eos_token is not None else None
45
+ )
46
+ if (
47
+ special_tokens_path := checkpoint_dir / "generation_config.json"
48
+ ).is_file():
49
+ with open(special_tokens_path, encoding="utf-8") as fp:
50
+ config = json.load(fp)
51
+ if self.bos_id is None:
52
+ self.bos_id = config.get("bos_token_id")
53
+ if self.eos_id is None:
54
+ self.eos_id = config.get("eos_token_id")
55
+
56
+ elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
57
+ from sentencepiece import SentencePieceProcessor
58
+
59
+ self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
60
+ self.backend = "sentencepiece"
61
+ self.bos_id = self.processor.bos_id()
62
+ self.eos_id = self.processor.eos_id()
63
+ else:
64
+ raise NotImplementedError
65
+
66
+ @property
67
+ def vocab_size(self) -> int:
68
+ if self.backend == "huggingface":
69
+ return self.processor.get_vocab_size(with_added_tokens=False)
70
+ if self.backend == "sentencepiece":
71
+ return self.processor.vocab_size()
72
+ raise RuntimeError
73
+
74
+ def token_to_id(self, token: str) -> int:
75
+ if self.backend == "huggingface":
76
+ id_ = self.processor.token_to_id(token)
77
+ elif self.backend == "sentencepiece":
78
+ id_ = self.processor.piece_to_id(token)
79
+ else:
80
+ raise RuntimeError
81
+ if id_ is None:
82
+ raise ValueError(f"token {token!r} not found in the collection.")
83
+ return id_
84
+
85
+ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
86
+ if not (
87
+ tokenizer_config_path := checkpoint_dir / "tokenizer_config.json"
88
+ ).is_file():
89
+ return False
90
+ with open(tokenizer_config_path, encoding="utf-8") as fp:
91
+ config = json.load(fp)
92
+ if "add_bos_token" in config:
93
+ return config["add_bos_token"]
94
+ # if `add_bos_token` isn't in the config file but the LLaMA tokenizer is used, return True.
95
+ # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
96
+ return config.get("tokenizer_class") == "LlamaTokenizer"
97
+
98
+ def encode(
99
+ self,
100
+ string: str,
101
+ device: Optional[torch.device] = None,
102
+ bos: Optional[bool] = None,
103
+ eos: bool = False,
104
+ max_length: int = -1,
105
+ ) -> torch.Tensor:
106
+ if self.backend == "huggingface":
107
+ tokens = self.processor.encode(string).ids
108
+ elif self.backend == "sentencepiece":
109
+ tokens = self.processor.encode(string)
110
+ else:
111
+ raise RuntimeError
112
+ if bos or (bos is None and self.use_bos):
113
+ bos_id = self.bos_id
114
+ if bos_id is None:
115
+ raise NotImplementedError(
116
+ "This tokenizer does not have a defined a bos token"
117
+ )
118
+ if tokens[0] != bos_id:
119
+ tokens = [bos_id] + tokens
120
+ if tokens is None:
121
+ raise ValueError("`tokens` is None")
122
+
123
+ if eos and (not tokens or tokens[-1] != self.eos_id):
124
+ tokens = tokens + [self.eos_id]
125
+ if max_length > 0:
126
+ tokens = tokens[:max_length]
127
+ return torch.tensor(tokens, dtype=torch.int, device=device)
128
+
129
+ def decode(self, tensor: torch.Tensor) -> str:
130
+ tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
131
+ return self.processor.decode(tokens)
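A typical round trip with the Tokenizer above; the checkpoint directory name is hypothetical:

tok = Tokenizer("checkpoints/some-model")    # any litgpt-style checkpoint directory
ids = tok.encode("hello world", bos=True)    # 1-D torch.int tensor of token ids
text = tok.decode(ids)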
litgpt/utils.py ADDED
@@ -0,0 +1,641 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Utility functions for training and inference."""
4
+ import inspect
5
+ import math
6
+ import os
7
+ import pickle
8
+ import shutil
9
+ import sys
10
+ from dataclasses import asdict, is_dataclass
11
+ from io import BytesIO
12
+ from pathlib import Path
13
+ from typing import (
14
+ TYPE_CHECKING,
15
+ Any,
16
+ Dict,
17
+ Iterable,
18
+ List,
19
+ Literal,
20
+ Mapping,
21
+ Optional,
22
+ TypeVar,
23
+ Union,
24
+ )
25
+
26
+ import lightning as L
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.utils._device
30
+ import yaml
31
+ from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
32
+ from lightning.fabric.strategies import FSDPStrategy
33
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
34
+ from lightning.pytorch.loggers import WandbLogger
35
+ from lightning.pytorch.cli import instantiate_class
36
+ from torch.serialization import normalize_storage_type
37
+ from typing_extensions import Self
38
+
39
+ if TYPE_CHECKING:
40
+ from litgpt import GPT, Config
41
+
42
+
43
+ def init_out_dir(out_dir: Path) -> Path:
44
+ if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
45
+ return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
46
+ return out_dir
47
+
48
+
49
+ def find_resume_path(
50
+ resume: Union[bool, Literal["auto"], Path], out_dir: Path
51
+ ) -> Optional[Path]:
52
+ if not resume or isinstance(resume, Path):
53
+ return resume
54
+
55
+ resume_path = max(
56
+ out_dir.rglob("step-*/*.pth"),
57
+ key=(lambda p: int(p.parent.name.split("-")[1])),
58
+ default=None,
59
+ )
60
+ if resume == "auto":
61
+ return resume_path
62
+ if resume is True and resume_path is None:
63
+ raise FileNotFoundError(
64
+ f"You passed `--resume=True`, but no checkpont file was found in `--out_dir={out_dir}`."
65
+ )
66
+ return resume_path
67
+
68
+
69
+ def find_multiple(n: int, k: int) -> int:
70
+ assert k > 0
71
+ if n % k == 0:
72
+ return n
73
+ return n + k - (n % k)
74
+
75
+
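find_multiple rounds n up to the next multiple of k; for example:

assert find_multiple(50, 8) == 56
assert find_multiple(64, 8) == 64   # already-aligned values pass through unchanged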
76
+ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
77
+ total = 0
78
+ for p in module.parameters():
79
+ if requires_grad is None or p.requires_grad == requires_grad:
80
+ if hasattr(p, "quant_state"):
81
+ # bitsandbytes 4bit layer support
82
+ total += math.prod(p.quant_state.shape)
83
+ else:
84
+ total += p.numel()
85
+ return total
86
+
87
+
88
+ def reset_parameters(module: nn.Module) -> None:
89
+ """Calls `reset_parameters` on the module and all its submodules."""
90
+ for mod in module.modules():
91
+ if callable(getattr(mod, "reset_parameters", None)):
92
+ mod.reset_parameters()
93
+
94
+
95
+ def check_valid_checkpoint_dir(
96
+ checkpoint_dir: Path,
97
+ model_filename: str = "lit_model.pth",
98
+ verbose: bool = True,
99
+ raise_error: bool = False,
100
+ ) -> None:
101
+ files = {
102
+ model_filename: (checkpoint_dir / model_filename).is_file(),
103
+ "model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
104
+ "tokenizer.json OR tokenizer.model": (
105
+ checkpoint_dir / "tokenizer.json"
106
+ ).is_file()
107
+ or (checkpoint_dir / "tokenizer.model").is_file(),
108
+ "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
109
+ }
110
+ if checkpoint_dir.is_dir():
111
+ if all(files.values()):
112
+ # we're good
113
+ return
114
+ problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
115
+ else:
116
+ problem = " is not a checkpoint directory"
117
+
118
+ # list locally available checkpoints
119
+ available = list(Path("checkpoints").glob("*/*"))
120
+ if available:
121
+ options = "\n".join([""] + [repr(str(p.resolve())) for p in available])
122
+ extra = f"\nYou have downloaded locally:{options}\n"
123
+ else:
124
+ extra = ""
125
+
126
+ if verbose:
127
+ error_message = (
128
+ f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
129
+ "\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n"
130
+ f"{extra}\nSee all download options by running:\n litgpt download"
131
+ )
132
+ print(error_message, file=sys.stderr)
133
+
134
+ if raise_error:
135
+ raise FileNotFoundError(
136
+ f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
137
+ )
138
+ else:
139
+ raise SystemExit(1)
140
+
141
+
142
+ class SavingProxyForStorage:
143
+ def __init__(self, obj, saver, protocol_version=5):
144
+ self.protocol_version = protocol_version
145
+ self.saver = saver
146
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
147
+ raise TypeError(f"expected storage, not {type(obj)}")
148
+
149
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
150
+ if isinstance(obj, torch.storage.TypedStorage):
151
+ # PT upstream wants to deprecate this eventually...
152
+ storage = obj._untyped_storage
153
+ storage_type_str = obj._pickle_storage_type()
154
+ storage_type = getattr(torch, storage_type_str)
155
+ storage_numel = obj._size()
156
+ else:
157
+ storage = obj
158
+ storage_type = normalize_storage_type(type(obj))
159
+ storage_numel = storage.nbytes()
160
+
161
+ storage_key = saver._write_storage_and_return_key(storage)
162
+ location = torch.serialization.location_tag(storage)
163
+
164
+ self.storage_info = (
165
+ "storage",
166
+ storage_type,
167
+ storage_key,
168
+ location,
169
+ storage_numel,
170
+ )
171
+
172
+ def __reduce_ex__(self, protocol_version):
173
+ assert False, "this should be handled with out of band"
174
+
175
+
176
+ class SavingProxyForTensor:
177
+ def __init__(self, tensor, saver, protocol_version=5):
178
+ self.protocol_version = protocol_version
179
+ self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
180
+ if reduce_args[0] == torch._utils._rebuild_tensor_v2:
181
+ # for Tensors with Python attributes
182
+ (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
183
+ assert isinstance(
184
+ storage, torch.storage.TypedStorage
185
+ ), "Please check for updates"
186
+ storage_proxy = SavingProxyForStorage(
187
+ storage, saver, protocol_version=protocol_version
188
+ )
189
+ self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
190
+ else:
191
+ (storage, *other_reduce_args) = reduce_args
192
+ assert isinstance(
193
+ storage, torch.storage.TypedStorage
194
+ ), "Please check for updates"
195
+ storage_proxy = SavingProxyForStorage(
196
+ storage, saver, protocol_version=protocol_version
197
+ )
198
+ self.reduce_args = (storage_proxy, *other_reduce_args)
199
+
200
+ def __reduce_ex__(self, protocol_version):
201
+ if protocol_version != self.protocol_version:
202
+ raise RuntimeError(
203
+ f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
204
+ )
205
+ return self.reduce_ret_fn, self.reduce_args
206
+
207
+
208
+ class IncrementalPyTorchPickler(pickle.Pickler):
209
+ def __init__(self, saver, *args, **kwargs):
210
+ super().__init__(*args, **kwargs)
211
+ self.storage_dtypes = {}
212
+ self.saver = saver
213
+ self.id_map = {}
214
+
215
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
216
+ def persistent_id(self, obj):
217
+ # FIXME: the docs say that persistent_id should only return a string
218
+ # but torch store returns tuples. This works only in the binary protocol
219
+ # see
220
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
221
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
222
+ if isinstance(obj, SavingProxyForStorage):
223
+ return obj.storage_info
224
+
225
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
226
+ if isinstance(obj, torch.storage.TypedStorage):
227
+ # TODO: Once we decide to break serialization FC, this case
228
+ # can be deleted
229
+ storage = obj._untyped_storage
230
+ storage_dtype = obj.dtype
231
+ storage_type_str = obj._pickle_storage_type()
232
+ storage_type = getattr(torch, storage_type_str)
233
+ storage_numel = obj._size()
234
+
235
+ else:
236
+ storage = obj
237
+ storage_dtype = torch.uint8
238
+ storage_type = normalize_storage_type(type(obj))
239
+ storage_numel = storage.nbytes()
240
+
241
+ # If storage is allocated, ensure that any other saved storages
242
+ # pointing to the same data all have the same dtype. If storage is
243
+ # not allocated, don't perform this check
244
+ if storage.data_ptr() != 0:
245
+ if storage.data_ptr() in self.storage_dtypes:
246
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
247
+ raise RuntimeError(
248
+ "Cannot save multiple tensors or storages that view the same data as different types"
249
+ )
250
+ else:
251
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
252
+
253
+ storage_key = self.id_map.get(storage._cdata)
254
+ if storage_key is None:
255
+ storage_key = self.saver._write_storage_and_return_key(storage)
256
+ self.id_map[storage._cdata] = storage_key
257
+ location = torch.serialization.location_tag(storage)
258
+
259
+ return ("storage", storage_type, storage_key, location, storage_numel)
260
+
261
+ return None
262
+
263
+
264
+ class incremental_save:
265
+ def __init__(self, name):
266
+ self.name = name
267
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
268
+ self.has_saved = False
269
+ self.next_key = 0
270
+
271
+ def __enter__(self):
272
+ return self
273
+
274
+ def store_early(self, tensor):
275
+ if isinstance(tensor, torch.Tensor):
276
+ return SavingProxyForTensor(tensor, self)
277
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
278
+
279
+ def save(self, obj):
280
+ if self.has_saved:
281
+ raise RuntimeError("already saved")
282
+ # Write the pickle data for `obj`
283
+ data_buf = BytesIO()
284
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
285
+ pickler.dump(obj)
286
+ data_value = data_buf.getvalue()
287
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
288
+ self.has_saved = True
289
+
290
+ def _write_storage_and_return_key(self, storage):
291
+ if self.has_saved:
292
+ raise RuntimeError("already saved")
293
+ key = self.next_key
294
+ self.next_key += 1
295
+ name = f"data/{key}"
296
+ if storage.device.type != "cpu":
297
+ storage = storage.cpu()
298
+ num_bytes = storage.nbytes()
299
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
300
+ return key
301
+
302
+ def __exit__(self, type, value, traceback):
303
+ self.zipfile.write_end_of_file()
304
+
305
+
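A minimal usage sketch of the incremental saving utilities above (the tiny model and file name are hypothetical). `store_early` streams each tensor's storage into the zip archive as soon as it is wrapped, so the full state dict is never duplicated in memory:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # hypothetical tiny model, for illustration only

with incremental_save("checkpoint.pth") as saver:
    # wrap each tensor so its storage is written to the archive immediately
    state_dict = {k: saver.store_early(v) for k, v in model.state_dict().items()}
    # pickle the proxy-filled state dict; IncrementalPyTorchPickler emits
    # persistent ids that point at the storages already written above
    saver.save({"model": state_dict})

# the result is a regular PyTorch zip checkpoint
loaded = torch.load("checkpoint.pth", weights_only=False)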
306
+ T = TypeVar("T")
307
+
308
+
309
+ def chunked_cross_entropy(
310
+ logits: Union[torch.Tensor, List[torch.Tensor]],
311
+ targets: torch.Tensor,
312
+ chunk_size: int = 128,
313
+ ignore_index: int = -100,
314
+ ) -> torch.Tensor:
315
+ # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
316
+ # the memory usage in fine-tuning settings with a low number of parameters.
317
+ # as a workaround, the cross entropy computation is chunked to force intermediate buffers to be deallocated on the go, reducing
318
+ # the memory spike's magnitude
319
+
320
+ # lm_head was chunked (we are fine-tuning)
321
+ if isinstance(logits, list):
322
+ # don't want to chunk cross entropy
323
+ if chunk_size == 0:
324
+ logits = torch.cat(logits, dim=1)
325
+ logits = logits.reshape(-1, logits.size(-1))
326
+ targets = targets.reshape(-1)
327
+ return torch.nn.functional.cross_entropy(
328
+ logits, targets, ignore_index=ignore_index
329
+ )
330
+
331
+ # chunk cross entropy
332
+ logit_chunks = [
333
+ logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
334
+ ]
335
+ target_chunks = [
336
+ target_chunk.reshape(-1)
337
+ for target_chunk in targets.split(logits[0].size(1), dim=1)
338
+ ]
339
+ loss_chunks = [
340
+ torch.nn.functional.cross_entropy(
341
+ logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
342
+ )
343
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
344
+ ]
345
+ non_masked_elems = (targets != ignore_index).sum()
346
+ # See [non_masked_elems div note]
347
+ return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
348
+ torch.ones_like(non_masked_elems)
349
+ )
350
+
351
+ # no chunking at all
352
+ logits = logits.reshape(-1, logits.size(-1))
353
+ targets = targets.reshape(-1)
354
+ if chunk_size == 0:
355
+ return torch.nn.functional.cross_entropy(
356
+ logits, targets, ignore_index=ignore_index
357
+ )
358
+
359
+ # lm_head wasn't chunked, chunk cross entropy
360
+ logit_chunks = logits.split(chunk_size)
361
+ target_chunks = targets.split(chunk_size)
362
+ loss_chunks = [
363
+ torch.nn.functional.cross_entropy(
364
+ logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
365
+ )
366
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
367
+ ]
368
+ non_masked_elems = (targets != ignore_index).sum()
369
+ # [non_masked_elems div note]:
370
+ # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
371
+ # results in a python int which is then passed back to torch division. By using the
372
+ # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
373
+ return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
374
+ torch.ones_like(non_masked_elems)
375
+ )
376
+
377
+
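As a quick sanity sketch (shapes and values below are made up), the chunked computation should agree with plain cross entropy up to floating point error, since the per-element losses are summed and divided by the same non-masked count:

import torch

logits = torch.randn(2, 16, 32)               # (batch, seq, vocab)
targets = torch.randint(0, 32, (2, 16))
targets[0, :4] = -100                         # a few ignored positions

full = torch.nn.functional.cross_entropy(
    logits.reshape(-1, 32), targets.reshape(-1), ignore_index=-100
)
chunked = chunked_cross_entropy(logits, targets, chunk_size=4)
assert torch.allclose(full, chunked, atol=1e-5)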
378
+ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
379
+ for checkpoint_name, attribute_name in mapping.items():
380
+ full_checkpoint_name = prefix + checkpoint_name
381
+ if full_checkpoint_name in state_dict:
382
+ full_attribute_name = prefix + attribute_name
383
+ state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
384
+ return state_dict
385
+
386
+
387
+ def get_default_supported_precision(training: bool) -> str:
388
+ """Return default precision that is supported by the hardware: either `bf16` or `16`.
389
+
390
+ Args:
391
+ training: `-mixed` or `-true` version of the precision to use
392
+
393
+ Returns:
394
+ default precision that is suitable for the task and is supported by the hardware
395
+ """
396
+ from lightning.fabric.accelerators import MPSAccelerator
397
+
398
+ if MPSAccelerator.is_available() or (
399
+ torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
400
+ ):
401
+ return "16-mixed" if training else "16-true"
402
+ return "bf16-mixed" if training else "bf16-true"
403
+
404
+
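A hypothetical call site, passing the detected precision straight to Fabric (which accepts precision strings such as "bf16-mixed" and "16-true"):

import lightning as L

fabric = L.Fabric(devices=1, precision=get_default_supported_precision(training=True))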
405
+ def load_checkpoint(
406
+ fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
407
+ ) -> None:
408
+ if isinstance(fabric.strategy, FSDPStrategy):
409
+ fabric.load_raw(checkpoint_path, model, strict=strict)
410
+ else:
411
+ state_dict = lazy_load(checkpoint_path)
412
+ state_dict = state_dict.get("model", state_dict)
413
+ model.load_state_dict(state_dict, strict=strict)
414
+
415
+
416
+ def flops_per_param(
417
+ max_seq_length: int, n_layer: int, n_embd: int, n_params: int
418
+ ) -> int:
419
+ flops_per_token = (
420
+ 2 * n_params
421
+ ) # each parameter is used for a MAC (2 FLOPS) per network operation
422
+ # this assumes that all samples have a fixed length equal to the block size
423
+ # which is most likely false during finetuning
424
+ flops_per_seq = flops_per_token * max_seq_length
425
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
426
+ return flops_per_seq + attn_flops_per_seq
427
+
428
+
429
+ def estimate_flops(model: "GPT", training: bool) -> int:
430
+ """Measures estimated FLOPs for MFU.
431
+
432
+ Refs:
433
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
434
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
435
+ """
436
+ # using all parameters for this is a naive overestimation because not all model parameters actually contribute to
437
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
438
+ # (~10%) compared to the measured FLOPs, which are lower but more realistic.
439
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
440
+ n_trainable_params = num_parameters(model, requires_grad=True)
441
+ trainable_flops = flops_per_param(
442
+ model.max_seq_length,
443
+ model.config.n_layer,
444
+ model.config.n_embd,
445
+ n_trainable_params,
446
+ )
447
+ # forward + backward + gradients (assumes no gradient accumulation)
448
+ ops_per_step = 3 if training else 1
449
+ n_frozen_params = num_parameters(model, requires_grad=False)
450
+ frozen_flops = flops_per_param(
451
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
452
+ )
453
+ # forward + backward
454
+ frozen_ops_per_step = 2 if training else 1
455
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
456
+
457
+
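A back-of-the-envelope sketch of how these estimates feed a model FLOPs utilization (MFU) number; all figures below are made up for illustration:

n_params = 1_000_000_000                         # 1B trainable parameters (hypothetical)
max_seq_length, n_layer, n_embd = 2048, 22, 2048

flops_per_seq = flops_per_param(max_seq_length, n_layer, n_embd, n_params)
train_flops_per_seq = 3 * flops_per_seq          # forward + backward + gradients

tokens_per_sec = 10_000                          # hypothetical measured throughput
achieved_flops_per_sec = train_flops_per_seq * tokens_per_sec / max_seq_length
a100_bf16_peak = 312e12                          # A100 peak bf16 throughput in FLOP/s
print(f"MFU ~ {achieved_flops_per_sec / a100_bf16_peak:.1%}")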
458
+ class CycleIterator:
459
+ """An iterator that cycles through an iterable indefinitely.
460
+
461
+ Example:
462
+ >>> iterator = CycleIterator([1, 2, 3])
463
+ >>> [next(iterator) for _ in range(5)]
464
+ [1, 2, 3, 1, 2]
465
+
466
+ Note:
467
+ Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
468
+ """
469
+
470
+ def __init__(self, iterable: Iterable) -> None:
471
+ self.iterable = iterable
472
+ self.epoch = 0
473
+ self._iterator = None
474
+
475
+ def __next__(self) -> Any:
476
+ if self._iterator is None:
477
+ self._iterator = iter(self.iterable)
478
+ try:
479
+ return next(self._iterator)
480
+ except StopIteration:
481
+ self._iterator = iter(self.iterable)
482
+ self.epoch += 1
483
+ return next(self._iterator)
484
+
485
+ def __iter__(self) -> Self:
486
+ return self
487
+
488
+
489
+ def copy_config_files(source_dir: Path, out_dir: Path) -> None:
490
+ """Copies the specified configuration and tokenizer files into the output directory."""
491
+
492
+ config_files = ["config.json", "generation_config.json", "model_config.yaml"]
493
+ tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"]
494
+
495
+ for file_name in config_files + tokenizer_files:
496
+ src_path = source_dir / file_name
497
+ if src_path.exists():
498
+ shutil.copy(src_path, out_dir)
499
+
500
+
501
+ def CLI(*args: Any, **kwargs: Any) -> Any:
502
+ from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options
503
+
504
+ set_docstring_parse_options(attribute_docstrings=True)
505
+ set_config_read_mode(urls_enabled=True)
506
+
507
+ return CLI(*args, **kwargs)
508
+
509
+
510
+ def capture_hparams() -> Dict[str, Any]:
511
+ """Captures the local variables ('hyperparameters') from where this function gets called."""
512
+ caller_frame = inspect.currentframe().f_back
513
+ locals_of_caller = caller_frame.f_locals
514
+ hparams = {}
515
+ for name, value in locals_of_caller.items():
516
+ if value is None or isinstance(value, (int, float, str, bool, Path)):
517
+ hparams[name] = value
518
+ elif is_dataclass(value):
519
+ hparams[name] = asdict(value)
520
+ else:
521
+ hparams[name] = str(value)
522
+ return hparams
523
+
524
+
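A sketch of how capture_hparams behaves when called from an entry point (the function below is hypothetical): it snapshots the caller's locals, keeping primitives and Paths as-is and converting dataclasses to dicts:

def setup(lr: float = 3e-4, precision: str = "bf16-mixed"):  # hypothetical entry point
    return capture_hparams()

assert setup() == {"lr": 3e-4, "precision": "bf16-mixed"}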
525
+ def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
526
+ """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
527
+ from jsonargparse import capture_parser
528
+
529
+ # TODO: Make this more robust
530
+ # This hack strips away the subcommands from the top-level CLI
531
+ # to parse the file as if it was called as a script
532
+ known_commands = [
533
+ ("finetune_full",), # For subcommands, use `("finetune", "full")` etc
534
+ ("finetune_lora",),
535
+ ("finetune_adapter",),
536
+ ("finetune_adapter_v2",),
537
+ ("finetune",),
538
+ ("pretrain",),
539
+ ]
540
+ for known_command in known_commands:
541
+ unwanted = slice(1, 1 + len(known_command))
542
+ if tuple(sys.argv[unwanted]) == known_command:
543
+ sys.argv[unwanted] = []
544
+
545
+ parser = capture_parser(lambda: CLI(function))
546
+ config = parser.parse_args()
547
+ parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
548
+
549
+
550
+ def save_config(config: "Config", checkpoint_dir: Path) -> None:
551
+ config_dict = asdict(config)
552
+ with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
553
+ yaml.dump(config_dict, fp)
554
+
555
+
556
+ def parse_devices(devices: Union[str, int]) -> int:
557
+ if devices in (-1, "auto"):
558
+ return torch.cuda.device_count() or 1
559
+ if isinstance(devices, int) and devices > 0:
560
+ return devices
561
+ raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")
562
+
563
+
564
+ def choose_logger(
565
+ logger_name: Literal["csv", "tensorboard", "wandb"],
566
+ out_dir: Path,
567
+ name: str,
568
+ log_interval: int = 1,
569
+ resume: Optional[bool] = None,
570
+ **kwargs: Any,
571
+ ):
572
+ if logger_name == "csv":
573
+ return CSVLogger(
574
+ root_dir=(out_dir / "logs"),
575
+ name="csv",
576
+ flush_logs_every_n_steps=log_interval,
577
+ **kwargs,
578
+ )
579
+ if logger_name == "tensorboard":
580
+ return TensorBoardLogger(
581
+ root_dir=(out_dir / "logs"), name="tensorboard", **kwargs
582
+ )
583
+ if logger_name == "wandb":
584
+ return WandbLogger(project=name, resume=resume, **kwargs)
585
+ raise ValueError(
586
+ f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'."
587
+ )
588
+
589
+
590
+ def get_argument_names(cls):
591
+ sig = inspect.signature(cls.__init__)
592
+ return {
593
+ name
594
+ for name, param in sig.parameters.items()
595
+ if param.kind
596
+ in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]
597
+ }
598
+
599
+
600
+ def instantiate_bnb_optimizer(optimizer, model_parameters):
601
+ if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
602
+ isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
603
+ ):
604
+ raise ValueError(
605
+ "The chosen quantization format only supports the AdamW optimizer."
606
+ )
607
+
608
+ import bitsandbytes as bnb
609
+
610
+ if isinstance(optimizer, str):
611
+ optimizer = bnb.optim.PagedAdamW(model_parameters)
612
+ else:
613
+ optim_args = get_argument_names(bnb.optim.PagedAdamW)
614
+ allowed_kwargs = {
615
+ key: optimizer["init_args"][key]
616
+ for key in optim_args & optimizer["init_args"].keys()
617
+ }
618
+ optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)
619
+ return optimizer
620
+
621
+
622
+ def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
623
+ if isinstance(optimizer, str):
624
+ optimizer_cls = getattr(torch.optim, optimizer)
625
+ optimizer = optimizer_cls(model_parameters, **kwargs)
626
+ else:
627
+ optimizer = dict(optimizer) # copy
628
+ optimizer["init_args"].update(kwargs)
629
+ optimizer = instantiate_class(model_parameters, optimizer)
630
+ return optimizer
631
+
632
+
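Both call styles accepted above, sketched with hypothetical settings: a plain optimizer name, or a jsonargparse-style dict with class_path/init_args (the latter relies on instantiate_class imported at the top of this module to resolve the class path):

import torch.nn as nn

opt_a = instantiate_torch_optimizer("AdamW", nn.Linear(8, 8).parameters(), lr=1e-4)

opt_b = instantiate_torch_optimizer(
    {"class_path": "torch.optim.AdamW", "init_args": {"lr": 1e-4, "weight_decay": 0.01}},
    nn.Linear(8, 8).parameters(),
)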
633
+ def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
634
+ new_checkpoint_dir = "checkpoints" / checkpoint_dir
635
+ should_return_new_dir = (
636
+ not checkpoint_dir.is_dir()
637
+ and checkpoint_dir.parts[0] != "checkpoints"
638
+ and not checkpoint_dir.is_absolute()
639
+ and new_checkpoint_dir.exists()
640
+ )
641
+ return new_checkpoint_dir if should_return_new_dir else checkpoint_dir
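The intended behaviour, sketched with a hypothetical model directory: a bare relative path is resolved under "checkpoints/" only when that resolved directory actually exists, otherwise the input is returned unchanged:

from pathlib import Path

ckpt_dir = extend_checkpoint_dir(Path("org/model-name"))
# -> Path("checkpoints/org/model-name") if that directory exists,
#    otherwise Path("org/model-name") is returned as-is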
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ torchaudio==2.3.1
4
+ litgpt==0.4.3
5
+ snac==1.2.0
6
+ soundfile==0.12.1
7
+ openai-whisper
8
+ tokenizers==0.19.1
9
+ streamlit==1.37.1
10
+ streamlit-webrtc
11
+ # PyAudio==0.2.14
12
+ pydub==0.25.1
13
+ onnxruntime==1.19.0
14
+ # numpy==1.26.3
15
+ librosa==0.10.2.post1
16
+ flask==3.0.3
17
+ fire
18
+ git+https://github.com/mini-omni/CLIP.git
19
+ gradio_webrtc[vad]==0.0.11
20
+ twilio
server.py ADDED
@@ -0,0 +1,90 @@
1
+ import flask
2
+ import base64
3
+ import tempfile
4
+ import traceback
5
+ from flask import Flask, Response, stream_with_context
6
+ from inference_vision import OmniVisionInference
7
+
8
+
9
+ class OmniChatServer(object):
10
+ def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
11
+ ckpt_dir='./checkpoint', device='cuda:0') -> None:
12
+ server = Flask(__name__)
13
+ # CORS(server, resources=r"/*")
14
+ # server.config["JSON_AS_ASCII"] = False
15
+
16
+ self.client = OmniVisionInference(ckpt_dir, device)
17
+ self.client.warm_up()
18
+
19
+ server.route("/chat", methods=["POST"])(self.chat)
20
+
21
+ if run_app:
22
+ server.run(host=ip, port=port, threaded=False)
23
+ else:
24
+ self.server = server
25
+
26
+ def chat(self) -> Response:
27
+
28
+ req_data = flask.request.get_json()
29
+ try:
30
+ audio_data_buf = req_data["audio"].encode("utf-8")
31
+ audio_data_buf = base64.b64decode(audio_data_buf)
32
+ stream_stride = req_data.get("stream_stride", 4)
33
+ max_tokens = req_data.get("max_tokens", 2048)
34
+
35
+ image_data_buf = req_data.get("image", None)
36
+ if image_data_buf:
37
+ image_data_buf = image_data_buf.encode("utf-8")
38
+ image_data_buf = base64.b64decode(image_data_buf)
39
+
40
+ audio_path, img_path = None, None
41
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_f, \
42
+ tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_f:
43
+ audio_f.write(audio_data_buf)
44
+ audio_path = audio_f.name
45
+
46
+ if image_data_buf:
47
+ img_f.write(image_data_buf)
48
+ img_path = img_f.name
49
+ else:
50
+ img_path = None
51
+
52
+ if img_path is not None:
53
+ resp_generator = self.client.run_vision_AA_batch_stream(audio_f.name, img_f.name,
54
+ stream_stride, max_tokens,
55
+ save_path='./vision_qa_out_cache.wav')
56
+ else:
57
+ resp_generator = self.client.run_AT_batch_stream(audio_f.name, stream_stride,
58
+ max_tokens,
59
+ save_path='./audio_qa_out_cache.wav')
60
+ return Response(stream_with_context(self.generator(resp_generator)),
61
+ mimetype='multipart/x-mixed-replace; boundary=frame')
62
+ except Exception as e:
63
+ print(traceback.format_exc())
64
+ return Response("An error occurred", status=500)
65
+
66
+ def generator(self, resp_generator):
67
+ for audio_stream, text_stream in resp_generator:
68
+ yield b'\r\n--frame\r\n'
69
+ yield b'Content-Type: audio/wav\r\n\r\n'
70
+ yield audio_stream
71
+ yield b'\r\n--frame\r\n'
72
+ yield b'Content-Type: text/plain\r\n\r\n'
73
+ yield text_stream.encode()
74
+
75
+
76
+ # CUDA_VISIBLE_DEVICES=1 gunicorn -w 2 -b 0.0.0.0:60808 'server:create_app()'
77
+ def create_app():
78
+ server = OmniChatServer(run_app=False)
79
+ return server.server
80
+
81
+
82
+ def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
83
+
84
+ OmniChatServer(ip, port=port, run_app=True, device=device)
85
+
86
+
87
+ if __name__ == "__main__":
88
+ import fire
89
+ fire.Fire(serve)
90
+
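A minimal client sketch against the /chat endpoint defined above (the wav file is hypothetical; the field names and defaults are the ones read by OmniChatServer.chat). The response is a multipart/x-mixed-replace stream whose frames alternate between audio/wav and text/plain parts:

import base64
import requests

with open("question.wav", "rb") as f:                       # hypothetical input file
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {"audio": audio_b64, "stream_stride": 4, "max_tokens": 2048}
with requests.post("http://0.0.0.0:60808/chat", json=payload, stream=True) as resp:
    for chunk in resp.iter_content(chunk_size=4096):
        ...  # frames separated by b"\r\n--frame\r\n" arrive here as they are generated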
utils/assets/silero_vad.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:591f853590d11ddde2f2a54f9e7ccecb2533a8af7716330e8adfa6f3849787a9
3
+ size 1807524
utils/snac_utils.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ import time
3
+ import numpy as np
4
+
5
+
6
+ class SnacConfig:
7
+ audio_vocab_size = 4096
8
+ padded_vocab_size = 4160
9
+ end_of_audio = 4097
10
+
11
+
12
+ snac_config = SnacConfig()
13
+
14
+
15
+ def get_time_str():
16
+ time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
17
+ return time_str
18
+
19
+
20
+ def layershift(input_id, layer, stride=4160, shift=152000):
21
+ return input_id + shift + layer * stride
22
+
23
+
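Worked examples of the layershift mapping with the defaults above (stride 4160 is the padded audio vocab size from SnacConfig, 152000 is the offset of the audio codebooks in the joint vocabulary): token t of layer l maps to 152000 + l * 4160 + t.

assert layershift(0, 0) == 152000                 # first token of codebook layer 0
assert layershift(4097, 0) == 156097              # end_of_audio (4097) in layer 0
assert layershift(0, 6) == 152000 + 6 * 4160      # first token of the 7th layer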
24
+ def generate_audio_data(snac_tokens, snacmodel, device=None):
25
+ audio = reconstruct_tensors(snac_tokens, device)
26
+ with torch.inference_mode():
27
+ audio_hat = snacmodel.decode(audio)
28
+ audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
29
+ audio_data = audio_data.astype(np.int16)
30
+ audio_data = audio_data.tobytes()
31
+ return audio_data
32
+
33
+
34
+ def get_snac(list_output, index, nums_generate):
35
+
36
+ snac = []
37
+ start = index
38
+ for i in range(nums_generate):
39
+ snac.append("#")
40
+ for j in range(7):
41
+ snac.append(list_output[j][start - nums_generate - 5 + j + i])
42
+ return snac
43
+
44
+
45
+ def reconscruct_snac(output_list):
46
+ if len(output_list) == 8:
47
+ output_list = output_list[:-1]
48
+ output = []
49
+ for i in range(7):
50
+ output_list[i] = output_list[i][i + 1 :]
51
+ for i in range(len(output_list[-1])):
52
+ output.append("#")
53
+ for j in range(7):
54
+ output.append(output_list[j][i])
55
+ return output
56
+
57
+
58
+ def reconstruct_tensors(flattened_output, device=None):
59
+ """Reconstructs the list of tensors from the flattened output."""
60
+
61
+ if device is None:
62
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+
64
+ def count_elements_between_hashes(lst):
65
+ try:
66
+ # Find the index of the first '#'
67
+ first_index = lst.index("#")
68
+ # Find the index of the second '#' after the first
69
+ second_index = lst.index("#", first_index + 1)
70
+ # Count the elements between the two indices
71
+ return second_index - first_index - 1
72
+ except ValueError:
73
+ # Handle the case where there aren't enough '#' symbols
74
+ return "List does not contain two '#' symbols"
75
+
76
+ def remove_elements_before_hash(flattened_list):
77
+ try:
78
+ # Find the index of the first '#'
79
+ first_hash_index = flattened_list.index("#")
80
+ # Return the list starting from the first '#'
81
+ return flattened_list[first_hash_index:]
82
+ except ValueError:
83
+ # Handle the case where there is no '#'
84
+ return "List does not contain the symbol '#'"
85
+
86
+ def list_to_torch_tensor(tensor1):
87
+ # Convert the list to a torch tensor
88
+ tensor = torch.tensor(tensor1)
89
+ # Reshape the tensor to have size (1, n)
90
+ tensor = tensor.unsqueeze(0)
91
+ return tensor
92
+
93
+ flattened_output = remove_elements_before_hash(flattened_output)
94
+ codes = []
95
+ tensor1 = []
96
+ tensor2 = []
97
+ tensor3 = []
98
+ tensor4 = []
99
+
100
+ n_tensors = count_elements_between_hashes(flattened_output)
101
+ if n_tensors == 7:
102
+ for i in range(0, len(flattened_output), 8):
103
+
104
+ tensor1.append(flattened_output[i + 1])
105
+ tensor2.append(flattened_output[i + 2])
106
+ tensor3.append(flattened_output[i + 3])
107
+ tensor3.append(flattened_output[i + 4])
108
+
109
+ tensor2.append(flattened_output[i + 5])
110
+ tensor3.append(flattened_output[i + 6])
111
+ tensor3.append(flattened_output[i + 7])
112
+ codes = [
113
+ list_to_torch_tensor(tensor1).to(device),
114
+ list_to_torch_tensor(tensor2).to(device),
115
+ list_to_torch_tensor(tensor3).to(device),
116
+ ]
117
+
118
+ if n_tensors == 15:
119
+ for i in range(0, len(flattened_output), 16):
120
+
121
+ tensor1.append(flattened_output[i + 1])
122
+ tensor2.append(flattened_output[i + 2])
123
+ tensor3.append(flattened_output[i + 3])
124
+ tensor4.append(flattened_output[i + 4])
125
+ tensor4.append(flattened_output[i + 5])
126
+ tensor3.append(flattened_output[i + 6])
127
+ tensor4.append(flattened_output[i + 7])
128
+ tensor4.append(flattened_output[i + 8])
129
+
130
+ tensor2.append(flattened_output[i + 9])
131
+ tensor3.append(flattened_output[i + 10])
132
+ tensor4.append(flattened_output[i + 11])
133
+ tensor4.append(flattened_output[i + 12])
134
+ tensor3.append(flattened_output[i + 13])
135
+ tensor4.append(flattened_output[i + 14])
136
+ tensor4.append(flattened_output[i + 15])
137
+
138
+ codes = [
139
+ list_to_torch_tensor(tensor1).to(device),
140
+ list_to_torch_tensor(tensor2).to(device),
141
+ list_to_torch_tensor(tensor3).to(device),
142
+ list_to_torch_tensor(tensor4).to(device),
143
+ ]
144
+
145
+ return codes
146
+
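A small sketch of the flattened layout reconstruct_tensors expects (the code values are arbitrary integers): each 7-token frame is "#" followed by one code for the coarsest codebook, then codes for the second and third codebooks in the interleaved order used above. With two frames, the three codebook tensors come back with 1, 2 and 4 codes per frame:

import torch

flat = ["#", 10, 20, 30, 31, 21, 32, 33,
        "#", 11, 22, 34, 35, 23, 36, 37]
codes = reconstruct_tensors(flat, device=torch.device("cpu"))
# codes[0] -> tensor([[10, 11]])                            1 code per frame
# codes[1] -> tensor([[20, 21, 22, 23]])                    2 codes per frame
# codes[2] -> tensor([[30, 31, 32, 33, 34, 35, 36, 37]])    4 codes per frame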
utils/vad.py ADDED
@@ -0,0 +1,290 @@
1
+ import bisect
2
+ import functools
3
+ import os
4
+ import warnings
5
+
6
+ from typing import List, NamedTuple, Optional
7
+
8
+ import numpy as np
9
+
10
+
11
+ # The code below is adapted from https://github.com/snakers4/silero-vad.
12
+ class VadOptions(NamedTuple):
13
+ """VAD options.
14
+
15
+ Attributes:
16
+ threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
17
+ probabilities ABOVE this value are considered as SPEECH. It is better to tune this
18
+ parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
19
+ min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
20
+ max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
21
+ than max_speech_duration_s will be split at the timestamp of the last silence that
22
+ lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
23
+ split aggressively just before max_speech_duration_s.
24
+ min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
25
+ before separating it
26
+ window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
27
+ WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
28
+ Values other than these may affect model performance!!
29
+ speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side
30
+ """
31
+
32
+ threshold: float = 0.5
33
+ min_speech_duration_ms: int = 250
34
+ max_speech_duration_s: float = float("inf")
35
+ min_silence_duration_ms: int = 2000
36
+ window_size_samples: int = 1024
37
+ speech_pad_ms: int = 400
38
+
39
+
40
+ def get_speech_timestamps(
41
+ audio: np.ndarray,
42
+ vad_options: Optional[VadOptions] = None,
43
+ **kwargs,
44
+ ) -> List[dict]:
45
+ """This method is used for splitting long audios into speech chunks using silero VAD.
46
+
47
+ Args:
48
+ audio: One dimensional float array.
49
+ vad_options: Options for VAD processing.
50
+ kwargs: VAD options passed as keyword arguments for backward compatibility.
51
+
52
+ Returns:
53
+ List of dicts containing begin and end samples of each speech chunk.
54
+ """
55
+ if vad_options is None:
56
+ vad_options = VadOptions(**kwargs)
57
+
58
+ threshold = vad_options.threshold
59
+ min_speech_duration_ms = vad_options.min_speech_duration_ms
60
+ max_speech_duration_s = vad_options.max_speech_duration_s
61
+ min_silence_duration_ms = vad_options.min_silence_duration_ms
62
+ window_size_samples = vad_options.window_size_samples
63
+ speech_pad_ms = vad_options.speech_pad_ms
64
+
65
+ if window_size_samples not in [512, 1024, 1536]:
66
+ warnings.warn(
67
+ "Unusual window_size_samples! Supported window_size_samples:\n"
68
+ " - [512, 1024, 1536] for 16000 sampling_rate"
69
+ )
70
+
71
+ sampling_rate = 16000
72
+ min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
73
+ speech_pad_samples = sampling_rate * speech_pad_ms / 1000
74
+ max_speech_samples = (
75
+ sampling_rate * max_speech_duration_s
76
+ - window_size_samples
77
+ - 2 * speech_pad_samples
78
+ )
79
+ min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
80
+ min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
81
+
82
+ audio_length_samples = len(audio)
83
+
84
+ model = get_vad_model()
85
+ state = model.get_initial_state(batch_size=1)
86
+
87
+ speech_probs = []
88
+ for current_start_sample in range(0, audio_length_samples, window_size_samples):
89
+ chunk = audio[current_start_sample : current_start_sample + window_size_samples]
90
+ if len(chunk) < window_size_samples:
91
+ chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
92
+ speech_prob, state = model(chunk, state, sampling_rate)
93
+ speech_probs.append(speech_prob)
94
+
95
+ triggered = False
96
+ speeches = []
97
+ current_speech = {}
98
+ neg_threshold = threshold - 0.15
99
+
100
+ # to save potential segment end (and tolerate some silence)
101
+ temp_end = 0
102
+ # to save potential segment limits in case of maximum segment size reached
103
+ prev_end = next_start = 0
104
+
105
+ for i, speech_prob in enumerate(speech_probs):
106
+ if (speech_prob >= threshold) and temp_end:
107
+ temp_end = 0
108
+ if next_start < prev_end:
109
+ next_start = window_size_samples * i
110
+
111
+ if (speech_prob >= threshold) and not triggered:
112
+ triggered = True
113
+ current_speech["start"] = window_size_samples * i
114
+ continue
115
+
116
+ if (
117
+ triggered
118
+ and (window_size_samples * i) - current_speech["start"] > max_speech_samples
119
+ ):
120
+ if prev_end:
121
+ current_speech["end"] = prev_end
122
+ speeches.append(current_speech)
123
+ current_speech = {}
124
+ # previously reached silence (< neg_thres) and is still not speech (< thres)
125
+ if next_start < prev_end:
126
+ triggered = False
127
+ else:
128
+ current_speech["start"] = next_start
129
+ prev_end = next_start = temp_end = 0
130
+ else:
131
+ current_speech["end"] = window_size_samples * i
132
+ speeches.append(current_speech)
133
+ current_speech = {}
134
+ prev_end = next_start = temp_end = 0
135
+ triggered = False
136
+ continue
137
+
138
+ if (speech_prob < neg_threshold) and triggered:
139
+ if not temp_end:
140
+ temp_end = window_size_samples * i
141
+ # condition to avoid cutting in very short silence
142
+ if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
143
+ prev_end = temp_end
144
+ if (window_size_samples * i) - temp_end < min_silence_samples:
145
+ continue
146
+ else:
147
+ current_speech["end"] = temp_end
148
+ if (
149
+ current_speech["end"] - current_speech["start"]
150
+ ) > min_speech_samples:
151
+ speeches.append(current_speech)
152
+ current_speech = {}
153
+ prev_end = next_start = temp_end = 0
154
+ triggered = False
155
+ continue
156
+
157
+ if (
158
+ current_speech
159
+ and (audio_length_samples - current_speech["start"]) > min_speech_samples
160
+ ):
161
+ current_speech["end"] = audio_length_samples
162
+ speeches.append(current_speech)
163
+
164
+ for i, speech in enumerate(speeches):
165
+ if i == 0:
166
+ speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
167
+ if i != len(speeches) - 1:
168
+ silence_duration = speeches[i + 1]["start"] - speech["end"]
169
+ if silence_duration < 2 * speech_pad_samples:
170
+ speech["end"] += int(silence_duration // 2)
171
+ speeches[i + 1]["start"] = int(
172
+ max(0, speeches[i + 1]["start"] - silence_duration // 2)
173
+ )
174
+ else:
175
+ speech["end"] = int(
176
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
177
+ )
178
+ speeches[i + 1]["start"] = int(
179
+ max(0, speeches[i + 1]["start"] - speech_pad_samples)
180
+ )
181
+ else:
182
+ speech["end"] = int(
183
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
184
+ )
185
+
186
+ return speeches
187
+
188
+
189
+ def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
190
+ """Collects and concatenates audio chunks."""
191
+ if not chunks:
192
+ return np.array([], dtype=np.float32)
193
+
194
+ return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
195
+
196
+
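An end-to-end sketch of the two helpers above on a 16 kHz mono float array (the audio here is synthetic noise, so the result is typically empty; in the app the samples come from the WebRTC stream, and the bundled assets/silero_vad.onnx plus onnxruntime are assumed to be available):

import numpy as np

audio = np.random.uniform(-0.1, 0.1, 16000 * 3).astype(np.float32)   # 3 s at 16 kHz
opts = VadOptions(threshold=0.5, min_silence_duration_ms=500)
chunks = get_speech_timestamps(audio, vad_options=opts)               # [{"start": ..., "end": ...}, ...]
speech_only = collect_chunks(audio, chunks)                           # concatenated speech samples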
197
+ class SpeechTimestampsMap:
198
+ """Helper class to restore original speech timestamps."""
199
+
200
+ def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
201
+ self.sampling_rate = sampling_rate
202
+ self.time_precision = time_precision
203
+ self.chunk_end_sample = []
204
+ self.total_silence_before = []
205
+
206
+ previous_end = 0
207
+ silent_samples = 0
208
+
209
+ for chunk in chunks:
210
+ silent_samples += chunk["start"] - previous_end
211
+ previous_end = chunk["end"]
212
+
213
+ self.chunk_end_sample.append(chunk["end"] - silent_samples)
214
+ self.total_silence_before.append(silent_samples / sampling_rate)
215
+
216
+ def get_original_time(
217
+ self,
218
+ time: float,
219
+ chunk_index: Optional[int] = None,
220
+ ) -> float:
221
+ if chunk_index is None:
222
+ chunk_index = self.get_chunk_index(time)
223
+
224
+ total_silence_before = self.total_silence_before[chunk_index]
225
+ return round(total_silence_before + time, self.time_precision)
226
+
227
+ def get_chunk_index(self, time: float) -> int:
228
+ sample = int(time * self.sampling_rate)
229
+ return min(
230
+ bisect.bisect(self.chunk_end_sample, sample),
231
+ len(self.chunk_end_sample) - 1,
232
+ )
233
+
234
+
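A worked example of the mapping (chunk boundaries are illustrative): with speech chunks at samples 16000-32000 and 64000-96000 of a 16 kHz recording, a timestamp measured on the concatenated speech-only audio is shifted back by the silence removed before its chunk:

chunks = [{"start": 16000, "end": 32000}, {"start": 64000, "end": 96000}]
ts_map = SpeechTimestampsMap(chunks, sampling_rate=16000)

assert ts_map.get_original_time(0.5) == 1.5   # 0.5 s into chunk 0, 1 s of silence before it
assert ts_map.get_original_time(1.5) == 4.5   # falls in chunk 1, 3 s of silence removed before it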
235
+ @functools.lru_cache
236
+ def get_vad_model():
237
+ """Returns the VAD model instance."""
238
+ asset_dir = os.path.join(os.path.dirname(__file__), "assets")
239
+ path = os.path.join(asset_dir, "silero_vad.onnx")
240
+ return SileroVADModel(path)
241
+
242
+
243
+ class SileroVADModel:
244
+ def __init__(self, path):
245
+ try:
246
+ import onnxruntime
247
+ except ImportError as e:
248
+ raise RuntimeError(
249
+ "Applying the VAD filter requires the onnxruntime package"
250
+ ) from e
251
+
252
+ opts = onnxruntime.SessionOptions()
253
+ opts.inter_op_num_threads = 1
254
+ opts.intra_op_num_threads = 1
255
+ opts.log_severity_level = 4
256
+
257
+ self.session = onnxruntime.InferenceSession(
258
+ path,
259
+ providers=["CPUExecutionProvider"],
260
+ sess_options=opts,
261
+ )
262
+
263
+ def get_initial_state(self, batch_size: int):
264
+ h = np.zeros((2, batch_size, 64), dtype=np.float32)
265
+ c = np.zeros((2, batch_size, 64), dtype=np.float32)
266
+ return h, c
267
+
268
+ def __call__(self, x, state, sr: int):
269
+ if len(x.shape) == 1:
270
+ x = np.expand_dims(x, 0)
271
+ if len(x.shape) > 2:
272
+ raise ValueError(
273
+ f"Too many dimensions for input audio chunk {len(x.shape)}"
274
+ )
275
+ if sr / x.shape[1] > 31.25:
276
+ raise ValueError("Input audio chunk is too short")
277
+
278
+ h, c = state
279
+
280
+ ort_inputs = {
281
+ "input": x,
282
+ "h": h,
283
+ "c": c,
284
+ "sr": np.array(sr, dtype="int64"),
285
+ }
286
+
287
+ out, h, c = self.session.run(None, ort_inputs)
288
+ state = (h, c)
289
+
290
+ return out, state