asigalov61 commited on
Commit
d26af00
1 Parent(s): b7b4e7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -145
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os.path
2
 
3
  import time as reqtime
4
  import datetime
@@ -23,15 +23,15 @@ in_space = os.getenv("SYSTEM") == "spaces"
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
- def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type, input_strip_notes):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = reqtime.time()
30
 
31
  print('Loading model...')
32
 
33
- SEQ_LEN = 8192 # Models seq len
34
- PAD_IDX = 707 # Models pad index
35
  DEVICE = 'cuda' # 'cuda'
36
 
37
  # instantiate the model
@@ -39,7 +39,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
39
  model = TransformerWrapper(
40
  num_tokens = PAD_IDX+1,
41
  max_seq_len = SEQ_LEN,
42
- attn_layers = Decoder(dim = 2048, depth = 4, heads = 16, attn_flash = True)
43
  )
44
 
45
  model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
@@ -50,7 +50,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
50
  print('Loading model checkpoint...')
51
 
52
  model.load_state_dict(
53
- torch.load('Chords_Progressions_Transformer_Small_2048_Trained_Model_12947_steps_0.9316_loss_0.7386_acc.pth',
54
  map_location=DEVICE))
55
  print('=' * 70)
56
 
@@ -59,145 +59,15 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
59
  if DEVICE == 'cpu':
60
  dtype = torch.bfloat16
61
  else:
62
- dtype = torch.float16
63
 
64
  ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
65
 
66
  print('Done!')
67
  print('=' * 70)
68
-
69
- fn = os.path.basename(input_midi.name)
70
- fn1 = fn.split('.')[0]
71
-
72
- input_num_tokens = max(4, min(128, input_num_tokens))
73
-
74
- print('-' * 70)
75
- print('Input file name:', fn)
76
- print('Req num toks:', input_num_tokens)
77
- print('Conditioning type:', input_conditioning_type)
78
- print('Strip notes:', input_strip_notes)
79
  print('-' * 70)
80
 
81
- #===============================================================================
82
- raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
83
-
84
- #===============================================================================
85
- # Enhanced score notes
86
-
87
- escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
88
-
89
- no_drums_escore_notes = [e for e in escore_notes if e[6] < 80]
90
-
91
- if len(no_drums_escore_notes) > 0:
92
-
93
- #=======================================================
94
- # PRE-PROCESSING
95
-
96
- #===============================================================================
97
- # Augmented enhanced score notes
98
-
99
- no_drums_escore_notes = TMIDIX.augment_enhanced_score_notes(no_drums_escore_notes)
100
-
101
- cscore = TMIDIX.chordify_score([1000, no_drums_escore_notes])
102
-
103
- clean_cscore = []
104
-
105
- for c in cscore:
106
- pitches = []
107
- cho = []
108
- for cc in c:
109
- if cc[4] not in pitches:
110
- cho.append(cc)
111
- pitches.append(cc[4])
112
-
113
- clean_cscore.append(cho)
114
-
115
- #=======================================================
116
- # FINAL PROCESSING
117
-
118
- melody_chords = []
119
- chords = []
120
- times = [0]
121
- durs = []
122
-
123
- #=======================================================
124
- # MAIN PROCESSING CYCLE
125
- #=======================================================
126
-
127
- pe = clean_cscore[0][0]
128
-
129
- first_chord = True
130
-
131
- for c in clean_cscore:
132
-
133
- # Chords
134
-
135
- c.sort(key=lambda x: x[4], reverse=True)
136
-
137
- tones_chord = sorted(set([cc[4] % 12 for cc in c]))
138
-
139
- try:
140
- chord_token = TMIDIX.ALL_CHORDS_SORTED.index(tones_chord)
141
- except:
142
- checked_tones_chord = TMIDIX.check_and_fix_tones_chord(tones_chord)
143
- chord_token = TMIDIX.ALL_CHORDS_SORTED.index(checked_tones_chord)
144
-
145
- melody_chords.extend([chord_token+384])
146
-
147
- if input_strip_notes:
148
- if len(tones_chord) > 1:
149
- chords.extend([chord_token+384])
150
-
151
- else:
152
- chords.extend([chord_token+384])
153
-
154
- if first_chord:
155
- melody_chords.extend([0])
156
- first_chord = False
157
-
158
- for e in c:
159
-
160
- #=======================================================
161
- # Timings...
162
-
163
- time = e[1]-pe[1]
164
-
165
- dur = e[2]
166
-
167
- if time != 0 and time % 2 != 0:
168
- time += 1
169
- if dur % 2 != 0:
170
- dur += 1
171
-
172
- delta_time = int(max(0, min(255, time)) / 2)
173
-
174
- # Durations
175
-
176
- dur = int(max(0, min(255, dur)) / 2)
177
-
178
- # Pitches
179
-
180
- ptc = max(1, min(127, e[4]))
181
-
182
- #=======================================================
183
- # FINAL NOTE SEQ
184
-
185
- # Writing final note asynchronously
186
-
187
- if delta_time != 0:
188
- melody_chords.extend([delta_time, dur+128, ptc+256])
189
- if input_strip_notes:
190
- if len(c) > 1:
191
- times.append(delta_time)
192
- durs.append(dur+128)
193
- else:
194
- times.append(delta_time)
195
- durs.append(dur+128)
196
- else:
197
- melody_chords.extend([dur+128, ptc+256])
198
-
199
- pe = e
200
-
201
  #==================================================================
