nakas committed
Commit a834046
1 Parent(s): 7c993b8

Update app.py

Files changed (1)
  1. app.py +59 -261
app.py CHANGED
@@ -1,269 +1,67 @@
  import os
  import gradio as gr
  from scipy.io.wavfile import write
- import subprocess
- import argparse
- from concurrent.futures import ProcessPoolExecutor
- import time
- import typing as tp
- import warnings
- from pathlib import Path
  import torch
- import gradio as gr

  from audiocraft.data.audio_utils import convert_audio
- from audiocraft.data.audio import audio_write
- from audiocraft.models import MusicGen
-
-
- MODEL = None  # Last used model
- IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
- MAX_BATCH_SIZE = 6
- BATCHED_DURATION = 15
- INTERRUPTING = False
-
- def interrupt():
-     global INTERRUPTING
-     INTERRUPTING = True
-
-
- class FileCleaner:
-     def __init__(self, file_lifetime: float = 3600):
-         self.file_lifetime = file_lifetime
-         self.files = []
-
-     def add(self, path: tp.Union[str, Path]):
-         self._cleanup()
-         self.files.append((time.time(), Path(path)))

-     def _cleanup(self):
-         now = time.time()
-         for time_added, path in list(self.files):
-             if now - time_added > self.file_lifetime:
-                 if path.exists():
-                     path.unlink()
-                 self.files.pop(0)
-             else:
-                 break
-
-
- file_cleaner = FileCleaner()
-
-
- def make_waveform(*args, **kwargs):
-     be = time.time()
-     with warnings.catch_warnings():
-         warnings.simplefilter('ignore')
-         out = gr.make_waveform(*args, **kwargs)
-         print("Make a video took", time.time() - be)
-         return out
-
-
- def load_model(version='melody'):
-     global MODEL
-     print("Loading model", version)
-     if MODEL is None or MODEL.name != version:
-         MODEL = MusicGen.get_pretrained(version)
-
-
- def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
-     MODEL.set_generation_params(duration=duration, **gen_kwargs)
-     print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
-     be = time.time()
-     processed_melodies = []
-     target_sr = 32000
-     target_ac = 1
-     for melody in melodies:
-         if melody is None:
-             processed_melodies.append(None)
          else:
-             sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
-             if melody.dim() == 1:
-                 melody = melody[None]
-             melody = melody[..., :int(sr * duration)]
-             melody = convert_audio(melody, sr, target_sr, target_ac)
-             processed_melodies.append(melody)
-
-     if any(m is not None for m in processed_melodies):
-         outputs = MODEL.generate_with_chroma(
-             descriptions=texts,
-             melody_wavs=processed_melodies,
-             melody_sample_rate=target_sr,
-             progress=progress,
-         )
-     else:
-         outputs = MODEL.generate(texts, progress=progress)
-
-     outputs = outputs.detach().cpu().float()
-     out_files = []
-     for output in outputs:
-         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
-             audio_write(
-                 file.name, output, MODEL.sample_rate, strategy="loudness",
-                 loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-             out_files.append(pool.submit(make_waveform, file.name))
-             file_cleaner.add(file.name)
-     res = [out_file.result() for out_file in out_files]
-     for file in res:
-         file_cleaner.add(file)
-     print("batch finished", len(texts), time.time() - be)
-     print("Tempfiles currently stored: ", len(file_cleaner.files))
-     return res
-
-
- def predict_batched(texts, melodies):
-     max_text_length = 512
-     texts = [text[:max_text_length] for text in texts]
-     load_model('melody')
-     res = _do_predictions(texts, melodies, BATCHED_DURATION)
-     return [res]
-
-
- def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
-     global INTERRUPTING
-     INTERRUPTING = False
-     if temperature < 0:
-         raise gr.Error("Temperature must be >= 0.")
-     if topk < 0:
-         raise gr.Error("Topk must be non-negative.")
-     if topp < 0:
-         raise gr.Error("Topp must be non-negative.")
-
-     topk = int(topk)
-     load_model(model)
-
-     def _progress(generated, to_generate):
-         progress((generated, to_generate))
-         if INTERRUPTING:
-             raise gr.Error("Interrupted.")
-     MODEL.set_custom_progress_callback(_progress)
-
-     outs = _do_predictions(
-         [text], [melody], duration, progress=True,
-         top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
-     return outs[0]
-
-
- def toggle_audio_src(choice):
-     if choice == "mic":
-         return gr.update(source="microphone", value=None, label="Microphone")
-     else:
-         return gr.update(source="upload", value=None, label="File")
-
-
- def ui_full(launch_kwargs):
-     with gr.Blocks() as interface:
-         gr.Markdown(
-             """
-             # MusicGen and Demucs Combination
-             This is a combined demo of MusicGen and Demucs.
-             MusicGen is a model for music generation based on text prompts,
-             and Demucs is a model for music source separation.
-             """
-         )
-         with gr.Row():
-             with gr.Column():
-                 with gr.Row():
-                     text = gr.Text(label="Input Text", interactive=True)
-                     with gr.Column():
-                         radio = gr.Radio(["file", "mic"], value="file",
-                                          label="Condition on a Melody (optional) File or Mic")
-                         melody = gr.Audio(source="upload", type="numpy", label="Melody File",
-                                           interactive=True, elem_id="melody-input")
-                 with gr.Row():
-                     submit = gr.Button("Generate Music")
-                 with gr.Row():
-                     audio_output = gr.Audio(type="numpy", label="Generated Music")
-                     vocals_output = gr.Audio(type="filepath", label="Vocals")
-                     bass_output = gr.Audio(type="filepath", label="Bass")
-                     drums_output = gr.Audio(type="filepath", label="Drums")
-                     other_output = gr.Audio(type="filepath", label="Other")
-         submit.click(predict_full,
-                      inputs=[text, melody, None, 10, 250, 0, 1.0, 3.0],
-                      outputs=[audio_output, vocals_output, bass_output, drums_output, other_output])
-         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
-         gr.Examples(
-             fn=predict_full,
-             examples=[
-                 [
-                     "An 80s driving pop song with heavy drums and synth pads in the background",
-                     "./assets/bach.mp3",
-                 ],
-                 [
-                     "A cheerful country song with acoustic guitars",
-                     "./assets/bolero_ravel.mp3",
-                 ],
-                 [
-                     "90s rock song with electric guitar and heavy drums",
-                     None,
-                 ],
-                 [
-                     "a light and cheerful EDM track, with syncopated drums, airy pads, and strong emotions",
-                     "./assets/bach.mp3",
-                 ],
-                 [
-                     "lofi slow bpm electro chill with organic samples",
-                     None,
-                 ],
-             ],
-             inputs=[text, melody],
-             outputs=[audio_output, vocals_output, bass_output, drums_output, other_output]
-         )
-
-     gr.Interface(
-         fn=inference,
-         inputs=gr.inputs.Audio(type="numpy", label="Input Audio"),
-         outputs=[
-             gr.outputs.Audio(type="filepath", label="Vocals"),
-             gr.outputs.Audio(type="filepath", label="Bass"),
-             gr.outputs.Audio(type="filepath", label="Drums"),
-             gr.outputs.Audio(type="filepath", label="Other"),
-         ],
-         title="MusicGen and Demucs Combination",
-         description="A combined demo of MusicGen and Demucs",
-         article="",
-     ).launch(enable_queue=True)
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         '--listen',
-         type=str,
-         default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
-         help='IP to listen on for connections to Gradio',
-     )
-     parser.add_argument(
-         '--username', type=str, default='', help='Username for authentication'
-     )
-     parser.add_argument(
-         '--password', type=str, default='', help='Password for authentication'
-     )
-     parser.add_argument(
-         '--server_port',
-         type=int,
-         default=0,
-         help='Port to run the server listener on',
-     )
-     parser.add_argument(
-         '--inbrowser', action='store_true', help='Open in browser'
-     )
-     parser.add_argument(
-         '--share', action='store_true', help='Share the gradio UI'
-     )
-
-     args = parser.parse_args()
-
-     launch_kwargs = {}
-     launch_kwargs['server_name'] = args.listen
-
-     if args.username and args.password:
-         launch_kwargs['auth'] = (args.username, args.password)
-     if args.server_port:
-         launch_kwargs['server_port'] = args.server_port
-     if args.inbrowser:
-         launch_kwargs['inbrowser'] = args.inbrowser
-     if args.share:
-         launch_kwargs['share'] = args.share
-
-     # Show the interface
-     ui_full(launch_kwargs)
 
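The deleted `_do_predictions` is where the melody-conditioning logic lived: it turned each Gradio `(sample_rate, numpy_array)` tuple into a mono 32 kHz torch tensor before calling `MusicGen.generate_with_chroma`. For reference, a minimal standalone sketch of that preprocessing (the name `prepare_melody` is mine; the logic follows the removed lines above):

```python
import torch
from audiocraft.data.audio_utils import convert_audio

TARGET_SR = 32000  # MusicGen's sample rate, per the removed code
TARGET_AC = 1      # mono

def prepare_melody(melody, duration, device="cpu"):
    # `melody` is a Gradio-style (sample_rate, np_array) tuple, or None.
    if melody is None:
        return None
    sr, wav = melody
    wav = torch.from_numpy(wav).to(device).float().t()  # -> (channels, time)
    if wav.dim() == 1:
        wav = wav[None]                  # add a channel axis for mono input
    wav = wav[..., :int(sr * duration)]  # trim to the requested duration
    return convert_audio(wav, sr, TARGET_SR, TARGET_AC)
```

Note the tuple order: Gradio's numpy audio is `(sample_rate, data)`, which is why the added code's `melodies = [(audio[1], audio[0])]` below looks suspect; it swaps the pair into `(data, sample_rate)` relative to what the old helper expected.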
  import os
  import gradio as gr
  from scipy.io.wavfile import write
+ import subprocess
  import torch
+ import typing as tp

  from audiocraft.data.audio_utils import convert_audio

+ # Import the necessary MusicGen code here
+
+ def load_model():
+     # Load the MusicGen model here
+
+ def music_gen_and_separation(audio):
+     # Perform music generation with the loaded MusicGen model
+     texts = [...]  # Provide the desired texts for music generation
+     melodies = [(audio[1], audio[0])]  # Convert audio to melody format for MusicGen
+
+     # Perform music generation using the loaded MusicGen model
+     generated_music = predict_full(model, texts, melodies, duration, topk, topp, temperature, cfg_coef)
+
+     # Perform source separation using Demucs
+     # Save the generated music to a temporary file
+     temp_file = "generated_music.wav"
+     write(temp_file, generated_music, 32000)
+
+     # Run Demucs for source separation
+     command = "python3 -m demucs.separate -n mdx_extra_q -d cpu " + temp_file + " -o out"
+     process = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+     print("Demucs script output:", process.stdout.decode())
+
+     # Check if files exist before returning
+     files = ["./out/mdx_extra_q/test/vocals.wav",
+              "./out/mdx_extra_q/test/bass.wav",
+              "./out/mdx_extra_q/test/drums.wav",
+              "./out/mdx_extra_q/test/other.wav"]
+     for file in files:
+         if not os.path.isfile(file):
+             print(f"File not found: {file}")
          else:
+             print(f"File exists: {file}")
+
+     # Convert the separated audio files to numpy arrays
+     separated_audio = []
+     for file in files:
+         _, audio = read(file)
+         separated_audio.append(audio)
+
+     return separated_audio
+
+
+ title = "MusicGen with Demucs"
+ description = "Combine MusicGen with Demucs for music generation and source separation."
+ article = "<p>Article content goes here.</p>"
+
+ gr.Interface(
+     music_gen_and_separation,
+     gr.inputs.Audio(label="Input"),
+     [gr.outputs.Audio(label="Vocals"),
+      gr.outputs.Audio(label="Bass"),
+      gr.outputs.Audio(label="Drums"),
+      gr.outputs.Audio(label="Other")],
+     title=title,
+     description=description,
+     article=article
+ ).launch()
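As committed, the new `app.py` will not run: `load_model` has no body, `predict_full` and the names `model`, `duration`, `topk`, `topp`, `temperature`, and `cfg_coef` are undefined (this very commit deletes `predict_full`), `read` is used without being imported, `scipy.io.wavfile.write` takes `(filename, rate, data)` rather than `(filename, data, rate)`, and Demucs writes its stems under the input file's stem (`generated_music`), not `test`. A self-contained sketch of what the code appears to intend, with those gaps filled, might look like the following; the 10-second duration, the prompt input, and the caching in `load_model` are my assumptions, not the committed behavior:

```python
import os
import subprocess

import gradio as gr
import torch
from scipy.io.wavfile import read, write  # `read` is missing from the committed imports

from audiocraft.models import MusicGen  # removed by this commit, but still needed

MODEL = None


def load_model(version="melody"):
    # Assumption: keep the lazy, cached loading the removed code used.
    global MODEL
    if MODEL is None:
        MODEL = MusicGen.get_pretrained(version)
    return MODEL


def music_gen_and_separation(text, audio):
    model = load_model()
    model.set_generation_params(duration=10)  # assumed duration; the commit never sets one

    # Gradio's numpy Audio yields (sample_rate, data); MusicGen wants (channels, time) tensors.
    if audio is not None:
        sr, wav = audio
        melody = torch.from_numpy(wav).float().t()
        if melody.dim() == 1:
            melody = melody[None]
        output = model.generate_with_chroma(
            descriptions=[text], melody_wavs=[melody], melody_sample_rate=sr)
    else:
        output = model.generate([text])

    # scipy expects write(filename, rate, data) with data shaped (time, channels).
    generated = output[0].cpu().numpy().T
    temp_file = "generated_music.wav"
    write(temp_file, model.sample_rate, generated)

    # Demucs writes stems to <out>/<model>/<input stem>/, i.e. ".../generated_music/",
    # not ".../test/" as the committed paths assume.
    subprocess.run(
        ["python3", "-m", "demucs.separate", "-n", "mdx_extra_q", "-d", "cpu",
         temp_file, "-o", "out"],
        check=True,
    )
    stem_dir = os.path.join("out", "mdx_extra_q", "generated_music")
    stems = []
    for name in ("vocals", "bass", "drums", "other"):
        rate, data = read(os.path.join(stem_dir, f"{name}.wav"))
        stems.append((rate, data))
    return stems


demo = gr.Interface(
    music_gen_and_separation,
    inputs=[gr.Textbox(label="Prompt"),
            gr.Audio(type="numpy", label="Melody (optional)")],
    outputs=[gr.Audio(label="Vocals"), gr.Audio(label="Bass"),
             gr.Audio(label="Drums"), gr.Audio(label="Other")],
    title="MusicGen with Demucs",
    description="Generate music from text, then split it into stems with Demucs.",
)

if __name__ == "__main__":
    demo.launch()
```

Passing the Demucs command as an argument list with `check=True`, rather than a `shell=True` string, avoids quoting issues and makes a failed separation raise immediately instead of silently returning missing files.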