Spaces:
Sleeping
Sleeping
GautamGaur
commited on
Commit
β’
19d4944
1
Parent(s):
7e497bf
Rename main.py to app.py
Browse files- main.py β app.py +55 -59
main.py β app.py
RENAMED
@@ -1,59 +1,55 @@
|
|
1 |
-
from fastapi import FastAPI,Header,HTTPException,Depends,WebSocket,WebSocketDisconnect
|
2 |
-
from fastapi.middleware.cors import CORSMiddleware
|
3 |
-
|
4 |
-
app = FastAPI()
|
5 |
-
|
6 |
-
app.add_middleware(
|
7 |
-
CORSMiddleware,
|
8 |
-
allow_origins=["*"], # Allow all origins
|
9 |
-
allow_methods=["GET", "POST"], # Allow only GET and POST methods
|
10 |
-
allow_headers=["*"], # Allow all headers
|
11 |
-
)
|
12 |
-
|
13 |
-
|
14 |
-
from fastapi import FastAPI, HTTPException
|
15 |
-
from pydantic import BaseModel
|
16 |
-
import torch
|
17 |
-
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
18 |
-
|
19 |
-
app = FastAPI()
|
20 |
-
|
21 |
-
# Load the tokenizer
|
22 |
-
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
23 |
-
|
24 |
-
# Load the model
|
25 |
-
model_path="model_ai_detection"
|
26 |
-
model = RobertaForSequenceClassification.from_pretrained(model_path)
|
27 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
-
model.to(device)
|
29 |
-
model.eval()
|
30 |
-
|
31 |
-
class TextData(BaseModel):
|
32 |
-
text: str
|
33 |
-
|
34 |
-
@app.post("/predict")
|
35 |
-
async def predict(data: TextData):
|
36 |
-
inputs = tokenizer(data.text, return_tensors="pt", padding=True, truncation=True)
|
37 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
38 |
-
|
39 |
-
with torch.no_grad():
|
40 |
-
outputs = model(**inputs)
|
41 |
-
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
42 |
-
ai_prob = probs[0][1].item() * 100 # Probability of the text being AI-generated
|
43 |
-
|
44 |
-
message = "The text is likely generated by AI." if ai_prob > 50 else "The text is likely generated by a human."
|
45 |
-
|
46 |
-
return {
|
47 |
-
"score": ai_prob,
|
48 |
-
"message": message
|
49 |
-
}
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
@app.get("/")
|
54 |
-
async def read_root():
|
55 |
-
return {"message": "Ready to go"}
|
56 |
-
|
57 |
-
if __name__ == "__main__":
|
58 |
-
import uvicorn
|
59 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
1 |
+
from fastapi import FastAPI,Header,HTTPException,Depends,WebSocket,WebSocketDisconnect
|
2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
|
4 |
+
app = FastAPI()
|
5 |
+
|
6 |
+
app.add_middleware(
|
7 |
+
CORSMiddleware,
|
8 |
+
allow_origins=["*"], # Allow all origins
|
9 |
+
allow_methods=["GET", "POST"], # Allow only GET and POST methods
|
10 |
+
allow_headers=["*"], # Allow all headers
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
from fastapi import FastAPI, HTTPException
|
15 |
+
from pydantic import BaseModel
|
16 |
+
import torch
|
17 |
+
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
18 |
+
|
19 |
+
app = FastAPI()
|
20 |
+
|
21 |
+
# Load the tokenizer
|
22 |
+
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
23 |
+
|
24 |
+
# Load the model
|
25 |
+
model_path="model_ai_detection"
|
26 |
+
model = RobertaForSequenceClassification.from_pretrained(model_path)
|
27 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
+
model.to(device)
|
29 |
+
model.eval()
|
30 |
+
|
31 |
+
class TextData(BaseModel):
|
32 |
+
text: str
|
33 |
+
|
34 |
+
@app.post("/predict")
|
35 |
+
async def predict(data: TextData):
|
36 |
+
inputs = tokenizer(data.text, return_tensors="pt", padding=True, truncation=True)
|
37 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
38 |
+
|
39 |
+
with torch.no_grad():
|
40 |
+
outputs = model(**inputs)
|
41 |
+
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
42 |
+
ai_prob = probs[0][1].item() * 100 # Probability of the text being AI-generated
|
43 |
+
|
44 |
+
message = "The text is likely generated by AI." if ai_prob > 50 else "The text is likely generated by a human."
|
45 |
+
|
46 |
+
return {
|
47 |
+
"score": ai_prob,
|
48 |
+
"message": message
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
@app.get("/")
|
54 |
+
async def read_root():
|
55 |
+
return {"message": "Ready to go"}
|
|
|
|
|
|
|
|