thecollabagepatch committed on
Commit
3eec7a7
1 Parent(s): 98b2108

passing the midi audio properly

Browse files
Files changed (1) hide show
  1. app.py +16 -19
app.py CHANGED
@@ -117,27 +117,27 @@ def generate_midi(seed, use_chords, chord_progression, bpm):
117
  # Clean up temporary MIDI file
118
  os.remove(midi_filename)
119
 
120
- return wav_filename
 
 
 
 
 
 
 
121
 
122
  @spaces.GPU(duration=120)
123
- def generate_music(midi_audio, prompt_duration, musicgen_model, num_iterations, bpm):
124
- if isinstance(midi_audio, tuple):
125
- wav_filename, sample_rate = midi_audio
126
- song, sr = torchaudio.load(wav_filename)
127
- else:
128
- # Assuming midi_audio is a numpy array
129
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
130
- temp_filename = temp_file.name
131
- torchaudio.save(temp_filename, torch.from_numpy(midi_audio), sample_rate=44100)
132
- song, sr = torchaudio.load(temp_filename)
133
 
134
- song = song.to(device)
 
135
 
136
  # Use the user-provided BPM value for duration calculation
137
  duration = calculate_duration(bpm)
138
 
139
  # Create slices from the song using the user-provided BPM value
140
- slices = create_slices(song, sr, 35, bpm, num_slices=5)
141
 
142
  # Load the model
143
  model_name = musicgen_model.split(" ")[0]
@@ -160,10 +160,10 @@ def generate_music(midi_audio, prompt_duration, musicgen_model, num_iterations,
160
 
161
  print(f"Running iteration {i + 1} using slice {slice_idx}...")
162
 
163
- prompt_waveform = slices[slice_idx][..., :int(prompt_duration * sr)]
164
  prompt_waveform = preprocess_audio(prompt_waveform)
165
 
166
- output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
167
  output = output.cpu() # Move the output tensor back to CPU
168
 
169
  # Make sure the output tensor has at most 2 dimensions
@@ -184,10 +184,7 @@ def generate_music(midi_audio, prompt_duration, musicgen_model, num_iterations,
184
  combined_audio_filename = f"combined_audio_{random.randint(1, 10000)}.mp3"
185
  combined_audio.export(combined_audio_filename, format="mp3")
186
 
187
-
188
  # Clean up temporary files
189
- if not isinstance(midi_audio, tuple):
190
- os.remove(temp_filename)
191
  for filename in all_audio_files:
192
  os.remove(filename)
193
 
@@ -253,6 +250,6 @@ with gr.Blocks() as iface:
253
  output_audio = gr.Audio(label="Generated Music")
254
 
255
  generate_midi_button.click(generate_midi, inputs=[seed, use_chords, chord_progression, bpm], outputs=midi_audio)
256
- generate_music_button.click(generate_music, inputs=[midi_audio, prompt_duration, musicgen_model, num_iterations, bpm], outputs=output_audio)
257
 
258
  iface.launch()
 
117
  # Clean up temporary MIDI file
118
  os.remove(midi_filename)
119
 
120
+ # Load the generated audio
121
+ song, sr = torchaudio.load(wav_filename)
122
+
123
+ # Clean up temporary MIDI file
124
+ os.remove(midi_filename)
125
+ os.remove(wav_filename)
126
+
127
+ return song.numpy(), sr
128
 
129
  @spaces.GPU(duration=120)
130
+ def generate_music(midi_data, prompt_duration, musicgen_model, num_iterations, bpm):
131
+ audio_data, sample_rate = midi_data
 
 
 
 
 
 
 
 
132
 
133
+ # Convert the audio data to a PyTorch tensor
134
+ song = torch.from_numpy(audio_data).to(device)
135
 
136
  # Use the user-provided BPM value for duration calculation
137
  duration = calculate_duration(bpm)
138
 
139
  # Create slices from the song using the user-provided BPM value
140
+ slices = create_slices(song, sample_rate, 35, bpm, num_slices=5)
141
 
142
  # Load the model
143
  model_name = musicgen_model.split(" ")[0]
 
160
 
161
  print(f"Running iteration {i + 1} using slice {slice_idx}...")
162
 
163
+ prompt_waveform = slices[slice_idx][..., :int(prompt_duration * sample_rate)]
164
  prompt_waveform = preprocess_audio(prompt_waveform)
165
 
166
+ output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sample_rate, progress=True)
167
  output = output.cpu() # Move the output tensor back to CPU
168
 
169
  # Make sure the output tensor has at most 2 dimensions
 
184
  combined_audio_filename = f"combined_audio_{random.randint(1, 10000)}.mp3"
185
  combined_audio.export(combined_audio_filename, format="mp3")
186
 
 
187
  # Clean up temporary files
 
 
188
  for filename in all_audio_files:
189
  os.remove(filename)
190
 
 
250
  output_audio = gr.Audio(label="Generated Music")
251
 
252
  generate_midi_button.click(generate_midi, inputs=[seed, use_chords, chord_progression, bpm], outputs=midi_audio)
253
+ generate_music_button.click(generate_music, inputs=[midi_audio[0], prompt_duration, musicgen_model, num_iterations, bpm], outputs=output_audio)
254
 
255
  iface.launch()