File size: 5,339 Bytes
e06cbbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import torchaudio
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from pydub import AudioSegment
import re
import os
from datetime import datetime
import gradio as gr

# Define the function to generate audio based on a prompt
def generate_audio(prompt, steps, cfg_scale, sigma_min, sigma_max, generation_time, seed, sampler_type, model_half):
    """Generate an audio clip from a text prompt with Stable Audio Open.

    Args:
        prompt: Text description of the desired audio; also used to build
            the output filename.
        steps: Number of diffusion sampling steps.
        cfg_scale: Classifier-free guidance scale.
        sigma_min: Lower bound of the sampler's noise schedule.
        sigma_max: Upper bound of the sampler's noise schedule.
        generation_time: Requested clip length in seconds (timing conditioning).
        seed: RNG seed passed to the sampler.
        sampler_type: Name of the diffusion sampler to use.
        model_half: If True, run the model in float16 to reduce VRAM usage.

    Returns:
        Path to the exported MP3 file under Output/<YYYY-MM-DD>/.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Download (or load the cached) pretrained model and its config.
    # NOTE(review): this reloads the model on every call; hoisting it to a
    # module-level cache would speed up repeated generations considerably.
    model, model_config = get_pretrained_model("audo/stable-audio-open-1.0")
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]

    model = model.to(device)

    # Print model data type before conversion
    print("Model data type before conversion:", next(model.parameters()).dtype)

    # Convert model to float16 if model_half is True
    if model_half:
        model = model.to(torch.float16)

    # Print model data type after conversion
    print("Model data type after conversion:", next(model.parameters()).dtype)

    # Set up text and timing conditioning
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": generation_time
    }]

    # Generate stereo audio
    output = generate_diffusion_cond(
        model,
        steps=steps,
        cfg_scale=cfg_scale,
        conditioning=conditioning,
        sample_size=sample_size,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        sampler_type=sampler_type,
        device=device,
        seed=seed
    )

    # Print output data type
    print("Output data type:", output.dtype)

    # Collapse the batch dimension into one long [channels, samples] sequence.
    output = rearrange(output, "b d n -> d (b n)")

    # Peak-normalize to [-1, 1] and scale into int16 range. Guard against a
    # silent (all-zero) clip, which would otherwise divide by zero.
    peak = torch.max(torch.abs(output))
    if peak > 0:
        output = output.div(peak)
    output = output.clamp(-1, 1).mul(32767)
    # Always route through float32 before int16: converting float16 directly
    # loses precision near the top of the int16 range.
    output = output.to(torch.float32).to(torch.int16).cpu()

    torchaudio.save("temp_output.wav", output, sample_rate)

    # Convert to MP3 format using pydub, then remove the intermediate WAV.
    audio = AudioSegment.from_wav("temp_output.wav")
    os.remove("temp_output.wav")

    # Create Output folder and dated subfolder if they do not exist
    output_folder = "Output"
    date_folder = datetime.now().strftime("%Y-%m-%d")
    save_path = os.path.join(output_folder, date_folder)
    os.makedirs(save_path, exist_ok=True)

    # Generate a filename based on the prompt
    filename = re.sub(r'\W+', '_', prompt) + ".mp3"  # Replace non-alphanumeric characters with underscores
    full_path = os.path.join(save_path, filename)

    # Ensure the filename is unique by appending a number if the file already
    # exists. splitext is safer than slicing off a hard-coded ".mp3" length.
    stem, ext = os.path.splitext(filename)
    counter = 1
    while os.path.exists(full_path):
        full_path = os.path.join(save_path, f"{stem}_{counter}{ext}")
        counter += 1

    # Export the audio to MP3 format
    audio.export(full_path, format="mp3")

    return full_path

def audio_generator(prompt, sampler_type, steps, cfg_scale, sigma_min, sigma_max, generation_time, seed, model_half):
    """Gradio callback: log the parameters, generate audio, report status.

    Returns:
        A (gr.Audio, status string) pair on success, or (None, error
        message) on failure — one value per declared Gradio output.
    """
    try:
        print("Generating audio with parameters:")
        print("Prompt:", prompt)
        print("Sampler Type:", sampler_type)
        print("Steps:", steps)
        print("CFG Scale:", cfg_scale)
        print("Sigma Min:", sigma_min)
        print("Sigma Max:", sigma_max)
        print("Generation Time:", generation_time)
        print("Seed:", seed)
        print("Model Half Precision:", model_half)

        filename = generate_audio(prompt, steps, cfg_scale, sigma_min, sigma_max, generation_time, seed, sampler_type, model_half)
        # Bug fix: the original f-string had no placeholder and always read
        # "Generated: (unknown)" — report the actual output path instead.
        return gr.Audio(filename), f"Generated: {filename}"
    except Exception as e:
        # Bug fix: the original returned a single string here, but the
        # interface declares two outputs; supply a value for each so
        # errors render in the UI instead of breaking it.
        return None, f"Error: {e}"

# --- Gradio interface wiring ---

# Input widgets, one per generate_audio parameter.
prompt_textbox = gr.Textbox(lines=5, label="Prompt")
sampler_dropdown = gr.Dropdown(
    label="Sampler Type",
    choices=[
        "dpmpp-3m-sde",
        "dpmpp-2m-sde",
        "k-heun",
        "k-lms",
        "k-dpmpp-2s-ancestral",
        "k-dpm-2",
        "k-dpm-fast",
    ],
    value="dpmpp-3m-sde",
)
steps_slider = gr.Slider(minimum=0, maximum=200, label="Steps", step=1, value=100)
cfg_scale_slider = gr.Slider(minimum=0, maximum=15, label="CFG Scale", step=0.1, value=7)
sigma_min_slider = gr.Slider(minimum=0, maximum=50, label="Sigma Min", step=0.1, value=0.3)
sigma_max_slider = gr.Slider(minimum=0, maximum=1000, label="Sigma Max", step=0.1, value=500)
generation_time_slider = gr.Slider(minimum=0, maximum=47, label="Generation Time (seconds)", step=1, value=47)
seed_slider = gr.Slider(minimum=-1, maximum=999999, label="Seed", step=1, value=123456)
model_half_checkbox = gr.Checkbox(label="Low VRAM (float16)", value=False)

# Status text shown next to the generated audio.
output_textbox = gr.Textbox(label="Output")

ui_inputs = [
    prompt_textbox,
    sampler_dropdown,
    steps_slider,
    cfg_scale_slider,
    sigma_min_slider,
    sigma_max_slider,
    generation_time_slider,
    seed_slider,
    model_half_checkbox,
]

# Build and serve the app.
app = gr.Interface(
    audio_generator,
    ui_inputs,
    [gr.Audio(), output_textbox],
    title="πŸ’€πŸ”Š StableAudioWebUI πŸ’€πŸ”Š",
    description="[Github Repository](https://github.com/Saganaki22/StableAudioWebUI)",
)
app.launch()