# ContactShieldAI / raw / tester.py
# Last commit 1d54b01 by parth parekh: "added readme and raw training files" (file size 5.16 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.vocab import build_vocab_from_iterator, GloVe
from torchtext.data.utils import get_tokenizer
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class ContactSharingClassifier(nn.Module):
    """BiLSTM + multi-width CNN text classifier.

    Token ids are embedded and run through a bidirectional LSTM. The LSTM
    features are then scanned by one ``Conv1d`` per entry in
    ``filter_sizes``. Each convolution is ReLU'd and max-pooled over time.
    The pooled vectors are concatenated and passed through dropout and
    LayerNorm. A two-layer MLP head then produces ``output_dim`` logits.

    The submodule attribute names below determine the ``state_dict`` keys.
    Renaming them would stop previously saved checkpoints from loading.
    """

    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, lstm_hidden_dim, output_dim, dropout, pad_idx):
        super().__init__()
        # Width of the concatenated, pooled feature vector and of the MLP's hidden layer.
        pooled_dim = num_filters * len(filter_sizes)
        hidden_dim = pooled_dim // 2
        # A bidirectional LSTM emits forward and backward states side by side.
        lstm_out_dim = 2 * lstm_hidden_dim

        # Submodules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, bidirectional=True, batch_first=True)
        self.convs = nn.ModuleList(
            nn.Conv1d(in_channels=lstm_out_dim, out_channels=num_filters, kernel_size=width)
            for width in filter_sizes
        )
        self.fc1 = nn.Linear(pooled_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(pooled_dim)

    def forward(self, text):
        """Return unnormalised class logits for a batch of token ids.

        ``text`` is a 2-D tensor of token ids indexed as (batch, seq_len),
        because the LSTM is built with ``batch_first=True``. ``seq_len`` must
        be at least ``max(filter_sizes)`` so every convolution yields at
        least one time step.
        """
        seq_feats, _ = self.lstm(self.embedding(text))
        # Conv1d expects channels before time: (batch, 2 * lstm_hidden_dim, seq_len).
        seq_feats = seq_feats.transpose(1, 2)
        # Global max-pool over the time axis for every filter width.
        pooled = [F.relu(conv(seq_feats)).amax(dim=2) for conv in self.convs]
        features = self.layer_norm(self.dropout(torch.cat(pooled, dim=1)))
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)
# Tokenizer and vocabulary, presumably the same ones used when the
# checkpoint was trained (TODO confirm against the training script).
tokenizer = get_tokenizer(tokenizer="spacy", language="en_core_web_sm")
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# vocab files from a trusted source.
vocab = torch.load('vocab.pth')  # vocabulary object saved ahead of time
# Text -> token-id conversion used by predict().
def text_pipeline(x):
    """Tokenize ``x`` and map each token to its vocabulary index.

    Returns a list with one index per token produced by ``tokenizer``.
    Out-of-vocabulary behaviour depends on how ``vocab`` was saved: a
    default index if one was set, otherwise a lookup error (TODO confirm).
    """
    tokens = tokenizer(x)
    return [vocab[tok] for tok in tokens]
# Hyperparameters. The tensor shapes in the checkpoint depend on these, so
# they must agree with the values used at training time or
# load_state_dict below raises on a size mismatch.
VOCAB_SIZE = len(vocab)
EMBED_DIM = 600
NUM_FILTERS = 600
FILTER_SIZES = [3, 4, 5, 6, 7, 8, 9, 10]
LSTM_HIDDEN_DIM = 768
OUTPUT_DIM = 2
DROPOUT = 0.5
PAD_IDX = vocab["<pad>"]

# Rebuild the architecture, restore the trained weights onto `device`, and
# switch to inference mode (dropout disabled).
model = ContactSharingClassifier(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    num_filters=NUM_FILTERS,
    filter_sizes=FILTER_SIZES,
    lstm_hidden_dim=LSTM_HIDDEN_DIM,
    output_dim=OUTPUT_DIM,
    dropout=DROPOUT,
    pad_idx=PAD_IDX,
)
model.load_state_dict(torch.load('contact_sharing_epoch_1.pth', map_location=device))
model.to(device).eval()
# Probe sentences for a manual smoke test. The first group tries to pass
# along contact details through paraphrase or obfuscation. The second
# group is the "non-contact sharing" examples.
test_sentences = [
    "You can reach me at my electronic mail address, it's my first name dot last name at that popular search engine company's mail service.",
    "Call me on my cellular device, the digits are the same as the year the Declaration of Independence was signed, followed by my birth year, twice.",
    "Visit my online presence at triple w dot my full name without spaces or punctuation dot com.",
    "Send a message to username 'not_my_real_name' on that instant messaging platform that starts with 'disc' and ends with 'ord'.",
    "My contact info is hidden in this sentence: Eight Six Seven Five Three Oh Nine.",
    "Find me on the professional networking site, just search for my name plus 'software engineer in San Francisco'.",
    "My handle on the bird-themed social media platform is at symbol followed by 'definitely_not_my_email_address'.",
    "You know that video sharing site? My channel is there, just add 'cool_coder_' before my full name, all lowercase.",
    "I'm listed in the phone book under 'Smith, John' but replace 'Smith' with my actual last name and 'John' with my first name.",
    # NOTE(review): "[email protected]" below looks like a web-scrape
    # email-obfuscation placeholder; the original address text may have
    # been lost. Verify against the source repository.
    "My contact details are encrypted: Rot13('[email protected]')",
] + [
    # New non-contact sharing examples
    "The weather today is absolutely beautiful, perfect for a picnic in the park.",
    "I'm really excited about the new sci-fi movie coming out next month.",
    "Did you hear about the latest advancements in artificial intelligence? It's fascinating!",
    "I'm planning to go hiking this weekend in the nearby mountains.",
    "The recipe calls for two cups of flour and a pinch of salt.",
    "The annual tech conference will be held virtually this year due to ongoing health concerns.",
    "I've been learning to play the guitar for the past six months. It's challenging but rewarding.",
    "The local farmer's market has the freshest produce every Saturday morning.",
    "Did you catch the game last night? It was an incredible comeback in the final quarter!",
    "Lets do '42069' tonight it will be really fun what do you say ?",
]
# Function to predict
def predict(text):
    """Classify a single sentence with the loaded model.

    Args:
        text: Raw input sentence.

    Returns:
        The argmax class index as a Python int. The loop that calls this
        treats 1 as "Contains contact info" and anything else as
        "No contact info".
    """
    min_len = max(FILTER_SIZES)
    token_ids = text_pipeline(text)
    # Every Conv1d needs at least `kernel_size` time steps, so short inputs
    # are right-padded up to the widest filter. Pad with PAD_IDX (the
    # embedding's padding_idx) rather than a literal 0, which may be a real
    # vocabulary entry.
    if len(token_ids) < min_len:
        token_ids = token_ids + [PAD_IDX] * (min_len - len(token_ids))
    # Explicit long dtype: torch.tensor([[]]) would otherwise default to
    # float32 for an empty token list, which nn.Embedding rejects.
    inputs = torch.tensor([token_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        outputs = model(inputs)
    return torch.argmax(outputs, dim=1).item()
# Run every probe sentence through the classifier and print its verdict.
for idx, sentence in enumerate(test_sentences, start=1):
    label = predict(sentence)
    verdict = "No contact info" if label != 1 else "Contains contact info"
    print(f"Sentence {idx}: {verdict}\nText: {sentence}\n")