skytnt commited on
Commit
ed10990
1 Parent(s): 942f170

update to onnx

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mid filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,48 +1,72 @@
1
  import argparse
2
  import glob
3
-
4
- import PIL
 
 
5
  import gradio as gr
6
  import numpy as np
7
- import torch
8
-
9
- import torch.nn.functional as F
10
  import tqdm
 
11
 
12
  import MIDI
13
- from midi_model import MIDIModel
14
- from midi_tokenizer import MIDITokenizer
15
  from midi_synthesizer import synthesis
16
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- @torch.inference_mode()
19
  def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
20
- disable_patch_change=False, disable_control_change=False, disable_channels=None, amp=True):
21
  if disable_channels is not None:
22
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
23
  else:
24
  disable_channels = []
25
  max_token_seq = tokenizer.max_token_seq
26
  if prompt is None:
27
- input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
28
  input_tensor[0, 0] = tokenizer.bos_id # bos
29
  else:
30
  prompt = prompt[:, :max_token_seq]
31
  if prompt.shape[-1] < max_token_seq:
32
  prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
33
  mode="constant", constant_values=tokenizer.pad_id)
34
- input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
35
- input_tensor = input_tensor.unsqueeze(0)
36
  cur_len = input_tensor.shape[1]
37
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
38
- with bar, torch.cuda.amp.autocast(enabled=amp):
39
  while cur_len < max_len:
40
  end = False
41
- hidden = model.forward(input_tensor)[0, -1].unsqueeze(0)
42
- next_token_seq = None
43
  event_name = ""
44
  for i in range(max_token_seq):
45
- mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=model.device)
46
  if i == 0:
47
  mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
48
  if disable_patch_change:
@@ -56,9 +80,9 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
56
  if param_name == "channel":
57
  mask_ids = [i for i in mask_ids if i not in disable_channels]
58
  mask[mask_ids] = 1
59
- logits = model.forward_token(hidden, next_token_seq)[:, -1:]
60
- scores = torch.softmax(logits / temp, dim=-1) * mask
61
- sample = model.sample_top_p_k(scores, top_p, top_k)
62
  if i == 0:
63
  next_token_seq = sample
64
  eid = sample.item()
@@ -67,29 +91,30 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
67
  break
68
  event_name = tokenizer.id_events[eid]
69
  else:
70
- next_token_seq = torch.cat([next_token_seq, sample], dim=1)
71
  if len(tokenizer.events[event_name]) == i:
72
  break
73
  if next_token_seq.shape[1] < max_token_seq:
74
- next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
75
- "constant", value=tokenizer.pad_id)
76
- next_token_seq = next_token_seq.unsqueeze(1)
77
- input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
78
  cur_len += 1
79
  bar.update(1)
80
- yield next_token_seq.reshape(-1).cpu().numpy()
81
  if end:
82
  break
83
 
84
 
85
- def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc, amp):
86
  mid_seq = []
87
  max_len = int(gen_events)
88
  img_len = 1024
89
  img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8)
90
  state = {"t1": 0, "t": 0, "cur_pos": 0}
91
- rand = np.random.RandomState(0)
92
- colors = {(i, j): rand.randint(0, 200, 3) for i in range(128) for j in range(16)}
 
93
 
94
  def draw_event(tokens):
95
  if tokens[0] in tokenizer.id_events:
@@ -112,7 +137,7 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
112
  img[:, -shift:] = 255
113
  state["cur_pos"] += shift
114
  t = t - state["cur_pos"]
115
- img[p * 2:(p + 1) * 2, t: t + d] = colors[(tr, c)]
116
 
117
  def get_img():
118
  t = state["t"] - state["cur_pos"]
@@ -135,7 +160,7 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
135
  mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p]))
136
  mid_seq = mid
137
  mid = np.asarray(mid, dtype=np.int64)
138
- if len(instruments) > 0 or drum_kit != "None":
139
  disable_patch_change = True
140
  disable_channels = [i for i in range(16) if i not in patches]
141
  elif mid is not None:
@@ -148,7 +173,7 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
148
  draw_event(token_seq)
149
  generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
150
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
151
- disable_channels=disable_channels, amp=amp)
152
  for token_seq in generator:
153
  mid_seq.append(token_seq)
154
  draw_event(token_seq)
@@ -179,17 +204,16 @@ if __name__ == "__main__":
179
  parser = argparse.ArgumentParser()
