Spaces:
Runtime error
Runtime error
use javascript
Browse files- app.py +58 -55
- javascript/app.js +252 -0
- midi_tokenizer.py +13 -0
app.py
CHANGED
@@ -2,12 +2,11 @@ import argparse
|
|
2 |
import glob
|
3 |
import os.path
|
4 |
|
5 |
-
import PIL
|
6 |
-
import PIL.ImageColor
|
7 |
import gradio as gr
|
8 |
import numpy as np
|
9 |
import onnxruntime as rt
|
10 |
import tqdm
|
|
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
|
13 |
import MIDI
|
@@ -107,44 +106,14 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
107 |
break
|
108 |
|
109 |
|
110 |
-
def
|
111 |
-
|
112 |
-
max_len = int(gen_events)
|
113 |
-
img_len = 1024
|
114 |
-
img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8)
|
115 |
-
state = {"t1": 0, "t": 0, "cur_pos": 0}
|
116 |
-
colors = ['navy', 'blue', 'deepskyblue', 'teal', 'green', 'lightgreen', 'lime', 'orange',
|
117 |
-
'brown', 'grey', 'red', 'pink', 'aqua', 'orchid', 'bisque', 'coral']
|
118 |
-
colors = [PIL.ImageColor.getrgb(color) for color in colors]
|
119 |
|
120 |
-
def draw_event(tokens):
|
121 |
-
if tokens[0] in tokenizer.id_events:
|
122 |
-
name = tokenizer.id_events[tokens[0]]
|
123 |
-
if len(tokens) <= len(tokenizer.events[name]):
|
124 |
-
return
|
125 |
-
params = tokens[1:]
|
126 |
-
params = [params[i] - tokenizer.parameter_ids[p][0] for i, p in enumerate(tokenizer.events[name])]
|
127 |
-
if not all([0 <= params[i] < tokenizer.event_parameters[p] for i, p in enumerate(tokenizer.events[name])]):
|
128 |
-
return
|
129 |
-
event = [name] + params
|
130 |
-
state["t1"] += event[1]
|
131 |
-
t = state["t1"] * 16 + event[2]
|
132 |
-
state["t"] = t
|
133 |
-
if name == "note":
|
134 |
-
tr, d, c, p = event[3:7]
|
135 |
-
shift = t + d - (state["cur_pos"] + img_len)
|
136 |
-
if shift > 0:
|
137 |
-
img[:, :-shift] = img[:, shift:]
|
138 |
-
img[:, -shift:] = 255
|
139 |
-
state["cur_pos"] += shift
|
140 |
-
t = t - state["cur_pos"]
|
141 |
-
img[p * 2:(p + 1) * 2, t: t + d] = colors[c]
|
142 |
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
return PIL.Image.fromarray(np.flip(img_new, 0))
|
148 |
|
149 |
disable_patch_change = False
|
150 |
disable_channels = None
|
@@ -170,25 +139,25 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
170 |
mid = mid[:int(midi_events)]
|
171 |
max_len += len(mid)
|
172 |
for token_seq in mid:
|
173 |
-
mid_seq.append(token_seq)
|
174 |
-
|
|
|
|
|
|
|
175 |
model = models[model_name]
|
176 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
177 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
178 |
disable_channels=disable_channels)
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
except Exception as e:
|
185 |
-
print(e)
|
186 |
-
|
187 |
mid = tokenizer.detokenize(mid_seq)
|
188 |
with open(f"output.mid", 'wb') as f:
|
189 |
f.write(MIDI.score2midi(mid))
|
190 |
-
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
191 |
-
yield mid_seq,
|
192 |
|
193 |
|
194 |
def cancel_run(mid_seq):
|
@@ -198,7 +167,39 @@ def cancel_run(mid_seq):
|
|
198 |
with open(f"output.mid", 'wb') as f:
|
199 |
f.write(MIDI.score2midi(mid))
|
200 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
201 |
-
return "output.mid", (44100, audio)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
|
204 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
@@ -226,6 +227,7 @@ if __name__ == "__main__":
|
|
226 |
model_token = rt.InferenceSession(model_token_path, providers=providers)
|
227 |
models[name] = [model_base, model_token]
|
228 |
|
|
|
229 |
app = gr.Blocks()
|
230 |
with app:
|
231 |
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
|
@@ -236,6 +238,7 @@ if __name__ == "__main__":
|
|
236 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
237 |
" for faster running and longer generation"
|
238 |
)
|
|
|
239 |
input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
|
240 |
type="value", value=list(models.keys())[0])
|
241 |
tab_select = gr.Variable(value=0)
|
@@ -277,12 +280,12 @@ if __name__ == "__main__":
|
|
277 |
run_btn = gr.Button("generate", variant="primary")
|
278 |
stop_btn = gr.Button("stop and output")
|
279 |
output_midi_seq = gr.Variable()
|
280 |
-
|
281 |
-
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
282 |
output_audio = gr.Audio(label="output audio", format="mp3")
|
|
|
283 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
|
284 |
input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
|
285 |
input_allow_cc],
|
286 |
-
[output_midi_seq,
|
287 |
-
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
|
288 |
app.queue(2).launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
|
|
2 |
import glob
|
3 |
import os.path
|
4 |
|
|
|
|
|
5 |
import gradio as gr
|
6 |
import numpy as np
|
7 |
import onnxruntime as rt
|
8 |
import tqdm
|
9 |
+
import json
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
|
12 |
import MIDI
|
|
|
106 |
break
|
107 |
|
108 |
|
109 |
+
def create_msg(name, data):
|
110 |
+
return {"name": name, "data": data}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
+
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
114 |
+
mid_seq = []
|
115 |
+
gen_events = int(gen_events)
|
116 |
+
max_len = gen_events
|
|
|
117 |
|
118 |
disable_patch_change = False
|
119 |
disable_channels = None
|
|
|
139 |
mid = mid[:int(midi_events)]
|
140 |
max_len += len(mid)
|
141 |
for token_seq in mid:
|
142 |
+
mid_seq.append(token_seq.tolist())
|
143 |
+
init_msgs = [create_msg("visualizer_clear", None)]
|
144 |
+
for tokens in mid_seq:
|
145 |
+
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
146 |
+
yield mid_seq, None, None, init_msgs
|
147 |
model = models[model_name]
|
148 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
149 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
150 |
disable_channels=disable_channels)
|
151 |
+
for i, token_seq in enumerate(generator):
|
152 |
+
token_seq = token_seq.tolist()
|
153 |
+
mid_seq.append(token_seq)
|
154 |
+
event = tokenizer.tokens2event(token_seq)
|
155 |
+
yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
|
|
|
|
|
|
|
156 |
mid = tokenizer.detokenize(mid_seq)
|
157 |
with open(f"output.mid", 'wb') as f:
|
158 |
f.write(MIDI.score2midi(mid))
|
159 |
+
audio = synthesis(MIDI.score2opus(mid), opt.soundfont_path)
|
160 |
+
yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
161 |
|
162 |
|
163 |
def cancel_run(mid_seq):
|
|
|
167 |
with open(f"output.mid", 'wb') as f:
|
168 |
f.write(MIDI.score2midi(mid))
|
169 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
170 |
+
return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
171 |
+
|
172 |
+
|
173 |
+
def load_javascript(dir="javascript"):
|
174 |
+
scripts_list = glob.glob(f"{dir}/*.js")
|
175 |
+
javascript = ""
|
176 |
+
for path in scripts_list:
|
177 |
+
with open(path, "r", encoding="utf8") as jsfile:
|
178 |
+
javascript += f"\n<!-- {path} --><script>{jsfile.read()}</script>"
|
179 |
+
template_response_ori = gr.routes.templates.TemplateResponse
|
180 |
+
|
181 |
+
def template_response(*args, **kwargs):
|
182 |
+
res = template_response_ori(*args, **kwargs)
|
183 |
+
res.body = res.body.replace(
|
184 |
+
b'</head>', f'{javascript}</head>'.encode("utf8"))
|
185 |
+
res.init_headers()
|
186 |
+
return res
|
187 |
+
|
188 |
+
gr.routes.templates.TemplateResponse = template_response
|
189 |
+
|
190 |
+
|
191 |
+
class JSMsgReceiver(gr.HTML):
|
192 |
+
|
193 |
+
def __init__(self, **kwargs):
|
194 |
+
super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
|
195 |
+
|
196 |
+
def postprocess(self, y):
|
197 |
+
if y:
|
198 |
+
y = f"<p>{json.dumps(y)}</p>"
|
199 |
+
return super().postprocess(y)
|
200 |
+
|
201 |
+
def get_block_name(self) -> str:
|
202 |
+
return "html"
|
203 |
|
204 |
|
205 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
|
|
227 |
model_token = rt.InferenceSession(model_token_path, providers=providers)
|
228 |
models[name] = [model_base, model_token]
|
229 |
|
230 |
+
load_javascript()
|
231 |
app = gr.Blocks()
|
232 |
with app:
|
233 |
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
|
|
|
238 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
239 |
" for faster running and longer generation"
|
240 |
)
|
241 |
+
js_msg = JSMsgReceiver()
|
242 |
input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
|
243 |
type="value", value=list(models.keys())[0])
|
244 |
tab_select = gr.Variable(value=0)
|
|
|
280 |
run_btn = gr.Button("generate", variant="primary")
|
281 |
stop_btn = gr.Button("stop and output")
|
282 |
output_midi_seq = gr.Variable()
|
283 |
+
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
|
|
284 |
output_audio = gr.Audio(label="output audio", format="mp3")
|
285 |
+
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
286 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
|
287 |
input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
|
288 |
input_allow_cc],
|
289 |
+
[output_midi_seq, output_midi, output_audio, js_msg])
|
290 |
+
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
291 |
app.queue(2).launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
javascript/app.js
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
function gradioApp() {
|
2 |
+
const elems = document.getElementsByTagName('gradio-app')
|
3 |
+
const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
|
4 |
+
return !!gradioShadowRoot ? gradioShadowRoot : document;
|
5 |
+
}
|
6 |
+
|
7 |
+
uiUpdateCallbacks = []
|
8 |
+
msgReceiveCallbacks = []
|
9 |
+
|
10 |
+
function onUiUpdate(callback){
|
11 |
+
uiUpdateCallbacks.push(callback)
|
12 |
+
}
|
13 |
+
|
14 |
+
function onMsgReceive(callback){
|
15 |
+
msgReceiveCallbacks.push(callback)
|
16 |
+
}
|
17 |
+
|
18 |
+
function runCallback(x, m){
|
19 |
+
try {
|
20 |
+
x(m)
|
21 |
+
} catch (e) {
|
22 |
+
(console.error || console.log).call(console, e.message, e);
|
23 |
+
}
|
24 |
+
}
|
25 |
+
function executeCallbacks(queue, m) {
|
26 |
+
queue.forEach(function(x){runCallback(x, m)})
|
27 |
+
}
|
28 |
+
|
29 |
+
document.addEventListener("DOMContentLoaded", function() {
|
30 |
+
var mutationObserver = new MutationObserver(function(m){
|
31 |
+
executeCallbacks(uiUpdateCallbacks, m);
|
32 |
+
});
|
33 |
+
mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
|
34 |
+
});
|
35 |
+
|
36 |
+
(()=>{
|
37 |
+
let mse_receiver_inited = null
|
38 |
+
onUiUpdate(()=>{
|
39 |
+
let app = gradioApp()
|
40 |
+
let msg_receiver = app.querySelector("#msg_receiver");
|
41 |
+
if(!!msg_receiver && mse_receiver_inited !== msg_receiver){
|
42 |
+
let mutationObserver = new MutationObserver(function(ms){
|
43 |
+
ms.forEach((m)=>{
|
44 |
+
m.addedNodes.forEach((node)=>{
|
45 |
+
if(node.nodeName === "P"){
|
46 |
+
let obj = JSON.parse(node.innerText);
|
47 |
+
if(obj instanceof Array){
|
48 |
+
obj.forEach((o)=>{executeCallbacks(msgReceiveCallbacks, o);});
|
49 |
+
}else{
|
50 |
+
executeCallbacks(msgReceiveCallbacks, obj);
|
51 |
+
}
|
52 |
+
}
|
53 |
+
})
|
54 |
+
})
|
55 |
+
});
|
56 |
+
mutationObserver.observe( msg_receiver, {childList:true, subtree:true, characterData:true})
|
57 |
+
console.log("receiver init");
|
58 |
+
mse_receiver_inited = msg_receiver;
|
59 |
+
}
|
60 |
+
})
|
61 |
+
})();
|
62 |
+
|
63 |
+
class MidiVisualizer extends HTMLElement{
|
64 |
+
constructor() {
|
65 |
+
super();
|
66 |
+
this.midiEvents = [];
|
67 |
+
this.wrapper = null;
|
68 |
+
this.svg = null;
|
69 |
+
this.timeLine = null;
|
70 |
+
this.config = {
|
71 |
+
noteHeight : 4,
|
72 |
+
beatWidth: 32
|
73 |
+
}
|
74 |
+
this.svgWidth = 0;
|
75 |
+
this.t1 = 0;
|
76 |
+
this.playTime = 0
|
77 |
+
this.colorMap = new Map();
|
78 |
+
this.init();
|
79 |
+
}
|
80 |
+
|
81 |
+
init(){
|
82 |
+
this.innerHTML=''
|
83 |
+
const shadow = this.attachShadow({mode: 'open'});
|
84 |
+
const style = document.createElement("style");
|
85 |
+
const wrapper = document.createElement('div');
|
86 |
+
style.textContent = ".note.active {stroke: black;stroke-width: 0.75;stroke-opacity: 0.75;}";
|
87 |
+
wrapper.style.overflowX= "scroll"
|
88 |
+
const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg');
|
89 |
+
svg.style.height = `${this.config.noteHeight*128}px`;
|
90 |
+
svg.style.width = `${this.svgWidth}px`;
|
91 |
+
const timeLine = document.createElementNS('http://www.w3.org/2000/svg', 'line');
|
92 |
+
timeLine.style.stroke = "green"
|
93 |
+
timeLine.style.strokeWidth = 2;
|
94 |
+
shadow.appendChild(style)
|
95 |
+
shadow.appendChild(wrapper);
|
96 |
+
wrapper.appendChild(svg);
|
97 |
+
svg.appendChild(timeLine)
|
98 |
+
this.wrapper = wrapper;
|
99 |
+
this.svg = svg;
|
100 |
+
this.timeLine= timeLine;
|
101 |
+
this.setPlayTime(0);
|
102 |
+
}
|
103 |
+
|
104 |
+
setPlayTime(t){
|
105 |
+
this.playTime = t
|
106 |
+
let x = Math.round((t/16)*this.config.beatWidth)
|
107 |
+
this.timeLine.setAttribute('x1', `${x}`);
|
108 |
+
this.timeLine.setAttribute('y1', '0');
|
109 |
+
this.timeLine.setAttribute('x2', `${x}`);
|
110 |
+
this.timeLine.setAttribute('y2', `${this.config.noteHeight*128}`);
|
111 |
+
}
|
112 |
+
|
113 |
+
clearMidiEvents(){
|
114 |
+
this.midiEvents = [];
|
115 |
+
this.svgWidth = 0
|
116 |
+
this.svg.innerHTML = ''
|
117 |
+
this.t1 = 0
|
118 |
+
this.colorMap.clear()
|
119 |
+
this.svg.style.width = `${this.svgWidth}px`;
|
120 |
+
this.setPlayTime(0);
|
121 |
+
this.svg.appendChild(this.timeLine)
|
122 |
+
}
|
123 |
+
|
124 |
+
appendMidiEvent(midiEvent){
|
125 |
+
if(midiEvent instanceof Array && midiEvent.length > 0){
|
126 |
+
this.midiEvents.push(midiEvent);
|
127 |
+
this.t1 += midiEvent[1]
|
128 |
+
let t = this.t1*16 + midiEvent[2]
|
129 |
+
if(midiEvent[0] === "note"){
|
130 |
+
let track = midiEvent[3]
|
131 |
+
let duration = midiEvent[4]
|
132 |
+
let channel = midiEvent[5]
|
133 |
+
let pitch = midiEvent[6]
|
134 |
+
let velocity = midiEvent[7]
|
135 |
+
let x = (t/16)*this.config.beatWidth
|
136 |
+
let y = (127 - pitch)*this.config.noteHeight
|
137 |
+
let w = (duration/16)*this.config.beatWidth
|
138 |
+
let h = this.config.noteHeight
|
139 |
+
this.svgWidth = Math.ceil(Math.max(x + w, this.svgWidth))
|
140 |
+
let color = this.getColor(track, channel)
|
141 |
+
let opacity = Math.min(1, velocity/127 + 0.1).toFixed(2)
|
142 |
+
this.drawNote(x,y,w,h, `rgba(${color[0]}, ${color[1]}, ${color[2]}, ${opacity})`)
|
143 |
+
this.setPlayTime(t);
|
144 |
+
this.wrapper.scrollTo(this.svgWidth - this.wrapper.offsetWidth, 0)
|
145 |
+
}
|
146 |
+
this.svg.style.width = `${this.svgWidth}px`;
|
147 |
+
}
|
148 |
+
|
149 |
+
}
|
150 |
+
|
151 |
+
getColor(track, channel){
|
152 |
+
let key = `${track},${channel}`;
|
153 |
+
let color = this.colorMap.get(key);
|
154 |
+
if(!!color){
|
155 |
+
return color;
|
156 |
+
}
|
157 |
+
color = [Math.round(Math.random()*240) + 10, Math.round(Math.random()*240)+ 10, Math.round(Math.random()*240)+ 10];
|
158 |
+
this.colorMap.set(key, color);
|
159 |
+
return color;
|
160 |
+
}
|
161 |
+
|
162 |
+
drawNote(x, y, w, h, fill) {
|
163 |
+
if (!this.svg) {
|
164 |
+
return null;
|
165 |
+
}
|
166 |
+
const rect = document.createElementNS('http://www.w3.org/2000/svg', 'rect');
|
167 |
+
rect.classList.add('note');
|
168 |
+
rect.setAttribute('fill', fill);
|
169 |
+
// Round values to the nearest integer to avoid partially filled pixels.
|
170 |
+
rect.setAttribute('x', `${Math.round(x)}`);
|
171 |
+
rect.setAttribute('y', `${Math.round(y)}`);
|
172 |
+
rect.setAttribute('width', `${Math.round(w)}`);
|
173 |
+
rect.setAttribute('height', `${Math.round(h)}`);
|
174 |
+
this.svg.appendChild(rect);
|
175 |
+
return rect
|
176 |
+
}
|
177 |
+
}
|
178 |
+
|
179 |
+
customElements.define('midi-visualizer', MidiVisualizer);
|
180 |
+
|
181 |
+
(()=>{
|
182 |
+
let midi_visualizer_container_inited = null
|
183 |
+
let midi_visualizer = document.createElement('midi-visualizer')
|
184 |
+
onUiUpdate((m)=>{
|
185 |
+
let app = gradioApp()
|
186 |
+
let midi_visualizer_container = app.querySelector("#midi_visualizer_container");
|
187 |
+
if(!!midi_visualizer_container && midi_visualizer_container_inited!== midi_visualizer_container){
|
188 |
+
midi_visualizer_container.appendChild(midi_visualizer)
|
189 |
+
midi_visualizer_container_inited = midi_visualizer_container;
|
190 |
+
}
|
191 |
+
})
|
192 |
+
|
193 |
+
function createProgressBar(progressbarContainer){
|
194 |
+
let parentProgressbar = progressbarContainer.parentNode;
|
195 |
+
let divProgress = document.createElement('div');
|
196 |
+
divProgress.className='progressDiv';
|
197 |
+
let rect = progressbarContainer.getBoundingClientRect();
|
198 |
+
divProgress.style.width = rect.width + "px";
|
199 |
+
divProgress.style.background = "#b4c0cc";
|
200 |
+
divProgress.style.borderRadius = "8px";
|
201 |
+
let divInner = document.createElement('div');
|
202 |
+
divInner.className='progress';
|
203 |
+
divInner.style.color = "white";
|
204 |
+
divInner.style.background = "#0060df";
|
205 |
+
divInner.style.textAlign = "right";
|
206 |
+
divInner.style.fontWeight = "bold";
|
207 |
+
divInner.style.borderRadius = "8px";
|
208 |
+
divInner.style.height = "20px";
|
209 |
+
divInner.style.lineHeight = "20px";
|
210 |
+
divInner.style.paddingRight = "8px"
|
211 |
+
divProgress.appendChild(divInner);
|
212 |
+
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
213 |
+
}
|
214 |
+
|
215 |
+
function removeProgressBar(progressbarContainer){
|
216 |
+
let parentProgressbar = progressbarContainer.parentNode;
|
217 |
+
let divProgress = parentProgressbar.querySelector(".progressDiv");
|
218 |
+
parentProgressbar.removeChild(divProgress);
|
219 |
+
}
|
220 |
+
|
221 |
+
function setProgressBar(progressbarContainer, progress, total){
|
222 |
+
let parentProgressbar = progressbarContainer.parentNode;
|
223 |
+
let divProgress = parentProgressbar.querySelector(".progressDiv");
|
224 |
+
let divInner = parentProgressbar.querySelector(".progress");
|
225 |
+
if(total===0)
|
226 |
+
total = 1;
|
227 |
+
divInner.style.width = `${(progress/total)*100}%`;
|
228 |
+
divInner.textContent = `${progress}/${total}`;
|
229 |
+
}
|
230 |
+
|
231 |
+
onMsgReceive((msg)=>{
|
232 |
+
switch (msg.name) {
|
233 |
+
case "visualizer_clear":
|
234 |
+
midi_visualizer.clearMidiEvents();
|
235 |
+
createProgressBar(midi_visualizer_container_inited)
|
236 |
+
break;
|
237 |
+
case "visualizer_append":
|
238 |
+
midi_visualizer.appendMidiEvent(msg.data);
|
239 |
+
break;
|
240 |
+
case "progress":
|
241 |
+
let progress = msg.data[0]
|
242 |
+
let total = msg.data[1]
|
243 |
+
setProgressBar(midi_visualizer_container_inited, progress, total)
|
244 |
+
break;
|
245 |
+
case "visualizer_end":
|
246 |
+
midi_visualizer.setPlayTime(0)
|
247 |
+
removeProgressBar(midi_visualizer_container_inited)
|
248 |
+
break;
|
249 |
+
default:
|
250 |
+
}
|
251 |
+
})
|
252 |
+
})();
|
midi_tokenizer.py
CHANGED
@@ -101,6 +101,19 @@ class MIDITokenizer:
|
|
101 |
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
102 |
return tokens
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
def detokenize(self, midi_seq):
|
105 |
ticks_per_beat = 480
|
106 |
tracks_dict = {}
|
|
|
101 |
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
102 |
return tokens
|
103 |
|
104 |
+
def tokens2event(self, tokens):
|
105 |
+
if tokens[0] in self.id_events:
|
106 |
+
name = self.id_events[tokens[0]]
|
107 |
+
if len(tokens) <= len(self.events[name]):
|
108 |
+
return []
|
109 |
+
params = tokens[1:]
|
110 |
+
params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
|
111 |
+
if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
|
112 |
+
return []
|
113 |
+
event = [name] + params
|
114 |
+
return event
|
115 |
+
return []
|
116 |
+
|
117 |
def detokenize(self, midi_seq):
|
118 |
ticks_per_beat = 480
|
119 |
tracks_dict = {}
|