import torch
import torch.nn as TNN
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader as DL
from torch import cuda
import streamlit as st
from transformers import BertTokenizer as BT, BertModel as BM

device = 'cuda' if cuda.is_available() else 'cpu'
# Configuration constants
MAX_LEN = 128
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4
LEARNING_RATE = 5e-05
modName = 'bert-base-uncased'  # pre-trained model checkpoint
categories = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']  # label columns
data = pd.read_csv('./train.csv')
data.drop(['id'], inplace=True, axis=1)
new = pd.DataFrame()
new['text'] = data['comment_text']
new['labels'] = data.iloc[:, 1:].values.tolist()  # all six label columns per comment, not just the first
tokenizer = BT.from_pretrained(modName, do_lower_case=True)
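# Illustrative sanity check (not in the original app; uncomment to inspect):
# encode_plus returns the three fields the Dataset below relies on —
# input_ids, attention_mask, and token_type_ids — each padded/truncated to MAX_LEN.
# sample_enc = tokenizer.encode_plus("an example comment", add_special_tokens=True,
#                                    max_length=MAX_LEN, padding='max_length',
#                                    truncation=True, return_token_type_ids=True)
# assert len(sample_enc['input_ids']) == MAX_LEN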
class MultiLabelDataset(Dataset):
    def __init__(self, df, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = df
        self.text = df.text
        self.targets = self.data.labels
        self.max_len = max_len

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        text = str(self.text[idx])
        text = " ".join(text.split())  # collapse runs of whitespace
        ins = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',   # replaces the deprecated pad_to_max_length=True
            truncation=True,
            return_token_type_ids=True
        )
        return {
            'input_ids': torch.tensor(ins['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(ins['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(ins['token_type_ids'], dtype=torch.long),
            'targets': torch.tensor(self.targets[idx], dtype=torch.float)
        }
trainSize = 0.8
trainData = new.sample(frac=trainSize, random_state=200)
testData = new.drop(trainData.index).reset_index(drop=True)
trainData = trainData.reset_index(drop=True)
trainSet = MultiLabelDataset(trainData, tokenizer, MAX_LEN)
testSet = MultiLabelDataset(testData, tokenizer, MAX_LEN)
training_loader = DL(trainSet, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
testing_loader = DL(testSet, batch_size=VALID_BATCH_SIZE, shuffle=False)  # keep evaluation order deterministic
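# Illustrative shape check (assumption, not part of the original app):
# a training batch should yield input_ids/attention_mask/token_type_ids of
# shape (TRAIN_BATCH_SIZE, MAX_LEN) and targets of shape (TRAIN_BATCH_SIZE, 6).
# batch = next(iter(training_loader))
# assert batch['input_ids'].shape == (TRAIN_BATCH_SIZE, MAX_LEN)
# assert batch['targets'].shape == (TRAIN_BATCH_SIZE, len(categories))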
# Neural network: BERT encoder with a small classification head
class BERTClass(TNN.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        self.l1 = BM.from_pretrained(modName)
        self.pre_classifier = TNN.Linear(768, 768)
        self.dropout = TNN.Dropout(0.1)
        self.classifier = TNN.Linear(768, 6)  # one logit per category

    def forward(self, input_ids, attention_mask, token_type_ids):
        out = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_state = out[0]
        po = hidden_state[:, 0]       # [CLS] token representation
        po = self.pre_classifier(po)
        po = TNN.Tanh()(po)
        po = self.dropout(po)
        outs = self.classifier(po)    # raw logits, shape (batch_size, 6)
        return outs

mod = BERTClass()
mod.to(device)
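# Illustrative forward-pass check (assumption, not part of the original app):
# with torch.no_grad():
#     b = next(iter(training_loader))
#     logits = mod(b['input_ids'].to(device), b['attention_mask'].to(device),
#                  b['token_type_ids'].to(device))
#     assert logits.shape == (TRAIN_BATCH_SIZE, len(categories))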
# Loss function and optimizer
def lossFN(outs, targets):
    # With the label fix above, targets already has shape (batch_size, 6),
    # matching the logits, so no reshaping is needed.
    return TNN.BCEWithLogitsLoss()(outs, targets)

opt = torch.optim.Adam(mod.parameters(), lr=LEARNING_RATE)
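# Quick illustration (hypothetical values, not from the original source):
# BCEWithLogitsLoss applies a sigmoid internally, so raw logits go in directly.
# With all-zero logits every per-label probability is 0.5, so the mean loss is
# -log(0.5) ~= 0.693 regardless of the targets.
# dummy_logits = torch.zeros(2, 6)
# dummy_targets = torch.tensor([[1., 0., 0., 0., 0., 0.],
#                               [0., 0., 0., 0., 0., 0.]])
# print(lossFN(dummy_logits, dummy_targets))  # ~0.6931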
# Training / fine-tuning loop
def train(mod, training_loader):
    mod.train()
    for data in tqdm(training_loader):
        input_ids = data['input_ids'].to(device, dtype=torch.long)
        attention_mask = data['attention_mask'].to(device, dtype=torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
        targets = data['targets'].to(device, dtype=torch.float)
        opt.zero_grad()
        outs = mod(input_ids, attention_mask, token_type_ids)
        loss = lossFN(outs, targets)
        loss.backward()
        opt.step()
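# The original script defines train() but never invokes it, so the app would
# serve untrained weights. Assuming a single epoch here — the actual epoch
# count is not recorded in the source.
EPOCHS = 1
for _ in range(EPOCHS):
    train(mod, training_loader)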
# Streamlit table of results
st.title("Finetuned Model for Toxicity")
st.subheader("Model: bert-base-uncased")

def predict(loader):
    mod.eval()
    res = []
    with torch.no_grad():
        for ins in loader:
            outs = mod(input_ids=ins['input_ids'].to(device),
                       attention_mask=ins['attention_mask'].to(device),
                       token_type_ids=ins['token_type_ids'].to(device))
            # Multi-label setup: sigmoid gives an independent probability per
            # category (softmax would wrongly force the labels to compete).
            probs = torch.sigmoid(outs)
            for i in range(probs.shape[0]):
                top = torch.argmax(probs[i]).item()
                text = tokenizer.decode(ins['input_ids'][i], skip_special_tokens=True)
                res.append({'TWEETS': text,
                            'LABEL': categories[top],
                            'PROBABILITY': probs[i][top].item()})
    return res

res = predict(testing_loader)
st.table(res)  # table of per-example predictions
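# Note (extension, not in the original app): reporting only the argmax label
# hides co-occurring categories. A genuinely multi-label readout would list
# every category whose probability clears a threshold, e.g. inside the loop:
# flagged = [categories[j] for j in range(len(categories)) if probs[i][j] > 0.5]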