File size: 3,364 Bytes
707d24a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field
from fastapi.responses import RedirectResponse
from model import CustomDataset, TransformerEncoder, load_model_to_cpu

# Application object: the route decorators further down register their
# handlers on it, and it is what uvicorn serves when this file is run directly.
app = FastAPI()



# Label vocabulary of the tagger: tag name -> class index.
tag2id = dict(zip(("O", "olumsuz", "nötr", "olumlu", "org"), range(5)))
# Reverse lookup used to turn a predicted class index back into its tag name.
id2tag = {index: tag for tag, index in tag2id.items()}

# Inference is pinned to the CPU.
device = torch.device("cpu")
def _extract_entities(tokens, labels, attention_mask):
    """Return ``(entity_text, label_id)`` pairs for the unique entities of one tagged sequence.

    A token opens an entity when it is not masked out and its label is neither
    "O" nor "org". The entity text is that token plus the unbroken run of
    "org"-labelled, non-masked tokens that directly follows it, joined with
    single spaces. Only the first occurrence of each entity text is kept.
    """
    outside_label, org_label = 0, 4  # ids of "O" and "org" in tag2id

    # zip() stops at the shortest input, so every index used below is valid
    # for all three sequences even if the model emits more positions than
    # there are tokens.
    aligned = list(zip(tokens, labels, attention_mask))

    seen = set()
    entities = []
    for start, (_, label, attention) in enumerate(aligned):
        if attention == 0 or label in (outside_label, org_label):
            continue

        # Extend the span over the following "org" tokens; stop at the first
        # token that is masked out or carries any other label.
        end = start + 1
        while (end < len(aligned)
               and aligned[end][1] == org_label
               and aligned[end][2] != 0):
            end += 1

        entity = " ".join(token for token, _, _ in aligned[start:end])
        if entity not in seen:
            seen.add(entity)
            entities.append((entity, label))
    return entities


def predict_fonk(model, device, example, tokenizer):
    """Tag one text with the model and collect its entities with their sentiment.

    The text is tokenized, run through ``model`` as a single-item batch, and
    the highest-scoring label is taken for every position. Sentiment-labelled
    tokens are then merged with the "org" tokens that follow them.

    Args:
        model: Network called as ``model(input_ids, attention_mask)``; its
            output is flattened over everything but the last dimension
            before the per-position argmax.
        device: Torch device the model and the batch tensors are moved to.
        example: Raw input text.
        tokenizer: Object whose ``encode(example)`` result exposes ``tokens``,
            ``ids``, ``attention_mask`` and ``type_ids``.

    Returns:
        dict with two keys: ``"entity_list"`` (unique entity strings in order
        of first appearance) and ``"results"`` (one
        ``{"entity": ..., "sentiment": ...}`` dict per unique entity, where
        the sentiment is the tag name looked up in ``id2tag``).
    """
    model.to(device)
    model.eval()

    encoding = tokenizer.encode(example)
    predict_attention_masks = [encoding.attention_mask]

    predict_data = CustomDataset(
        [encoding.tokens],
        [encoding.ids],
        predict_attention_masks,
        [encoding.type_ids],
        # There are no gold labels at inference time; the type ids are passed
        # as a stand-in and are never read back in this function.
        [encoding.type_ids],
    )
    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)

    predictions = []
    with torch.no_grad():
        for batch in predict_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_att_mask = batch['attention_mask'].to(device)

            outputs = model(batch_input_ids, batch_att_mask)
            logits = outputs.view(-1, outputs.size(-1))  # one row per position
            _, predicted = torch.max(logits, 1)
            predictions.append(predicted)

    # The dataset holds exactly one example, so only the first prediction
    # tensor is relevant. Fetch the tokens and labels once instead of
    # re-reading them inside the merging loop.
    entities = _extract_entities(
        predict_loader.dataset[0]["text"],
        predictions[0].tolist(),
        predict_attention_masks[0],
    )

    return {
        "entity_list": [entity for entity, _ in entities],
        "results": [
            {"entity": entity, "sentiment": id2tag.get(label)}
            for entity, label in entities
        ],
    }

# Build the network and load its weights from "model.pth" once at import
# time; every request reuses this instance.
model = load_model_to_cpu(TransformerEncoder(), "model.pth")

# Tokenizer definition is read from "tokenizer.json" and shared the same way.
tokenizer = Tokenizer.from_file("tokenizer.json")

# Request body accepted by the /predict/ endpoint.
# (Documented with comments on purpose: a class docstring would presumably be
# picked up as the model description in the generated schema — TODO confirm.)
class Item(BaseModel):
    # The text to analyse. ``...`` as the default marks the field as required;
    # ``example`` attaches a sample value to the field definition.
    text: str = Field(..., example="""Fiber 100mb SuperOnline kullanıcısıyım yaklaşık 2 haftadır @Twitch @Kick_Turkey gibi canlı yayın platformlarında 360p yayın izlerken donmalar yaşıyoruz.  Başka hiç bir operatörler bu sorunu yaşamazken ben parasını verip alamadığım hizmeti neden ödeyeyim ? @Turkcell """)
    
@app.get("/")
async def root():
    # Nothing is served at the bare root URL; forward the client to the docs
    # page anchor of the predict endpoint instead.
    # (A comment rather than a docstring, so the route's generated
    # description stays exactly as it was.)
    docs_target = "/docs#/default/predict_predict__post"
    return RedirectResponse(url=docs_target)

@app.post("/predict/", response_model=dict)
async def predict(item: Item):
    # Runs the module-level model and tokenizer on the submitted text and
    # returns the dict built by predict_fonk unchanged
    # (keys "entity_list" and "results").
    # NOTE(review): the inference call is synchronous inside an async handler,
    # so it occupies the event loop while it runs — confirm this is acceptable
    # for the expected load.
    return predict_fonk(
        example=item.text,
        tokenizer=tokenizer,
        model=model,
        device=device,
    )


if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)