Daeyongkwon98's picture
load model to cpu
4811d2a
raw
history blame
3.54 kB
import warnings
warnings.simplefilter('ignore')
import numpy as np
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel
import logging
logging.basicConfig(level=logging.ERROR)
from torch import cuda
import gradio as gr
class DistilBERTClass(nn.Module):
    """DistilBERT encoder followed by a small classification head.

    The attribute names ``l1``, ``fc1`` and ``fc2`` are part of the saved
    state-dict keys and must not be renamed.

    Args:
        num_intents: Number of output logits (one per label).
    """

    def __init__(self, num_intents):
        super().__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # Project the 768-dim encoder features down to 64 before classifying.
        self.fc1 = nn.Sequential(
            nn.Linear(768, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(64, num_intents)
        )

    def forward(self, input_ids, attention_mask):
        """Return raw logits (no sigmoid applied) for a batch of token ids."""
        encoder_out = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = encoder_out[0]
        # Use the first token position of every sequence as the sentence
        # representation (the [CLS] token when special tokens are added).
        pooled = hidden_state[:, 0]
        pooled = self.fc1(pooled)
        return self.fc2(pooled)


# Label vocabularies, keyed by the data type used in the checkpoint file name.
# Insertion order ("user" first, then "music") is the order of the output lines.
_LABELS = {
    "user": [
        'initial_query', 'greeting', 'add_filter', 'remove_filter',
        'continue', 'accept_response', 'reject_response',
    ],
    "music": [
        'track', 'artist', 'year', 'popularity', 'culture', 'similar_track',
        'similar_artist', 'user', 'theme', 'mood', 'genre', 'instrument',
        'vocal', 'tempo',
    ],
}
# Heading printed in front of each data type's predictions.
_HEADINGS = {"user": "User Intents", "music": "Musical Attributes"}
_MAX_LENGTH = 128
_THRESHOLD = 0.5

# Filled on the first call to classify(); holds "device", "tokenizer", "models".
_RESOURCES = {}


def _load_resources():
    """Build the tokenizer and both fine-tuned models once and cache them.

    The cache is only populated after everything loaded successfully, so a
    failed load is retried on the next call instead of leaving a partial state.
    """
    if not _RESOURCES:
        device = 'cuda:0' if cuda.is_available() else 'cpu'
        tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        models = {}
        for data_type, labels in _LABELS.items():
            model = DistilBERTClass(len(labels))
            # NOTE(review): torch.load unpickles the file; the checkpoints are
            # local app assets, but consider weights_only=True if the installed
            # torch version supports it.
            state_dict = torch.load(
                f"./models/{data_type}_finetune_model.pth",
                map_location=torch.device('cpu'),
            )
            model.load_state_dict(state_dict)
            model.to(device)
            # Inference only: eval mode makes BatchNorm use its running
            # statistics, which is required for a batch of one sentence.
            model.eval()
            models[data_type] = model
        _RESOURCES.update(device=device, tokenizer=tokenizer, models=models)
    return _RESOURCES


def _decode_labels(logits, labels):
    """Turn one row of logits into the list of label names predicted for it."""
    probabilities = torch.sigmoid(logits).cpu().numpy()
    selected = probabilities >= _THRESHOLD
    # If no label reaches the threshold, the last label of the list is
    # switched on (behavior kept from the original implementation).
    # NOTE(review): confirm that "last label" is the intended fallback.
    selected[~selected.any(axis=1), -1] = True
    return [label for label, is_on in zip(labels, selected[0]) if is_on]


def classify(sentence):
    """Predict user intents and musical attributes for one sentence.

    Args:
        sentence: The user utterance to classify.

    Returns:
        A two-line string: ``"User Intents: [...]\\n"`` followed by
        ``"Musical Attributes: [...]\\n"``, each listing the predicted labels.
    """
    resources = _load_resources()
    device = resources["device"]

    # Tokenize once; the same encoding is fed to both models.
    encoded = resources["tokenizer"](
        sentence,
        add_special_tokens=True,
        max_length=_MAX_LENGTH,
        padding="max_length",  # replaces the deprecated pad_to_max_length=True
        truncation=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    lines = []
    with torch.no_grad():
        for data_type, labels in _LABELS.items():
            logits = resources["models"][data_type](input_ids, attention_mask)
            predicted_intents = _decode_labels(logits, labels)
            lines.append(f"{_HEADINGS[data_type]}: {predicted_intents}\n")
    return "".join(lines)
# --- Gradio UI -------------------------------------------------------------
# Static texts shown around the demo: heading, intro (with an illustration),
# and a footer link to the project repository.
title = "User Intents and Musical Attributes Classifier"
description = "\n".join([
    "",
    "You can engage in a conversation with the music recommendation system, imagining a situation where it recommends music to you. The model will then predict the intents and musical attributes based on the sentence you provide.",
    '<img src="https://github.com/user-attachments/assets/a8bfb1dc-856b-4f85-82dd-510cddcc2aeb" width=400px>',
    "",
])
article = (
    "For more information, visit "
    "[Github Repository.](https://github.com/DaeyongKwon98/Intent-Classification/tree/main)"
)

# Text-in / text-out interface around classify(); every example is a
# one-element row because the interface has a single input.
demo = gr.Interface(
    fn=classify,
    inputs="text",
    outputs="text",
    examples=[
        [utterance]
        for utterance in (
            "Hi, I need a playlist of rock songs to listen when I exercise.",
            "I love Ariana Grande! Give me more.",
            "I think these are too fast for me.",
        )
    ],
    title=title,
    description=description,
    article=article,
)

# Start the web app as soon as this script is executed.
demo.launch()