from pathlib import Path
from functools import partial

import gradio as gr
import torch
from torchaudio.functional import resample

import utils.train_util as train_util


def load_model(cfg, ckpt_path, device):
    """Build the captioning model and tokenizer from a config and checkpoint."""
    model = train_util.init_model_from_config(cfg["model"])
    ckpt = torch.load(ckpt_path, map_location="cpu")
    train_util.load_pretrained_model(model, ckpt)
    model.eval()
    model = model.to(device)
    tokenizer = train_util.init_obj_from_dict(cfg["tokenizer"])
    if not tokenizer.loaded:
        tokenizer.load_state_dict(ckpt["tokenizer"])
    model.set_index(tokenizer.bos, tokenizer.eos, tokenizer.pad)
    return model, tokenizer


def infer(file, runner):
    """Caption one audio clip passed in from the Gradio audio component."""
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Gradio delivers raw integer PCM; normalize to [-1, 1] floats.
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Downmix multi-channel audio to mono.
    if wav.ndim > 1:
        wav = wav.mean(1)
    wav = resample(wav, sr, runner.target_sr)
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0).to(runner.device)
    input_dict = {
        "mode": "inference",
        "wav": wav,
        "wav_len": [wav_len],
        "specaug": False,
        "sample_method": "beam",
        "beam_size": 3,
    }
    with torch.no_grad():
        output_dict = runner.model(input_dict)
    seq = output_dict["seq"].cpu().numpy()
    cap = runner.tokenizer.decode(seq)[0]
    return cap


# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)


class InferRunner:
    """Holds the currently selected model/tokenizer and swaps them on demand."""

    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.change_model(model_name)

    def change_model(self, model_name):
        exp_dir = Path(f"./checkpoints/{model_name.lower()}")
        cfg = train_util.load_config(exp_dir / "config.yaml")
        self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
        self.target_sr = cfg["target_sr"]


def change_model(radio):
    # Mutate the shared runner in place so the `partial` bound to the
    # "Run" button picks up the newly selected model.
    global infer_runner
    infer_runner.change_model(radio)


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")
    with gr.Row():
        gr.Markdown("""
        [![arXiv](https://img.shields.io/badge/arXiv-2407.14329-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.14329)
        [![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
        """)
    with gr.Row():
        with gr.Column():
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model",
            )
            infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            radio.change(fn=change_model, inputs=[radio])
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
            btn.click(
                fn=partial(infer, runner=infer_runner),
                inputs=[file],
                outputs=output,
            )

demo.launch()
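
# Programmatic usage without the UI (a sketch, not part of the original app):
# it assumes a local clip "example.wav" and the `soundfile` package, and feeds
# `infer` the same (sample_rate, waveform) tuple that gr.Audio produces.
#
#     import soundfile as sf
#     runner = InferRunner("AudioCaps")
#     wav, sr = sf.read("example.wav", dtype="int16")  # int16 PCM, like gr.Audio
#     print(infer((sr, wav), runner))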