|
import gradio as gr |
|
import torch as T |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchaudio |
|
import matplotlib.pyplot as plt |
|
from utils import load_ckpt, print_colored |
|
from tokenizer import make_tokenizer |
|
from model import get_hertz_dev_config |
|
from typing import Tuple |
|
import numpy as np |
|
import os |
|
|
|
|
|
# Process-wide model cache: filled in by init_model() and read by
# generate_completion() so the model/tokenizer are built only once.
global_generator = None

global_tokenizer = None

# Pre-loaded into the input audio widget by create_interface() when the file exists.
default_audio_path = "testingtesting.wav"
|
|
|
def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
    """Build (once) and return the cached generator model and audio tokenizer.

    On the first call the tokenizer and the Hertz-dev model are constructed,
    the model is put in eval mode, cast to bfloat16 and moved to the selected
    device, and both objects are stored in module globals. Later calls return
    the cached pair without rebuilding anything.

    Args:
        use_pure_audio_ablation: Forwarded to ``get_hertz_dev_config``.
            NOTE: this only has an effect on the first successful call; once
            the models are cached the argument is ignored.

    Returns:
        A ``(generator, tokenizer)`` tuple.
    """
    global global_generator, global_tokenizer

    if global_generator is not None and global_tokenizer is not None:
        return global_generator, global_tokenizer

    device = 'cuda' if T.cuda.is_available() else 'cpu'
    if device == 'cuda':
        # Make the bare 'cuda' device string used below resolve to GPU 0.
        T.cuda.set_device(0)

    print_colored("Initializing model and tokenizer...", "blue")

    # Build into locals first and publish to the globals only once both
    # succeed, so a failure cannot leave a half-initialized cache behind.
    tokenizer = make_tokenizer(device)
    model_config = get_hertz_dev_config(
        is_split=False, use_pure_audio_ablation=use_pure_audio_ablation
    )
    # model_config is called to construct the nn.Module (presumably a config
    # object / factory — confirm where weights get loaded).
    generator = model_config().eval().to(T.bfloat16).to(device)

    global_tokenizer = tokenizer
    global_generator = generator
    print_colored("Model initialization complete!", "green")

    return global_generator, global_tokenizer
|
|
|
def process_audio(audio_path: str, sr: int = 16000) -> T.Tensor:
    """Load an audio file and convert it to the model's input layout.

    The file is down-mixed to mono, resampled to ``sr`` Hz when its native
    rate differs, and truncated to at most five minutes.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        sr: Target sample rate in Hz. Previously this argument was shadowed
            by the file's own rate and the target was hard-coded to 16000;
            the default keeps that behavior for callers that pass 16000.

    Returns:
        A tensor of shape ``(1, 1, num_samples)``.
    """
    # torchaudio.load returns (channels, frames) plus the file's native rate.
    audio_tensor, file_sr = torchaudio.load(audio_path)

    # Down-mix any multi-channel input (not only stereo) to a single channel.
    if audio_tensor.shape[0] > 1:
        audio_tensor = audio_tensor.mean(dim=0, keepdim=True)

    if file_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=file_sr, new_freq=sr)
        audio_tensor = resampler(audio_tensor)

    # Cap the input at five minutes; slicing is a no-op for shorter clips.
    max_samples = sr * 60 * 5
    audio_tensor = audio_tensor[:, :max_samples]

    # Add a leading batch dimension.
    return audio_tensor.unsqueeze(0)
|
|
|
def generate_completion(
    audio_file,
    prompt_len_seconds: float = 3.0,
    num_completions: int = 5,
    generation_seconds: float = 20.0,
    token_temp: float = 0.8,
    categorical_temp: float = 0.5,
    gaussian_temp: float = 0.1,
    progress=gr.Progress(track_tqdm=True)
) -> list:
    """Generate audio continuations of the first seconds of an input clip.

    Args:
        audio_file: Path to the input audio file (Gradio ``type="filepath"``).
        prompt_len_seconds: How many seconds of the input to use as prompt.
        num_completions: Number of independent completions to sample.
        generation_seconds: Length of each generated continuation.
        token_temp: Token sampling temperature passed to the generator.
        categorical_temp: Categorical temperature passed to the generator.
        gaussian_temp: Gaussian temperature passed to the generator.
        progress: Gradio progress tracker. It is only live when the event
            handler passes it through explicitly.

    Returns:
        A list of ``(16000, ndarray)`` tuples, one per completion, suitable
        as values for ``gr.Audio(type="numpy")``. Each array is the decoded
        audio (typically ``(samples, channels)``), starting one second before
        the end of the prompt.

    Raises:
        gr.Error: If no input audio was provided.
    """
    if audio_file is None:
        raise gr.Error("Please provide input audio.")

    device = 'cuda' if T.cuda.is_available() else 'cpu'

    # Returns the cached models, building them if this runs before
    # create_interface() has initialized them.
    generator, audio_tokenizer = init_model()

    # Slider values may arrive as floats; range() needs an int.
    num_completions = int(num_completions)
    # The latent sequence runs at 8 frames per second of audio.
    prompt_len = int(prompt_len_seconds * 8)
    gen_len = int(generation_seconds * 8)

    progress(0, desc="Processing input audio...")
    prompt_audio = process_audio(audio_file, sr=16000)

    progress(0.2, desc="Encoding prompt...")
    with T.no_grad(), T.autocast(device_type=device, dtype=T.bfloat16):
        encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
    # The prompt is the same for every completion, so slice it once.
    encoded_prompt = encoded_prompt_audio[:, :prompt_len]

    # At 16 kHz and 8 latent frames/s each frame spans 2000 samples; keep the
    # last second of the prompt in front of the generated audio.
    output_start = max(prompt_len * 2000 - 16000, 0)

    completions = []
    for i in range(num_completions):
        # Spread the remaining 80% of the bar over the completions; report
        # the fraction done *before* this one starts.
        progress(
            0.2 + 0.8 * i / num_completions,
            desc=f"Generating completion {i+1}/{num_completions}",
        )

        with T.no_grad():
            with T.autocast(device_type=device, dtype=T.bfloat16):
                completed_audio_batch = generator.completion(
                    encoded_prompt,
                    temps=(token_temp, (categorical_temp, gaussian_temp)),
                    use_cache=True,
                    gen_len=gen_len,
                )
            decoded_completion = audio_tokenizer.data_from_latent(completed_audio_batch.bfloat16())

        # bfloat16 tensors cannot be converted to NumPy, hence .float().
        audio_tensor = decoded_completion.cpu().squeeze().float()
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)

        # Peak-normalize only when the signal would clip.
        peak = audio_tensor.abs().max()
        if peak > 1:
            audio_tensor = audio_tensor / peak

        output_audio = audio_tensor[:, output_start:]
        completions.append((16000, output_audio.numpy().T))

    progress(1.0, desc="Generation complete!")
    return completions
