"""Gradio demo: classify a sentence into user intents and musical attributes.

Two fine-tuned DistilBERT multi-label classifiers are loaded from ./models/
(one for user intents, one for musical attributes). Every label whose sigmoid
probability is >= 0.5 is reported for each classifier.
"""
import warnings
warnings.simplefilter('ignore')  # keep library warnings out of the demo output
import numpy as np
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel
import logging
logging.basicConfig(level=logging.ERROR)
from torch import cuda
import gradio as gr
from functools import lru_cache

PRETRAINED_NAME = "distilbert-base-uncased"
MAX_LENGTH = 128       # tokenizer pads/truncates every input to this length
THRESHOLD = 0.5        # sigmoid probability at/above which a label is predicted

USER_INTENTS = ['initial_query', 'greeting', 'add_filter', 'remove_filter',
                'continue', 'accept_response', 'reject_response']
MUSICAL_ATTRIBUTES = ['track', 'artist', 'year', 'popularity', 'culture',
                      'similar_track', 'similar_artist', 'user', 'theme',
                      'mood', 'genre', 'instrument', 'vocal', 'tempo']

# Label lists per classifier; the order must match the output units of the
# corresponding fine-tuned checkpoint.
INTENTS_DICT = {"user": USER_INTENTS, "music": MUSICAL_ATTRIBUTES}
# Prefix used for each classifier's line in the returned text.
OUTPUT_LABELS = {"user": "User Intents", "music": "Musical Attributes"}

DEVICE = 'cuda:0' if cuda.is_available() else 'cpu'


class DistilBERTClass(nn.Module):
    """DistilBERT encoder followed by a small MLP head producing one logit per label."""

    def __init__(self, num_intents):
        super(DistilBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained(PRETRAINED_NAME)
        self.fc1 = nn.Sequential(
            nn.Linear(768, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(64, num_intents)
        )

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]
        # Use the first token's ([CLS]) 768-dim embedding as the sentence vector.
        pooler = hidden_state[:, 0]
        pooler = self.fc1(pooler)
        return self.fc2(pooler)


@lru_cache(maxsize=1)
def _get_tokenizer():
    """Load the DistilBERT tokenizer once and reuse it across requests."""
    return DistilBertTokenizer.from_pretrained(PRETRAINED_NAME)


@lru_cache(maxsize=len(INTENTS_DICT))
def _get_model(data_type):
    """Build and load the fine-tuned classifier for *data_type* once ("user" or "music")."""
    model = DistilBERTClass(len(INTENTS_DICT[data_type]))
    # NOTE: torch.load unpickles the file; only load trusted local checkpoints.
    state_dict = torch.load(f"./models/{data_type}_finetune_model.pth",
                            map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    # eval() is required: BatchNorm1d in fc1 cannot run in training mode on a
    # batch of size 1, which is what classify() feeds it.
    model.eval()
    return model


def classify(sentence):
    """Predict user intents and musical attributes for a single sentence.

    Args:
        sentence: Input text from the Gradio textbox (None is treated as "").

    Returns:
        Two lines of text: "User Intents: [...]" followed by
        "Musical Attributes: [...]", each listing the predicted labels.
    """
    if sentence is None:
        sentence = ""

    # Tokenize once; both classifiers share the same encoder input.
    inputs = _get_tokenizer().encode_plus(
        sentence,
        None,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        return_token_type_ids=False,
        return_attention_mask=True,
        truncation=True,
    )
    input_ids = torch.tensor(inputs['input_ids']).unsqueeze(0).to(DEVICE)
    attention_mask = torch.tensor(inputs['attention_mask']).unsqueeze(0).to(DEVICE)

    output = ""
    for data_type in ("user", "music"):
        model = _get_model(data_type)
        with torch.no_grad():
            logits = model(input_ids, attention_mask)
        probabilities = torch.sigmoid(logits).cpu().numpy()
        binary_outputs = probabilities >= THRESHOLD
        # If no label clears the threshold, force the LAST label on.
        # NOTE(review): that is 'reject_response' / 'tempo' here — confirm this
        # fallback is intended and not a leftover "none"-class convention.
        binary_outputs[~np.any(binary_outputs, axis=1), -1] = True

        intents = INTENTS_DICT[data_type]
        predicted_intents = [intent for intent, is_on in zip(intents, binary_outputs[0]) if is_on]
        output += f"{OUTPUT_LABELS[data_type]}: {predicted_intents}\n"
    return output


title = "User Intents and Musical Attributes Classifier"
description = """
You can engage in a conversation with the music recommendation system, imagining a situation where it recommends music to you.
The model will then predict the intents and musical attributes based on the sentence you provide.
"""
article = "For more information, visit [Github Repository.](https://github.com/DaeyongKwon98/Intent-Classification/tree/main)"

demo = gr.Interface(
    fn=classify,
    inputs="text",
    outputs="text",
    title=title,
    description=description,
    article=article,
    examples=[["Hi, I need a playlist of rock songs to listen when I exercise."],
              ["I love Ariana Grande! Give me more."],
              ["I think these are too fast for me."]],
)

demo.launch()