|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel |
|
import torch |
|
from torch.nn.functional import softmax |
|
import re |
|
from predictor import predict, batch_predict |
|
|
|
# Application instance. docs_url="/" mounts the interactive API docs at the
# root path instead of FastAPI's default location.
app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text, great thanks to xxparthparekhxx/ContactShieldAI for the model",
    version="1.0.0",
    docs_url="/",
)
|
|
|
def preprocess_text(text):
    """Return *text* with every character removed except word characters,
    whitespace, ``@`` and ``.``.

    Punctuation such as ``-``, ``:``, ``/``, ``,`` and parentheses is
    dropped; nothing is inserted in its place.
    """
    disallowed_chars = r'[^\w\s@.]'
    cleaned = re.sub(disallowed_chars, '', text)
    return cleaned
|
|
|
# Request body accepted by the single-text detection endpoint.
class TextInput(BaseModel):
    # The raw text to scan for contact information.
    text: str
|
|
|
# Request body accepted by the batch detection endpoint.
class BatchTextInput(BaseModel):
    # The raw texts to scan; results are returned in the same order.
    texts: list[str]
|
|
|
# Compiled once at import time instead of on every call. All patterns are
# matched case-insensitively.
_CONTACT_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        # Email address. The TLD class is [A-Za-z]; the previous [A-Z|a-z]
        # also accepted a literal "|" as a TLD character.
        r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        # 10-digit phone number with optional "-" or "." separators.
        r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
        # 5-digit ZIP code, optionally followed by a 4-digit extension.
        r'\b\d{5}(?:[-\s]\d{4})?\b',
        # Street address: number, free text, a street-type word, then
        # optional unit and city/state/ZIP parts.
        r'\b\d+\s+[\w\s]+(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\b\s*(?:[a-z]+\s*\d{1,3})?(?:,\s*(?:apt|bldg|dept|fl|hngr|lot|pier|rm|ste|unit|#)\s*[a-z0-9-]+)?(?:,\s*[a-z]+\s*[a-z]{2}\s*\d{5}(?:-\d{4})?)?',
        # http/https URL.
        r'(?:http|https)://(?:www\.)?[a-zA-Z0-9-]+\.[a-zA-Z]{2,}(?:/[^\s]*)?',
    )
)


def check_regex_patterns(text):
    """Return True if *text* matches any of the contact-information patterns.

    The patterns cover email addresses, phone numbers, ZIP codes, street
    addresses and http(s) URLs. Matching is case-insensitive and uses
    ``search``, so a pattern may match anywhere in the text.

    Args:
        text: The string to scan.

    Returns:
        True when at least one pattern matches, otherwise False.
    """
    return any(pattern.search(text) for pattern in _CONTACT_PATTERNS)
|
|
|
@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    """Report whether the submitted text contains contact information.

    The regex patterns are tried first; only when none of them match is the
    text passed to the model.

    Returns a dict with the original ``text``, a boolean ``is_contact_info``
    and the ``method`` that produced the answer (``"regex"`` or ``"model"``).

    Raises:
        HTTPException: status 500, with the underlying error message as the
            detail, when preprocessing, matching or prediction fails.
    """
    try:
        preprocessed_text = preprocess_text(input.text)

        # Check the raw text as well as the cleaned text. preprocess_text
        # strips ":", "/" and "-", so on cleaned text alone the URL pattern
        # (which needs "://") can never match and a ZIP+4 loses its hyphen.
        # The cleaned text is still checked so earlier matches are kept.
        if check_regex_patterns(input.text) or check_regex_patterns(preprocessed_text):
            return {
                "text": input.text,
                "is_contact_info": True,
                "method": "regex",
            }

        is_contact = predict(preprocessed_text)
        return {
            "text": input.text,
            # bool() keeps the field a plain boolean even if predict returns
            # a non-int scalar whose == comparison is not a Python bool.
            "is_contact_info": bool(is_contact == 1),
            "method": "model",
        }
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
@app.post("/batch_detect_contact", summary="Detect contact information in batch of texts")
async def batch_detect_contact(inputs: BatchTextInput):
    """Report, for each submitted text, whether it contains contact information.

    The regex patterns are tried on every text first. Texts that match no
    pattern are sent to the model together in one ``batch_predict`` call.

    Returns a list with one dict per input, in input order, each holding the
    original ``text``, a boolean ``is_contact_info`` and the ``method`` that
    produced the answer (``"regex"`` or ``"model"``).

    Raises:
        HTTPException: status 500, with the underlying error message as the
            detail, when preprocessing, matching or prediction fails.
    """
    try:
        preprocessed_texts = [preprocess_text(text) for text in inputs.texts]

        # Check the raw text as well as the cleaned text. preprocess_text
        # strips ":", "/" and "-", so on cleaned text alone the URL pattern
        # (which needs "://") can never match and a ZIP+4 loses its hyphen.
        # The cleaned text is still checked so earlier matches are kept.
        regex_results = [
            check_regex_patterns(raw) or check_regex_patterns(cleaned)
            for raw, cleaned in zip(inputs.texts, preprocessed_texts)
        ]

        # Only texts without a regex hit need the model, which receives the
        # cleaned form.
        texts_for_model = [
            cleaned
            for cleaned, regex_match in zip(preprocessed_texts, regex_results)
            if not regex_match
        ]
        model_results = batch_predict(texts_for_model) if texts_for_model else []

        # Merge both sources back into input order. model_idx walks
        # model_results in step with the texts that had no regex hit.
        results = []
        model_idx = 0
        for original_text, regex_match in zip(inputs.texts, regex_results):
            if regex_match:
                results.append({
                    "text": original_text,
                    "is_contact_info": True,
                    "method": "regex",
                })
            else:
                is_contact = model_results[model_idx]
                model_idx += 1
                results.append({
                    "text": original_text,
                    "is_contact_info": bool(is_contact),
                    "method": "model",
                })

        return results
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e