Commit 9639ae1
Parent(s): 1e126d1

Reuse a working space to make it work (#1)


- Reuse a working space to make it work (d004e30cb351fcdc74a023fd09c588884b21f0de)


Co-authored-by: Fabrice TIERCELIN <[email protected]>

Files changed (1)
  1. app.py +107 -43
app.py CHANGED
@@ -1,44 +1,59 @@
-import spaces
 import torch
 import torchaudio
 from einops import rearrange
+import gradio as gr
+import spaces
+import os
+import uuid
+
+# Importing the model-related functions
 from stable_audio_tools import get_pretrained_model
 from stable_audio_tools.inference.generation import generate_diffusion_cond
-import os

-# Load model config from stable-audio-tools
-model, model_config = get_pretrained_model(
-    "stabilityai/stable-audio-open-1.0", config_filename="model_config.json"
-)
-sample_rate = model_config["sample_rate"]
-sample_size = model_config["sample_size"]
+# Load the model outside of the GPU-decorated function
+def load_model():
+    print("Loading model...")
+    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
+    print("Model loaded successfully.")
+    return model, model_config
+
+# Function to set up, generate, and process the audio
+@spaces.GPU(duration=120)  # Allocate GPU only when this function is called
+def generate_audio(prompt, seconds_total=30, steps=100, cfg_scale=7):
+    print(f"Prompt received: {prompt}")
+    print(f"Settings: Duration={seconds_total}s, Steps={steps}, CFG Scale={cfg_scale}")
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {device}")

-# Load the model using the transformers library
-token = os.environ.get("TOKEN")
-model = AutoModelForAudioClassification.from_pretrained(
-    "stabilityai/stable-audio-open-1.0", use_auth_token=token, cache_dir=None
-)
+    # Fetch the Hugging Face token from the environment variable
+    hf_token = os.getenv('HF_TOKEN')
+    print(f"Hugging Face token: {hf_token}")

-device = "cuda" if torch.cuda.is_available() else "cpu"
-model = model.to(device)
+    # Use pre-loaded model and configuration
+    model, model_config = load_model()
+    sample_rate = model_config["sample_rate"]
+    sample_size = model_config["sample_size"]

-# --- Gradio App ---
+    print(f"Sample rate: {sample_rate}, Sample size: {sample_size}")

-def generate_music(prompt, seconds_total, bpm, genre):
-    """Generates music from a prompt using Stable Diffusion."""
+    model = model.to(device)
+    print("Model moved to device.")

     # Set up text and timing conditioning
     conditioning = [{
-        "prompt": f"{bpm} BPM {genre} {prompt}",
+        "prompt": prompt,
         "seconds_start": 0,
         "seconds_total": seconds_total
     }]
+    print(f"Conditioning: {conditioning}")

     # Generate stereo audio
+    print("Generating audio...")
     output = generate_diffusion_cond(
         model,
-        steps=100,
-        cfg_scale=7,
+        steps=steps,
+        cfg_scale=cfg_scale,
         conditioning=conditioning,
         sample_size=sample_size,
         sigma_min=0.3,
@@ -46,37 +61,86 @@ def generate_music(prompt, seconds_total, bpm, genre):
         sampler_type="dpmpp-3m-sde",
         device=device
     )
+    print("Audio generated.")

     # Rearrange audio batch to a single sequence
     output = rearrange(output, "b d n -> d (b n)")
+    print("Audio rearranged.")

-    # Peak normalize, clip, convert to int16, and save to file
+    # Peak normalize, clip, convert to int16
     output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
-    return output
+    print("Audio normalized and converted.")

-@spaces.GPU(duration=120)
-def generate_music_and_save(prompt, seconds_total, bpm, genre):
-    """Generates music, saves it to a file, and returns the file path."""
+    # Generate a unique filename for the output
+    unique_filename = f"output_{uuid.uuid4().hex}.wav"
+    print(f"Saving audio to file: {unique_filename}")

-    output = generate_music(prompt, seconds_total, bpm, genre)
-    filename = "output.wav"
-    torchaudio.save(filename, output, sample_rate)
-    return filename
+    # Save to file
+    torchaudio.save(unique_filename, output, sample_rate)
+    print(f"Audio saved: {unique_filename}")

-# Create Gradio interface
-iface = spaces.Interface(
-    generate_music_and_save,
+    # Return the path to the generated audio file
+    return unique_filename
+
+# Setting up the Gradio Interface
+interface = gr.Interface(
+    fn=generate_audio,
     inputs=[
-        spaces.Textbox(label="Prompt (e.g., 'upbeat drum loop')", lines=1),
-        spaces.Slider(label="Duration (seconds)", minimum=1, maximum=60, step=1),
-        spaces.Slider(label="BPM", minimum=60, maximum=200, step=1),
-        spaces.Dropdown(label="Genre", choices=["pop", "rock", "hip hop", "electronic", "classical"], value="pop")
-    ],
-    outputs=[
-        spaces.Audio(label="Generated Music")
-    ],
-    title="Stable Audio Open",
-    description="Generate music from text prompts using Stable Audio."
-)
+        gr.Textbox(label="Prompt", placeholder="Enter your text prompt here"),
+        gr.Slider(0, 47, value=30, label="Duration in Seconds"),
+        gr.Slider(10, 150, value=100, step=10, label="Number of Diffusion Steps"),
+        gr.Slider(1, 15, value=7, step=0.1, label="CFG Scale")
+    ],
+    outputs=gr.Audio(type="filepath", label="Generated Audio"),
+    title="Stable Audio Generator",
+    description="Generate variable-length stereo audio at 44.1kHz from text prompts using Stable Audio Open 1.0.",
+    examples=[
+    [
+        "Create a serene soundscape of a quiet beach at sunset.", # Text prompt
+
+        45, # Duration in Seconds
+        100, # Number of Diffusion Steps
+        10, # CFG Scale
+    ],
+    [
+        "Generate an energetic and bustling city street scene with distant traffic and close conversations.", # Text prompt
+
+        30, # Duration in Seconds
+        120, # Number of Diffusion Steps
+        5, # CFG Scale
+    ],
+    [
+        "Simulate a forest ambiance with birds chirping and wind rustling through the leaves.", # Text prompt
+        60, # Duration in Seconds
+        140, # Number of Diffusion Steps
+        7.5, # CFG Scale
+    ],
+    [
+        "Recreate a gentle rainfall with distant thunder.", # Text prompt
+
+        35, # Duration in Seconds
+        110, # Number of Diffusion Steps
+        8, # CFG Scale
+
+    ],
+    [
+        "Imagine a jazz cafe environment with soft music and ambient chatter.", # Text prompt
+        25, # Duration in Seconds
+        90, # Number of Diffusion Steps
+        6, # CFG Scale
+
+    ],
+    ["Rock beat played in a treated studio, session drumming on an acoustic kit.",
+        30, # Duration in Seconds
+        100, # Number of Diffusion Steps
+        7, # CFG Scale
+
+    ]
+    ])
+
+
+# Pre-load the model to avoid multiprocessing issues
+model, model_config = load_model()

-iface.launch(share=True)
+# Launch the Interface
+interface.launch()
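
One caveat when reusing this pattern: as committed, generate_audio calls load_model() on every request, in addition to the module-level pre-load, so each generation repeats the full checkpoint load. A minimal sketch of a memoized loader that keeps the rest of the code unchanged; functools.lru_cache is a standard-library suggestion here, not part of this commit:

from functools import lru_cache

from stable_audio_tools import get_pretrained_model


@lru_cache(maxsize=1)
def load_model():
    # First call downloads/loads the checkpoint; later calls return the cached pair.
    print("Loading model...")
    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
    print("Model loaded successfully.")
    return model, model_config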
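For reference, the pipeline this commit adopts can also be exercised outside the Space, without the Gradio UI or the ZeroGPU decorator. A minimal sketch, assuming stable_audio_tools is installed and a GPU is available; it passes only the arguments visible in the hunks above (the one argument line that falls between the two hunks is left at its library default, and the prompt string is illustrative):

import torch
import torchaudio
from einops import rearrange

from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

device = "cuda" if torch.cuda.is_available() else "cpu"

# Same model id and config keys as in the diff above.
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
model = model.to(device)

conditioning = [{
    "prompt": "Gentle rainfall with distant thunder",  # illustrative prompt
    "seconds_start": 0,
    "seconds_total": 30,
}]

output = generate_diffusion_cond(
    model,
    steps=100,
    cfg_scale=7,
    conditioning=conditioning,
    sample_size=model_config["sample_size"],
    sigma_min=0.3,
    sampler_type="dpmpp-3m-sde",
    device=device,
)

# Collapse the batch dimension, peak-normalize, convert to 16-bit PCM,
# and write a WAV file, mirroring the post-processing in generate_audio.
output = rearrange(output, "b d n -> d (b n)")
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", output, model_config["sample_rate"])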