skytnt commited on
Commit
743aa2c
1 Parent(s): 0b1d3cf

update sdk

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +14 -14
  3. javascript/app.js +81 -14
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🎼🎶
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.41.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.36.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -121,6 +121,8 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
121
  i = 0
122
  mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
123
  patches = {}
 
 
124
  for instr in instruments:
125
  patches[i] = patch2number[instr]
126
  i = (i + 1) if i != 8 else 10
@@ -187,20 +189,17 @@ def load_javascript(dir="javascript"):
187
 
188
  gr.routes.templates.TemplateResponse = template_response
189
 
 
 
190
 
191
- class JSMsgReceiver(gr.HTML):
192
 
193
- def __init__(self, **kwargs):
194
- super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
 
 
195
 
196
- def postprocess(self, y):
197
- if y:
198
- y = f"<p>{json.dumps(y)}</p>"
199
- return super().postprocess(y)
200
-
201
- def get_block_name(self) -> str:
202
- return "html"
203
 
 
204
 
205
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
206
  40: "Blush", 48: "Orchestra"}
@@ -216,7 +215,8 @@ if __name__ == "__main__":
216
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
217
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
218
  "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
219
- "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"]}
 
220
  models = {}
221
  tokenizer = MIDITokenizer()
222
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
@@ -238,10 +238,10 @@ if __name__ == "__main__":
238
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
239
  " for faster running and longer generation"
240
  )
241
- js_msg = JSMsgReceiver()
242
  input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
243
  type="value", value=list(models.keys())[0])
244
- tab_select = gr.Variable(value=0)
245
  with gr.Tabs():
246
  with gr.TabItem("instrument prompt") as tab1:
247
  input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
@@ -279,7 +279,7 @@ if __name__ == "__main__":
279
  example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
280
  run_btn = gr.Button("generate", variant="primary")
281
  stop_btn = gr.Button("stop and output")
282
- output_midi_seq = gr.Variable()
283
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
284
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
285
  output_midi = gr.File(label="output midi", file_types=[".mid"])
 
121
  i = 0
122
  mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
123
  patches = {}
124
+ if instruments is None:
125
+ instruments = []
126
  for instr in instruments:
127
  patches[i] = patch2number[instr]
128
  i = (i + 1) if i != 8 else 10
 
189
 
190
  gr.routes.templates.TemplateResponse = template_response
191
 
192
+ # JSMsgReceiver
193
+ HTML_postprocess_ori = gr.HTML.postprocess
194
 
 
195
 
196
+ def JSMsgReceiver_postprocess(self, y):
197
+ if self.elem_id == "msg_receiver" and y:
198
+ y = f"<p>{json.dumps(y)}</p>"
199
+ return HTML_postprocess_ori(self, y)
200
 
 
 
 
 
 
 
 
201
 
202
+ gr.HTML.postprocess = JSMsgReceiver_postprocess
203
 
204
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
205
  40: "Blush", 48: "Orchestra"}
 
215
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
216
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
217
  "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
218
+ "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
219
+ }
220
  models = {}
221
  tokenizer = MIDITokenizer()
222
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 
238
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
239
  " for faster running and longer generation"
240
  )
241
+ js_msg = gr.HTML(elem_id="msg_receiver", visible=False)
242
  input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
243
  type="value", value=list(models.keys())[0])
244
+ tab_select = gr.State(value=0)
245
  with gr.Tabs():
246
  with gr.TabItem("instrument prompt") as tab1:
247
  input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
 
279
  example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
280
  run_btn = gr.Button("generate", variant="primary")
281
  stop_btn = gr.Button("stop and output")
282
+ output_midi_seq = gr.State()
283
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
284
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
285
  output_midi = gr.File(label="output midi", file_types=[".mid"])
