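"""Gradio demo for a lightweight EfficientNetB2-Transformer audio captioning model.

Loads a pretrained captioning checkpoint (AudioCaps or Clotho) and serves a
web UI that generates a text caption for an uploaded or recorded audio clip.
"""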
from pathlib import Path
import argparse
from functools import partial
import gradio as gr
import torch
from torchaudio.functional import resample
import utils.train_util as train_util
def load_model(cfg,
ckpt_path,
device):
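    # Build the model from the experiment config, restore the trained weights,
    # and set up the matching tokenizer (stored in the same checkpoint).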
model = train_util.init_model_from_config(cfg["model"])
    ckpt = torch.load(ckpt_path, map_location="cpu")
train_util.load_pretrained_model(model, ckpt)
model.eval()
model = model.to(device)
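    # If the tokenizer was not fully initialized from its config, restore its
    # state from the checkpoint, then register the special token indices
    # (BOS/EOS/PAD) with the model.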
tokenizer = train_util.init_obj_from_dict(cfg["tokenizer"])
if not tokenizer.loaded:
tokenizer.load_state_dict(ckpt["tokenizer"])
model.set_index(tokenizer.bos, tokenizer.eos, tokenizer.pad)
return model, tokenizer
def infer(file, runner):
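    # Gradio passes audio as a (sample_rate, numpy_array) tuple, or None if
    # nothing was uploaded or recorded.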
    if file is None:
        return "Please provide an audio clip first."
    sr, wav = file
wav = torch.as_tensor(wav)
    # Normalize integer PCM to floats in [-1, 1].
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Downmix multi-channel audio to mono.
    if wav.ndim > 1:
        wav = wav.mean(1)
wav = resample(wav, sr, runner.target_sr)
wav_len = len(wav)
wav = wav.float().unsqueeze(0).to(runner.device)
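    # Decode with beam search (beam size 3); SpecAugment is disabled since
    # this is inference, not training.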
input_dict = {
"mode": "inference",
"wav": wav,
"wav_len": [wav_len],
"specaug": False,
"sample_method": "beam",
"beam_size": 3,
}
with torch.no_grad():
output_dict = runner.model(input_dict)
seq = output_dict["seq"].cpu().numpy()
cap = runner.tokenizer.decode(seq)[0]
return cap
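
# A minimal sketch of captioning a file without the UI (assumes `soundfile`
# is installed and the checkpoints below are in place; "example.wav" is a
# hypothetical path):
#
#   import soundfile as sf
#   wav, sr = sf.read("example.wav", dtype="int16")
#   print(infer((sr, wav), InferRunner("AudioCaps")))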
class InferRunner:
    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.change_model(model_name)

    def change_model(self, model_name):
        # Checkpoints are laid out as ./checkpoints/<name>/{config.yaml, ckpt.pth}.
        exp_dir = Path(f"./checkpoints/{model_name.lower()}")
        cfg = train_util.load_config(exp_dir / "config.yaml")
        self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
        self.target_sr = cfg["target_sr"]
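
# Radio callback: reload the selected checkpoint on the shared global runner.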
def change_model(radio):
global infer_runner
infer_runner.change_model(radio)
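
# Gradio UI: model selector and audio input on the left, caption output on
# the right.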
with gr.Blocks() as demo:
with gr.Row():
gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")
with gr.Row():
gr.Markdown("""
[![arXiv](https://img.shields.io/badge/arXiv-2407.14329-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.14329)
[![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
""")
with gr.Row():
with gr.Column():
radio = gr.Radio(
["AudioCaps", "Clotho"],
value="AudioCaps",
label="Select model"
)
infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input")
            radio.change(fn=change_model, inputs=[radio])
btn = gr.Button("Run")
with gr.Column():
output = gr.Textbox(label="Output")
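        # `partial` binds the runner object once; change_model mutates it in
        # place, so switching models takes effect without rebinding here.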
btn.click(
fn=partial(infer,
runner=infer_runner),
inputs=[file,],
outputs=output
)
demo.launch()