# hertz-dev / app.py - commit 824afbf (8.63 kB)
import gradio as gr
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import matplotlib.pyplot as plt
from utils import load_ckpt, print_colored
from tokenizer import make_tokenizer
from model import get_hertz_dev_config
from typing import Tuple
import numpy as np
import os
# Module-level cache for the model and tokenizer; both start out empty and are
# populated on the first init_model() call.
global_generator = global_tokenizer = None
# Clip preloaded into the input widget when this path exists (resolved relative
# to the current working directory).
default_audio_path = "testingtesting.wav"  # Your default audio file
def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
    """Return the ``(generator, tokenizer)`` pair, building it on first use.

    The pair is cached in the module-level ``global_generator`` and
    ``global_tokenizer`` variables. Later calls return the cached objects
    without rebuilding them.

    Args:
        use_pure_audio_ablation: Forwarded to ``get_hertz_dev_config`` when the
            model is first built. NOTE(review): once the cache is populated this
            flag is ignored, so a later call with a different value still gets
            the originally-built model.

    Returns:
        ``(generator, tokenizer)``. The generator is in eval mode, cast to
        bfloat16 and moved to the selected device. The tokenizer is whatever
        ``make_tokenizer`` returns.
    """
    global global_generator, global_tokenizer
    if global_generator is not None and global_tokenizer is not None:
        return global_generator, global_tokenizer

    device = "cuda" if T.cuda.is_available() else "cpu"
    if device == "cuda":
        # Make GPU 0 the current CUDA device before anything is moved to "cuda".
        T.cuda.set_device(0)

    print_colored("Initializing model and tokenizer...", "blue")
    global_tokenizer = make_tokenizer(device)
    model_config = get_hertz_dev_config(is_split=False, use_pure_audio_ablation=use_pure_audio_ablation)
    # Inference only: eval mode, bfloat16 weights, on the chosen device.
    global_generator = model_config().eval().to(T.bfloat16).to(device)
    print_colored("Model initialization complete!", "green")
    return global_generator, global_tokenizer
def process_audio(audio_path: str, sr: int) -> T.Tensor:
    """Load an audio file and turn it into a mono 16 kHz prompt tensor.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        sr: Unused. The file's own sample rate is read from disk, and the
            output is always resampled to 16 kHz. The parameter is kept for
            interface compatibility. NOTE(review): confirm no caller expects
            it to select a different target rate.

    Returns:
        Tensor of shape ``(1, 1, num_samples)`` at 16 kHz, truncated to at
        most five minutes of audio.
    """
    target_sr = 16000
    audio_tensor, file_sr = torchaudio.load(audio_path)  # (channels, frames)
    # Downmix any multi-channel recording (not only stereo) to one channel;
    # previously files with more than two channels were passed through as-is.
    if audio_tensor.shape[0] > 1:
        audio_tensor = audio_tensor.mean(dim=0, keepdim=True)
    if file_sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=file_sr, new_freq=target_sr)
        audio_tensor = resampler(audio_tensor)
    # Cap the prompt at five minutes; slicing is a no-op for shorter clips.
    max_samples = target_sr * 60 * 5
    audio_tensor = audio_tensor[:, :max_samples]
    # Add a leading batch dimension.
    return audio_tensor.unsqueeze(0)
def _to_output_audio(decoded_completion: T.Tensor, prompt_len: int) -> np.ndarray:
    """Convert one decoded completion into a ``(samples, channels)`` float32 array for ``gr.Audio``.

    The signal is peak-normalised only when it would clip. All of the prompt
    except its final second (16000 samples) is dropped, so the clip starts with
    a little context before the generated audio.
    """
    audio_tensor = decoded_completion.cpu().squeeze()
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.float()
    peak = audio_tensor.abs().max()
    if peak > 1:
        audio_tensor = audio_tensor / peak
    # Each latent frame covers 2000 samples at 16 kHz; keep 1 s of prompt context.
    start = max(prompt_len * 2000 - 16000, 0)
    return audio_tensor[:, start:].numpy().T


