"""Gradio demo for lightweight EfficientNetB2-Transformer audio captioning."""
from functools import partial

import gradio as gr
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast


def load_model(model_name, device):
    """Load the captioning model and tokenizer for the selected dataset."""
    if model_name == "AudioCaps":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-audiocaps-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/audiocaps-simple-tokenizer"
        )
    elif model_name == "Clotho":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-clotho-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/clotho-simple-tokenizer"
        )
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model, tokenizer


def infer(file, runner):
    """Caption one clip; `file` is the (sample_rate, array) pair from gr.Audio."""
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Normalize integer PCM (int16 / int32) to floats in [-1, 1].
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Downmix multi-channel audio (samples, channels) to mono.
    if wav.ndim > 1:
        wav = wav.mean(1)
    # Resample to the rate the model was trained on.
    wav = resample(wav, sr, runner.target_sr)
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        word_idx = runner.model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
    cap = runner.tokenizer.decode(word_idx, skip_special_tokens=True)
    return cap


# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)


class InferRunner:
    """Hold the current model, tokenizer, and target sample rate on one device."""

    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.tokenizer = load_model(model_name, self.device)
        self.target_sr = self.model.config.sample_rate

    def change_model(self, model_name):
        self.model, self.tokenizer = load_model(model_name, self.device)
        self.target_sr = self.model.config.sample_rate


def change_model(radio):
    # Swap the weights inside the existing runner rather than rebinding it, so
    # the partial() already attached to the Run button keeps seeing the update.
    global infer_runner
    infer_runner.change_model(radio)


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")
    with gr.Row():
        gr.Markdown("""
        [![arXiv](https://img.shields.io/badge/arXiv-2407.14329-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.14329)
        [![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
        """)
    with gr.Row():
        with gr.Column():
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model"
            )
            infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            radio.change(fn=change_model, inputs=[radio])
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
        btn.click(
            fn=partial(infer, runner=infer_runner),
            inputs=[file],
            outputs=output
        )

demo.launch()
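
# ---------------------------------------------------------------------------
# Optional programmatic usage (a minimal sketch, not part of the demo): the
# same infer() pipeline can be driven from a local file instead of the web UI.
# Assumptions: the soundfile package is installed and "example.wav" is a
# hypothetical path; gr.Audio normally supplies the (sample_rate, array) pair
# that infer() expects, so the sketch mimics that shape.
#
#   import soundfile as sf
#   runner = InferRunner("AudioCaps")
#   wav, sr = sf.read("example.wav", dtype="int16")  # (samples,) or (samples, channels)
#   print(infer((sr, wav), runner))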