javascript/app.js CHANGED
@@ -1,3 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  function gradioApp() {
2
  const elems = document.getElementsByTagName('gradio-app')
3
  const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
@@ -98,6 +141,7 @@ class MidiVisualizer extends HTMLElement{
98
  this.timePreBeat = 16
99
  this.svgWidth = 0;
100
  this.t1 = 0;
 
101
  this.playTime = 0
102
  this.playTimeMs = 0
103
  this.colorMap = new Map();
@@ -137,6 +181,7 @@ class MidiVisualizer extends HTMLElement{
137
  this.t1 = 0
138
  this.colorMap.clear()
139
  this.setPlayTime(0);
 
140
  this.playTimeMs = 0
141
  this.svgWidth = 0
142
  this.svg.innerHTML = ''
@@ -215,6 +260,9 @@ class MidiVisualizer extends HTMLElement{
215
  tempo = (60 / midiEvent[3]) * 10 ** 3
216
  this.midiTimes.push({ms:ms, t: t, tempo: tempo})
217
  }
 
 
 
218
  lastT = t
219
  })
220
  }
@@ -277,16 +325,10 @@ class MidiVisualizer extends HTMLElement{
277
 
278
  play(){
279
  this.playing = true;
280
- this.timer = setInterval(() => {
281
- this.setPlayTimeMs(this.playTimeMs + 10)
282
- }, 10);
283
  }
284
 
285
  pause(){
286
- if(!!this.timer)
287
- clearInterval(this.timer)
288
  this.removeActiveNotes(this.activeNotes)
289
- this.timer = null;
290
  this.playing = false;
291
  }
292
 
@@ -299,9 +341,25 @@ class MidiVisualizer extends HTMLElement{
299
  audio.addEventListener("pause", (event)=>{
300
  this.pause()
301
  })
302
- audio.addEventListener("timeupdate", (event)=>{
303
- this.setPlayTimeMs(event.target.currentTime*10**3)
304
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  }
306
  }
307
 
@@ -309,7 +367,8 @@ customElements.define('midi-visualizer', MidiVisualizer);
309
 
310
  (()=>{
311
  let midi_visualizer_container_inited = null
312
- let midi_audio_inited = null;
 
313
  let midi_visualizer = document.createElement('midi-visualizer')
314
  onUiUpdate((m)=>{
315
  let app = gradioApp()
@@ -318,10 +377,18 @@ customElements.define('midi-visualizer', MidiVisualizer);
318
  midi_visualizer_container.appendChild(midi_visualizer)
319
  midi_visualizer_container_inited = midi_visualizer_container;
320
  }
321
- let midi_audio = app.querySelector("#midi_audio > audio");
322
- if(!!midi_audio && midi_audio_inited!==midi_audio){
323
- midi_visualizer.bindAudioPlayer(midi_audio)
324
- midi_audio_inited = midi_audio
 
 
 
 
 
 
 
 
325
  }
326
  })
327
 
 
1
+ /**
2
+ * 自动绕过 shadowRoot 的 querySelector
3
+ * @param {string} selector - 要查询的 CSS 选择器
4
+ * @returns {Element|null} - 匹配的元素或 null 如果未找到
5
+ */
6
+ function deepQuerySelector(selector) {
7
+ /**
8
+ * 在指定的根元素或文档对象下深度查询元素
9
+ * @param {Element|Document} root - 要开始搜索的根元素或文档对象
10
+ * @param {string} selector - 要查询的 CSS 选择器
11
+ * @returns {Element|null} - 匹配的元素或 null 如果未找到
12
+ */
13
+ function deepSearch(root, selector) {
14
+ // 在当前根元素下查找
15
+ let element = root.querySelector(selector);
16
+ if (element) {
17
+ return element;
18
+ }
19
+
20
+ // 如果未找到,递归检查 shadow DOM
21
+ const shadowHosts = root.querySelectorAll('*');
22
+
23
+ for (let i = 0; i < shadowHosts.length; i++) {
24
+ const host = shadowHosts[i];
25
+
26
+ // 检查当前元素是否有 shadowRoot
27
+ if (host.shadowRoot) {
28
+ element = deepSearch(host.shadowRoot, selector);
29
+ if (element) {
30
+ return element;
31
+ }
32
+ }
33
+ }
34
+ // 未找到元素
35
+ return null;
36
+ }
37
+
38
+ return deepSearch(this, selector);
39
+ }
40
+
41
+ Element.prototype.deepQuerySelector = deepQuerySelector;
42
+ Document.prototype.deepQuerySelector = deepQuerySelector;
43
+
44
  function gradioApp() {
45
  const elems = document.getElementsByTagName('gradio-app')
46
  const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
 
141
  this.timePreBeat = 16
142
  this.svgWidth = 0;
143
  this.t1 = 0;
144
+ this.totalTimeMs = 0
145
  this.playTime = 0
146
  this.playTimeMs = 0
147
  this.colorMap = new Map();
 
181
  this.t1 = 0
182
  this.colorMap.clear()
183
  this.setPlayTime(0);
184
+ this.totalTimeMs = 0;
185
  this.playTimeMs = 0
186
  this.svgWidth = 0
187
  this.svg.innerHTML = ''
 
260
  tempo = (60 / midiEvent[3]) * 10 ** 3
261
  this.midiTimes.push({ms:ms, t: t, tempo: tempo})
262
  }
263
+ if(midiEvent[0]==="note"){
264
+ this.totalTimeMs = ms + (midiEvent[3]/ this.timePreBeat)*tempo
265
+ }
266
  lastT = t
267
  })
268
  }
 
