import os
import json
import random
import argparse

import soundfile as sf
import numpy as np
import torch
from diffusers import DDPMScheduler

from pico_model import PicoDiffusion, build_pretrained_models
from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
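
# Illustrative note (not part of the original script): dotdict enables
# attribute-style access on a plain dict, e.g.
#   cfg = dotdict({"snr_gamma": 5.0})
#   cfg.snr_gamma    # -> 5.0
#   cfg.missing_key  # -> None (dict.get fallback; no KeyError is raised)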


def parse_args():
    parser = argparse.ArgumentParser(description="Inference for the text-to-audio generation task.")
    parser.add_argument(
        "--text", "-t", type=str, default="spraying two times then gunshot three times.",
        help="Free-text caption.",
    )
    parser.add_argument(
        "--timestamp_caption", "-c", type=str,
        default=None,
        # e.g. "spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."
        help="Timestamp caption, formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
    )
    parser.add_argument(
        "--exp_path", "-exp", type=str, default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/pico_model",
        help="Path to the experiment (checkpoint) directory.",
    )
    parser.add_argument(
        "--freeze_text_encoder_ckpt", type=str, default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/laion_clap/630k-audioset-best.pt",
        help="Path to the CLAP checkpoint (frozen text encoder).",
    )
    parser.add_argument(
        "--seed", type=int, default=0,
        help="Random seed.",
    )
    args = parser.parse_args()
    # Derived paths: training summary and diffusion weights live in the experiment directory.
    args.original_args = os.path.join(args.exp_path, "summary.jsonl")
    args.diffusion_pt = os.path.join(args.exp_path, "diffusion.pt")
    return args


def main():
    args = parse_args()
    # Training-time arguments are stored as the first line of summary.jsonl.
    with open(args.original_args) as f:
        train_args = dotdict(json.loads(f.readline()))

    # Seed everything for reproducible sampling.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Step 1: if no timestamp caption is given, derive one from the free text via an LLM.
    if args.timestamp_caption is None:
        # args.timestamp_caption = preprocess_gpt(args.text)
        args.timestamp_caption = preprocess_gemini(args.text)
    # Step 2: load the pretrained AudioLDM VAE/STFT and the PicoDiffusion model.
    print("------Load model")
    name = "audioldm-s-full"
    vae, stft = build_pretrained_models(name)
    vae, stft = vae.cuda(), stft.cuda()
    model = PicoDiffusion(
        scheduler_name=train_args.scheduler_name,
        unet_model_config_path=train_args.unet_model_config,
        snr_gamma=train_args.snr_gamma,
        freeze_text_encoder_ckpt=args.freeze_text_encoder_ckpt,
        diffusion_pt=args.diffusion_pt,
    ).cuda().eval()
    scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
    # Step 3: sample latents, decode to mel then waveform, and write a 10 s clip.
    num_steps, guidance, num_samples, audio_len = 200, 3.0, 1, 16000 * 10  # 10 s at 16 kHz
    output_dir = os.path.join(
        "/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/synthesized",
        f"huggingface_demo_steps-{num_steps}_guidance-{guidance}_samples-{num_samples}",
    )
    os.makedirs(output_dir, exist_ok=True)
    print("------Diffusion begin!")
    with torch.no_grad():
        latents = model.demo_inference(args.timestamp_caption, scheduler, num_steps, guidance, num_samples, disable_progress=True)
        mel = vae.decode_first_stage(latents)
        wave = vae.decode_to_waveform(mel)
    out_path = f"{output_dir}/{args.timestamp_caption}.wav"
    sf.write(out_path, wave[0][:audio_len], samplerate=16000, subtype="PCM_16")
    print(f"------Wrote file to {out_path}")


if __name__ == "__main__":
    main()
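
# Example invocation (the script filename "inference.py" is assumed; default paths
# point at the original author's cluster and will need adjusting):
#   python inference.py -t "spraying two times then gunshot three times."
#   python inference.py -c "spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."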