# ContactShieldAI / raw / trainer.py
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, GloVe
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class ContactSharingDataset(Dataset):
    """Wraps (text, label) pairs and applies the text/label pipelines lazily."""

    def __init__(self, data, text_pipeline, label_pipeline):
        self.data = data
        self.text_pipeline = text_pipeline
        self.label_pipeline = label_pipeline

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, label = self.data[idx]
        return self.text_pipeline(text), self.label_pipeline(label)
class EnhancedContactSharingModel(nn.Module):
    """BiLSTM encoder followed by multi-width 1D convolutions with
    max-over-time pooling, then a two-layer classifier head."""

    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes,
                 lstm_hidden_dim, output_dim, dropout, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, bidirectional=True, batch_first=True)
        # One Conv1d per filter width; input channels equal the BiLSTM output size.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=lstm_hidden_dim * 2, out_channels=num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])
        self.fc1 = nn.Linear(len(filter_sizes) * num_filters, len(filter_sizes) * num_filters // 2)
        self.fc2 = nn.Linear(len(filter_sizes) * num_filters // 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(len(filter_sizes) * num_filters)

    def forward(self, text):
        embedded = self.embedding(text)       # (batch, seq, embed_dim)
        lstm_out, _ = self.lstm(embedded)     # (batch, seq, 2 * lstm_hidden_dim)
        lstm_out = lstm_out.permute(0, 2, 1)  # (batch, channels, seq) for Conv1d
        conved = [F.relu(conv(lstm_out)) for conv in self.convs]
        # Max-over-time pooling collapses the sequence dimension per filter width.
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        cat = self.dropout(torch.cat(pooled, dim=1))
        cat = self.layer_norm(cat)
        x = F.relu(self.fc1(cat))
        x = self.dropout(x)
        return self.fc2(x)
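
# Shape sketch (illustrative, using the hyperparameters set below): for a batch
# of 32 sequences of length 50, `embedded` is (32, 50, 600), `lstm_out` is
# (32, 50, 1536) before the permute, each conv output is (32, 600, 50 - fs + 1),
# and the pooled concatenation entering fc1 is (32, 8 * 600) = (32, 4800).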
def load_data(filename='contacts_data.json'):
    with open(filename, 'r') as f:
        data = json.load(f)
    return [(item['text'], item['label']) for item in data]

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")

def yield_tokens(data_iter):
    for text, _ in data_iter:
        yield tokenizer(text)
data = load_data()
vocab = build_vocab_from_iterator(yield_tokens(data), specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])

# Initialise embedding rows from 300-dim GloVe vectors; tokens missing from
# GloVe keep zero vectors.
glove = GloVe(name="6B", dim=300)
pretrained_embedding = torch.zeros(len(vocab), 300)
for token, index in vocab.get_stoi().items():
    if token in glove.stoi:
        pretrained_embedding[index] = glove[token]
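
# GloVe coverage of the vocabulary could be checked with something like:
#   hits = sum(t in glove.stoi for t in vocab.get_itos())
#   print(f"GloVe coverage: {hits / len(vocab):.1%}")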
def text_pipeline(x):
    return [vocab[token] for token in tokenizer(x)]

def label_pipeline(x):
    return int(x)

def collate_batch(batch):
    label_list, text_list = [], []
    for (_text, _label) in batch:
        label_list.append(_label)
        text_list.append(torch.tensor(_text, dtype=torch.int64))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = nn.utils.rnn.pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
    # Guard: the widest conv kernel (max(FILTER_SIZES) == 10) needs at least
    # that many time steps, so pad very short batches out to that width.
    if text_list.size(1) < max(FILTER_SIZES):
        text_list = F.pad(text_list, (0, max(FILTER_SIZES) - text_list.size(1)), value=vocab["<pad>"])
    return text_list, label_list
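
# Example (illustrative): collate_batch([(list(range(12)), 1), ([4, 5], 0)])
# returns a (2, 12) LongTensor whose second row is padded with vocab["<pad>"],
# plus the label tensor tensor([1, 0]).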
VOCAB_SIZE = len(vocab)
EMBED_DIM = 600  # 300 GloVe dims plus 300 extra dims initialised to zero
NUM_FILTERS = 600
FILTER_SIZES = [3, 4, 5, 6, 7, 8, 9, 10]
LSTM_HIDDEN_DIM = 768
OUTPUT_DIM = 2  # binary: contains contact info / does not
DROPOUT = 0.5
PAD_IDX = vocab["<pad>"]

model = EnhancedContactSharingModel(VOCAB_SIZE, EMBED_DIM, NUM_FILTERS, FILTER_SIZES,
                                    LSTM_HIDDEN_DIM, OUTPUT_DIM, DROPOUT, PAD_IDX).to(device)

# Copy the 300-dim GloVe vectors into the first 300 of the 600 embedding dims.
pretrained_embedding_padded = torch.zeros(VOCAB_SIZE, EMBED_DIM)
pretrained_embedding_padded[:, :300] = pretrained_embedding
model.embedding.weight.data.copy_(pretrained_embedding_padded)
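
# Sanity check (illustrative): the first 300 dims of an in-vocabulary token
# should now match its GloVe vector, e.g.
#   assert torch.allclose(model.embedding.weight[vocab["the"], :300].cpu(), glove["the"])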
def train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, num_epochs=15):
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            text, labels = batch
            text, labels = text.to(device), labels.to(device)
            optimizer.zero_grad()
            predictions = model(text)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_train_loss = total_loss / len(train_loader)
        val_loss = evaluate(model, val_loader, criterion)
        scheduler.step(val_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")
        # Checkpoint the weights whenever validation loss improves.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
def evaluate(model, data_loader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in data_loader:
            text, labels = batch
            text, labels = text.to(device), labels.to(device)
            predictions = model(text)
            loss = criterion(predictions, labels)
            total_loss += loss.item()
    return total_loss / len(data_loader)
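
# A minimal accuracy helper, not part of the original script, that mirrors
# evaluate() if per-fold accuracy is wanted alongside the loss:
def accuracy(model, data_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for text, labels in data_loader:
            text, labels = text.to(device), labels.to(device)
            correct += (model(text).argmax(1) == labels).sum().item()
            total += labels.size(0)
    return correct / total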
def k_fold_cross_validation(model, dataset, k=5, batch_size=128, num_epochs=4):
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
        print(f"Fold {fold+1}/{k}")
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_subsampler, collate_fn=collate_batch)
        val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_subsampler, collate_fn=collate_batch)
        # Re-initialise the shared model between folds, then restore the
        # GloVe-initialised embedding rows.
        model.apply(lambda m: m.reset_parameters() if hasattr(m, 'reset_parameters') else None)
        model.embedding.weight.data.copy_(pretrained_embedding_padded)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
        train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, num_epochs)
def predict(text):
    model.eval()
    with torch.no_grad():
        tokens = text_pipeline(text)
        # Guard: pad very short inputs so the widest conv kernel (10) fits.
        if len(tokens) < max(FILTER_SIZES):
            tokens += [PAD_IDX] * (max(FILTER_SIZES) - len(tokens))
        tensor = torch.tensor(tokens, dtype=torch.int64).unsqueeze(0).to(device)
        output = model(tensor)
        return output.argmax(1).item()
if __name__ == "__main__":
    dataset = ContactSharingDataset(data, text_pipeline, label_pipeline)
    k_fold_cross_validation(model, dataset)
    # Load the best checkpoint saved during training rather than predicting
    # with the last fold's final-epoch weights.
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    sample_text = "Please contact me at [email protected] or call 555-1234."
    prediction = predict(sample_text)
    print(f"Prediction for '{sample_text}': {'Contains contact info' if prediction == 1 else 'No contact info'}")