|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchtext.vocab import build_vocab_from_iterator, GloVe |
|
from torchtext.data.utils import get_tokenizer |
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
|
|
class ContactSharingClassifier(nn.Module):
    """Text classifier: embedding -> BiLSTM -> multi-width Conv1d + max-pool -> MLP.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_dim: embedding width (LSTM input size).
        num_filters: output channels of each Conv1d.
        filter_sizes: one Conv1d is built per kernel width in this sequence.
        lstm_hidden_dim: hidden size per LSTM direction; the bidirectional
            output therefore has ``2 * lstm_hidden_dim`` features.
        output_dim: number of output scores produced by the final layer.
        dropout: probability for the shared ``nn.Dropout`` module.
        pad_idx: passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, lstm_hidden_dim, output_dim, dropout, pad_idx):
        super().__init__()
        # Width of the BiLSTM output (both directions concatenated).
        bilstm_width = 2 * lstm_hidden_dim
        # Width after concatenating the pooled output of every conv branch.
        conv_features = len(filter_sizes) * num_filters

        # Attribute names and creation order are kept stable: the names are the
        # state_dict keys a saved checkpoint is loaded against.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, bidirectional=True, batch_first=True)
        self.convs = nn.ModuleList(
            nn.Conv1d(in_channels=bilstm_width, out_channels=num_filters, kernel_size=width)
            for width in filter_sizes
        )
        self.fc1 = nn.Linear(conv_features, conv_features // 2)
        self.fc2 = nn.Linear(conv_features // 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(conv_features)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores of shape (batch, output_dim).

        ``text`` holds token indices laid out (batch, seq), matching the
        LSTM's ``batch_first=True``. ``seq`` must be at least
        ``max(filter_sizes)`` so the widest Conv1d has an output position.
        """
        tokens = self.embedding(text)
        sequence, _ = self.lstm(tokens)
        # Conv1d expects (batch, channels, length); swap the seq and feature axes.
        channels_first = sequence.transpose(1, 2)

        features = []
        for conv in self.convs:
            activation = F.relu(conv(channels_first))
            # Global max over time: the pooling window spans the full conv output.
            features.append(F.max_pool1d(activation, activation.size(2)).squeeze(2))

        hidden = self.layer_norm(self.dropout(torch.cat(features, dim=1)))
        hidden = self.dropout(F.relu(self.fc1(hidden)))
        return self.fc2(hidden)
|
|
|
|
|
# Word tokenizer backed by spaCy's small English pipeline (presumably requires
# spaCy and en_core_web_sm to be installed — confirm in the environment).
tokenizer = get_tokenizer(tokenizer="spacy", language="en_core_web_sm")

# Vocabulary object restored with torch.load, which unpickles the file: only load
# a trusted vocab.pth. NOTE(review): PyTorch releases whose torch.load defaults to
# weights_only=True may refuse a non-tensor object like this — confirm against the
# installed version.
vocab = torch.load(f='vocab.pth')
|
|
|
|
|
def text_pipeline(x):
    """Tokenize the sentence ``x`` and return its token indices in order.

    Each token from ``tokenizer`` is looked up in ``vocab``. NOTE(review):
    behavior for out-of-vocabulary tokens depends on how vocab.pth was built
    (e.g. whether a default index was set) — confirm.
    """
    token_ids = []
    for token in tokenizer(x):
        token_ids.append(vocab[token])
    return token_ids
|
|
|
|
|
# Model hyperparameters. They must agree with the checkpoint loaded below, since
# load_state_dict compares every parameter's shape.

# Values derived from the loaded vocabulary.
VOCAB_SIZE = len(vocab)
PAD_IDX = vocab["<pad>"]  # handed to nn.Embedding as padding_idx

