SunilGopal committed on
Commit
c2a2b08
1 Parent(s): 6c9e559

Upload 2 files

Browse files
Files changed (2) hide show
  1. musicgen_app.py +148 -0
  2. requirements.txt +28 -0
musicgen_app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from concurrent.futures import ProcessPoolExecutor
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ import subprocess as sp
7
+ import sys
8
+ from tempfile import NamedTemporaryFile
9
+ import time
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ import gradio as gr
15
+ from audiocraft.data.audio_utils import convert_audio
16
+ from audiocraft.data.audio import audio_write
17
+ from audiocraft.models import MusicGen
18
+
19
# Module-level state shared by the functions below.
# MODEL stays None until load_model() assigns a loaded model to it.
MODEL = None # Last used model
# Not read or written anywhere else in the visible part of this file.
INTERRUPTING = False
# Process pool created with the argument 4 and entered as a context manager
# at import time; no matching __exit__ call and no other use of `pool`
# appears in the visible part of this file.
pool = ProcessPoolExecutor(4)
pool.__enter__()
23
+
24
+
25
class FileCleaner:
    """Remember generated files and delete them once they outlive a lifetime.

    Files are registered with :meth:`add`. Each registration first purges
    every previously registered file whose age (measured with ``time.time()``,
    i.e. in seconds) exceeds ``file_lifetime``.
    """

    def __init__(self, file_lifetime: float = 3600):
        """Create a cleaner.

        Args:
            file_lifetime: Maximum age, in seconds, a registered file may reach
                before it becomes eligible for deletion.
        """
        self.file_lifetime = file_lifetime
        # (time_added, path) tuples in registration order, oldest first.
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Purge expired files, then register ``path`` with the current time."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        """Delete and forget every expired file at the head of the list.

        Entries are appended in chronological order, so scanning stops at the
        first entry that is still within its lifetime.
        """
        now = time.time()
        while self.files:
            time_added, path = self.files[0]
            if now - time_added > self.file_lifetime:
                # EAFP instead of exists()+unlink(): the file may vanish
                # between the check and the removal (e.g. cleaned up by the
                # OS or another process), which must not abort the purge.
                try:
                    path.unlink()
                except FileNotFoundError:
                    pass
                self.files.pop(0)
            else:
                break
43
+
44
+
45
+ file_cleaner = FileCleaner()
46
+
47
+
48
def load_model(version='facebook/musicgen-small'):
    """Make sure the global ``MODEL`` holds the model identified by ``version``.

    The "Loading model" message is printed on every call. The model is only
    (re)loaded when ``MODEL`` is unset or its ``name`` differs from
    ``version``.
    """
    global MODEL
    print("Loading model", version)
    already_loaded = MODEL is not None and MODEL.name == version
    if already_loaded:
        return
    # Drop the previous model and clear the CUDA cache before loading the
    # next one.
    del MODEL
    torch.cuda.empty_cache()
    # Rebind the global first so it exists (as None) if loading raises.
    MODEL = None
    MODEL = MusicGen.get_pretrained(version)
56
+
57
+
58
def _do_predictions(texts, duration):
    """Generate audio for ``texts`` and write each result to a temporary WAV.

    Uses the global ``MODEL``, so ``load_model`` must have been called first.
    Every written file is registered with ``file_cleaner``.

    Returns:
        A list with the path of each temporary ``.wav`` file, one per item
        yielded by the model output.
    """
    MODEL.set_generation_params(duration=duration)
    generated = MODEL.generate(texts).detach().cpu().float()

    wav_paths = []
    for waveform in generated:
        # delete=False keeps the file on disk after the handle is closed.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
            audio_write(
                wav_path,
                waveform,
                MODEL.sample_rate,
                strategy="loudness",
                loudness_headroom_db=16,
                loudness_compressor=True,
                add_suffix=False,
            )
            wav_paths.append(wav_path)
            file_cleaner.add(wav_path)
    return wav_paths
71
+
72
+
73
def predict(text, duration):
    """Generate one clip for ``text`` and return the path of its WAV file.

    Always loads 'facebook/musicgen-small' through ``load_model`` before
    generating.
    """
    load_model('facebook/musicgen-small')
    # A single prompt is submitted, so only the first path is returned.
    first_wav, *_ = _do_predictions([text], duration)
    return first_wav
77
+
78
+
79
def ui(launch_kwargs):
    """Build the Gradio interface and launch it.

    Args:
        launch_kwargs: Keyword arguments forwarded unchanged to ``launch``.
    """
    # NOTE(review): block nesting and the indentation inside the Markdown
    # strings were reconstructed from a flattened diff view — confirm against
    # the original file.
    with gr.Blocks() as interface:
        gr.Markdown(
            """
            # MusicGen
            This demo uses the MusicGen model to generate music based on a text prompt.
            """
        )
        with gr.Row():
            prompt_box = gr.Text(label="Input Text", interactive=True)
            duration_slider = gr.Slider(
                minimum=1, maximum=120, value=10, label="Duration", interactive=True
            )
            submit_button = gr.Button("Submit")
        with gr.Row():
            audio_player = gr.Audio(label="Generated Music", type='filepath')

        # Clicking the button runs predict(text, duration) and shows the
        # returned file path in the audio component.
        submit_button.click(
            predict,
            inputs=[prompt_box, duration_slider],
            outputs=[audio_player],
        )

        gr.Markdown("""
            ### More details

            This model generates audio based on a textual description. You can specify the duration of the generated audio.
            """)

        queued_interface = interface.queue(max_size=8 * 4)
        queued_interface.launch(**launch_kwargs)
102
+
103
+
104
if __name__ == "__main__":
    # Script entry point: parse the command line, translate it into Gradio
    # launch keyword arguments, configure logging, and start the UI.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        # All interfaces when SPACE_ID is set in the environment, loopback otherwise.
        default='127.0.0.1' if 'SPACE_ID' not in os.environ else '0.0.0.0',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument('--username', type=str, default='', help='Username for authentication')
    parser.add_argument('--password', type=str, default='', help='Password for authentication')
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument('--inbrowser', action='store_true', help='Open in browser')
    parser.add_argument('--share', action='store_true', help='Share the gradio UI')

    args = parser.parse_args()

    launch_kwargs = {'server_name': args.listen}

    # Authentication is only enabled when both credentials are non-empty.
    if args.username and args.password:
        launch_kwargs['auth'] = (args.username, args.password)
    # Remaining options are forwarded only when set to a truthy value.
    for option in ('server_port', 'inbrowser', 'share'):
        option_value = getattr(args, option)
        if option_value:
            launch_kwargs[option] = option_value

    logging.basicConfig(level=logging.INFO, stream=sys.stderr)

    # Show the interface
    ui(launch_kwargs)
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Please make sure you already have a CUDA-enabled PyTorch install!
2
+ av==11.0.0
3
+ einops
4
+ flashy>=0.0.1
5
+ hydra-core>=1.1
6
+ hydra_colorlog
7
+ julius
8
+ num2words
9
+ numpy<2.0.0
10
+ sentencepiece
11
+ spacy>=3.6.1
12
+ torch==2.1.0
13
+ torchaudio>=2.0.0,<2.1.2
14
+ huggingface_hub
15
+ tqdm
16
+ transformers>=4.31.0 # need Encodec there.
17
+ xformers<0.0.23
18
+ demucs
19
+ librosa
20
+ soundfile
21
+ gradio
22
+ torchmetrics
23
+ encodec
24
+ protobuf
25
+ torchvision==0.16.0
26
+ torchtext==0.16.0
27
+ pesq
28
+ pystoi