180
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
181
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
182
- parser.add_argument("--device", type=str, default="cpu", help="device to run model")
183
- parser.add_argument("--max-gen", type=int, default=512, help="max")
184
- soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
185
- model_path = hf_hub_download(repo_id="skytnt/midi-model", filename="model.ckpt")
186
  opt = parser.parse_args()
 
 
 
 
187
  tokenizer = MIDITokenizer()
188
- model = MIDIModel(tokenizer).to(device=opt.device)
189
- ckpt = torch.load(model_path, map_location="cpu")
190
- state_dict = ckpt.get("state_dict", ckpt)
191
- model.load_state_dict(state_dict, strict=False)
192
- model.eval()
193
 
194
  app = gr.Blocks()
195
  with app:
@@ -199,39 +223,52 @@ if __name__ == "__main__":
199
  "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
200
  "[Open In Colab]"
201
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
202
- " for faster running")
 
203
 
204
  tab_select = gr.Variable(value=0)
205
  with gr.Tabs():
206
  with gr.TabItem("instrument prompt") as tab1:
207
  input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
208
- multiselect=True, max_choices=10, type="value")
209
  input_drum_kit = gr.Dropdown(label="drum kit", choices=list(drum_kits2number.keys()), type="value",
210
  value="None")
 
 
 
 
 
 
 
 
 
 
 
211
  with gr.TabItem("midi prompt") as tab2:
212
  input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
213
  input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
214
  step=1,
215
  value=128)
 
 
216
 
217
  tab1.select(lambda: 0, None, tab_select, queue=False)
218
  tab2.select(lambda: 1, None, tab_select, queue=False)
219
  input_gen_events = gr.Slider(label="generate n midi events", minimum=1, maximum=opt.max_gen,
220
  step=1, value=opt.max_gen)
221
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
222
- input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.97)
223
  input_top_k = gr.Slider(label="top k", minimum=1, maximum=50, step=1, value=20)
224
- input_allow_cc = gr.Checkbox(label="allow control change event", value=True)
225
- input_amp = gr.Checkbox(label="enable amp", value=True)
226
  run_btn = gr.Button("generate", variant="primary")
227
  stop_btn = gr.Button("stop")
228
  output_midi_seq = gr.Variable()
229
  output_midi_img = gr.Image(label="output image")
230
  output_midi = gr.File(label="output midi", file_types=[".mid"])
231
- output_audio = gr.Audio(label="output audio", format="mp3")
232
  run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
233
  input_gen_events, input_temp, input_top_p, input_top_k,
234
- input_allow_cc, input_amp],
235
  [output_midi_seq, output_midi_img, output_midi, output_audio])
236
  stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
237
- app.queue(1).launch(server_port=opt.port, share=opt.share)
 
1
  import argparse
2
  import glob
3
+ import os
4
+ import os.path
5
+ from sys import exit
6
+ import shutil
7
  import gradio as gr
8
  import numpy as np
9
+ import onnxruntime as rt
10
+ import PIL
11
+ import PIL.ImageColor
12
  import tqdm
13
+ from huggingface_hub import hf_hub_download
14
 
15
  import MIDI
 
 
16
  from midi_synthesizer import synthesis
17
+ from midi_tokenizer import MIDITokenizer
18
+
19
+ def softmax(x, axis):
20
+ x_max = np.amax(x, axis=axis, keepdims=True)
21
+ exp_x_shifted = np.exp(x - x_max)
22
+ return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
23
+
24
+
25
+ def sample_top_p_k(probs, p, k):
26
+ probs_idx = np.argsort(-probs, axis=-1)
27
+ probs_sort = np.take_along_axis(probs, probs_idx, -1)
28
+ probs_sum = np.cumsum(probs_sort, axis=-1)
29
+ mask = probs_sum - probs_sort > p
30
+ probs_sort[mask] = 0.0
31
+ mask = np.zeros(probs_sort.shape[-1])
32
+ mask[:k] = 1
33
+ probs_sort = probs_sort * mask
34
+ probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
35
+ shape = probs_sort.shape
36
+ probs_sort_flat = probs_sort.reshape(-1, shape[-1])
37
+ probs_idx_flat = probs_idx.reshape(-1, shape[-1])
38
+ next_token = np.stack([np.random.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
39
+ next_token = next_token.reshape(*shape[:-1])
40
+ return next_token
41
+
42
 
 
43
  def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
44
+ disable_patch_change=False, disable_control_change=False, disable_channels=None):
45
  if disable_channels is not None:
46
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
47
  else:
48
  disable_channels = []
49
  max_token_seq = tokenizer.max_token_seq
50
  if prompt is None:
51
+ input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
52
  input_tensor[0, 0] = tokenizer.bos_id # bos
53
  else:
54
  prompt = prompt[:, :max_token_seq]
55
  if prompt.shape[-1] < max_token_seq:
56
  prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
57
  mode="constant", constant_values=tokenizer.pad_id)
