gaunernst committed
Commit cafc237
1 Parent(s): 5b04966
Files changed (1)
  1. app.py +36 -5
app.py CHANGED
@@ -1,6 +1,8 @@
 import json
 
 import gradio as gr
+import matplotlib.pyplot as plt
+import numpy as np
 import requests
 import timm
 import torch
@@ -27,7 +29,7 @@ def preprocess(x: torch.Tensor):
     else:
         melspec = melspec[:1024]
     melspec = (melspec - MEAN) / (STD * 2)
-    return melspec.view(1, 1024, 128)
+    return melspec
 
 
 def predict(audio, start):
@@ -43,15 +45,44 @@ def predict(audio, start):
     x = preprocess(x)
 
     with torch.inference_mode():
-        logits = MODEL(x.unsqueeze(0)).squeeze(0)
+        logits = MODEL(x.view(1, 1, 1024, 128)).squeeze(0)
 
     topk_probs, topk_classes = logits.sigmoid().topk(10)
-    return [[AUDIOSET_LABELS[cls], prob.item() * 100] for cls, prob in zip(topk_classes, topk_probs)]
+    preds = [[AUDIOSET_LABELS[cls], prob.item() * 100] for cls, prob in zip(topk_classes, topk_probs)]
 
+    fig = plt.figure()
+    plt.imshow(x.T, origin="lower")
+    plt.title("Log mel-spectrogram")
+    plt.xlabel("Time (s)")
+    plt.xticks(np.arange(11) * 100, np.arange(11))
+    plt.yticks([0, 64, 128])
+    plt.tight_layout()
+
+    return preds, fig
+
+
+DESCRIPTION = """
+Classify audio into AudioSet classes with a ViT-B/16 pre-trained using the AudioMAE objective.
+
+- For more information about AudioMAE, visit https://github.com/facebookresearch/AudioMAE.
+- For how to use the AudioMAE model in timm, visit https://huggingface.co/gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k.
+
+Input audio is converted to a log Mel-spectrogram and treated as a grayscale image. The model is a vanilla ViT-B/16.
+
+NOTE: The AudioMAE model only accepts 10s of audio (10.24s to be exact). Longer audio will be cropped; shorter audio will be zero-padded.
+"""
 
 gr.Interface(
+    title="AudioSet classification with AudioMAE (ViT-B/16)",
+    description=DESCRIPTION,
     fn=predict,
     inputs=["audio", "number"],
-    outputs="dataframe",
-    examples=[["LS_female_1462-170138-0008.flac", 0], ["LS_male_3170-137482-0005.flac", 0]],
+    outputs=[
+        gr.Dataframe(headers=["class", "score"], row_count=10, label="prediction"),
+        gr.Plot(label="spectrogram"),
+    ],
+    examples=[
+        ["LS_female_1462-170138-0008.flac", 0],
+        ["LS_male_3170-137482-0005.flac", 0],
+    ],
 ).launch()
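
For reference, here is a minimal standalone sketch of the inference path this commit sets up, without the Gradio wrapper. The hub model name, the 1024×128 input shape, the `[:1024]` crop, and the `(melspec - MEAN) / (STD * 2)` normalization all come from the diff above; the Kaldi-style fbank parameters and the MEAN/STD constants are assumptions based on the usual AudioMAE recipe and do not appear in this commit.

```python
# Minimal standalone sketch of this app's inference path (no Gradio).
# ASSUMPTIONS not shown in the diff: the kaldi fbank parameters and the
# MEAN/STD constants follow the usual AudioMAE recipe; the hub model name
# is taken from the URL in DESCRIPTION.
import timm
import torch
import torch.nn.functional as F
import torchaudio

MODEL = timm.create_model(
    "hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k",
    pretrained=True,
).eval()

MEAN, STD = -4.2677393, 4.5689974  # assumed AudioSet log-mel statistics


def preprocess(x: torch.Tensor) -> torch.Tensor:
    # Kaldi-style log mel-spectrogram: 128 mel bins, 10 ms hop -> ~100 frames/s,
    # so 1024 frames cover 10.24 s of audio.
    melspec = torchaudio.compliance.kaldi.fbank(x, htk_compat=True, num_mel_bins=128)
    if melspec.shape[0] < 1024:  # zero-pad short clips along the time axis
        melspec = F.pad(melspec, (0, 0, 0, 1024 - melspec.shape[0]))
    else:  # crop long clips to exactly 1024 frames
        melspec = melspec[:1024]
    return (melspec - MEAN) / (STD * 2)  # normalization as in the diff


waveform, sr = torchaudio.load("LS_female_1462-170138-0008.flac")  # expects 16 kHz
x = preprocess(waveform)

with torch.inference_mode():
    # Treat the spectrogram as a 1-channel image: (batch, channel, time, freq).
    logits = MODEL(x.view(1, 1, 1024, 128)).squeeze(0)

topk_probs, topk_classes = logits.sigmoid().topk(10)
print(topk_classes.tolist(), (topk_probs * 100).tolist())
```

Note the design change in the commit itself: the batching reshape moves out of `preprocess()` and into `predict()` (the `view(1, 1, 1024, 128)` call), which leaves `x` in (time, freq) form for the new spectrogram plot, whose xticks map 100 frames (at a 10 ms hop) to one second.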