# File size: 3,742 Bytes
# 7e63028 (presumably the revision identifier shown by the file viewer)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from torch.nn.functional import softmax
import re
from predictor import predict, batch_predict # Assuming batch_predict is in predictor module
# FastAPI application object that the route decorators in this file attach to.
# docs_url="/" overrides the default docs location so the interactive API
# documentation is served from the site root.
app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text, great thanks to xxparthparekhxx/ContactShieldAI for the model",
    version="1.0.0",
    docs_url="/"
)
def preprocess_text(text):
    """Return *text* with punctuation removed, except ``@`` and ``.``.

    Word characters (letters, digits, underscore) and whitespace are kept,
    together with ``@`` and ``.`` because those commonly appear in email
    addresses. Every other character is dropped.
    """
    # Negated class: anything that is NOT a word char, whitespace, '@' or '.'.
    unwanted_chars = re.compile(r'[^\w\s@.]')
    cleaned = unwanted_chars.sub('', text)
    return cleaned
# Request body for the single-text endpoint (/detect_contact).
class TextInput(BaseModel):
    # Raw text to scan; it is echoed back unchanged in the response.
    text: str
# Request body for the batch endpoint (/batch_detect_contact).
class BatchTextInput(BaseModel):
    # Raw texts to scan; results are returned in the same order.
    texts: list[str]
# Compiled once at import time so each call only pays for the searches.
# All patterns are matched case-insensitively.
_CONTACT_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        # Email. The TLD class is [A-Za-z]; the previous [A-Z|a-z] also
        # accepted a literal '|' as part of the top-level domain.
        r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        # Phone number: 3-3-4 digits with optional '-' or '.' separators.
        r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
        # ZIP code: 5 digits, optionally followed by a 4-digit extension.
        r'\b\d{5}(?:[-\s]\d{4})?\b',
        # Street address: number, name, street-type keyword, optional
        # unit and city/state/ZIP suffixes.
        r'\b\d+\s+[\w\s]+(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\b\s*(?:[a-z]+\s*\d{1,3})?(?:,\s*(?:apt|bldg|dept|fl|hngr|lot|pier|rm|ste|unit|#)\s*[a-z0-9-]+)?(?:,\s*[a-z]+\s*[a-z]{2}\s*\d{5}(?:-\d{4})?)?',
        # Website URL with an explicit http:// or https:// scheme.
        r'(?:http|https)://(?:www\.)?[a-zA-Z0-9-]+\.[a-zA-Z]{2,}(?:/[^\s]*)?',
    )
)


def check_regex_patterns(text):
    """Return True if *text* contains anything that looks like contact info.

    The text is searched for an email address, a phone number, a ZIP code,
    a street address, or an http(s) URL. The search is case-insensitive and
    stops at the first pattern that matches.

    Args:
        text: The string to scan.

    Returns:
        ``True`` when at least one pattern matches anywhere in *text*,
        otherwise ``False``.
    """
    return any(pattern.search(text) for pattern in _CONTACT_PATTERNS)
@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    """Report whether a single text contains contact information.

    Regex patterns are tried first; the model is consulted only when no
    pattern matches. The response echoes the original text, a boolean
    ``is_contact_info`` flag, and which ``method`` ("regex" or "model")
    produced the verdict.

    Raises:
        HTTPException: status 500 with the error message as detail if
            anything fails during detection.
    """
    try:
        preprocessed_text = preprocess_text(input.text)

        # Check the raw text as well as the preprocessed one: preprocessing
        # strips characters such as ':' and '/', so a URL ("http://...") can
        # never match once it has been cleaned.
        if check_regex_patterns(input.text) or check_regex_patterns(preprocessed_text):
            return {
                "text": input.text,
                "is_contact_info": True,
                "method": "regex"
            }

        # No regex hit: fall back to the model on the preprocessed text.
        is_contact = predict(preprocessed_text)
        return {
            "text": input.text,
            # Coerce to a plain bool, as the batch endpoint does; predict()'s
            # return type is not visible here and may not be JSON-serializable.
            "is_contact_info": bool(is_contact == 1),
            "method": "model"
        }
    except Exception as e:
        # Endpoint boundary: surface any failure as a 500, keeping the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/batch_detect_contact", summary="Detect contact information in batch of texts")
async def batch_detect_contact(inputs: BatchTextInput):
    """Report, for each text in the batch, whether it contains contact info.

    Regex patterns are tried first on every text; only the texts with no
    regex match are sent to the model, in a single batch call. The response
    is a list in the same order as ``inputs.texts``, each item holding the
    original text, a boolean ``is_contact_info`` flag, and the ``method``
    ("regex" or "model") that produced the verdict.

    Raises:
        HTTPException: status 500 with the error message as detail if
            anything fails during detection.
    """
    try:
        preprocessed_texts = [preprocess_text(text) for text in inputs.texts]

        # Check the raw text as well as the preprocessed one: preprocessing
        # strips characters such as ':' and '/', so a URL ("http://...") can
        # never match once it has been cleaned.
        regex_results = [
            check_regex_patterns(raw) or check_regex_patterns(cleaned)
            for raw, cleaned in zip(inputs.texts, preprocessed_texts)
        ]

        # Only texts without a regex hit are sent to the model.
        texts_for_model = [
            cleaned
            for cleaned, matched in zip(preprocessed_texts, regex_results)
            if not matched
        ]
        model_results = batch_predict(texts_for_model) if texts_for_model else []

        # Merge both result streams back into the original input order.
        results = []
        model_idx = 0
        for raw, matched in zip(inputs.texts, regex_results):
            if matched:
                results.append({
                    "text": raw,
                    "is_contact_info": True,
                    "method": "regex"
                })
            else:
                is_contact = model_results[model_idx]
                model_idx += 1
                results.append({
                    "text": raw,
                    "is_contact_info": bool(is_contact),  # Convert numpy bool
                    "method": "model"
                })
        return results
    except Exception as e:
        # Endpoint boundary: surface any failure as a 500, keeping the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e