# Commit 24e04ee by snaramirez872: "Update app.py"
import torch
import torch.nn as TNN
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset as set, DataLoader as DL
from torch import cuda
import streamlit as st
from transformers import BertTokenizer as BT, BertModel as BM
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Run configuration shared by the rest of the script.
MAX_LEN = 128  # every tokenized comment is padded/truncated to this many tokens
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4
LEARNING_RATE = 5e-05
modName = 'bert-base-uncased'  # checkpoint loaded for both tokenizer and encoder
# Label names
categories = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
# Load the training CSV and keep the comment text plus one label per category.
# NOTE(review): assumes train.csv has an 'id', a 'comment_text' and one 0/1
# column for every name in `categories` — confirm against the data file.
data = pd.read_csv('./train.csv')
data.drop(['id'], inplace=True, axis=1)
new = pd.DataFrame()
new['text'] = data['comment_text']
# Previously `data.iloc[:, 1]` kept only the first label column ('toxic'), so
# every row had a single scalar target even though the classifier predicts
# len(categories) outputs. Select all category columns to get a multi-hot list.
new['labels'] = data[categories].values.tolist()
tokenizer = BT.from_pretrained(modName, truncation=True, do_lower_case=True)
class MultiLabelDataset(set):
    """Map-style dataset that tokenizes one comment per item on demand.

    ``set`` is ``torch.utils.data.Dataset`` imported under that alias.

    Args:
        df: DataFrame with a ``text`` column and a ``labels`` column holding,
            per row, either a single label or a list of per-category labels.
        tokenizer: tokenizer exposing ``encode_plus`` (e.g. ``BertTokenizer``).
        max_len: every example is padded/truncated to exactly this many tokens.
    """

    def __init__(self, df, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = df
        self.text = df.text
        self.targets = self.data.labels
        self.max_len = max_len

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        """Return long tensors for the encoded text and a float tensor of targets."""
        # Positional lookup: plain ``series[idx]`` is label-based and raises a
        # KeyError for DataFrames whose index is not a 0..n-1 RangeIndex.
        text = str(self.text.iloc[idx])
        # Collapse runs of whitespace/newlines into single spaces.
        text = " ".join(text.split())
        ins = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            # Modern spelling of the deprecated ``pad_to_max_length=True``;
            # truncation is requested explicitly rather than left implicit.
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
        )
        return {
            'input_ids': torch.tensor(ins['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(ins['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(ins['token_type_ids'], dtype=torch.long),
            'targets': torch.tensor(self.targets.iloc[idx], dtype=torch.float),
        }
# Hold out 20% of the rows for evaluation; the fixed seed makes the split reproducible.
trainSize = 0.8
trainData = new.sample(frac=trainSize, random_state=200)
testData = new.drop(trainData.index).reset_index(drop=True)
trainData = trainData.reset_index(drop=True)
trainSet = MultiLabelDataset(trainData, tokenizer, MAX_LEN)
testSet = MultiLabelDataset(testData, tokenizer, MAX_LEN)
training_loader = DL(trainSet, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the results
# table reproducible between runs and aligned with testData's row order.
testing_loader = DL(testSet, batch_size=VALID_BATCH_SIZE, shuffle=False)
# Neural network: pre-trained BERT encoder plus a small classification head.
class BERTClass(TNN.Module):
    """BERT encoder whose first-token ([CLS]) vector feeds a classification head.

    Head: dense layer -> tanh -> dropout -> linear layer producing one raw logit
    per label.
    """

    def __init__(self, num_labels=6):
        """Load the encoder named by ``modName`` and build the head.

        Args:
            num_labels: width of the output layer; defaults to 6, the value
                that was previously hard-coded.
        """
        super().__init__()
        self.l1 = BM.from_pretrained(modName)
        # Size the head from the encoder's config rather than hard-coding 768,
        # so checkpoints with a different hidden size also work.
        hidden = self.l1.config.hidden_size
        self.pre_classifier = TNN.Linear(hidden, hidden)
        self.dropout = TNN.Dropout(0.1)
        self.classifier = TNN.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return un-normalised logits, one column per label."""
        out = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_state = out[0]
        # First token of every sequence (the [CLS] token added by the tokenizer).
        po = hidden_state[:, 0]
        po = self.pre_classifier(po)
        po = torch.tanh(po)
        po = self.dropout(po)
        return self.classifier(po)
# Build the model once and move its parameters to the selected device
# (nn.Module.to returns the module itself, so the chained form is equivalent).
mod = BERTClass().to(device)
# Loss function and Optimizer
def lossFN(outs, targets):
    """Mean binary cross-entropy on raw logits for multi-label classification.

    Args:
        outs: logits of shape (batch, num_labels).
        targets: either (batch, num_labels) 0/1 labels, or a (batch,) tensor
            with a single label per row, which is broadcast to every column.

    Returns:
        Scalar loss tensor.
    """
    # Only broadcast when targets are one rank short of the logits; per-label
    # targets that already match the logits must be used as-is (the previous
    # unconditional unsqueeze/expand raised on them).
    if targets.dim() == outs.dim() - 1:
        targets = targets.unsqueeze(1).expand_as(outs)
    # BCEWithLogitsLoss requires floating-point targets matching the logits.
    return TNN.BCEWithLogitsLoss()(outs, targets.to(outs.dtype))
opt = torch.optim.Adam(mod.parameters(), lr=LEARNING_RATE)


# Training and Finetuning
def train(mod, training_loader):
    """Run one fine-tuning epoch over ``training_loader``.

    Uses the module-level optimizer ``opt`` (built just above from the global
    ``mod``'s parameters), ``lossFN`` and ``device``.

    Args:
        mod: model to train; switched to training mode here.
        training_loader: iterable of batches, each a dict with 'input_ids',
            'attention_mask', 'token_type_ids' and 'targets' tensors.

    Returns:
        Mean loss over the epoch as a float (0.0 if the loader is empty).
    """
    # NOTE(review): `opt` only steps the parameters of the global `mod`; passing
    # a different model here would compute gradients that are never applied.
    mod.train()
    total_loss = 0.0
    n_batches = 0
    # Wrap the loader itself rather than enumerate(...) so tqdm can read its
    # length and display a total and ETA.
    for batch in tqdm(training_loader):
        input_ids = batch['input_ids'].to(device, dtype=torch.long)
        attention_mask = batch['attention_mask'].to(device, dtype=torch.long)
        token_type_ids = batch['token_type_ids'].to(device, dtype=torch.long)
        targets = batch['targets'].to(device, dtype=torch.float)
        outs = mod(input_ids, attention_mask, token_type_ids)
        opt.zero_grad()
        loss = lossFN(outs, targets)
        loss.backward()
        opt.step()
        total_loss += loss.item()
        n_batches += 1
    return total_loss / n_batches if n_batches else 0.0
# Streamlit page header and results table.
st.title("Finetuned Model for Toxicity")
# Derive the label from modName so it cannot drift from the checkpoint loaded.
st.subheader(f"Model: {modName}")
def predict(tweets):
    """Score every example yielded by ``tweets`` and return one row per comment.

    Uses the module-level ``mod``, ``device``, ``tokenizer`` and ``categories``.

    Args:
        tweets: iterable of batches (e.g. a DataLoader over MultiLabelDataset),
            each a dict with 'input_ids', 'attention_mask' and 'token_type_ids'.

    Returns:
        List of dicts, one per example, with keys 'TWEETS' (decoded comment
        text), 'LABEL' (name of the highest-scoring category) and
        'PROBABILITY' (sigmoid score of that category), ready for st.table.
    """
    mod.eval()
    res = []
    with torch.no_grad():
        for ins in tweets:
            input_ids = ins['input_ids']
            outs = mod(
                input_ids=input_ids.to(device),
                attention_mask=ins['attention_mask'].to(device),
                token_type_ids=ins['token_type_ids'].to(device),
            )
            # Labels are scored independently (the model is trained with
            # BCE-with-logits), so apply a per-label sigmoid over the whole
            # batch instead of a softmax over only the first row's logits.
            probs = torch.sigmoid(outs)
            top_probs, top_idx = probs.max(dim=-1)
            # Iterate over the examples in this batch, not over the number of
            # batches as before.
            for ids, idx, p in zip(input_ids, top_idx.tolist(), top_probs.tolist()):
                res.append({
                    'TWEETS': tokenizer.decode(ids.tolist(), skip_special_tokens=True),
                    'LABEL': categories[idx],
                    'PROBABILITY': p,
                })
    return res
# Score the held-out split and render the rows as a static table.
# NOTE(review): train() is never called in this file, so these predictions come
# from an un-fine-tuned classification head — confirm this is intended.
res = predict(testing_loader)
st.table(data=res)