58
+ input_tensor = prompt
59
+ input_tensor = input_tensor[None, :, :]
60
  cur_len = input_tensor.shape[1]
61
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
62
+ with bar:
63
  while cur_len < max_len:
64
  end = False
65
+ hidden = model_base.run(None, {'x': input_tensor})[0][:, -1]
66
+ next_token_seq = np.empty((1, 0), dtype=np.int64)
67
  event_name = ""
68
  for i in range(max_token_seq):
69
+ mask = np.zeros(tokenizer.vocab_size, dtype=np.int64)
70
  if i == 0:
71
  mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
72
  if disable_patch_change:
 
80
  if param_name == "channel":
81
  mask_ids = [i for i in mask_ids if i not in disable_channels]
82
  mask[mask_ids] = 1
83
+ logits = model_token.run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
84
+ scores = softmax(logits / temp, -1) * mask
85
+ sample = sample_top_p_k(scores, top_p, top_k)
86
  if i == 0:
87
  next_token_seq = sample
88
  eid = sample.item()
 
91
  break
92
  event_name = tokenizer.id_events[eid]
93
  else:
94
+ next_token_seq = np.concatenate([next_token_seq, sample], axis=1)
95
  if len(tokenizer.events[event_name]) == i:
96
  break
97
  if next_token_seq.shape[1] < max_token_seq:
98
+ next_token_seq = np.pad(next_token_seq, ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
99
+ mode="constant", constant_values=tokenizer.pad_id)
100
+ next_token_seq = next_token_seq[None, :, :]
101
+ input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
102
  cur_len += 1
103
  bar.update(1)
104
+ yield next_token_seq.reshape(-1)
105
  if end:
106
  break
107
 
108
 
109
+ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
110
  mid_seq = []
111
  max_len = int(gen_events)
112
  img_len = 1024
113
  img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8)
114
  state = {"t1": 0, "t": 0, "cur_pos": 0}
115
+ colors = ['navy', 'blue', 'deepskyblue', 'teal', 'green', 'lightgreen', 'lime', 'orange',
116
+ 'brown', 'grey', 'red', 'pink', 'aqua', 'orchid', 'bisque', 'coral']
117
+ colors = [PIL.ImageColor.getrgb(color) for color in colors]
118
 
119
  def draw_event(tokens):
120
  if tokens[0] in tokenizer.id_events:
 
137
  img[:, -shift:] = 255
138
  state["cur_pos"] += shift
139
  t = t - state["cur_pos"]
140
+ img[p * 2:(p + 1) * 2, t: t + d] = colors[c]
141
 
142
  def get_img():
143
  t = state["t"] - state["cur_pos"]
 
160
  mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p]))
161
  mid_seq = mid
162
  mid = np.asarray(mid, dtype=np.int64)
163
+ if len(instruments) > 0:
164
  disable_patch_change = True
165
  disable_channels = [i for i in range(16) if i not in patches]
166
  elif mid is not None:
 
173
  draw_event(token_seq)
174
  generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
175
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
176
+ disable_channels=disable_channels)
177
  for token_seq in generator:
178
  mid_seq.append(token_seq)
179
  draw_event(token_seq)
 
204
  parser = argparse.ArgumentParser()
205
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
206
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
207
+ parser.add_argument("--max-gen", type=int, default=256, help="max")
 
 
 
208
  opt = parser.parse_args()
209
+ soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
210
+ model_base_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx")
211
+ model_token_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx")
212
+
213
  tokenizer = MIDITokenizer()
214
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
215
+ model_base = rt.InferenceSession(model_base_path, providers=providers)
216
+ model_token = rt.InferenceSession(model_token_path, providers=providers)
 
 
217
 
218
  app = gr.Blocks()
219
  with app:
 
223
  "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
224
  "[Open In Colab]"
225
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
226
+ " for faster running and longer generation"
227
+ )
228
 
229
  tab_select = gr.Variable(value=0)
230
  with gr.Tabs():
231
  with gr.TabItem("instrument prompt") as tab1:
232
  input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
