# NOTE(review): stripped non-Python extraction metadata (file size, commit
# hash, line-number listing) that preceded the module and would break import.
import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field
from fastapi.responses import RedirectResponse
from model import CustomDataset, TransformerEncoder, load_model_to_cpu
app = FastAPI()
# Label vocabulary: "O" = outside any entity, 1-3 = sentiment classes
# (olumsuz/nötr/olumlu = negative/neutral/positive), 4 = "org" continuation
# tag used to glue multi-token entity names together during decoding.
tag2id = {"O": 0, "olumsuz": 1, "nötr": 2, "olumlu": 3, "org": 4}
# Reverse lookup: predicted class id -> human-readable tag.
id2tag = {value: key for key, value in tag2id.items()}
# Inference is served on CPU only (weights are loaded via load_model_to_cpu).
device = torch.device('cpu')
def predict_fonk(model, device, example, tokenizer):
    """Run entity-level sentiment prediction on a single raw text.

    Encodes *example*, runs it through *model* to get one class id per
    token, then merges each entity token with the run of following tokens
    tagged ``org`` (id 4) into a single multi-word entity string.

    Args:
        model: TransformerEncoder producing per-token class logits.
        device: torch.device to run inference on (CPU in this service).
        example: The raw input text (str).
        tokenizer: A ``tokenizers.Tokenizer`` matching the model's vocab.

    Returns:
        dict with keys ``entity_list`` (unique entity strings) and
        ``results`` (list of ``{"entity": ..., "sentiment": ...}`` dicts).
    """
    model.to(device)
    model.eval()

    encoding = tokenizer.encode(example)
    # CustomDataset expects parallel batch lists; wrap the single example.
    # type_ids double as dummy labels — they are never used for scoring.
    predict_data = CustomDataset(
        [encoding.tokens],
        [encoding.ids],
        [encoding.attention_mask],
        [encoding.type_ids],
        [encoding.type_ids],
    )
    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)

    predictions = []
    with torch.no_grad():
        for batch in predict_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_att_mask = batch['attention_mask'].to(device)
            outputs = model(batch_input_ids, batch_att_mask)
            # Flatten (batch, seq, classes) -> (tokens, classes) and take argmax.
            logits = outputs.view(-1, outputs.size(-1))
            _, predicted = torch.max(logits, 1)
            predictions.append(predicted)

    # Single example, batch_size=1 -> exactly one batch of predictions.
    tokens = predict_loader.dataset[0]["text"]
    labels = predictions[0].tolist()
    attention = encoding.attention_mask

    entity_list = []
    results_list = []
    for i, (token, label, att) in enumerate(zip(tokens, labels, attention)):
        # Skip padding, the "O" tag (0) and bare continuation tags (4);
        # continuation tokens are consumed by the merge loop below.
        if att == 0 or label in (0, 4):
            continue
        # Merge the contiguous run of following tokens tagged 4 ("org")
        # into this entity. A separate scan index replaces the original
        # code's mutation of the enumerate index `i`, which was fragile.
        j = i + 1
        while j < len(labels) and labels[j] == 4:
            token = token + " " + tokens[j]
            j += 1
        # Record each distinct entity once (first occurrence wins).
        if token not in entity_list:
            entity_list.append(token)
            results_list.append({"entity": token, "sentiment": id2tag.get(label)})

    return {"entity_list": entity_list, "results": results_list}
# Build the encoder and load the trained weights onto CPU for serving.
model = TransformerEncoder()
model = load_model_to_cpu(model, "model.pth")
# Tokenizer trained alongside the model; must match its vocabulary.
tokenizer = Tokenizer.from_file("tokenizer.json")
class Item(BaseModel):
    """Request body for POST /predict/: the raw text to analyze."""

    text: str = Field(..., example="""Fiber 100mb SuperOnline kullanıcısıyım yaklaşık 2 haftadır @Twitch @Kick_Turkey gibi canlı yayın platformlarında 360p yayın izlerken donmalar yaşıyoruz. Başka hiç bir operatörler bu sorunu yaşamazken ben parasını verip alamadığım hizmeti neden ödeyeyim ? @Turkcell """)
@app.get("/")
async def root():
    """Redirect the bare root URL straight to the Swagger UI entry for /predict/."""
    docs_url = "/docs#/default/predict_predict__post"
    return RedirectResponse(url=docs_url)
@app.post("/predict/", response_model=dict)
async def predict(item: Item):
    """Run the NER + sentiment model on the submitted text.

    The model's output dict (entity list plus per-entity sentiment) is
    returned directly as the response body.
    """
    return predict_fonk(model=model, device=device, example=item.text, tokenizer=tokenizer)
if __name__ == "__main__":
    # Serve the API on all interfaces, port 8000.
    # (Removed a stray trailing "|" token that made this line invalid Python.)
    uvicorn.run(app, host="0.0.0.0", port=8000)