gaunernst's picture
beautify
cafc237
import json
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi
from torchaudio.functional import resample
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# Checkpoint loaded from the Hugging Face Hub via timm, switched to eval mode for inference.
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"


def _fetch_audioset_labels(url: str = LABEL_URL, timeout: float = 30.0) -> list:
    """Download an id -> label JSON mapping and return the labels ordered by class id.

    Args:
        url: Location of the JSON object whose keys are class ids (as strings)
            and whose values are label names.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        List where element ``i`` is the label for class id ``i``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly at startup instead of trying to parse an error page as labels.
    response.raise_for_status()
    id2label = json.loads(response.content)
    # Order by numeric id rather than trusting the key order in the file, so that
    # AUDIOSET_LABELS[i] is the label of model output index i (see predict()).
    return [id2label[key] for key in sorted(id2label, key=int)]


AUDIOSET_LABELS = _fetch_audioset_labels()

# Sample rate the waveform is resampled to before feature extraction.
SAMPLING_RATE = 16_000
# Normalization constants applied to the log-mel features in preprocess();
# presumably dataset-level statistics used when the checkpoint was trained — verify.
MEAN = -4.2677393
STD = 4.5689974
def preprocess(x: torch.Tensor, num_frames: int = 1024, num_mel_bins: int = 128) -> torch.Tensor:
    """Turn a mono waveform into the normalized log-mel "image" fed to the model.

    Args:
        x: 1-D waveform tensor, assumed to be sampled at ``SAMPLING_RATE``
            (predict() resamples to that rate before calling this).
        num_frames: Number of time frames in the output. Shorter inputs are
            zero-padded at the end; longer ones are truncated. The default of
            1024 matches the model input used in predict().
        num_mel_bins: Number of mel filterbank channels (default 128).

    Returns:
        Tensor of shape ``(num_frames, num_mel_bins)``.
    """
    # Remove the DC offset before extracting features.
    x = x - x.mean()
    # kaldi.fbank takes a (channels, samples) tensor; the sample rate is passed
    # explicitly so the frame timing follows SAMPLING_RATE.
    melspec = kaldi.fbank(
        x.unsqueeze(0),
        htk_compat=True,
        sample_frequency=float(SAMPLING_RATE),
        window_type="hanning",
        num_mel_bins=num_mel_bins,
    )
    frames = melspec.shape[0]
    if frames < num_frames:
        # Pad only the time axis (dim 0), at the end.
        melspec = F.pad(melspec, (0, 0, 0, num_frames - frames))
    else:
        melspec = melspec[:num_frames]
    # Normalize with the fixed MEAN/STD constants; note the divisor is 2 * STD.
    melspec = (melspec - MEAN) / (STD * 2)
    return melspec
def _waveform_to_tensor(samples: np.ndarray) -> torch.Tensor:
    """Convert a sample array to a float32 tensor, scaling integer PCM to about [-1, 1)."""
    if np.issubdtype(samples.dtype, np.integer):
        # Divide by the integer full-scale value (1 << 15 for int16 audio).
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        return torch.from_numpy(samples.astype(np.float32) / full_scale)
    # Floating-point audio is assumed to already be in [-1, 1]; only the dtype is unified.
    return torch.from_numpy(samples.astype(np.float32))


def _plot_spectrogram(melspec: torch.Tensor):
    """Render a (time, mel) spectrogram tensor as a standalone matplotlib Figure."""
    # A bare Figure avoids pyplot's global figure registry: pyplot figures stay
    # alive until plt.close() and its state machine is not thread-safe when
    # several requests are served concurrently.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.imshow(melspec.T.numpy(), origin="lower")
    ax.set_title("Log mel-spectrogram")
    ax.set_xlabel("Time (s)")
    # 100 frames per second of audio -> one tick per second over 0..10 s.
    ax.set_xticks(np.arange(11) * 100)
    ax.set_xticklabels(np.arange(11))
    ax.set_yticks([0, 64, 128])
    fig.tight_layout()
    return fig


def predict(audio, start):
    """Classify an audio clip into AudioSet classes and plot the model input.

    Args:
        audio: ``(sample_rate, samples)`` tuple from the Gradio audio input;
            ``samples`` is indexed (time,) or (time, channels).
        start: Offset in seconds at which the analysed window begins.

    Returns:
        ``(preds, fig)``: ``preds`` is a list of the 10 highest-scoring
        ``[label, score]`` pairs, where score is the sigmoid probability in
        percent. ``fig`` is a matplotlib Figure of the log-mel spectrogram.

    Raises:
        gr.Error: if no audio is given, ``start`` is negative or not before the
            end of the clip, or the sample array has more than two dimensions.
    """
    if audio is None:
        raise gr.Error("Please provide an audio clip.")
    sr, x = audio
    if start < 0:
        raise gr.Error(f"`start` ({start}) must not be negative")
    # `>=`: starting exactly at the end would leave an empty waveform.
    if start * sr >= x.shape[0]:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.2f}s)")

    waveform = _waveform_to_tensor(x)
    if waveform.ndim > 1:
        # Down-mix to mono by averaging the channel axis.
        waveform = waveform.mean(-1)
    if waveform.ndim != 1:
        raise gr.Error(f"Unsupported audio array shape {tuple(x.shape)}")

    waveform = resample(waveform[int(start * sr) :], sr, SAMPLING_RATE)
    melspec = preprocess(waveform)

    with torch.inference_mode():
        # Model input: (batch=1, channels=1, time=1024, mel=128).
        logits = MODEL(melspec.view(1, 1, 1024, 128)).squeeze(0)
    # Independent per-class sigmoid scores (multi-label), top 10 reported.
    topk_probs, topk_classes = logits.sigmoid().topk(10)
    preds = [
        [AUDIOSET_LABELS[cls], prob * 100]
        for cls, prob in zip(topk_classes.tolist(), topk_probs.tolist())
    ]
    return preds, _plot_spectrogram(melspec)
# Markdown text shown at the top of the demo page (passed as gr.Interface's
# `description`); blank lines separate paragraphs and the bullet list.
DESCRIPTION = """
Classify audio into AudioSet classes with a ViT-B/16 pre-trained using the AudioMAE objective.

- For more information about AudioMAE, visit https://github.com/facebookresearch/AudioMAE.
- For how to use the AudioMAE model in timm, visit https://huggingface.co/gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k.

Input audio is converted to a log Mel-spectrogram and treated as a grayscale image. The model is a vanilla ViT-B/16.

NOTE: The AudioMAE model only accepts 10s of audio (10.24s to be exact). Longer audio will be cropped. Shorter audio will be zero-padded.
"""
# Gradio UI: an audio clip plus a start offset in seconds go into predict();
# a top-10 class table and the spectrogram plot come out.
_outputs = [
    gr.Dataframe(headers=["class", "score"], row_count=10, label="prediction"),
    gr.Plot(label="spectrogram"),
]
# Bundled example clips, each analysed from offset 0.
_examples = [
    ["LS_female_1462-170138-0008.flac", 0],
    ["LS_male_3170-137482-0005.flac", 0],
]
demo = gr.Interface(
    fn=predict,
    inputs=["audio", "number"],
    outputs=_outputs,
    examples=_examples,
    title="AudioSet classification with AudioMAE (ViT-B/16)",
    description=DESCRIPTION,
)
demo.launch()