animikhaich committed on
Commit
8a2882e
1 Parent(s): d50bd1e

Added Audio Generator - Working, Tested

Browse files
Files changed (2) hide show
  1. .gitignore +4 -1
  2. engine/audio_generator.py +71 -112
.gitignore CHANGED
@@ -167,4 +167,7 @@ cython_debug/
167
  *.mp3
168
  *.mp4
169
 
170
- creds.json
 
 
 
 
167
  *.mp3
168
  *.mp4
169
 
170
+ creds.json
171
+
172
+ # Ignore the test file
173
+ test.py
engine/audio_generator.py CHANGED
@@ -3,6 +3,7 @@ import warnings
3
 
4
  warnings.simplefilter("ignore")
5
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
6
  import torch
7
  import numpy as np
8
  from audiocraft.models import musicgen
@@ -19,6 +20,8 @@ class GenerateAudio:
19
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
  self.model_name = self.get_model_name(model)
21
  self.model = self.get_model(self.model_name, self.device)
 
 
22
 
23
  @staticmethod
24
  def get_model(model, device):
@@ -36,127 +39,83 @@ class GenerateAudio:
36
  if model_name.startswith("facebook/"):
37
  return model_name
38
  return f"facebook/{model_name}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def generate_audio(self, prompts, duration=30):
41
  try:
42
  self.model.set_generation_params(duration=duration)
43
  result = self.model.generate(prompts, progress=False)
44
- result = result.squeeze().cpu().numpy().T
45
- sample_rate = self.model.sample_rate
 
46
  logging.info(
47
- f"Generated audio with shape: {result.shape}, sample rate: {sample_rate} Hz"
48
  )
49
- return sample_rate, result
 
50
  except Exception as e:
51
  logging.error(f"Failed to generate audio: {e}")
52
  raise ValueError(f"Failed to generate audio: {e}")
53
-
54
-
55
-
56
-
57
- # Parse command line arguments
58
- parser = argparse.ArgumentParser(description="Music Generation Server")
59
- parser.add_argument(
60
- "--model", type=str, default="musicgen-stereo-small", help="Pretrained model name"
61
- )
62
- parser.add_argument(
63
- "--device", type=str, default="cuda", help="Device to load the model on"
64
- )
65
- parser.add_argument(
66
- "--duration", type=int, default=10, help="Duration of generated music in seconds"
67
- )
68
- parser.add_argument(
69
- "--host", type=str, default="0.0.0.0", help="Host to run the server on"
70
- )
71
- parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
72
-
73
- args = parser.parse_args()
74
-
75
-
76
- # Initialize the FastAPI app
77
- app = FastAPI()
78
-
79
- # Build the model name based on the provided arguments
80
- if args.model.startswith("facebook/"):
81
- args.model_name = args.model
82
- else:
83
- args.model_name = f"facebook/{args.model}"
84
-
85
-
86
- logging.info(f"Initializing Model Server with Settings: {args}")
87
-
88
- # Load the model with the provided arguments
89
- try:
90
- musicgen_model = musicgen.MusicGen.get_pretrained(
91
- args.model_name, device=args.device
92
- )
93
- model_loaded = True
94
- logging.info(f"Model Loaded: {args.model_name}")
95
- except Exception as e:
96
- logging.error(f"Failed to load model: {e}")
97
- musicgen_model = None
98
- model_loaded = False
99
-
100
-
101
- class MusicRequest(BaseModel):
102
- prompts: List[str]
103
- duration: Optional[int] = 10 # Default duration is 10 seconds if not provided
104
-
105
-
106
- @app.get("/generate_music")
107
- def generate_music(request: MusicRequest):
108
-
109
- if not model_loaded:
110
- raise HTTPException(status_code=500, detail="Model is not loaded.")
111
-
112
- try:
113
- logging.info(
114
- f"Generating music with prompts: {request.prompts}, duration: {request.duration} seconds"
115
- )
116
-
117
- musicgen_model.set_generation_params(duration=request.duration)
118
- result = musicgen_model.generate(request.prompts, progress=False)
119
- result = result.squeeze().cpu().numpy().T
120
-
121
- sample_rate = musicgen_model.sample_rate
122
-
123
- logging.info(
124
- f"Music generated with shape: {result.shape}, sample rate: {sample_rate} Hz"
125
- )
126
-
127
- buffer = io.BytesIO()
128
- wav_write(buffer, sample_rate, result)
129
- buffer.seek(0)
130
- return StreamingResponse(buffer, media_type="audio/wav")
131
- except Exception as e:
132
- logging.error(f"Failed to generate music: {e}")
133
- raise HTTPException(status_code=500, detail=str(e))
134
-
135
-
136
- @app.get("/health")
137
- def health_check():
138
- cpu_usage = psutil.cpu_percent(interval=1)
139
- ram_usage = psutil.virtual_memory().percent
140
- stats = {
141
- "server_running": True,
142
- "model_loaded": model_loaded,
143
- "cpu_usage_percent": cpu_usage,
144
- "ram_usage_percent": ram_usage,
145
- }
146
- if args.device == "cuda" and torch.cuda.is_available():
147
- gpu_memory_allocated = memory_allocated()
148
- gpu_memory_reserved = memory_reserved()
149
- stats.update(
150
- {
151
- "gpu_memory_allocated": gpu_memory_allocated,
152
- "gpu_memory_reserved": gpu_memory_reserved,
153
- }
154
- )
155
-
156
- logging.info(f"Health Check: {stats}")
157
-
158
- return JSONResponse(content=stats)
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  if __name__ == "__main__":
162
- uvicorn.run("main:app", host=args.host, port=args.port, reload=False, workers=1)
 
 
 
 
 
 
3
 
