# -*- coding: UTF-8 -*-
import gradio as gr
import torch
import torchaudio
from timeit import default_timer as timer
from torchaudio.transforms import Resample

from models.model import HarmonicCNN

device = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 16000  # sample rate the model was trained on
AUDIO_LEN = 2.90     # expected clip length in seconds (not enforced below)

# Load the trained weights (mapped to CPU so the checkpoint loads on any
# machine), move the model to the available device, and switch to eval mode
# once at startup rather than on every request.
model = HarmonicCNN()
state_dict = torch.load("models/best_model.pth", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model.to(device)
model.eval()

LABELS = [
    "alternative", "ambient", "atmospheric", "chillout", "classical", "dance",
    "downtempo", "easylistening", "electronic", "experimental", "folk", "funk",
    "hiphop", "house", "indie", "instrumentalpop", "jazz", "lounge", "metal",
    "newage", "orchestral", "pop", "popfolk", "poprock", "reggae", "rock",
    "soundtrack", "techno", "trance", "triphop", "world", "acousticguitar",
    "bass", "computer", "drummachine", "drums", "electricguitar",
    "electricpiano", "guitar", "keyboard", "piano", "strings", "synthesizer",
    "violin", "voice", "emotional", "energetic", "film", "happy", "relaxing",
]

example_list = [
    "samples/guitar_acoustic.wav",
    "samples/guitar_electric.wav",
    "samples/piano.wav",
    "samples/violin.wav",
    "samples/flute.wav",
]


def predict(audio_path):
    """Tag an audio file and return ({label: probability}, inference time in s)."""
    start_time = timer()

    wav, sample_rate = torchaudio.load(audio_path)

    # Resample whenever the file's rate differs from the model's rate
    # (the original only handled rates above 16 kHz, so lower-rate files
    # were fed to the model at the wrong rate).
    if sample_rate != SAMPLE_RATE:
        wav = Resample(sample_rate, SAMPLE_RATE)(wav)

    # Downmix multi-channel audio to mono, keeping the channel dimension.
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)

    with torch.inference_mode():
        pred_probs = model(wav.to(device))  # per-tag scores, one per entry in LABELS

    pred_labels_and_probs = {label: float(prob) for label, prob in zip(LABELS, pred_probs[0])}
    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time


title = "Music Tagging"

demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(type="filepath"),
    outputs=[
        gr.Label(num_top_classes=10, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
)

demo.launch(debug=False)
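
# --- Optional sketch (not part of the original app) --------------------------
# AUDIO_LEN is declared above but never applied. If HarmonicCNN expects a
# fixed-length input of SAMPLE_RATE * AUDIO_LEN samples, the waveform could be
# trimmed or zero-padded before inference with a helper like the hypothetical
# one below, e.g. `wav = fix_length(wav)` just before calling the model.

def fix_length(wav, num_samples=int(SAMPLE_RATE * AUDIO_LEN)):
    """Trim or zero-pad a (channels, time) waveform to exactly num_samples."""
    t = wav.shape[-1]
    if t >= num_samples:
        return wav[..., :num_samples]
    return torch.nn.functional.pad(wav, (0, num_samples - t))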