Spaces:
Build error
Build error
# -*- coding: UTF-8 -*- | |
import gradio as gr | |
import torch, torchaudio | |
from timeit import default_timer as timer | |
from torchaudio.transforms import Resample | |
from models.model import HarmonicCNN | |
# Preferred compute device. NOTE(review): nothing in the visible code moves
# the model or the input tensors to this device -- inference runs wherever
# the model was constructed (CPU, given the map_location below). Confirm
# whether GPU inference was intended.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sample rate (Hz) that input audio is resampled to before inference.
SAMPLE_RATE = 16000

# NOTE(review): presumably the clip length in seconds the model was trained
# on; it is not referenced anywhere in the visible code -- confirm.
AUDIO_LEN = 2.90

# Build the network and restore the trained weights. map_location forces the
# checkpoint tensors onto the CPU so loading works on machines without CUDA.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model = HarmonicCNN()
S = torch.load('models/best_model.pth', map_location=torch.device('cpu'))
model.load_state_dict(S)

# Put the model in inference mode once at load time so dropout/batch-norm
# layers behave deterministically no matter which entry point calls it.
model.eval()
# Tag names, one per model output. The ORDER IS SIGNIFICANT: predict() pairs
# LABELS[i] with the i-th value of the model output, so entries must not be
# sorted, inserted, or removed without retraining/re-exporting the model.
# The grouping comments below are a reading aid only (inferred from the
# names, not enforced anywhere).
LABELS = [
    # genre-like tags
    "alternative",
    "ambient",
    "atmospheric",
    "chillout",
    "classical",
    "dance",
    "downtempo",
    "easylistening",
    "electronic",
    "experimental",
    "folk",
    "funk",
    "hiphop",
    "house",
    "indie",
    "instrumentalpop",
    "jazz",
    "lounge",
    "metal",
    "newage",
    "orchestral",
    "pop",
    "popfolk",
    "poprock",
    "reggae",
    "rock",
    "soundtrack",
    "techno",
    "trance",
    "triphop",
    "world",
    # instrument-like tags
    "acousticguitar",
    "bass",
    "computer",
    "drummachine",
    "drums",
    "electricguitar",
    "electricpiano",
    "guitar",
    "keyboard",
    "piano",
    "strings",
    "synthesizer",
    "violin",
    "voice",
    # mood/theme-like tags
    "emotional",
    "energetic",
    "film",
    "happy",
    "relaxing",
]
# Audio files offered as one-click examples in the Gradio interface
# (passed as `examples=` when the interface is built). Paths are relative
# to the working directory the app is started from.
example_list = [
    "samples/guitar_acoustic.wav",
    "samples/guitar_electric.wav",
    "samples/piano.wav",
    "samples/violin.wav",
    "samples/flute.wav",
]
def predict(audio_path):
    """Tag an audio file and report how long the prediction took.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A ``(scores, seconds)`` tuple. ``scores`` maps every entry of
        ``LABELS`` to the model output at the same index, converted to a
        Python float. ``seconds`` is the wall-clock duration of the whole
        call (loading, preprocessing and inference), rounded to five
        decimal places.

    Raises:
        IndexError: If the model returns fewer values than there are labels.
    """
    start_time = timer()

    wav, sample_rate = torchaudio.load(audio_path)

    # Resample whenever the file's rate differs from the target. The previous
    # `>` comparison let recordings below SAMPLE_RATE through unchanged, so
    # the model saw them at the wrong rate.
    if sample_rate != SAMPLE_RATE:
        wav = Resample(sample_rate, SAMPLE_RATE)(wav)

    # Down-mix to a single channel. keepdim=True keeps the leading dimension,
    # so mono and multi-channel files both end up as (1, num_samples); for a
    # mono file the mean over its one channel is a no-op.
    # NOTE(review): assumes the model takes (batch, samples) input -- confirm
    # against HarmonicCNN.forward.
    wav = torch.mean(wav, dim=0, keepdim=True)

    model.eval()
    with torch.inference_mode():
        pred_probs = model(wav)

    # Only one clip is in the batch, so read row 0 of the output.
    scores = pred_probs[0]
    pred_labels_and_probs = {
        label: float(scores[i]) for i, label in enumerate(LABELS)
    }

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
title = "Music Tagging"

# Gradio UI: the uploaded/recorded audio is handed to predict() as a file
# path; its two return values feed the label widget (top 10 classes shown)
# and the timing read-out, in that order.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(type="filepath"),
    outputs=[
        gr.Label(num_top_classes=10, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
)

# Start the web server as soon as the script is executed.
demo.launch(debug=False)