325
 
326
  play(){
327
  this.playing = true;
 
 
 
328
  }
329
 
330
  pause(){
 
 
331
  this.removeActiveNotes(this.activeNotes)
 
332
  this.playing = false;
333
  }
334
 
 
341
  audio.addEventListener("pause", (event)=>{
342
  this.pause()
343
  })
344
+ }
345
+
346
+ bindWaveformCursor(cursor){
347
+ let self = this;
348
+ const callback = function(mutationsList, observer) {
349
+ for(let mutation of mutationsList) {
350
+ if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
351
+ let progress = parseFloat(mutation.target.style.left.slice(0,-1))*0.01;
352
+ if(!isNaN(progress)){
353
+ self.setPlayTimeMs(progress*self.totalTimeMs);
354
+ }
355
+ }
356
+ }
357
+ };
358
+ const observer = new MutationObserver(callback);
359
+ observer.observe(cursor, {
360
+ attributes: true,
361
+ attributeFilter: ['style']
362
+ });
363
  }
364
  }
365
 
 
367
 
368
  (()=>{
369
  let midi_visualizer_container_inited = null
370
+ let midi_audio_audio_inited = null;
371
+ let midi_audio_cursor_inited = null;
372
  let midi_visualizer = document.createElement('midi-visualizer')
373
  onUiUpdate((m)=>{
374
  let app = gradioApp()
 
377
  midi_visualizer_container.appendChild(midi_visualizer)
378
  midi_visualizer_container_inited = midi_visualizer_container;
379
  }
380
+ let midi_audio = app.querySelector("#midi_audio");
381
+ if (!!midi_audio){
382
+ let midi_audio_cursor = midi_audio.deepQuerySelector(".cursor");
383
+ if(!!midi_audio_cursor && midi_audio_cursor_inited!==midi_audio_cursor){
384
+ midi_visualizer.bindWaveformCursor(midi_audio_cursor)
385
+ midi_audio_cursor_inited = midi_audio_cursor
386
+ }
387
+ let midi_audio_audio = midi_audio.deepQuerySelector("audio");
388
+ if(!!midi_audio_audio && midi_audio_audio_inited!==midi_audio_audio){
389
+ midi_visualizer.bindAudioPlayer(midi_audio_audio)
390
+ midi_audio_audio_inited = midi_audio_audio
391
+ }
392
  }
393
  })
394