import re

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from torch.nn.functional import softmax

from predictor import predict, batch_predict  # Assuming batch_predict is in predictor module

app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text, great thanks to xxparthparekhxx/ContactShieldAI for the model",
    version="1.0.0",
    docs_url="/"
)

# Everything except word characters, whitespace, '@' and '.' is stripped
# before the text is handed to the model.
_PUNCTUATION_RE = re.compile(r'[^\w\s@.]')

# Compiled once at import time; every pattern is matched case-insensitively.
_CONTACT_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        # Email (the TLD class is letters only; the previous `[A-Z|a-z]`
        # also accepted a literal '|').
        r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        # Phone number
        r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
        # ZIP code
        r'\b\d{5}(?:[-\s]\d{4})?\b',
        # Street address
        r'\b\d+\s+[\w\s]+(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\b\s*(?:[a-z]+\s*\d{1,3})?(?:,\s*(?:apt|bldg|dept|fl|hngr|lot|pier|rm|ste|unit|#)\s*[a-z0-9-]+)?(?:,\s*[a-z]+\s*[a-z]{2}\s*\d{5}(?:-\d{4})?)?',
        # Website URL
        r'(?:http|https)://(?:www\.)?[a-zA-Z0-9-]+\.[a-zA-Z]{2,}(?:/[^\s]*)?',
    )
)


def preprocess_text(text):
    """Return *text* with all punctuation removed except '@' and '.'.

    '@' and '.' are kept because they are often used in email addresses.
    """
    return _PUNCTUATION_RE.sub('', text)


# Request body for the single-text endpoint.
class TextInput(BaseModel):
    text: str


# Request body for the batch endpoint.
class BatchTextInput(BaseModel):
    texts: list[str]


def check_regex_patterns(text):
    """Return True if *text* matches any of the contact-information patterns.

    The patterns cover email addresses, phone numbers, ZIP codes, street
    addresses and http(s) URLs; matching is case-insensitive.
    """
    return any(pattern.search(text) for pattern in _CONTACT_PATTERNS)


def _matches_contact_regex(raw_text, cleaned_text):
    """Return True if either the raw or the preprocessed text matches.

    The raw text must be checked because preprocessing strips ':', '/',
    '-', ',' and '#', which the URL, ZIP+4 and street-address patterns
    rely on. The preprocessed text is still checked so that everything
    that matched before keeps matching.
    """
    return check_regex_patterns(raw_text) or check_regex_patterns(cleaned_text)


def _build_result(text, is_contact_info, method):
    """Build one response item for the original *text*."""
    return {
        "text": text,
        "is_contact_info": is_contact_info,
        "method": method
    }


# Detect contact information in a single text.
#
# The regex patterns are tried first; the model is consulted (on the
# preprocessed text) only when no pattern matches. Any failure is reported
# as an HTTP 500 carrying the error message.
#
# NOTE(review): predict() is called synchronously inside an async handler,
# so the event loop is blocked for the duration of the call.
@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    try:
        cleaned_text = preprocess_text(input.text)

        # First, check with regex patterns
        if _matches_contact_regex(input.text, cleaned_text):
            return _build_result(input.text, True, "regex")

        # If no regex patterns match, use the model
        is_contact = predict(cleaned_text)
        # bool() turns a numpy/tensor comparison result into a plain bool
        # so the response can be serialised.
        return _build_result(input.text, bool(is_contact == 1), "model")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# Detect contact information in a batch of texts.
#
# Results are returned in the same order as the submitted texts. Texts that
# match a regex pattern are answered directly; the rest are sent to the
# model in a single batch_predict() call. Any failure is reported as an
# HTTP 500 carrying the error message.
@app.post("/batch_detect_contact", summary="Detect contact information in batch of texts")
async def batch_detect_contact(inputs: BatchTextInput):
    try:
        # Preprocess all texts
        cleaned_texts = [preprocess_text(text) for text in inputs.texts]

        # First, use regex to check patterns
        regex_hits = [
            _matches_contact_regex(raw_text, cleaned_text)
            for raw_text, cleaned_text in zip(inputs.texts, cleaned_texts)
        ]

        # For texts where regex doesn't detect anything, use the model
        texts_for_model = [
            cleaned_text
            for cleaned_text, hit in zip(cleaned_texts, regex_hits)
            if not hit
        ]
        model_results = list(batch_predict(texts_for_model)) if texts_for_model else []
        if len(model_results) != len(texts_for_model):
            raise RuntimeError(
                f"batch_predict returned {len(model_results)} results "
                f"for {len(texts_for_model)} texts"
            )

        # Prepare final results; model outputs are consumed in submission
        # order, one per text that had no regex match.
        remaining_model_results = iter(model_results)
        results = []
        for raw_text, hit in zip(inputs.texts, regex_hits):
            if hit:
                results.append(_build_result(raw_text, True, "regex"))
            else:
                is_contact = next(remaining_model_results)
                # Convert numpy bool
                results.append(_build_result(raw_text, bool(is_contact), "model"))

        return results
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e