import warnings
warnings.filterwarnings('ignore')

import torchvision
torchvision.disable_beta_transforms_warning()

import os
import re

import numpy as np
import torch
from alive_progress import alive_bar
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from transformers import BertForSequenceClassification, BertTokenizer


class Preprocessor:
    """Tokenizes text with a BERT tokenizer and maps keyword matches to
    one of 21 mental-health topic labels."""

    def __init__(self, modelName='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(modelName)
        self.labelMap = {
            0: 'Anxiety', 1: 'Depression', 2: 'Stress', 3: 'Happiness',
            4: 'Relationship Issues', 5: 'Self-Harm', 6: 'Substance Abuse',
            7: 'Trauma', 8: 'Obsessive Compulsive Disorder', 9: 'Eating Disorders',
            10: 'Grief', 11: 'Phobias', 12: 'Bipolar Disorder',
            13: 'Post-Traumatic Stress Disorder', 14: 'Mental Fatigue',
            15: 'Mood Swings', 16: 'Anger Management', 17: 'Social Isolation',
            18: 'Perfectionism', 19: 'Low Self-Esteem', 20: 'Family Issues'
        }
        # Keyword -> label id, used for weak labelling of raw text.
        self.keywords = {
            'anxiety': 0, 'depressed': 1, 'sad': 1, 'stress': 2, 'happy': 3,
            'relationship': 4, 'self-harm': 5, 'substance': 6, 'trauma': 7,
            'ocd': 8, 'eating': 9, 'grief': 10, 'phobia': 11, 'bipolar': 12,
            'ptsd': 13, 'fatigue': 14, 'mood': 15, 'anger': 16, 'isolated': 17,
            'perfectionism': 18, 'self-esteem': 19, 'family': 20
        }

    def tokenizeText(self, text, maxLength=128):
        return self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=maxLength,
            return_tensors='pt'
        )

    def preprocessDataset(self, texts):
        inputIds, attentionMasks = [], []
        for text in texts:
            encodedDict = self.tokenizeText(text)
            inputIds.append(encodedDict['input_ids'])
            attentionMasks.append(encodedDict['attention_mask'])
        return torch.cat(inputIds, dim=0), torch.cat(attentionMasks, dim=0)

    def labelContext(self, context):
        """Return the label id of the first keyword found in context,
        or None if no keyword matches."""
        context = context.lower()
        pattern = r'\b(?:' + '|'.join(re.escape(keyword) for keyword in self.keywords) + r')\b'
        match = re.search(pattern, context)
        return self.keywords[match.group(0)] if match else None


class etal(Preprocessor):
    """BERT sequence classifier with training, evaluation, prediction,
    and checkpointing helpers."""

    def __init__(self, modelName='bert-base-uncased', numLabels=21):
        super().__init__(modelName)
        # No separate criterion is needed: the HF model computes
        # cross-entropy internally when labels are passed to forward().
        self.model = BertForSequenceClassification.from_pretrained(modelName, num_labels=numLabels)

    def train(self, texts, labels, epochs=3, batchSize=8, learningRate=2e-5):
        inputIds, attentionMasks = self.preprocessDataset(texts)
        labels = torch.tensor(labels, dtype=torch.long)

        # 80/20 train/validation split over row indices.
        trainIdx, valIdx = train_test_split(np.arange(len(labels)), test_size=0.2, random_state=42)
        trainData = torch.utils.data.TensorDataset(inputIds[trainIdx], attentionMasks[trainIdx], labels[trainIdx])
        valData = torch.utils.data.TensorDataset(inputIds[valIdx], attentionMasks[valIdx], labels[valIdx])
        trainLoader = torch.utils.data.DataLoader(trainData, batch_size=batchSize, shuffle=True)
        valLoader = torch.utils.data.DataLoader(valData, batch_size=batchSize)

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learningRate)
        bestValLoss = float('inf')
        with alive_bar(epochs, title='Training Progress') as bar:
            for epoch in range(epochs):
                totalLoss = 0
                self.model.train()
                for i, batch in enumerate(trainLoader):
                    batchIds, batchMasks, batchLabels = batch
                    self.model.zero_grad()
                    outputs = self.model(input_ids=batchIds, attention_mask=batchMasks, labels=batchLabels)
                    loss = outputs.loss
                    totalLoss += loss.item()
                    loss.backward()
                    optimizer.step()
                    print(f'Epoch {epoch + 1}/{epochs}, Batch {i + 1}/{len(trainLoader)}, Loss: {loss.item()}')

                avgTrainLoss = totalLoss / len(trainLoader)
                valLoss = self.evaluate(valLoader)
                # Checkpoint whenever validation loss improves.
                if valLoss < bestValLoss:
                    bestValLoss = valLoss
                    self.save('models', f'e{epoch}l{valLoss}.pt')
                    print(f"Model state dict saved at: {os.path.join(os.getcwd(), 'models', f'e{epoch}l{valLoss}.pt')}")
                print(f'Epoch {epoch + 1}, Train Loss: {avgTrainLoss}, Validation Loss: {valLoss}')
                bar()

    def evaluate(self, dataLoader):
        """Print a classification report and return the mean validation loss."""
        self.model.eval()
        predictions, trueLabels = [], []
        totalLoss = 0
        with torch.no_grad():
            for batchIds, batchMasks, batchLabels in dataLoader:
                outputs = self.model(input_ids=batchIds, attention_mask=batchMasks, labels=batchLabels)
                totalLoss += outputs.loss.item()
                predictions.extend(torch.argmax(outputs.logits, dim=1).cpu().numpy())
                trueLabels.extend(batchLabels.cpu().numpy())
        print(classification_report(trueLabels, predictions))
        return totalLoss / len(dataLoader)

    def predict(self, text):
        """Classify a single text and return its human-readable label."""
        self.model.eval()
        tokens = self.tokenizeText(text)
        with torch.no_grad():
            outputs = self.model(input_ids=tokens['input_ids'], attention_mask=tokens['attention_mask'])
        prediction = torch.argmax(outputs.logits, dim=1).item()
        return self.labelMap.get(prediction)

    def save(self, folder, filename):
        os.makedirs(folder, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(folder, filename))

    def load(self, filePath, best=True):
        """Load a saved state dict. With best=True, filePath is treated as a
        folder and the checkpoint with the lowest validation loss is chosen."""
        if best:
            modelFiles = [f for f in os.listdir(filePath) if f.endswith('.pt')]
            if not modelFiles:
                print('No model files found in the specified folder.')
                return
            # Filenames look like 'e{epoch}l{valLoss}.pt'; strip the '.pt'
            # suffix before parsing so the loss keeps its decimal part.
            modelPath = os.path.join(
                filePath,
                min(modelFiles, key=lambda name: float(name[:-3].split('l', 1)[1]))
            )
        else:
            modelPath = filePath
        self.model.load_state_dict(torch.load(modelPath))
        print(f'Loaded model state dict from {modelPath}')
        self.model.eval()
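

# --- Usage sketch -----------------------------------------------------------
# A minimal, illustrative driver showing how the pieces above fit together:
# weakly label raw texts with Preprocessor.labelContext, fine-tune with
# etal.train, reload the best checkpoint, then classify new input. The sample
# texts below are hypothetical placeholders, not part of the original
# pipeline, and running this downloads the pretrained 'bert-base-uncased'
# weights on first use.
if __name__ == '__main__':
    classifier = etal()
    rawTexts = [
        'I feel anxiety every time I leave the house.',
        'I have been so depressed I cannot focus on anything.',
        'Work stress is keeping me up at night.',
        'I am genuinely happy with how things are going.',
        'My relationship has been falling apart for months.',
    ]
    # Derive integer labels from keyword matches; drop texts with no match.
    labelled = [(t, classifier.labelContext(t)) for t in rawTexts]
    labelled = [(t, l) for t, l in labelled if l is not None]
    texts, labels = map(list, zip(*labelled))

    classifier.train(texts, labels, epochs=1, batchSize=2)
    classifier.load('models', best=True)
    print(classifier.predict('Exams always flood me with anxiety.'))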