# Gradio app: AudioSet classification with an AudioMAE-pretrained ViT-B/16 (via timm).
import json

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi
from torchaudio.functional import resample

# AudioMAE checkpoint (ViT-B/16 pre-trained on AudioSet-2M, fine-tuned on AS-20k),
# loaded from the Hugging Face Hub through timm.
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

# AudioSet class-index -> label-name mapping.
LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
AUDIOSET_LABELS = list(json.loads(requests.get(LABEL_URL).content).values())

SAMPLING_RATE = 16_000
# Spectrogram normalization statistics for AudioSet, as used in the AudioMAE recipe.
MEAN = -4.2677393
STD = 4.5689974
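
# Illustrative sanity check (a hypothetical helper, not part of the app's control
# flow): the model consumes a 1-channel 1024x128 spectrogram "image" and emits one
# logit per AudioSet class.
def _check_model_io() -> None:
    with torch.inference_mode():
        out = MODEL(torch.zeros(1, 1, 1024, 128))
    assert out.shape == (1, len(AUDIOSET_LABELS))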

def preprocess(x: torch.Tensor):
    # Remove DC offset, then compute a 128-bin log Mel-spectrogram
    # (Kaldi-style fbank: Hann window, 10 ms shift -> ~100 frames/s).
    x = x - x.mean()
    melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=128)

    # AudioMAE expects exactly 1024 time frames (10.24 s): zero-pad or crop.
    if melspec.shape[0] < 1024:
        melspec = F.pad(melspec, (0, 0, 0, 1024 - melspec.shape[0]))
    else:
        melspec = melspec[:1024]

    # Normalize with the dataset statistics (AudioMAE divides by 2 * std).
    melspec = (melspec - MEAN) / (STD * 2)
    return melspec
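
# Minimal sketch (hypothetical helper, not called by the app): 10 s of silence
# should yield the fixed (1024, 128) spectrogram the model expects.
def _check_preprocess_shape() -> None:
    dummy = torch.zeros(SAMPLING_RATE * 10)  # 10 s of silence at 16 kHz
    assert preprocess(dummy).shape == (1024, 128)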

def predict(audio, start):
    sr, x = audio  # Gradio's audio input: (sample_rate, int16 numpy waveform)
    if x.shape[0] < start * sr:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.0f}s)")

    x = torch.from_numpy(x) / (1 << 15)  # int16 -> float in [-1, 1)
    if x.ndim > 1:  # downmix multi-channel audio to mono
        x = x.mean(-1)
    assert x.ndim == 1

    # Skip to the requested start offset and resample to the model's 16 kHz.
    x = resample(x[int(start * sr) :], sr, SAMPLING_RATE)
    x = preprocess(x)

    with torch.inference_mode():
        logits = MODEL(x.view(1, 1, 1024, 128)).squeeze(0)

    # AudioSet is multi-label, so scores come from per-class sigmoids, not softmax.
    topk_probs, topk_classes = logits.sigmoid().topk(10)
    preds = [[AUDIOSET_LABELS[cls], prob.item() * 100] for cls, prob in zip(topk_classes, topk_probs)]

    # Plot the spectrogram; at ~100 frames/s, a tick every 100 frames marks 1 s.
    fig = plt.figure()
    plt.imshow(x.T, origin="lower")
    plt.title("Log mel-spectrogram")
    plt.xlabel("Time (s)")
    plt.xticks(np.arange(11) * 100, np.arange(11))
    plt.yticks([0, 64, 128])
    plt.tight_layout()
    return preds, fig
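
# Illustrative usage (a sketch with made-up inputs, not invoked by the app):
# predict() can be called directly with the (sample_rate, int16 waveform) tuple
# that Gradio's audio component produces.
def _demo_predict() -> None:
    sr = 44_100
    noise = (np.random.randn(sr * 5) * 1000).astype(np.int16)  # 5 s of noise
    preds, fig = predict((sr, noise), 0)
    print(preds)  # top-10 [label, probability %] pairs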

DESCRIPTION = """
Classify audio into AudioSet classes with a ViT-B/16 pre-trained using the AudioMAE objective.

- For more information about AudioMAE, visit https://github.com/facebookresearch/AudioMAE.
- For how to use the AudioMAE model in timm, visit https://huggingface.co/gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k.

Input audio is converted to a log Mel-spectrogram and treated as a grayscale image. The model is a vanilla ViT-B/16.

NOTE: The AudioMAE model only accepts 10 s of audio (10.24 s to be exact). Longer audio will be cropped. Shorter audio will be zero-padded.
"""

gr.Interface(
    title="AudioSet classification with AudioMAE (ViT-B/16)",
    description=DESCRIPTION,
    fn=predict,
    inputs=["audio", "number"],
    outputs=[
        gr.Dataframe(headers=["class", "score"], row_count=10, label="prediction"),
        gr.Plot(label="spectrogram"),
    ],
    examples=[
        ["LS_female_1462-170138-0008.flac", 0],
        ["LS_male_3170-137482-0005.flac", 0],
    ],
).launch()