202
 
203
  print('=' * 70)
@@ -368,11 +238,8 @@ if __name__ == "__main__":
368
  gr.Markdown(
369
  "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Melody2Song-Seq2Seq-Music-Transformer&style=flat)\n\n")
370
 
371
- input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
372
- input_num_tokens = gr.Slider(4, 128, value=32, step=1, label="Number of composition chords to generate progression for")
373
- input_conditioning_type = gr.Radio(["Chords", "Chords-Times", "Chords-Times-Durations"], label="Conditioning type")
374
- input_strip_notes = gr.Checkbox(label="Strip notes from the composition")
375
-
376
  run_btn = gr.Button("generate", variant="primary")
377
 
378
  gr.Markdown("## Generation results")
@@ -383,8 +250,7 @@ if __name__ == "__main__":
383
  output_plot = gr.Plot(label="Output MIDI score plot")
384
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
385
 
386
-
387
- run_event = run_btn.click(GenerateAccompaniment, [input_midi, input_num_tokens, input_conditioning_type, input_strip_notes],
388
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
389
 
390
  app.queue().launch()
 
1
+ # https://huggingface.co/spaces/asigalov61/Melody2Song-Seq2Seq-Music-Transformer
2
 
3
  import time as reqtime
4
  import datetime
 
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
+ def GenerateSong(input_melody_seed_number):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = reqtime.time()
30
 
31
  print('Loading model...')
32
 
33
+ SEQ_LEN = 2560
34
+ PAD_IDX = 514
35
  DEVICE = 'cuda' # 'cuda'
36
 
37
  # instantiate the model
 
39
  model = TransformerWrapper(
40
  num_tokens = PAD_IDX+1,
41
  max_seq_len = SEQ_LEN,
42
+ attn_layers = Decoder(dim = 1024, depth = 24, heads = 16, attn_flash = True)
43
  )
44
 
45
  model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
 
50
  print('Loading model checkpoint...')
51
 
52
  model.load_state_dict(
53
+ torch.load('Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth',
54
  map_location=DEVICE))
55
  print('=' * 70)
56
 
 
59
  if DEVICE == 'cpu':
60
  dtype = torch.bfloat16
61
  else:
62
+ dtype = torch.bfloat16
63
 
64
  ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
65
 
66
  print('Done!')
67
  print('=' * 70)
68
+ print('Input melody seed number:', input_melody_seed_number)
 
 
 
 
 
 
 
 
 
 
69
  print('-' * 70)
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  #==================================================================
72
 
73
  print('=' * 70)
 
238
  gr.Markdown(
239
  "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Melody2Song-Seq2Seq-Music-Transformer&style=flat)\n\n")
240
 
241
+ input_melody_seed_number = gr.Slider(0, 200000, value=0, step=1, label="Select seed melody number")
242
+
 
 
 
243
  run_btn = gr.Button("generate", variant="primary")
244
 
245
  gr.Markdown("## Generation results")
 
250
  output_plot = gr.Plot(label="Output MIDI score plot")
251
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
252
 
253
+ run_event = run_btn.click(GenerateSong, [input_melody_seed_number],
 
254
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
255
 
256
  app.queue().launch()