skytnt commited on
Commit
209f9d6
1 Parent(s): fe820fd
Files changed (2) hide show
  1. README.md +1 -1
  2. midi_tokenizer.py +22 -2
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Midi Composer
3
  emoji: 🎼🎶
4
  colorFrom: red
5
  colorTo: indigo
 
1
  ---
2
+ title: Midi Music Generator
3
  emoji: 🎼🎶
4
  colorFrom: red
5
  colorTo: indigo
midi_tokenizer.py CHANGED
@@ -169,16 +169,36 @@ class MIDITokenizer:
169
  img = PIL.Image.fromarray(np.flip(img, 0))
170
  return img
171
 
172
- def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10):
 
173
  pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
174
  vel_shift = random.randint(-max_vel_shift, max_vel_shift)
175
  cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
176
  bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
 
 
177
  midi_seq_new = []
178
  for tokens in midi_seq:
179
  tokens_new = [*tokens]
180
  if tokens[0] in self.id_events:
181
  name = self.id_events[tokens[0]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  if name == "note":
183
  c = tokens[5] - self.parameter_ids["channel"][0]
184
  p = tokens[6] - self.parameter_ids["pitch"][0]
@@ -206,7 +226,7 @@ class MIDITokenizer:
206
  midi_seq_new.append(tokens_new)
207
  return midi_seq_new
208
 
209
- def check_alignment(self, midi_seq, threshold=0.4):
210
  total = 0
211
  hist = [0] * 16
212
  for tokens in midi_seq:
 
169
  img = PIL.Image.fromarray(np.flip(img, 0))
170
  return img
171
 
172
+ def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
173
+ max_track_shift=128, max_channel_shift=16):
174
  pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
175
  vel_shift = random.randint(-max_vel_shift, max_vel_shift)
176
  cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
177
  bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
178
+ track_shift = random.randint(0, max_track_shift)
179
+ channel_shift = random.randint(0, max_channel_shift)
180
  midi_seq_new = []
181
  for tokens in midi_seq:
182
  tokens_new = [*tokens]
183
  if tokens[0] in self.id_events:
184
  name = self.id_events[tokens[0]]
185
+ for i, pn in enumerate(self.events[name]):
186
+ if pn == "track":
187
+ tr = tokens[1 + i] - self.parameter_ids[pn][0]
188
+ tr += track_shift
189
+ tr = tr % self.event_parameters[pn]
190
+ tokens_new[1 + i] = self.parameter_ids[pn][tr]
191
+ elif pn == "channel":
192
+ c = tokens[1 + i] - self.parameter_ids[pn][0]
193
+ c0 = c
194
+ c += channel_shift
195
+ c = c % self.event_parameters[pn]
196
+ if c0 == 9:
197
+ c = 9
198
+ elif c == 9:
199
+ c = (9 + channel_shift) % self.event_parameters[pn]
200
+ tokens_new[1 + i] = self.parameter_ids[pn][c]
201
+
202
  if name == "note":
203
  c = tokens[5] - self.parameter_ids["channel"][0]
204
  p = tokens[6] - self.parameter_ids["pitch"][0]
 
226
  midi_seq_new.append(tokens_new)
227
  return midi_seq_new
228
 
229
+ def check_alignment(self, midi_seq, threshold=0.3):
230
  total = 0
231
  hist = [0] * 16
232
  for tokens in midi_seq: