Spaces:
Runtime error
Runtime error
update
Browse files- README.md +1 -1
- midi_tokenizer.py +22 -2
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title: Midi
|
3 |
emoji: 🎼🎶
|
4 |
colorFrom: red
|
5 |
colorTo: indigo
|
|
|
1 |
---
|
2 |
+
title: Midi Music Generator
|
3 |
emoji: 🎼🎶
|
4 |
colorFrom: red
|
5 |
colorTo: indigo
|
midi_tokenizer.py
CHANGED
@@ -169,16 +169,36 @@ class MIDITokenizer:
|
|
169 |
img = PIL.Image.fromarray(np.flip(img, 0))
|
170 |
return img
|
171 |
|
172 |
-
def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10
|
|
|
173 |
pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
|
174 |
vel_shift = random.randint(-max_vel_shift, max_vel_shift)
|
175 |
cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
|
176 |
bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
|
|
|
|
|
177 |
midi_seq_new = []
|
178 |
for tokens in midi_seq:
|
179 |
tokens_new = [*tokens]
|
180 |
if tokens[0] in self.id_events:
|
181 |
name = self.id_events[tokens[0]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
if name == "note":
|
183 |
c = tokens[5] - self.parameter_ids["channel"][0]
|
184 |
p = tokens[6] - self.parameter_ids["pitch"][0]
|
@@ -206,7 +226,7 @@ class MIDITokenizer:
|
|
206 |
midi_seq_new.append(tokens_new)
|
207 |
return midi_seq_new
|
208 |
|
209 |
-
def check_alignment(self, midi_seq, threshold=0.
|
210 |
total = 0
|
211 |
hist = [0] * 16
|
212 |
for tokens in midi_seq:
|
|
|
169 |
img = PIL.Image.fromarray(np.flip(img, 0))
|
170 |
return img
|
171 |
|
172 |
+
def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
|
173 |
+
max_track_shift=128, max_channel_shift=16):
|
174 |
pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
|
175 |
vel_shift = random.randint(-max_vel_shift, max_vel_shift)
|
176 |
cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
|
177 |
bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
|
178 |
+
track_shift = random.randint(0, max_track_shift)
|
179 |
+
channel_shift = random.randint(0, max_channel_shift)
|
180 |
midi_seq_new = []
|
181 |
for tokens in midi_seq:
|
182 |
tokens_new = [*tokens]
|
183 |
if tokens[0] in self.id_events:
|
184 |
name = self.id_events[tokens[0]]
|
185 |
+
for i, pn in enumerate(self.events[name]):
|
186 |
+
if pn == "track":
|
187 |
+
tr = tokens[1 + i] - self.parameter_ids[pn][0]
|
188 |
+
tr += track_shift
|
189 |
+
tr = tr % self.event_parameters[pn]
|
190 |
+
tokens_new[1 + i] = self.parameter_ids[pn][tr]
|
191 |
+
elif pn == "channel":
|
192 |
+
c = tokens[1 + i] - self.parameter_ids[pn][0]
|
193 |
+
c0 = c
|
194 |
+
c += channel_shift
|
195 |
+
c = c % self.event_parameters[pn]
|
196 |
+
if c0 == 9:
|
197 |
+
c = 9
|
198 |
+
elif c == 9:
|
199 |
+
c = (9 + channel_shift) % self.event_parameters[pn]
|
200 |
+
tokens_new[1 + i] = self.parameter_ids[pn][c]
|
201 |
+
|
202 |
if name == "note":
|
203 |
c = tokens[5] - self.parameter_ids["channel"][0]
|
204 |
p = tokens[6] - self.parameter_ids["pitch"][0]
|
|
|
226 |
midi_seq_new.append(tokens_new)
|
227 |
return midi_seq_new
|
228 |
|
229 |
+
def check_alignment(self, midi_seq, threshold=0.3):
|
230 |
total = 0
|
231 |
hist = [0] * 16
|
232 |
for tokens in midi_seq:
|