233
+ multiselect=True, max_choices=15, type="value")
234
  input_drum_kit = gr.Dropdown(label="drum kit", choices=list(drum_kits2number.keys()), type="value",
235
  value="None")
236
+ example1 = gr.Examples([
237
+ [[], "None"],
238
+ [["Acoustic Grand"], "None"],
239
+ [["Acoustic Grand", "Violin", "Viola", "Cello", "Contrabass", "Timpani"], "Orchestra"],
240
+ [["Acoustic Guitar(nylon)", "Acoustic Guitar(steel)", "Electric Guitar(jazz)",
241
+ "Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
242
+ "Electric Bass(finger)"], "Standard"],
243
+ [["Acoustic Grand", "String Ensemble 1", "Trombone", "Tuba", "Muted Trumpet", "French Horn", "Oboe",
244
+ "English Horn", "Bassoon", "Clarinet"], "Orchestra"]
245
+
246
+ ], [input_instruments, input_drum_kit])
247
  with gr.TabItem("midi prompt") as tab2:
248
  input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
249
  input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
250
  step=1,
251
  value=128)
252
+ example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
253
+ [input_midi, input_midi_events])
254
 
255
  tab1.select(lambda: 0, None, tab_select, queue=False)
256
  tab2.select(lambda: 1, None, tab_select, queue=False)
257
  input_gen_events = gr.Slider(label="generate n midi events", minimum=1, maximum=opt.max_gen,
258
  step=1, value=opt.max_gen)
259
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
260
+ input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
261
  input_top_k = gr.Slider(label="top k", minimum=1, maximum=50, step=1, value=20)
262
+ input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
 
263
  run_btn = gr.Button("generate", variant="primary")
264
  stop_btn = gr.Button("stop")
265
  output_midi_seq = gr.Variable()
266
  output_midi_img = gr.Image(label="output image")
267
  output_midi = gr.File(label="output midi", file_types=[".mid"])
268
+ output_audio = gr.Audio(label="output audio", format="wav")
269
  run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
270
  input_gen_events, input_temp, input_top_p, input_top_k,
271
+ input_allow_cc],
272
  [output_midi_seq, output_midi_img, output_midi, output_audio])
273
  stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
