import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

import uvicorn
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, Field

from model import CustomDataset, TransformerEncoder, load_model_to_cpu
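
# The local model module is expected to provide:
#   - CustomDataset: items are dicts with at least 'input_ids',
#     'attention_mask', and 'text' keys (used in predict_fonk below),
#   - TransformerEncoder: the token-classification model,
#   - load_model_to_cpu: loads saved weights onto the CPU.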

app = FastAPI()

# Tag scheme used by the model: "O" marks non-entity tokens; the Turkish
# sentiment tags mean olumsuz = negative, nötr = neutral, olumlu = positive;
# "org" marks a continuation token of a multi-token entity (see the merging
# loop in predict_fonk below).
tag2id = {"O": 0, "olumsuz": 1, "nötr": 2, "olumlu": 3, "org": 4}
id2tag = {value: key for key, value in tag2id.items()}

device = torch.device('cpu')


def predict_fonk(model, device, example, tokenizer):
    model.to(device)
    model.eval()
    predictions = []

    encodings_predict = tokenizer.encode(example)

    predict_texts = [encodings_predict.tokens]
    predict_input_ids = [encodings_predict.ids]
    predict_attention_masks = [encodings_predict.attention_mask]
    predict_token_type_ids = [encodings_predict.type_ids]
    # True labels are unknown at inference time; the type_ids are passed as
    # placeholders so CustomDataset receives the field it expects.
    prediction_labels = [encodings_predict.type_ids]

    predict_data = CustomDataset(predict_texts, predict_input_ids, predict_attention_masks,
                                 predict_token_type_ids, prediction_labels)

    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)

    with torch.no_grad():
        for dataset in predict_loader:
            batch_input_ids = dataset['input_ids'].to(device)
            batch_att_mask = dataset['attention_mask'].to(device)

            outputs = model(batch_input_ids, batch_att_mask)
            # Flatten to (sequence_length, num_tags) and take the most likely
            # tag id for every token position.
            logits = outputs.view(-1, outputs.size(-1))
            _, predicted = torch.max(logits, 1)
            predictions.append(predicted)

    results_list = []
    entity_list = []
    results_dict = {}

    tokens = predict_loader.dataset[0]["text"]
    labels = predictions[0].tolist()
    trio = zip(tokens, labels, predict_attention_masks[0])

    for i, (token, label, attention) in enumerate(trio):
        # An attended token with a sentiment tag (1-3) starts an entity; any
        # directly following "org" (4) tokens are merged into the same entity.
        if attention != 0 and label != 0 and label != 4:
            for j, next_label in enumerate(labels[i + 1:], start=i + 1):
                if next_label == 4:
                    token = token + " " + tokens[j]
                else:
                    break
            if token not in entity_list:
                entity_list.append(token)
                results_list.append({"entity": token, "sentiment": id2tag.get(label)})

    results_dict["entity_list"] = entity_list
    results_dict["results"] = results_list

    return results_dict
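
# Illustrative return value of predict_fonk (entity and sentiment values here
# are hypothetical; only the structure is fixed by the code above):
#   {"entity_list": ["@Turkcell"],
#    "results": [{"entity": "@Turkcell", "sentiment": "olumsuz"}]}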


# Load the trained encoder weights and the matching tokenizer from disk.
model = TransformerEncoder()
model = load_model_to_cpu(model, "model.pth")
tokenizer = Tokenizer.from_file("tokenizer.json")


class Item(BaseModel):
    text: str = Field(..., example="""Fiber 100mb SuperOnline kullanıcısıyım yaklaşık 2 haftadır @Twitch @Kick_Turkey gibi canlı yayın platformlarında 360p yayın izlerken donmalar yaşıyoruz. Başka hiç bir operatörler bu sorunu yaşamazken ben parasını verip alamadığım hizmeti neden ödeyeyim ? @Turkcell """)


@app.get("/")
async def root():
    # Send visitors of the bare root URL straight to the interactive docs
    # entry for the /predict/ endpoint.
    return RedirectResponse(url="/docs#/default/predict_predict__post")


@app.post("/predict/", response_model=dict)
async def predict(item: Item):
    predict_list = predict_fonk(model=model, device=device, example=item.text, tokenizer=tokenizer)
    return predict_list


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
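
# Example request against a locally running instance (host/port as configured
# above); the example text is truncated here for brevity:
#   curl -X POST http://localhost:8000/predict/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Fiber 100mb SuperOnline kullanıcısıyım ... @Turkcell"}'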