# apidetect / app.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from datetime import datetime

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_methods=["GET", "POST"],  # Allow only GET and POST methods
    allow_headers=["*"],  # Allow all headers
)
# Load the tokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# Load the first model
model_path1 = "model_ai_detection"
model1 = RobertaForSequenceClassification.from_pretrained(model_path1)
device1 = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model1.to(device1)
model1.eval()
# Load the second model
model_path2 = "best-ai-model" # Change this to your second model's path
model2 = RobertaForSequenceClassification.from_pretrained(model_path2)
device2 = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model2.to(device2)
model2.eval()
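
# Both endpoints below run the same tokenize -> forward -> softmax pipeline.
# A shared helper they could delegate to (a sketch, not wired in; it assumes,
# as the endpoints do, that index 1 of the logits is the "AI-generated" class):
def ai_probability(model, device, text: str) -> float:
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    return probs[0][1].item() * 100  # percent chance the text is AI-generated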
class TextData(BaseModel):
    text: str
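
# Example request body accepted by both prediction endpoints:
#   {"text": "Some passage to classify"}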
# Append each incoming text to a timestamped log file
def log_text(text: str):
    timestamp = datetime.now().isoformat()
    log_entry = f"{timestamp} - {text}\n"
    with open("/tmp/text_logs.txt", "a") as file:
        file.write(log_entry)
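
# Note: /tmp is ephemeral on most hosts (including Hugging Face Spaces), so
# these logs are lost on restart; a persistent volume or external store would
# be needed to keep them.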
@app.post("/predict")
async def predict(data: TextData):
log_text(data.text)
inputs = tokenizer(data.text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(device1) for k, v in inputs.items()}
with torch.no_grad():
outputs = model1(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
ai_prob = probs[0][1].item() * 100 # Probability of the text being AI-generated
message = "The text is likely generated by AI." if ai_prob > 50 else "The text is likely generated by a human."
return {
"score": ai_prob,
"message": message
}
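
# Example call (host/port are assumptions for a local run):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Sample passage to classify"}'
# The response has the shape {"score": <0-100 float>, "message": "..."}.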
@app.post("/predict_v2")
async def predict_v2(data: TextData):
log_text(data.text)
inputs = tokenizer(data.text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(device2) for k, v in inputs.items()}
with torch.no_grad():
outputs = model2(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
ai_prob = probs[0][1].item() * 100 # Probability of the text being AI-generated
message = "The text is likely generated by AI." if ai_prob > 50 else "The text is likely generated by a human."
return {
"score": ai_prob,
"message": message
}
@app.get("/")
async def read_root():
    return {"message": "Ready to go"}