4
  warnings.simplefilter("ignore")
5
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
6
+ import io
7
  import torch
8
  import numpy as np
9
  from audiocraft.models import musicgen
 
20
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
  self.model_name = self.get_model_name(model)
22
  self.model = self.get_model(self.model_name, self.device)
23
+ self.generated_audio = None
24
+ self.sampling_rate = None
25
 
26
  @staticmethod
27
  def get_model(model, device):
 
39
  if model_name.startswith("facebook/"):
40
  return model_name
41
  return f"facebook/{model_name}"
42
+
43
+ @staticmethod
44
+ def duration_sanity_check(duration):
45
+ if duration < 1:
46
+ logging.warning("Duration is less than 1 second. Setting duration to 1 second.")
47
+ return 1
48
+ elif duration > 30:
49
+ logging.warning("Duration is greater than 30 seconds. Setting duration to 30 seconds.")
50
+ return 30
51
+ return duration
52
+
53
+ @staticmethod
54
+ def prompts_sanity_check(prompts):
55
+ if isinstance(prompts, str):
56
+ prompts = [prompts]
57
+ elif not isinstance(prompts, list):
58
+ raise ValueError("Prompts should be a string or a list of strings.")
59
+ else:
60
+ for prompt in prompts:
61
+ if not isinstance(prompt, str):
62
+ raise ValueError("Prompts should be a string or a list of strings.")
63
+ if len(prompts) > 8: # Too many prompts will cause OOM error
64
+ raise ValueError("Maximum number of prompts allowed is 8.")
65
+ return prompts
66
+
67
+
68
def generate_audio(self, prompts, duration=10):
    """Generate one audio clip per prompt and cache the result on the instance.

    Args:
        prompts: A prompt string or a list of prompt strings; validated by
            ``prompts_sanity_check``.
        duration: Requested clip length in seconds; adjusted by
            ``duration_sanity_check``.

    Returns:
        A ``(sampling_rate, audio)`` tuple, where ``audio`` is the model
        output as a NumPy array with its last two axes swapped.

    Raises:
        ValueError: If the inputs are invalid or the model call fails.
    """
    duration = self.duration_sanity_check(duration)
    prompts = self.prompts_sanity_check(prompts)

    try:
        self.model.set_generation_params(duration=duration)
        result = self.model.generate(prompts, progress=False)
        # Swap the last two axes of the 3-D output; presumably
        # (batch, channels, samples) -> (batch, samples, channels) so each
        # item can be written as a WAV -- TODO confirm against the model.
        audio = result.detach().cpu().numpy().transpose(0, 2, 1)
        sampling_rate = self.model.sample_rate
    except Exception as e:
        logging.error(f"Failed to generate audio: {e}")
        raise ValueError(f"Failed to generate audio: {e}") from e

    # Publish state only after the whole generation succeeded. Both names are
    # set: ``result`` is what the save/buffer helpers read, and
    # ``generated_audio`` is the attribute declared in ``__init__``.
    self.result = audio
    self.generated_audio = audio
    self.sampling_rate = sampling_rate
    logging.info(
        f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
    )
    return self.sampling_rate, self.result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
def save_audio(self, audio_dir="generated_audio"):
    """Write every cached clip to ``audio_dir`` as ``audio_<index>.wav``.

    Args:
        audio_dir: Target directory; created if it does not exist.

    Returns:
        The list of written file paths, in clip order.

    Raises:
        ValueError: If no audio or sampling rate has been stored yet.
    """
    # ``result`` is only assigned by generate_audio, so read it defensively:
    # on a fresh instance the attribute does not exist at all.
    clips = getattr(self, "result", None)
    sampling_rate = getattr(self, "sampling_rate", None)
    if clips is None:
        raise ValueError("Audio is not generated yet.")
    if sampling_rate is None:
        raise ValueError("Sampling rate is not available.")

    os.makedirs(audio_dir, exist_ok=True)
    paths = []
    for index, clip in enumerate(clips):
        path = os.path.join(audio_dir, f"audio_{index}.wav")
        wav_write(path, sampling_rate, clip)
        paths.append(path)
    return paths
100
+
101
def get_audio_buffer(self):
    """Encode every cached clip as WAV into its own in-memory buffer.

    Returns:
        A list of ``io.BytesIO`` objects, one per clip, each rewound to
        position 0 so it can be read immediately.

    Raises:
        ValueError: If no audio or sampling rate has been stored yet.
    """
    # ``result`` is only assigned by generate_audio, so read it defensively:
    # on a fresh instance the attribute does not exist at all.
    clips = getattr(self, "result", None)
    sampling_rate = getattr(self, "sampling_rate", None)
    if clips is None:
        raise ValueError("Audio is not generated yet.")
    if sampling_rate is None:
        raise ValueError("Sampling rate is not available.")

    buffers = []
    for clip in clips:
        buffer = io.BytesIO()
        wav_write(buffer, sampling_rate, clip)
        buffer.seek(0)  # rewind so consumers start at the WAV header
        buffers.append(buffer)
    return buffers
114
 
115
  if __name__ == "__main__":
116
+ audio_gen = GenerateAudio()
117
+ sample_rate, result = audio_gen.generate_audio(["A piano playing a jazz melody", "A guitar playing a rock riff", "A LoFi music for coding"], duration=10)
118
+ paths = audio_gen.save_audio()
119
+ print(f"Saved audio to: {paths}")
120
+ buffers = audio_gen.get_audio_buffer()
121
+ print(f"Audio buffers: {buffers}")