GautamGaur commited on
Commit
f2b5652
1 Parent(s): 63e83a7

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +59 -0
main.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI,Header,HTTPException,Depends,WebSocket,WebSocketDisconnect
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+
4
+ app = FastAPI()
5
+
6
+ app.add_middleware(
7
+ CORSMiddleware,
8
+ allow_origins=["*"], # Allow all origins
9
+ allow_methods=["GET", "POST"], # Allow only GET and POST methods
10
+ allow_headers=["*"], # Allow all headers
11
+ )
12
+
13
+
14
+ from fastapi import FastAPI, HTTPException
15
+ from pydantic import BaseModel
16
+ import torch
17
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification
18
+
19
+ app = FastAPI()
20
+
21
+ # Load the tokenizer
22
+ tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
23
+
24
+ # Load the model
25
+ model_path="model_ai_detection"
26
+ model = RobertaForSequenceClassification.from_pretrained(model_path)
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ model.to(device)
29
+ model.eval()
30
+
31
+ class TextData(BaseModel):
32
+ text: str
33
+
34
+ @app.post("/predict")
35
+ async def predict(data: TextData):
36
+ inputs = tokenizer(data.text, return_tensors="pt", padding=True, truncation=True)
37
+ inputs = {k: v.to(device) for k, v in inputs.items()}
38
+
39
+ with torch.no_grad():
40
+ outputs = model(**inputs)
41
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
42
+ ai_prob = probs[0][1].item() * 100 # Probability of the text being AI-generated
43
+
44
+ message = "The text is likely generated by AI." if ai_prob > 50 else "The text is likely generated by a human."
45
+
46
+ return {
47
+ "score": ai_prob,
48
+ "message": message
49
+ }
50
+
51
+
52
+
53
+ @app.get("/")
54
+ async def read_root():
55
+ return {"message": "Ready to go"}
56
+
57
+ if __name__ == "__main__":
58
+ import uvicorn
59
+ uvicorn.run(app, host="0.0.0.0", port=8000)