# -*- coding: UTF-8 -*-
# music_tagging / app.py
# Last commit: aa314ec by cchaun ("quick fix")
import gradio as gr
import torch, torchaudio
from timeit import default_timer as timer
from torchaudio.transforms import Resample
from models.model import HarmonicCNN
# Preferred inference device: CUDA when available, otherwise CPU.
# NOTE(review): nothing in this file moves the model or its inputs to
# `device`, so inference currently runs on CPU either way.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Target sample rate (Hz) for audio fed to the model.
SAMPLE_RATE = 16000
# Presumably a clip length in seconds.
# NOTE(review): not referenced anywhere in this file.
AUDIO_LEN = 2.90

# Build the network and restore its trained weights. The checkpoint tensors
# are mapped onto the CPU so loading works on machines without a GPU.
model = HarmonicCNN()
S = torch.load("models/best_model.pth", map_location="cpu")
model.load_state_dict(S)
# Tag vocabulary. predict() pairs model output index i with LABELS[i], so the
# order here must match the order of the model's output units exactly.
# The list is three alphabetical runs; by name they look like genres,
# instruments, and moods/themes.
# NOTE(review): that grouping is inferred from the ordering; confirm it
# against the training tag set.
LABELS = (
    # run 1: genre-like tags (31)
    "alternative ambient atmospheric chillout classical dance downtempo "
    "easylistening electronic experimental folk funk hiphop house indie "
    "instrumentalpop jazz lounge metal newage orchestral pop popfolk "
    "poprock reggae rock soundtrack techno trance triphop world "
    # run 2: instrument-like tags (14)
    "acousticguitar bass computer drummachine drums electricguitar "
    "electricpiano guitar keyboard piano strings synthesizer violin voice "
    # run 3: mood/theme-like tags (5)
    "emotional energetic film happy relaxing"
).split()
# Bundled audio clips offered as one-click examples in the Gradio UI.
# Paths are relative, so they resolve against the working directory.
example_list = [
    f"samples/{clip}.wav"
    for clip in ("guitar_acoustic", "guitar_electric", "piano", "violin", "flute")
]
def predict(audio_path):
    """Score an audio file against every tag in ``LABELS``.

    The audio is loaded, resampled to ``SAMPLE_RATE``, down-mixed to one
    channel and passed through ``model`` in inference mode.

    Args:
        audio_path: Path to the audio file, as produced by the
            ``gr.Audio(type="filepath")`` input. May be ``None`` when the
            user submits without providing audio.

    Returns:
        A ``(pred_labels_and_probs, pred_time)`` tuple:
          * ``pred_labels_and_probs`` maps each entry of ``LABELS`` to the
            model's output value for it, as a Python float;
          * ``pred_time`` is the wall-clock time in seconds, rounded to 5
            decimals, spent loading, preprocessing and running the model.

    Raises:
        ValueError: If ``audio_path`` is ``None``.
    """
    if audio_path is None:
        raise ValueError("No audio provided: upload or record a clip first.")

    start_time = timer()
    wav, sample_rate = torchaudio.load(audio_path)  # (channels, samples)

    # Resample in either direction. Previously only higher rates were
    # converted, so e.g. 8 kHz input reached the model at the wrong rate.
    if sample_rate != SAMPLE_RATE:
        wav = Resample(sample_rate, SAMPLE_RATE)(wav)

    # Average the channels but keep a leading axis, so mono and multi-channel
    # files both end up as (1, samples). Previously stereo produced
    # (1, samples) while mono produced (1, 1, samples).
    # NOTE(review): assumes HarmonicCNN takes (batch, samples); confirm
    # against models/model.py.
    wav = wav.mean(dim=0, keepdim=True)

    model.eval()
    with torch.inference_mode():
        pred_probs = model(wav)

    # Output positions are aligned with LABELS: index i <-> LABELS[i].
    pred_labels_and_probs = dict(zip(LABELS, pred_probs[0].tolist()))
    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
# Gradio UI. One audio input, delivered to predict() as a file path. Two
# outputs: the 10 highest-scoring tags, and the inference time.
title = "Music Tagging"

# Components are created in the same order as before: input, then outputs.
audio_input = gr.Audio(type="filepath")
output_components = [
    gr.Label(num_top_classes=10, label="Predictions"),
    gr.Number(label="Prediction time (s)"),
]

demo = gr.Interface(
    fn=predict,
    inputs=audio_input,
    outputs=output_components,
    examples=example_list,
    title=title,
)

# Start the web app with debug mode off.
demo.launch(debug=False)