import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.data.utils import get_tokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hybrid BiLSTM + multi-kernel-size CNN classifier for detecting contact-info sharing
class ContactSharingClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, lstm_hidden_dim, output_dim, dropout, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        # Bidirectional LSTM doubles the feature dimension (forward + backward states)
        self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, bidirectional=True, batch_first=True)
        # One Conv1d per kernel size, applied over the LSTM output sequence
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=lstm_hidden_dim * 2, out_channels=num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])
        self.fc1 = nn.Linear(len(filter_sizes) * num_filters, len(filter_sizes) * num_filters // 2)
        self.fc2 = nn.Linear(len(filter_sizes) * num_filters // 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(len(filter_sizes) * num_filters)
    def forward(self, text):
        embedded = self.embedding(text)       # [batch, seq_len, embed_dim]
        lstm_out, _ = self.lstm(embedded)     # [batch, seq_len, 2 * lstm_hidden_dim]
        lstm_out = lstm_out.permute(0, 2, 1)  # Conv1d expects [batch, channels, seq_len]
        conved = [F.relu(conv(lstm_out)) for conv in self.convs]
        # Global max pooling over time for each kernel size
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        cat = self.dropout(torch.cat(pooled, dim=1))  # [batch, len(filter_sizes) * num_filters]
        cat = self.layer_norm(cat)
        x = F.relu(self.fc1(cat))
        x = self.dropout(x)
        return self.fc2(x)
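
# Quick shape sanity check with tiny hypothetical dimensions (illustrative
# only; the real settings are defined further down):
_toy_model = ContactSharingClassifier(vocab_size=100, embed_dim=32, num_filters=8,
                                      filter_sizes=[3, 4], lstm_hidden_dim=16,
                                      output_dim=2, dropout=0.5, pad_idx=0)
assert _toy_model(torch.randint(0, 100, (2, 12))).shape == (2, 2)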
# Initialize tokenizer and vocabulary
tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
vocab = torch.load('vocab.pth')  # assumes the vocabulary was saved during training
# Map out-of-vocabulary tokens to <unk> instead of raising an error
# (assumes the vocabulary was built with an "<unk>" special token)
vocab.set_default_index(vocab["<unk>"])

# Define text pipeline: raw string -> list of token indices
def text_pipeline(x):
    return [vocab[token] for token in tokenizer(x)]
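
# e.g. text_pipeline("email me tonight") might return [1042, 57, 389]
# (indices here are hypothetical; actual values depend on vocab.pth)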
# Model parameters (must match the training configuration)
VOCAB_SIZE = len(vocab)
EMBED_DIM = 600
NUM_FILTERS = 600
FILTER_SIZES = [3, 4, 5, 6, 7, 8, 9, 10]  # 8 kernel sizes x 600 filters = 4800 pooled features
LSTM_HIDDEN_DIM = 768
OUTPUT_DIM = 2  # binary: contact info vs. no contact info
DROPOUT = 0.5
PAD_IDX = vocab["<pad>"]
# Load the model
model = ContactSharingClassifier(VOCAB_SIZE, EMBED_DIM, NUM_FILTERS, FILTER_SIZES, LSTM_HIDDEN_DIM, OUTPUT_DIM, DROPOUT, PAD_IDX)
model.load_state_dict(torch.load('contact_sharing_epoch_1.pth', map_location=device))
model.to(device)
model.eval()
# Test sentences
test_sentences = [
"You can reach me at my electronic mail address, it's my first name dot last name at that popular search engine company's mail service.",
"Call me on my cellular device, the digits are the same as the year the Declaration of Independence was signed, followed by my birth year, twice.",
"Visit my online presence at triple w dot my full name without spaces or punctuation dot com.",
"Send a message to username 'not_my_real_name' on that instant messaging platform that starts with 'disc' and ends with 'ord'.",
"My contact info is hidden in this sentence: Eight Six Seven Five Three Oh Nine.",
"Find me on the professional networking site, just search for my name plus 'software engineer in San Francisco'.",
"My handle on the bird-themed social media platform is at symbol followed by 'definitely_not_my_email_address'.",
"You know that video sharing site? My channel is there, just add 'cool_coder_' before my full name, all lowercase.",
"I'm listed in the phone book under 'Smith, John' but replace 'Smith' with my actual last name and 'John' with my first name.",
"My contact details are encrypted: Rot13('[email protected]')",
# New non-contact sharing examples
"The weather today is absolutely beautiful, perfect for a picnic in the park.",
"I'm really excited about the new sci-fi movie coming out next month.",
"Did you hear about the latest advancements in artificial intelligence? It's fascinating!",
"I'm planning to go hiking this weekend in the nearby mountains.",
"The recipe calls for two cups of flour and a pinch of salt.",
"The annual tech conference will be held virtually this year due to ongoing health concerns.",
"I've been learning to play the guitar for the past six months. It's challenging but rewarding.",
"The local farmer's market has the freshest produce every Saturday morning.",
"Did you catch the game last night? It was an incredible comeback in the final quarter!",
"Lets do '42069' tonight it will be really fun what do you say ?"
]
# JIT Script the model for faster inference
scripted_model = torch.jit.script(model)
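
# Optionally, the scripted module can also be frozen, which inlines weights
# and attributes; any additional speedup is assumed and workload-dependent.
# torch.jit.freeze requires eval mode, which was set above:
scripted_model = torch.jit.freeze(scripted_model)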
# Preallocate a padding tensor to avoid repeated memory allocation. Sequences
# shorter than the largest conv kernel would crash Conv1d, so every input is
# padded to at least MAX_LEN; PAD_IDX (not 0, which may be a real token) is
# used as filler so the embedding layer treats it as padding.
MAX_LEN = max(FILTER_SIZES)
padding_tensor = torch.full((1, MAX_LEN), PAD_IDX, dtype=torch.long, device=device)
# Prediction function using JIT and inference optimizations
def predict(text):
with torch.inference_mode(): # Use inference mode instead of no_grad
inputs = torch.tensor([text_pipeline(text)]).to(device)
# Perform padding if necessary
if inputs.size(1) < MAX_LEN:
inputs = torch.cat([inputs, padding_tensor[:, :MAX_LEN - inputs.size(1)]], dim=1)
# Pass inputs through the scripted model
outputs = scripted_model(inputs)
# Return predicted class
return torch.argmax(outputs, dim=1).item()
def batch_predict(texts):
    with torch.inference_mode():  # Use inference mode for better performance
        # Tokenize and convert to tensors
        inputs = [torch.tensor(text_pipeline(text)) for text in texts]
        # Pad every sequence to the longest one in the batch, but never shorter
        # than the largest conv kernel; use PAD_IDX as filler
        max_len = max(MAX_LEN, max(len(seq) for seq in inputs))
        padded_inputs = torch.stack([
            torch.cat([seq, torch.full((max_len - len(seq),), PAD_IDX, dtype=torch.long)])
            for seq in inputs
        ]).to(device)
        # Pass the batch through the scripted model
        outputs = scripted_model(padded_inputs)
        # Return the predicted class for each sentence
        return torch.argmax(outputs, dim=1).cpu().numpy()
# Test the sentences
for i, sentence in enumerate(test_sentences, 1):
prediction = predict(sentence)
result = "Contains contact info" if prediction == 1 else "No contact info"
print(f"Sentence {i}: {result}")
print(f"Text: {sentence}\n")