# Architecture sizes.
EMBED_DIM = 600
NUM_FILTERS = 600
FILTER_SIZES = list(range(3, 11))  # Conv1d kernel widths 3, 4, ..., 10
LSTM_HIDDEN_DIM = 768
OUTPUT_DIM = 2
DROPOUT = 0.5
|
|
|
|
|
|
|
model = ContactSharingClassifier(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    num_filters=NUM_FILTERS,
    filter_sizes=FILTER_SIZES,
    lstm_hidden_dim=LSTM_HIDDEN_DIM,
    output_dim=OUTPUT_DIM,
    dropout=DROPOUT,
    pad_idx=PAD_IDX,
)
# map_location remaps the stored tensors onto the chosen device while loading.
model.load_state_dict(torch.load('contact_sharing_epoch_1.pth', map_location=device))
# Module.to works in place and returns the module, so eval() can be chained.
# eval() disables the Dropout layers for inference.
model.to(device).eval()
|
|
|
|
|
# Sample sentences fed to predict() by the demo loop at the end of the file.
test_sentences = [
    # First ten: contact details described in indirect or obfuscated ways.
    "You can reach me at my electronic mail address, it's my first name dot last name at that popular search engine company's mail service.",
    "Call me on my cellular device, the digits are the same as the year the Declaration of Independence was signed, followed by my birth year, twice.",
    "Visit my online presence at triple w dot my full name without spaces or punctuation dot com.",
    "Send a message to username 'not_my_real_name' on that instant messaging platform that starts with 'disc' and ends with 'ord'.",
    "My contact info is hidden in this sentence: Eight Six Seven Five Three Oh Nine.",
    "Find me on the professional networking site, just search for my name plus 'software engineer in San Francisco'.",
    "My handle on the bird-themed social media platform is at symbol followed by 'definitely_not_my_email_address'.",
    "You know that video sharing site? My channel is there, just add 'cool_coder_' before my full name, all lowercase.",
    "I'm listed in the phone book under 'Smith, John' but replace 'Smith' with my actual last name and 'John' with my first name.",
    # NOTE(review): '[email protected]' looks like an email-obfuscation placeholder
    # inserted when this file was copied from a web page, not the intended test
    # string — confirm and restore the original literal.
    "My contact details are encrypted: Rot13('[email protected]')",
    # Last ten: everyday sentences with no contact details.
    "The weather today is absolutely beautiful, perfect for a picnic in the park.",
    "I'm really excited about the new sci-fi movie coming out next month.",
    "Did you hear about the latest advancements in artificial intelligence? It's fascinating!",
    "I'm planning to go hiking this weekend in the nearby mountains.",
    "The recipe calls for two cups of flour and a pinch of salt.",
    "The annual tech conference will be held virtually this year due to ongoing health concerns.",
    "I've been learning to play the guitar for the past six months. It's challenging but rewarding.",
    "The local farmer's market has the freshest produce every Saturday morning.",
    "Did you catch the game last night? It was an incredible comeback in the final quarter!",
    # Contains a digit string but no contact info; presumably a false-positive probe.
    "Lets do '42069' tonight it will be really fun what do you say ?"
]
|
|
|
|
|
# TorchScript-compile the eval-mode model once; the prediction helpers call this.
scripted_model = torch.jit.script(model)

# Every Conv1d needs at least kernel_size positions, so short inputs are padded to
# the widest filter. (Despite the name, this is a minimum sequence length.)
MAX_LEN = max(FILTER_SIZES)

# Fill with PAD_IDX, the index the embedding was constructed with as padding_idx,
# rather than 0; index 0 is not guaranteed to be the padding token. The tensor is
# created directly on the target device.
padding_tensor = torch.full((1, MAX_LEN), PAD_IDX, dtype=torch.long, device=device)
|
|
|
|
|
def predict(text):
    """Classify a single sentence.

    Args:
        text: raw sentence; converted to token indices by ``text_pipeline``.

    Returns:
        int: index of the highest-scoring output (the demo loop treats 1 as
        "Contains contact info").
    """
    with torch.inference_mode():
        token_ids = text_pipeline(text)
        # Right-pad to MAX_LEN (the widest conv kernel) so every Conv1d has enough
        # positions. Pad with PAD_IDX, the embedding's padding_idx.
        shortfall = MAX_LEN - len(token_ids)
        if shortfall > 0:
            token_ids = token_ids + [PAD_IDX] * shortfall
        # Explicit dtype: torch.tensor([[]]) would infer float32 for empty input,
        # which the embedding lookup rejects.
        inputs = torch.tensor([token_ids], dtype=torch.long, device=device)
        outputs = scripted_model(inputs)
        return torch.argmax(outputs, dim=1).item()
|
|
|
def batch_predict(texts):
    """Classify several sentences in a single forward pass.

    Args:
        texts: iterable of raw sentences.

    Returns:
        1-D numpy array holding the argmax output index for each sentence, in
        input order. An empty ``texts`` yields an empty array.

    NOTE: the model's forward() receives no lengths or mask, so padding tokens
    pass through the LSTM; a sentence's result can depend on how much padding
    it receives (i.e. on the longest sentence in its batch).
    """
    with torch.inference_mode():
        id_lists = [text_pipeline(text) for text in texts]
        if not id_lists:
            # max() over an empty batch would raise; return an empty result instead.
            return torch.empty(0, dtype=torch.long).numpy()

        # Pad to the longest sentence, but never below MAX_LEN: each Conv1d needs
        # at least kernel_size positions. Padding uses PAD_IDX, the embedding's
        # padding_idx; an explicit long dtype keeps empty sentences integral.
        width = max(MAX_LEN, max(len(ids) for ids in id_lists))
        rows = [ids + [PAD_IDX] * (width - len(ids)) for ids in id_lists]
        padded_inputs = torch.tensor(rows, dtype=torch.long, device=device)

        outputs = scripted_model(padded_inputs)
        predictions = torch.argmax(outputs, dim=1).cpu().numpy()
        return predictions
|
|
|
|
|
# Demo: classify each sample sentence and print the verdict with the text.
for i, sentence in enumerate(test_sentences, start=1):
    prediction = predict(sentence)
    if prediction == 1:
        result = "Contains contact info"
    else:
        result = "No contact info"
    print(f"Sentence {i}: {result}", f"Text: {sentence}\n", sep="\n")
|
|