import gradio as gr # from transformers import Wav2Vec2FeatureExtractor from transformers import AutoModel import torch from torch import nn import torchaudio import torchaudio.transforms as T import logging import json import os import re import pandas as pd import importlib modeling_MERT = importlib.import_module("MERT-v1-95M.modeling_MERT") from Prediction_Head.MTGGenre_head import MLPProberBase # input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py logger = logging.getLogger("MERT-v1-95M-app") logger.setLevel(logging.INFO) ch = logging.StreamHandler() ch.setLevel(logging.INFO) formatter = logging.Formatter( "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S") ch.setFormatter(formatter) logger.addHandler(ch) inputs = [ gr.components.Audio(type="filepath", label="Add music audio file"), gr.inputs.Audio(source="microphone", type="filepath"), ] live_inputs = [ gr.Audio(source="microphone",streaming=True, type="filepath"), ] title = "One Model for All Music Understanding Tasks" description = "An example of using the [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) model as backbone to conduct multiple music understanding tasks with the universal represenation." # article = "The tasks include EMO, GS, MTGInstrument, MTGGenre, MTGTop50, MTGMood, NSynthI, NSynthP, VocalSetS, VocalSetT. \n\n More models can be referred at the [map organization page](https://huggingface.co/m-a-p)." with open('./README.md', 'r') as f: # skip the header header_count = 0 for line in f: if '---' in line: header_count += 1 if header_count >= 2: break # read the rest conent article = f.read() audio_examples = [ # ["input/example-1.wav"], # ["input/example-2.wav"], ] df_init = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5']) transcription_df = gr.DataFrame(value=df_init, label="Output Dataframe", row_count=( 0, "dynamic"), max_rows=30, wrap=True, overflow_row_behaviour='paginate') # outputs = [gr.components.Textbox()] outputs = transcription_df df_init_live = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5']) transcription_df_live = gr.DataFrame(value=df_init_live, label="Output Dataframe", row_count=( 0, "dynamic"), max_rows=30, wrap=True, overflow_row_behaviour='paginate') outputs_live = transcription_df_live # Load the model and the corresponding preprocessor config # model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True) # processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public",trust_remote_code=True) model = modeling_MERT.MERTModel.from_pretrained("./MERT-v1-95M") processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v1-95M") device = 'cuda' if torch.cuda.is_available() else 'cpu' MERT_BEST_LAYER_IDX = { 'EMO': 5, 'GS': 8, 'GTZAN': 7, 'MTGGenre': 7, 'MTGInstrument': 'all', 'MTGMood': 6, 'MTGTop50': 6, 'MTT': 'all', 'NSynthI': 6, 'NSynthP': 1, 'VocalSetS': 2, 'VocalSetT': 9, } MERT_BEST_LAYER_IDX = { 'EMO': 5, 'GS': 8, 'GTZAN': 7, 'MTGGenre': 7, 'MTGInstrument': 'all', 'MTGMood': 6, 'MTGTop50': 6, 'MTT': 'all', 'NSynthI': 6, 'NSynthP': 1, 'VocalSetS': 2, 'VocalSetT': 9, } CLASSIFIERS = { } ID2CLASS = { } TASKS = ['GS', 'MTGInstrument', 'MTGGenre', 'MTGTop50', 'MTGMood', 'NSynthI', 'NSynthP', 'VocalSetS', 'VocalSetT','EMO',] Regression_TASKS = ['EMO'] head_dir = './Prediction_Head/best-layer-MERT-v1-95M' for task in TASKS: print('loading', task) with open(os.path.join(head_dir,f'{task}.id2class.json'), 'r') as f: ID2CLASS[task]=json.load(f) num_class = len(ID2CLASS[task].keys()) CLASSIFIERS[task] = MLPProberBase(d=768, layer=MERT_BEST_LAYER_IDX[task], num_outputs=num_class) CLASSIFIERS[task].load_state_dict(torch.load(f'{head_dir}/{task}.ckpt')['state_dict']) CLASSIFIERS[task].to(device) model.to(device) def model_infernce(inputs): waveform, sample_rate = torchaudio.load(inputs) resample_rate = processor.sampling_rate # make sure the sample_rate aligned if resample_rate != sample_rate: # print(f'setting rate from {sample_rate} to {resample_rate}') resampler = T.Resample(sample_rate, resample_rate) waveform = resampler(waveform) waveform = waveform.view(-1,) # make it (n_sample, ) model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt") model_inputs.to(device) with torch.no_grad(): model_outputs = model(**model_inputs, output_hidden_states=True) # take a look at the output shape, there are 13 layers of representation # each layer performs differently in different downstream tasks, you should choose empirically all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()[1:,:,:].unsqueeze(0) print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim] all_layer_hidden_states = all_layer_hidden_states.mean(dim=2) task_output_texts = "" df = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5']) df_objects = [] for task in TASKS: num_class = len(ID2CLASS[task].keys()) if MERT_BEST_LAYER_IDX[task] == 'all': logits = CLASSIFIERS[task](all_layer_hidden_states) # [1, 87] else: logits = CLASSIFIERS[task](all_layer_hidden_states[:, MERT_BEST_LAYER_IDX[task]]) # print(f'task {task} logits:', logits.shape, 'num class:', num_class) sorted_idx = torch.argsort(logits, dim = -1, descending=True)[0] # batch =1 sorted_prob,_ = torch.sort(nn.functional.softmax(logits[0], dim=-1), dim=-1, descending=True) # print(sorted_prob) # print(sorted_prob.shape) top_n_show = 5 if num_class >= 5 else num_class # task_output_texts = task_output_texts + f"TASK {task} output:\n" + "\n".join([str(ID2CLASS[task][str(sorted_idx[idx].item())])+f', probability: {sorted_prob[idx].item():.2%}' for idx in range(top_n_show)]) + '\n' # task_output_texts = task_output_texts + '----------------------\n' row_elements = [task] for idx in range(top_n_show): print(ID2CLASS[task]) # print('id', str(sorted_idx[idx].item())) output_class_name = str(ID2CLASS[task][str(sorted_idx[idx].item())]) output_class_name = re.sub(r'^\w+---', '', output_class_name) output_class_name = re.sub(r'^\w+\/\w+---', '', output_class_name) # print('output name', output_class_name) output_prob = f' {sorted_prob[idx].item():.2%}' row_elements.append(output_class_name+output_prob) # fill empty elment for _ in range(5+1 - len(row_elements)): row_elements.append(' ') df_objects.append(row_elements) df = pd.DataFrame(df_objects, columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5']) return df def convert_audio(inputs, microphone): if (microphone is not None): inputs = microphone df = model_infernce(inputs) return df def live_convert_audio(microphone): if (microphone is not None): inputs = microphone df = model_infernce(inputs) return df audio_chunked = gr.Interface( fn=convert_audio, inputs=inputs, outputs=outputs, allow_flagging="never", title=title, description=description, article=article, examples=audio_examples, ) live_audio_chunked = gr.Interface( fn=live_convert_audio, inputs=live_inputs, outputs=outputs_live, allow_flagging="never", title=title, description=description, article=article, # examples=audio_examples, live=True, ) demo = gr.Blocks() with demo: gr.TabbedInterface( [ audio_chunked, live_audio_chunked, ], [ "Audio File or Recording", "Live Streaming Music" ] ) # demo.queue(concurrency_count=1, max_size=5) demo.launch(show_api=False)