# Silence noisy library warnings before the heavier imports run.
import warnings
warnings.filterwarnings('ignore')

import torchvision
torchvision.disable_beta_transforms_warning()

import os
import re

import numpy as np
import torch
import torch.nn as nn
from alive_progress import alive_bar
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification
class Preprocessor:
    """Tokenizes text with a BERT tokenizer and assigns weak labels via keyword matching."""

    def __init__(self, modelName='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(modelName)
        # Class index -> human-readable topic label.
        self.labelMap = {
            0: 'Anxiety',
            1: 'Depression',
            2: 'Stress',
            3: 'Happiness',
            4: 'Relationship Issues',
            5: 'Self-Harm',
            6: 'Substance Abuse',
            7: 'Trauma',
            8: 'Obsessive Compulsive Disorder',
            9: 'Eating Disorders',
            10: 'Grief',
            11: 'Phobias',
            12: 'Bipolar Disorder',
            13: 'Post-Traumatic Stress Disorder',
            14: 'Mental Fatigue',
            15: 'Mood Swings',
            16: 'Anger Management',
            17: 'Social Isolation',
            18: 'Perfectionism',
            19: 'Low Self-Esteem',
            20: 'Family Issues'
        }
        # Trigger keyword -> class index, used by labelContext for rule-based labeling.
        self.keywords = {
            'anxiety': 0,
            'depressed': 1,
            'sad': 1,
            'stress': 2,
            'happy': 3,
            'relationship': 4,
            'self-harm': 5,
            'substance': 6,
            'trauma': 7,
            'ocd': 8,
            'eating': 9,
            'grief': 10,
            'phobia': 11,
            'bipolar': 12,
            'ptsd': 13,
            'fatigue': 14,
            'mood': 15,
            'anger': 16,
            'isolated': 17,
            'perfectionism': 18,
            'self-esteem': 19,
            'family': 20
        }
    def tokenizeText(self, text, maxLength=128):
        return self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=maxLength,
            return_tensors='pt'
        )
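    # Hypothetical illustration: every value returned by tokenizeText is a
    # tensor of shape (1, maxLength), e.g.
    #   tokens = Preprocessor().tokenizeText('some input text')
    #   tokens['input_ids'].shape      -> torch.Size([1, 128])
    #   tokens['attention_mask'].shape -> torch.Size([1, 128])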
    def preprocessDataset(self, texts):
        # Tokenize each text separately, then stack into (numTexts, maxLength) tensors.
        inputIds, attentionMasks = [], []
        for text in texts:
            encodedDict = self.tokenizeText(text)
            inputIds.append(encodedDict['input_ids'])
            attentionMasks.append(encodedDict['attention_mask'])
        return torch.cat(inputIds, dim=0), torch.cat(attentionMasks, dim=0)
    def labelContext(self, context):
        """Return the class index of the first trigger keyword found in the text, or None."""
        context = context.lower()
        pattern = r'\b(?:' + '|'.join(re.escape(keyword) for keyword in self.keywords) + r')\b'
        match = re.search(pattern, context)
        return self.keywords[match.group(0)] if match else None
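
# A minimal sketch of the keyword-based weak labeling above (made-up inputs):
#   pre = Preprocessor()
#   pre.labelContext("I've been feeling sad for weeks")  -> 1  ('Depression')
#   pre.labelContext('my ocd is getting worse')          -> 8  ('Obsessive Compulsive Disorder')
#   pre.labelContext('no trigger words here')            -> None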
class etal(Preprocessor):
    """BERT sequence classifier over the 21 topic labels defined in Preprocessor."""

    def __init__(self, modelName='bert-base-uncased', numLabels=21):
        super().__init__(modelName)
        self.model = BertForSequenceClassification.from_pretrained(modelName, num_labels=numLabels)
        # Note: the HF model already computes cross-entropy internally when labels are passed.
        self.criterion = nn.CrossEntropyLoss()
    def train(self, texts, labels, epochs=3, batchSize=8, learningRate=2e-5):
        inputIds, attentionMasks = self.preprocessDataset(texts)
        labels = torch.tensor(labels, dtype=torch.long)
        # Hold out 20% of the data for validation.
        trainIdx, valIdx = train_test_split(np.arange(len(labels)), test_size=0.2, random_state=42)
        trainIds, valIds = inputIds[trainIdx], inputIds[valIdx]
        trainMasks, valMasks = attentionMasks[trainIdx], attentionMasks[valIdx]
        trainLabels, valLabels = labels[trainIdx], labels[valIdx]
        trainData = torch.utils.data.TensorDataset(trainIds, trainMasks, trainLabels)
        valData = torch.utils.data.TensorDataset(valIds, valMasks, valLabels)
        trainLoader = torch.utils.data.DataLoader(trainData, batch_size=batchSize, shuffle=True)
        valLoader = torch.utils.data.DataLoader(valData, batch_size=batchSize)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learningRate)
        bestValLoss = float('inf')
        with alive_bar(epochs, title='Training Progress') as bar:
            for epoch in range(epochs):
                totalLoss = 0
                self.model.train()
                for i, batch in enumerate(trainLoader):
                    batchIds, batchMasks, batchLabels = batch
                    self.model.zero_grad()
                    # Passing labels makes the HF model compute the loss itself.
                    outputs = self.model(input_ids=batchIds, attention_mask=batchMasks, labels=batchLabels)
                    loss = outputs.loss
                    totalLoss += loss.item()
                    loss.backward()
                    optimizer.step()
                    print(f'Epoch {epoch + 1}/{epochs}, Batch {i + 1}/{len(trainLoader)}, Loss: {loss.item():.4f}')
                avgTrainLoss = totalLoss / len(trainLoader)
                valLoss = self.evaluate(valLoader)
                # Checkpoint whenever validation loss improves.
                if valLoss < bestValLoss:
                    bestValLoss = valLoss
                    checkpointName = f'e{epoch}l{valLoss:.4f}.pt'
                    self.save('models', checkpointName)
                    print(f"Model state dict saved at: {os.path.join(os.getcwd(), 'models', checkpointName)}")
                print(f'Epoch {epoch + 1}, Train Loss: {avgTrainLoss:.4f}, Validation Loss: {valLoss:.4f}')
                bar()
    def evaluate(self, dataLoader):
        """Run the model over a DataLoader; print a classification report and return the mean loss."""
        self.model.eval()
        predictions, trueLabels = [], []
        totalLoss = 0
        with torch.no_grad():
            for batch in dataLoader:
                batchIds, batchMasks, batchLabels = batch
                outputs = self.model(input_ids=batchIds, attention_mask=batchMasks, labels=batchLabels)
                logits = outputs.logits
                loss = outputs.loss
                totalLoss += loss.item()
                predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
                trueLabels.extend(batchLabels.cpu().numpy())
        print(classification_report(trueLabels, predictions))
        return totalLoss / len(dataLoader)
    def predict(self, text):
        """Classify a single text and return its human-readable label."""
        self.model.eval()
        tokens = self.tokenizeText(text)
        with torch.no_grad():
            outputs = self.model(input_ids=tokens['input_ids'], attention_mask=tokens['attention_mask'])
            prediction = torch.argmax(outputs.logits, dim=1).item()
        return self.labelMap.get(prediction)
    def save(self, folder, filename):
        os.makedirs(folder, exist_ok=True)
        filepath = os.path.join(folder, filename)
        torch.save(self.model.state_dict(), filepath)
    def load(self, filePath, best=True):
        """Load weights. With best=True, filePath is a folder and the checkpoint with
        the lowest validation loss is chosen; otherwise filePath is a .pt file."""
        if best:
            modelFiles = [f for f in os.listdir(filePath) if f.endswith('.pt')]
            if not modelFiles:
                print('No model files found in the specified folder.')
                return
            # Filenames follow e{epoch}l{valLoss}.pt; the best checkpoint is the one
            # with the lowest validation loss, not the last file in sorted order.
            def checkpointLoss(name):
                match = re.match(r'e(\d+)l([\d.]+)\.pt$', name)
                return float(match.group(2)) if match else float('inf')
            modelPath = os.path.join(filePath, min(modelFiles, key=checkpointLoss))
        else:
            modelPath = filePath
        self.model.load_state_dict(torch.load(modelPath))
        print(f'Loaded model state dict from {modelPath}')
        self.model.eval()
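
if __name__ == '__main__':
    # A minimal end-to-end sketch, assuming you supply your own labeled corpus;
    # the texts below are made-up placeholders, not real training data.
    sampleTexts = [
        'My anxiety spikes whenever I leave the house',
        'Nothing brings me joy anymore, I just feel sad',
        'Work stress is keeping me up at night',
        'I had a great day and feel genuinely happy',
    ]
    classifier = etal()
    # Weak labels from the keyword matcher: [0, 1, 2, 3] for the texts above.
    sampleLabels = [classifier.labelContext(t) for t in sampleTexts]
    classifier.train(sampleTexts, sampleLabels, epochs=1, batchSize=2)
    print(classifier.predict('I keep worrying about everything'))
    # To reuse a checkpoint later:
    #   classifier.load('models')                            # lowest-val-loss .pt in folder
    #   classifier.load('models/e0l0.1234.pt', best=False)   # or an explicit file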