|
|
|
def create_interface():
    """Build the Gradio Blocks app for audio completion.

    The models are initialized eagerly so the first request does not pay the
    load cost. The layout has an input column (audio, generation sliders,
    button, status line) and an output column with a fixed pool of hidden
    audio players, of which the first ``num_completions`` are made visible.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    # Size of the pre-allocated output-player pool; also the slider maximum.
    max_completions = 10

    init_model()

    with gr.Blocks(title="Audio Completion Generator") as app:
        gr.Markdown(
            "# Audio Completion Generator\n\n"
            "Upload an audio file (or use the default) and generate AI completions based on the prompt."
        )

        with gr.Row():
            with gr.Column():
                default_value = default_audio_path if os.path.exists(default_audio_path) else None
                audio_input = gr.Audio(
                    label="Input Audio",
                    type="filepath",
                    sources=["microphone", "upload"],
                    value=default_value,
                )

                with gr.Row():
                    prompt_len = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=0.5,
                        label="Prompt Length (seconds)",
                    )
                    default_num_completions = 5
                    num_completions = gr.Slider(
                        minimum=1,
                        maximum=max_completions,
                        value=default_num_completions,
                        step=1,
                        label="Number of Completions",
                    )
                    gen_length = gr.Slider(
                        minimum=5,
                        maximum=60,
                        value=20,
                        step=5,
                        label="Generation Length (seconds)",
                    )

                with gr.Row():
                    token_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.8,
                        step=0.1,
                        label="Token Temperature",
                    )
                    cat_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.1,
                        label="Categorical Temperature",
                    )
                    gauss_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                        label="Gaussian Temperature",
                    )

                generate_btn = gr.Button("Generate Completions")
                status_text = gr.Markdown("Ready")

            with gr.Column():
                output_audios = [
                    gr.Audio(label=f"Generated Completion {i+1}", type="numpy", visible=False)
                    for i in range(max_completions)
                ]

        def update_visibility(num):
            """Show the first ``num`` output players and hide the rest."""
            return [gr.update(visible=(i < num)) for i in range(max_completions)]

        def generate_with_status(
            audio_file,
            prompt_len_seconds,
            n_completions,
            generation_seconds,
            token_temperature,
            categorical_temperature,
            gaussian_temperature,
            progress=gr.Progress(track_tqdm=True),
        ):
            """Run generation and return one value per output player, plus the status text.

            ``progress`` must be declared here, on the event function, for
            Gradio to wire it up. A default inside generate_completion alone
            is not connected to the UI.
            """
            completions = generate_completion(
                audio_file,
                prompt_len_seconds,
                n_completions,
                generation_seconds,
                token_temperature,
                categorical_temperature,
                gaussian_temperature,
                progress=progress,
            )
            # Pad with None so every player receives a value.
            outputs = list(completions[:max_completions])
            outputs += [None] * (max_completions - len(outputs))
            # Component state must be changed via return values; assigning
            # status_text.value inside a handler does not update the UI.
            return [*outputs, "Generation complete!"]

        app.load(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios,
        )

        num_completions.change(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios,
        )

        # Show the "processing" status immediately, then run the generation.
        generate_btn.click(
            fn=lambda: "Processing input audio...",
            outputs=status_text,
        ).then(
            fn=generate_with_status,
            inputs=[
                audio_input,
                prompt_len,
                num_completions,
                gen_length,
                token_temp,
                cat_temp,
                gauss_temp,
            ],
            outputs=[*output_audios, status_text],
        )

    return app
|
|
|
if __name__ == "__main__":
    # Build the UI (this also loads the model) and serve it. share=True asks
    # Gradio to create a public share link in addition to the local server.
    app = create_interface()
    app.launch(
        share=True,
    )