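# NOTE: the lines originally here were page residue from a repository file
# viewer, not code. Kept as a comment so the module parses:
#   parth parekh / working demo / 7e63028 / raw / history blame / 3.74 kB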
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from torch.nn.functional import softmax
import re
from predictor import predict, batch_predict # Assuming batch_predict is in predictor module
# FastAPI application object; the @app.post decorators in this module
# register their handlers on it.
app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text, great thanks to xxparthparekhxx/ContactShieldAI for the model",
    version="1.0.0",
    # Mount the interactive API docs at the site root instead of FastAPI's
    # default docs path.
    docs_url="/"
)
def preprocess_text(text):
    """Return *text* with punctuation removed.

    Only word characters, whitespace, ``@`` and ``.`` are kept; the at-sign
    and the dot survive because e-mail addresses are built from them.
    """
    unwanted_chars = re.compile(r'[^\w\s@.]')
    return unwanted_chars.sub('', text)
# Request body accepted by the single-text detection endpoint.
class TextInput(BaseModel):
    # The text to scan for contact information.
    text: str
# Request body accepted by the batch detection endpoint.
class BatchTextInput(BaseModel):
    # The texts to scan; the batch endpoint returns one result per entry,
    # in the same order.
    texts: list[str]
# Contact-information patterns, compiled once at import time and matched
# case-insensitively.
_CONTACT_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        # Email. The TLD class is [A-Za-z]; a "|" inside a character class is
        # a literal pipe, not alternation, so it must not appear there.
        r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        # Phone number: 3-3-4 digits, optionally separated by "-" or ".".
        r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
        # ZIP code (5 digits, optional +4). Note that this matches any
        # standalone 5-digit number.
        r'\b\d{5}(?:[-\s]\d{4})?\b',
        # Street address: house number, name, street-type word, then optional
        # unit and "city ST zip" tails.
        r'\b\d+\s+[\w\s]+(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\b\s*(?:[a-z]+\s*\d{1,3})?(?:,\s*(?:apt|bldg|dept|fl|hngr|lot|pier|rm|ste|unit|#)\s*[a-z0-9-]+)?(?:,\s*[a-z]+\s*[a-z]{2}\s*\d{5}(?:-\d{4})?)?',
        # Website URL (http/https only).
        r'(?:http|https)://(?:www\.)?[a-zA-Z0-9-]+\.[a-zA-Z]{2,}(?:/[^\s]*)?',
    )
)


def check_regex_patterns(text):
    """Return True if *text* contains anything that looks like contact info.

    The text is searched for an e-mail address, a phone number, a ZIP code,
    a street address or an http(s) URL; the first pattern that matches
    anywhere in the text decides the result.

    Args:
        text: The string to search.

    Returns:
        True if any pattern matches, otherwise False.
    """
    return any(pattern.search(text) for pattern in _CONTACT_PATTERNS)
@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    """Decide whether a single text contains contact information.

    The regex patterns are tried first, against both the raw text and the
    punctuation-stripped text. Only when neither matches is the model
    consulted, on the stripped text.

    Args:
        input: Request body carrying the text to classify.

    Returns:
        A dict with the original ``text``, a boolean ``is_contact_info`` and
        the ``method`` ("regex" or "model") that produced the verdict.

    Raises:
        HTTPException: Status 500 carrying the message of any exception
            raised while processing the request.
    """
    # NOTE(review): predict() is called synchronously inside an async handler;
    # if inference is slow it will hold up the event loop. Consider moving it
    # to a worker thread.
    try:
        preprocessed_text = preprocess_text(input.text)

        # preprocess_text strips characters such as ":" and "/", so the URL
        # pattern (which needs "://") can never match the stripped text.
        # Check the raw text too; the stripped text is still checked so that
        # everything detected before is still detected.
        if check_regex_patterns(input.text) or check_regex_patterns(preprocessed_text):
            return {
                "text": input.text,
                "is_contact_info": True,
                "method": "regex",
            }

        # No pattern matched: fall back to the model.
        is_contact = predict(preprocessed_text)
        return {
            "text": input.text,
            # predict()'s return type is not visible here; bool() keeps the
            # value a plain Python bool (as the batch endpoint does) in case
            # the comparison yields a numpy/torch scalar.
            "is_contact_info": bool(is_contact == 1),
            "method": "model",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/batch_detect_contact", summary="Detect contact information in batch of texts")
async def batch_detect_contact(inputs: BatchTextInput):
    """Decide, for each text in a batch, whether it contains contact information.

    Every text is first checked against the regex patterns, both raw and
    punctuation-stripped. The texts that no pattern matched are then sent to
    the model together in one ``batch_predict`` call, in stripped form.

    Args:
        inputs: Request body carrying the list of texts to classify.

    Returns:
        A list with one dict per input text, in input order. Each dict holds
        the original ``text``, a boolean ``is_contact_info`` and the
        ``method`` ("regex" or "model") that produced the verdict.

    Raises:
        HTTPException: Status 500 carrying the message of any exception
            raised while processing the request.
    """
    try:
        preprocessed_texts = [preprocess_text(text) for text in inputs.texts]

        # preprocess_text strips characters such as ":" and "/", so the URL
        # pattern (which needs "://") can never match the stripped text.
        # Check the raw text too; the stripped text is still checked so that
        # everything detected before is still detected.
        regex_results = [
            check_regex_patterns(raw) or check_regex_patterns(cleaned)
            for raw, cleaned in zip(inputs.texts, preprocessed_texts)
        ]

        # Only the texts the regexes did not flag go to the model.
        texts_for_model = [
            cleaned
            for cleaned, regex_match in zip(preprocessed_texts, regex_results)
            if not regex_match
        ]
        model_results = batch_predict(texts_for_model) if texts_for_model else []

        # Merge regex and model verdicts back into input order. model_idx
        # walks model_results in step with the texts that were sent to it.
        results = []
        model_idx = 0
        for raw, regex_match in zip(inputs.texts, regex_results):
            if regex_match:
                results.append({
                    "text": raw,
                    "is_contact_info": True,
                    "method": "regex",
                })
                continue
            results.append({
                "text": raw,
                # Convert a possible numpy bool into a plain Python bool.
                "is_contact_info": bool(model_results[model_idx]),
                "method": "model",
            })
            model_idx += 1
        return results
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e