def generate_completion(
    audio_file,
    prompt_len_seconds: float = 3.0,
    num_completions: int = 5,
    generation_seconds: float = 20.0,
    token_temp: float = 0.8,
    categorical_temp: float = 0.5,
    gaussian_temp: float = 0.1,
    progress=gr.Progress(track_tqdm=True)
) -> list:
    """Generate audio completions that continue the prompt in ``audio_file``.

    Args:
        audio_file: Path to the input audio file. Must not be None.
        prompt_len_seconds: Seconds of the processed input used as the prompt.
        num_completions: Number of independent completions to sample.
        generation_seconds: Length to generate, converted to latent frames at
            8 per second and passed to ``generator.completion`` as ``gen_len``.
        token_temp: First element of the ``temps`` argument to
            ``generator.completion``.
        categorical_temp: Passed in ``temps`` as
            ``(token_temp, (categorical_temp, gaussian_temp))``.
        gaussian_temp: Passed in ``temps`` as above.
        progress: Gradio progress tracker.

    Returns:
        List of ``(16000, audio)`` tuples, one per completion. ``audio`` is a
        float32 array of shape ``(samples, channels)``. It holds the decoded
        output with all but the last second of the prompt trimmed off.

    Raises:
        gr.Error: If no input audio was provided.
    """
    if audio_file is None:
        raise gr.Error("Please provide an input audio file.")
    device = "cuda" if T.cuda.is_available() else "cpu"
    # Reuse the cached model/tokenizer, building them if nothing has yet.
    generator, audio_tokenizer = init_model()
    num_completions = int(num_completions)

    progress(0, desc="Processing input audio...")
    prompt_audio = process_audio(audio_file, sr=16000)
    # 8 latent frames per second of audio (2000 samples per frame at 16 kHz).
    prompt_len = int(prompt_len_seconds * 8)
    gen_len = int(generation_seconds * 8)

    completions = []
    # Pure inference: no autograd graph is needed.
    with T.no_grad():
        progress(0.2, desc="Encoding prompt...")
        # Autocast on the device actually in use; the hard-coded 'cuda' made
        # autocast a (warning-emitting) no-op on CPU.
        with T.autocast(device_type=device, dtype=T.bfloat16):
            encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
        # The prompt slice is identical for every completion, so compute it once.
        encoded_prompt = encoded_prompt_audio[:, :prompt_len]

        for i in range(num_completions):
            # Report progress at the start of each completion, within the 0.2-1.0 range.
            progress(0.2 + 0.8 * i / num_completions, desc=f"Generating completion {i+1}/{num_completions}")
            with T.autocast(device_type=device, dtype=T.bfloat16):
                completed_audio_batch = generator.completion(
                    encoded_prompt,
                    temps=(token_temp, (categorical_temp, gaussian_temp)),
                    use_cache=True,
                    gen_len=gen_len,
                )
                decoded_completion = audio_tokenizer.data_from_latent(completed_audio_batch.bfloat16())
            completions.append((16000, _to_output_audio(decoded_completion, prompt_len)))

    progress(1.0, desc="Generation complete!")
    return completions
def create_interface():
    """Build, but do not launch, the Gradio Blocks UI for audio completion.

    The model and tokenizer are loaded up front via ``init_model()``, so the
    first request does not pay the load cost.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    # Initialize model at startup
    init_model()
    max_completions = 10  # number of pre-built output players; also the slider maximum
    default_num_completions = 5

    with gr.Blocks(title="Audio Completion Generator") as app:
        gr.Markdown("""
# Audio Completion Generator
Upload an audio file (or use the default) and generate AI completions based on the prompt.
""")
        with gr.Row():
            with gr.Column():
                # Preload the default clip only when it exists on disk.
                default_value = default_audio_path if os.path.exists(default_audio_path) else None
                audio_input = gr.Audio(
                    label="Input Audio",
                    type="filepath",
                    sources=["microphone", "upload"],
                    value=default_value,
                )
                with gr.Row():
                    prompt_len = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=0.5,
                        label="Prompt Length (seconds)",
                    )
                    num_completions = gr.Slider(
                        minimum=1,
                        maximum=max_completions,
                        value=default_num_completions,
                        step=1,
                        label="Number of Completions",
                    )
                    gen_length = gr.Slider(
                        minimum=5,
                        maximum=60,
                        value=20,
                        step=5,
                        label="Generation Length (seconds)",
                    )
                with gr.Row():
                    token_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.8,
                        step=0.1,
                        label="Token Temperature",
                    )
                    cat_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.1,
                        label="Categorical Temperature",
                    )
                    gauss_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                        label="Gaussian Temperature",
                    )
                generate_btn = gr.Button("Generate Completions")
                status_text = gr.Markdown("Ready")
            with gr.Column():
                # One hidden player per possible completion; visibility follows the slider.
                output_audios = [
                    gr.Audio(
                        label=f"Generated Completion {i+1}",
                        type="numpy",
                        visible=False,
                    )
                    for i in range(max_completions)
                ]

        def update_visibility(num):
            """Show the first ``num`` output players and hide the rest."""
            return [gr.update(visible=(i < num)) for i in range(max_completions)]

        def mark_processing():
            """Return the status text shown while a generation request runs."""
            return "Processing input audio..."

        def generate_with_status(*args):
            """Run generation; return one value per output player plus the final status text.

            Assigning ``status_text.value`` inside a callback does not update
            the browser. Component updates must be returned through ``outputs``,
            so the status string is the last returned value.
            """
            completions = list(generate_completion(*args))[:max_completions]
            # Pad with None so every output player receives a value.
            padded = completions + [None] * (max_completions - len(completions))
            return padded + ["Generation complete!"]

        # Set initial visibility on load
        app.load(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios,
        )
        # Update visibility when slider changes
        num_completions.change(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios,
        )
        # Show the "processing" status first, then run generation and report completion.
        generate_btn.click(
            fn=mark_processing,
            outputs=status_text,
        ).then(
            fn=generate_with_status,
            inputs=[
                audio_input,
                prompt_len,
                num_completions,
                gen_length,
                token_temp,
                cat_temp,
                gauss_temp,
            ],
            outputs=output_audios + [status_text],
        )
    return app
if __name__ == "__main__":
    # Build the UI (which also loads the model) and serve it; share=True asks
    # Gradio for a public share link in addition to the local server.
    create_interface().launch(share=True)