274
+ app.queue(2).launch(server_port=opt.port, share=opt.share, inbrowser=True)
example/Bach--Fugue-in-D-Minor.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1398121eb86a33e73f90ec84be71dac6abc0ddf11372ea7cdd9e01586938a56b
3
+ size 7720
example/Beethoven--Symphony-No5-in-C-Minor-Fate-Opus-67.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28ff6fdcd644e781d36411bf40ab7a1f4849adddbcd1040eaec22751c5ca99d2
3
+ size 87090
example/Chopin--Nocturne No. 9 in B Major, Opus 32 No.1, Andante Sostenuto.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a236e647ad9f5d0af680d3ca19d3b60f334c4bde6b4f86310f63405245c476e
3
+ size 13484
example/Mozart--Requiem, No.1..mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa49bf4633401e16777fe47f6f53a494c2166f5101af6dafc60114932a59b9bd
3
+ size 14695
example/castle_in_the_sky.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa14aec6f1be15c4fddd0decc6d9152204f160d4e07e05d8d1dc9f209c309ff7
3
+ size 7957
example/eva-残酷な天使のテーゼ.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e513487543d7e27ec5dc30f027302d2a3b5a3aaf9af554def1e5cd6a7a8d355a
3
+ size 17671
midi_model.py DELETED
@@ -1,123 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import tqdm
6
- from transformers import LlamaModel, LlamaConfig
7
- from transformers.modeling_utils import ModuleUtilsMixin
8
-
9
- from midi_tokenizer import MIDITokenizer
10
-
11
-
12
- class MIDIModel(nn.Module, ModuleUtilsMixin):
13
- def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
14
- *args, **kwargs):
15
- super(MIDIModel, self).__init__()
16
- self.tokenizer = tokenizer
17
- self.net = LlamaModel(LlamaConfig(vocab_size=tokenizer.vocab_size,
18
- hidden_size=n_embd, num_attention_heads=n_head,
19
- num_hidden_layers=n_layer, intermediate_size=n_inner,
20
- pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
21
- self.net_token = LlamaModel(LlamaConfig(vocab_size=tokenizer.vocab_size,
22
- hidden_size=n_embd, num_attention_heads=n_head // 4,
23
- num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
24
- pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
25
- if flash:
26
- self.net = self.net.to_bettertransformer()
27
- self.net_token = self.net_token.to_bettertransformer()
28
- self.lm_head = nn.Linear(n_embd, tokenizer.vocab_size, bias=False)
29
-
30
- def forward_token(self, hidden_state, x=None):
31
- """
32
-
33
- :param hidden_state: (batch_size, n_embd)
34
- :param x: (batch_size, token_sequence_length)
35
- :return: (batch_size, 1 + token_sequence_length, vocab_size)
36
- """
37
- hidden_state = hidden_state.unsqueeze(1) # (batch_size, 1, n_embd)
38
- if x is not None:
39
- x = self.net_token.embed_tokens(x)
40
- hidden_state = torch.cat([hidden_state, x], dim=1)
41
- hidden_state = self.net_token.forward(inputs_embeds=hidden_state).last_hidden_state
42
- return self.lm_head(hidden_state)
43
-
44
- def forward(self, x):
45
- """
46
- :param x: (batch_size, time_sequence_length, token_sequence_length)
47
- :return: hidden (batch_size, time_sequence_length, n_embd)
48
- """
49
-
50
- # merge token sequence
51
- x = self.net.embed_tokens(x)
52
- x = x.sum(dim=-2)
53
- x = self.net.forward(inputs_embeds=x)
54
- return x.last_hidden_state
55
-
56
- def sample_top_p_k(self, probs, p, k):
57
- probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
58
- probs_sum = torch.cumsum(probs_sort, dim=-1)
59
- mask = probs_sum - probs_sort > p
60
- probs_sort[mask] = 0.0
61
- mask = torch.zeros(probs_sort.shape[-1], device=probs_sort.device)
62
- mask[:k] = 1
63
- probs_sort = probs_sort * mask
64
- probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
65
- shape = probs_sort.shape
66
- next_token = torch.multinomial(probs_sort.reshape(-1, shape[-1]), num_samples=1).reshape(*shape[:-1], 1)
67
- next_token = torch.gather(probs_idx, -1, next_token).reshape(*shape[:-1])
68
- return next_token
69
-
70
- @torch.inference_mode()
71
- def generate(self, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20, amp=True):
72
- tokenizer = self.tokenizer
73
- max_token_seq = tokenizer.max_token_seq
74
- if prompt is None:
75
- input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
76
- input_tensor[0, 0] = tokenizer.bos_id # bos
77
- else:
78
- prompt = prompt[:, :max_token_seq]
79
- if prompt.shape[-1] < max_token_seq:
80
- prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
81
- mode="constant", constant_values=tokenizer.pad_id)
82
- input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)
83
- input_tensor = input_tensor.unsqueeze(0)
84
- cur_len = input_tensor.shape[1]
85
- bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
86
- with bar, torch.cuda.amp.autocast(enabled=amp):
87
- while cur_len < max_len:
88
- end = False
89
- hidden = self.forward(input_tensor)[0, -1].unsqueeze(0)
90
- next_token_seq = None
91
- event_name = ""
92
- for i in range(max_token_seq):
93
- mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=self.device)
94
- if i == 0:
95
- mask[list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
96
- else:
97
- param_name = tokenizer.events[event_name][i - 1]
98
- mask[tokenizer.parameter_ids[param_name]] = 1
99
-
100
- logits = self.forward_token(hidden, next_token_seq)[:, -1:]
101
- scores = torch.softmax(logits / temp, dim=-1) * mask
102
- sample = self.sample_top_p_k(scores, top_p, top_k)
103
- if i == 0:
104
- next_token_seq = sample
105
- eid = sample.item()
106
- if eid == tokenizer.eos_id:
107
- end = True
108
- break
109
- event_name = tokenizer.id_events[eid]
110
- else:
111
- next_token_seq = torch.cat([next_token_seq, sample], dim=1)
112
- if len(tokenizer.events[event_name]) == i:
113
- break
114
- if next_token_seq.shape[1] < max_token_seq:
115
- next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
116
- "constant", value=tokenizer.pad_id)
117
- next_token_seq = next_token_seq.unsqueeze(1)
118
- input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
119
- cur_len += 1
120
- bar.update(1)
121
- if end:
122
- break
123
- return input_tensor[0].cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  Pillow
2
  numpy
3
- torch
4
- transformers
5
  gradio
6
  pyfluidsynth
 
1
  Pillow
2
  numpy
3
+ onnxruntime-gpu
 
4
  gradio
5
  pyfluidsynth