# --- Hugging Face file-viewer residue (page chrome, not part of the program) ---
# chuks-cmu's picture
# Fixed the audio length to 30 seconds
# 5927ed7 verified
# raw
# history blame
# No virus
# 2.02 kB
import gradio as gr
import torch
import os
import uuid
import torchaudio
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
# Models already loaded in this process, keyed by device string, so repeated
# requests reuse one instance instead of calling get_pretrained_model again.
_MODEL_CACHE = {}


def _load_model(device):
    """Return ``(model, model_config)`` for *device*, loading it on first use."""
    if device not in _MODEL_CACHE:
        model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
        _MODEL_CACHE[device] = (model.to(device), model_config)
    return _MODEL_CACHE[device]


def _to_int16_pcm(audio):
    """Peak-normalize a waveform tensor and return it as int16 on the CPU."""
    audio = audio.to(torch.float32)
    # Floor the peak so an all-zero (silent) waveform does not become 0/0 = NaN
    # before the integer conversion.
    peak = audio.abs().max().clamp(min=1e-8)
    return audio.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()


def gen_music(description, seconds_total=30):
    """Generate audio from a text prompt and save it as a WAV file.

    Args:
        description: Text prompt passed to the model as the "prompt"
            conditioning value.
        seconds_total: Requested clip length in seconds. It is passed to the
            model as timing conditioning and also used to trim the result.
            Defaults to 30.

    Returns:
        The name of the WAV file written to the current working directory.

    Raises:
        ValueError: If ``seconds_total`` is not positive.
    """
    if seconds_total <= 0:
        raise ValueError(f"seconds_total must be positive, got {seconds_total!r}")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Report only whether a token is configured; the secret itself must never
    # be written to the logs.
    # NOTE(review): the token is not passed to get_pretrained_model explicitly;
    # presumably the hub client reads HF_TOKEN from the environment -- confirm.
    token_state = "set" if os.getenv("HF_TOKEN") else "not set"
    print(f"Hugging Face token: {token_state}")

    # Fetch the model (cached after the first call for this device)
    model, model_config = _load_model(device)
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]

    # Set up text and timing conditioning
    conditioning = [{
        "prompt": f"{description}",
        "seconds_start": 0,
        "seconds_total": seconds_total,
    }]

    # Generate audio
    output = generate_diffusion_cond(
        model,
        conditioning=conditioning,
        sample_size=sample_size,
        device=device,
    )

    # Rearrange the audio batch into a single (channels, samples) sequence
    output = rearrange(output, "b d n -> d (b n)")

    # Generation is sized by sample_size, not by seconds_total, so cut off
    # anything past the requested duration. Slicing is a no-op when the
    # generated clip is already shorter than that.
    output = output[:, : int(seconds_total * sample_rate)]

    # Peak normalize, clip and convert to int16
    output = _to_int16_pcm(output)

    # Use a unique filename so concurrent requests never overwrite each other
    unique_filename = f"output_{uuid.uuid4().hex}.wav"
    print(f"Saving audio to file: {unique_filename}")
    torchaudio.save(unique_filename, output, sample_rate)
    print(f"Audio saved: {unique_filename}")

    # Return the path to the generated audio file
    return unique_filename
# Define the Gradio interface: a text prompt as input, and an audio component
# that receives the file path returned by gen_music as output.
description = gr.Textbox(
    placeholder="128 BPM tech house drum loop",
    label="Description",
)
output_path = gr.Audio(type="filepath", label="Generated Music")

demo = gr.Interface(
    title="StableAudio Music Generation Demo",
    fn=gen_music,
    inputs=[description],
    outputs=output_path,
)
demo.launch()