# PicoAudio/inference.py
import os
import json
import random
import argparse
import soundfile as sf
import numpy as np
import torch
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion, build_pretrained_models
from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt


class dotdict(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
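

# Example (illustrative only): dotdict lets configuration keys be read as attributes;
# missing keys return None instead of raising, because __getattr__ maps to dict.get.
#   cfg = dotdict({"snr_gamma": 5.0})
#   cfg.snr_gamma    # -> 5.0
#   cfg.missing_key  # -> None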


def parse_args():
parser = argparse.ArgumentParser(description="Inference for text to audio generation task.")
parser.add_argument(
"--text", '-t', type=str, default="spraying two times then gunshot three times.",
help="free-text caption."
)
parser.add_argument(
"--timestamp_caption", '-c', type=str,
default=None,
#default="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
help="timestamp caption, formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'."
)
parser.add_argument(
"--exp_path", '-exp', type=str, default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/pico_model",
help="Path for experiment."
)
parser.add_argument(
"--freeze_text_encoder_ckpt", type=str, default='/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/laion_clap/630k-audioset-best.pt',
help="Path for clap."
)
parser.add_argument(
"--seed", type=int, default=0,
help="seed.",
)
args = parser.parse_args()
args.original_args = os.path.join(args.exp_path, "summary.jsonl")
args.diffusion_pt = os.path.join(args.exp_path, "diffusion.pt")
return args
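

# Illustrative helper (not part of the original pipeline and unused below): a minimal
# sketch of how the timestamp-caption format described by --timestamp_caption could be
# parsed, assuming events are separated by " and ", onset/offset by "-", and pairs by "_".
# The name _parse_timestamp_caption is hypothetical.
def _parse_timestamp_caption(caption):
    """Return {event: [(onset, offset), ...]} for a caption such as
    'spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729'."""
    events = {}
    for clause in caption.strip().rstrip(".").split(" and "):
        event, _, spans = clause.partition(" at ")
        events[event.strip()] = [
            tuple(float(t) for t in span.split("-")) for span in spans.split("_")
        ]
    return events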


def main():
args = parse_args()
    # Load the training configuration from the first line of summary.jsonl.
    with open(args.original_args) as f:
        train_args = dotdict(json.loads(f.readline()))

    # Fix random seeds for reproducibility.
    seed = args.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
    # Step 1: convert the free-text caption into a timestamp caption via an LLM.
    if args.timestamp_caption is None:
        # args.timestamp_caption = preprocess_gpt(args.text)
        args.timestamp_caption = preprocess_gemini(args.text)
    # Step 2: load the pretrained VAE/STFT and the PicoDiffusion model.
    print("------Load models")
name = "audioldm-s-full"
vae, stft = build_pretrained_models(name)
vae, stft = vae.cuda(), stft.cuda()
model = PicoDiffusion(
scheduler_name=train_args.scheduler_name,
unet_model_config_path=train_args.unet_model_config,
snr_gamma=train_args.snr_gamma,
freeze_text_encoder_ckpt=args.freeze_text_encoder_ckpt,
diffusion_pt=args.diffusion_pt,
).cuda().eval()
scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
    # Step 3: generate. 200 denoising steps, guidance scale 3.0, one sample, 10 s of audio at 16 kHz.
    num_steps, guidance, num_samples, audio_len = 200, 3.0, 1, 16000 * 10
output_dir = os.path.join("/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/synthesized",
f"huggingface_demo_steps-{num_steps}_guidance-{guidance}_samples-{num_samples}")
os.makedirs(output_dir, exist_ok=True)
print("------Diffusion begin!")
with torch.no_grad():
latents = model.demo_inference(args.timestamp_caption, scheduler, num_steps, guidance, num_samples, disable_progress=True)
mel = vae.decode_first_stage(latents)
wave = vae.decode_to_waveform(mel)
sf.write(f"{output_dir}/{args.timestamp_caption}.wav", wave[0][:audio_len], samplerate=16000, subtype='PCM_16')
print(f"------Write to files to {output_dir}/{args.timestamp_caption}.wav")


if __name__ == "__main__":
main()
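
# Example invocations (illustrative; the default checkpoint paths above must exist locally):
#   python inference.py -t "spraying two times then gunshot three times."
#   python inference.py -c "spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729"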