# PicoAudio/inference.py
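"""Text-to-audio inference for PicoAudio.

Generates a waveform from a timestamp-conditioned caption, e.g.
"spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729".

Example invocation (a sketch: the default checkpoint paths below are
machine-specific, so point -exp and --freeze_text_encoder_ckpt at your
own copies):

    python inference.py -t "dog_barking at 1.0-3.0 and gunshot at 4.0-5.0" \
        -exp ckpts/pico_model \
        --freeze_text_encoder_ckpt ckpts/laion_clap/630k-audioset-best.pt
"""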
import os
import json
import random
import argparse
import soundfile as sf
import numpy as np
import torch
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion, build_pretrained_models


class dotdict(dict):
    """Dot-notation access to dictionary attributes (missing keys return None via dict.get)."""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def parse_args():
    parser = argparse.ArgumentParser(description="Inference for the text-to-audio generation task.")
    parser.add_argument(
        "--text", '-t', type=str, default="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
        help="Timestamp-conditioned caption: '<event> at <onset>-<offset>' pairs joined by '_', events joined by ' and '."
    )
    parser.add_argument(
        "--exp_path", '-exp', type=str, default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/pico_model",
        help="Path to the experiment directory containing summary.jsonl and diffusion.pt."
    )
    parser.add_argument(
        "--freeze_text_encoder_ckpt", type=str, default='/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/laion_clap/630k-audioset-best.pt',
        help="Path to the frozen LAION-CLAP text encoder checkpoint."
    )
    parser.add_argument(
        "--seed", type=int, default=0,
        help="Random seed for reproducibility.",
    )
    args = parser.parse_args()
    args.original_args = os.path.join(args.exp_path, "summary.jsonl")
    args.diffusion_pt = os.path.join(args.exp_path, "diffusion.pt")
    return args


def main():
    args = parse_args()
    # Training-time hyperparameters are stored as the first line of summary.jsonl.
    with open(args.original_args) as f:
        train_args = dotdict(json.loads(f.readline()))

    # Seed all RNGs so generation is reproducible.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
# Load Models #
print("------Load model")
name = "audioldm-s-full"
vae, stft = build_pretrained_models(name)
vae, stft = vae.cuda(), stft.cuda()
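    # PicoDiffusion: the latent diffusion model, built with the training-time
    # config and (judging by the argument names) a frozen CLAP text encoder;
    # denoiser weights are restored from diffusion.pt.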
model = PicoDiffusion(
scheduler_name=train_args.scheduler_name,
unet_model_config_path=train_args.unet_model_config,
snr_gamma=train_args.snr_gamma,
freeze_text_encoder_ckpt=args.freeze_text_encoder_ckpt,
diffusion_pt=args.diffusion_pt,
).cuda().eval()
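    # Sampling scheduler, loaded from the same pretrained scheduler config
    # recorded in the training args.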
scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
# Generate #
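    # 200 denoising steps, classifier-free guidance scale 3.0, one sample,
    # output truncated to 10 seconds of audio at 16 kHz.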
num_steps, guidance, num_samples, audio_len = 200, 3.0, 1, 16000 * 10
output_dir = os.path.join("/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/synthesized",
f"huggingface_demo_steps-{num_steps}_guidance-{guidance}_samples-{num_samples}")
os.makedirs(output_dir, exist_ok=True)
print("------Diffusion begin!")
with torch.no_grad():
latents = model.demo_inference(args.text, scheduler, num_steps, guidance, num_samples, disable_progress=True)
mel = vae.decode_first_stage(latents)
wave = vae.decode_to_waveform(mel)
sf.write(f"{output_dir}/{args.text}.wav", wave[0][:audio_len], samplerate=16000, subtype='PCM_16')
print(f"------Write to files to {output_dir}/{args.text}.wav")
if __name__ == "__main__":
main()