thecollabagepatch committed on
Commit
58fc3d4
1 Parent(s): 5098605

continuing continuation attempt 1

Browse files
Files changed (1) hide show
  1. app.py +56 -1
app.py CHANGED
@@ -182,6 +182,58 @@ def generate_music(wav_filename, prompt_duration, musicgen_model, num_iterations
182
 
183
  return combined_audio_filename
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  # Define the expandable sections
186
  musiclang_blurb = """
187
  ## musiclang
@@ -234,12 +286,15 @@ with gr.Blocks() as iface:
234
  "thepatch/bleeps-medium (medium)",
235
  "thepatch/hoenn_lofi (large)"
236
  ], value="thepatch/vanya_ai_dnb_0.1 (small)")
237
- num_iterations = gr.Slider(label="Number of Iterations", minimum=1, maximum=10, step=1, value=3)
238
  generate_music_button = gr.Button("Generate Music")
239
  output_audio = gr.Audio(label="Generated Music")
 
 
240
 
241
  # Connecting the components
242
  generate_midi_button.click(generate_midi, inputs=[seed, use_chords, chord_progression, bpm], outputs=[midi_audio])
243
  generate_music_button.click(generate_music, inputs=[midi_audio, prompt_duration, musicgen_model, num_iterations, bpm], outputs=[output_audio])
 
244
 
245
  iface.launch()
 
182
 
183
  return combined_audio_filename
184
 
185
+ def continue_music(input_audio_path, prompt_duration, musicgen_model, num_iterations, bpm):
186
+ # Load the audio from the given file path
187
+ song, sr = torchaudio.load(input_audio_path)
188
+ song = song.to(device)
189
+
190
+ # Calculate the slice from the end of the song based on prompt_duration
191
+ num_samples = int(prompt_duration * sr)
192
+ if song.shape[-1] < num_samples:
193
+ raise ValueError("The prompt_duration is longer than the audio length.")
194
+ start_idx = song.shape[-1] - num_samples
195
+ prompt_waveform = song[..., start_idx:]
196
+
197
+ # Prepare the audio slice for generation
198
+ prompt_waveform = preprocess_audio(prompt_waveform)
199
+
200
+ # Load the model and set generation parameters as before
201
+ model_continue = MusicGen.get_pretrained(musicgen_model.split(" ")[0])
202
+ model_continue.set_generation_params(
203
+ use_sampling=True,
204
+ top_k=250,
205
+ top_p=0.0,
206
+ temperature=1.0,
207
+ duration=calculate_duration(bpm),
208
+ cfg_coef=3
209
+ )
210
+
211
+ all_audio_files = []
212
+ for i in range(num_iterations):
213
+ output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
214
+ output = output.cpu() # Ensure the output is on CPU for further processing
215
+ if len(output.size()) > 2:
216
+ output = output.squeeze()
217
+
218
+ filename_without_extension = f'continue_{i}'
219
+ filename_with_extension = f'{filename_without_extension}.wav'
220
+ audio_write(filename_with_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
221
+ all_audio_files.append(filename_with_extension)
222
+
223
+ # Combine all audio files as before
224
+ combined_audio = AudioSegment.empty()
225
+ for filename in all_audio_files:
226
+ combined_audio += AudioSegment.from_wav(filename)
227
+
228
+ combined_audio_filename = f"combined_audio_{random.randint(1, 10000)}.mp3"
229
+ combined_audio.export(combined_audio_filename, format="mp3")
230
+
231
+ # Clean up temporary files
232
+ for filename in all_audio_files:
233
+ os.remove(filename)
234
+
235
+ return combined_audio_filename
236
+
237
  # Define the expandable sections
238
  musiclang_blurb = """
239
  ## musiclang
 
286
  "thepatch/bleeps-medium (medium)",
287
  "thepatch/hoenn_lofi (large)"
288
  ], value="thepatch/vanya_ai_dnb_0.1 (small)")
289
+ num_iterations = gr.Slider(label="Number of Iterations", minimum=1, maximum=3, step=1, value=3)
290
  generate_music_button = gr.Button("Generate Music")
291
  output_audio = gr.Audio(label="Generated Music")
292
+ continue_button = gr.Button("Continue Generating Music")
293
+ continue_output_audio = gr.Audio(label="Continued Music Output")
294
 
295
  # Connecting the components
296
  generate_midi_button.click(generate_midi, inputs=[seed, use_chords, chord_progression, bpm], outputs=[midi_audio])
297
  generate_music_button.click(generate_music, inputs=[midi_audio, prompt_duration, musicgen_model, num_iterations, bpm], outputs=[output_audio])
298
+ continue_button.click(continue_music, inputs=[output_audio, prompt_duration, musicgen_model, num_iterations, bpm], outputs=continue_output_audio)
299
 
300
  iface.launch()