File size: 3,460 Bytes
a0a6e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731e929
a0a6e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731e929
 
 
 
 
a0a6e98
731e929
a0a6e98
731e929
 
 
 
 
 
a0a6e98
 
 
 
731e929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0a6e98
731e929
 
a0a6e98
 
731e929
a0a6e98
 
 
 
 
731e929
a0a6e98
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from os import path
import streamlit as st

import nltk, subprocess, sys

# Resource files shipped next to this script: the Vietnamese stop-word list
# and the punctuation list used by pre_progress().
_HERE = path.dirname(__file__)
stwfilename = "Vstopword_new.txt"
punfilename = "punctuation.txt"
STW_PATH = path.join(_HERE, stwfilename)
PUNCT_PATH = path.join(_HERE, punfilename)


from pyvi import ViTokenizer
@st.cache_resource
def open2list_vn(path):
    """Read a UTF-8 text file and return its lines as a list of strings.

    The function is wrapped in ``st.cache_resource``, which hands back the
    cached object itself rather than a copy, so callers should treat the
    returned list as read-only.

    Args:
        path: File to read. (Inside this function the name shadows the
            module-level ``os.path`` import.)

    Returns:
        The file's lines with line terminators removed, or ``None`` when
        *path* is falsy (e.g. ``None`` or ``""``).
    """
    if not path:
        return None
    with open(path, encoding="utf8") as handle:
        contents = handle.read()
    return contents.splitlines()
def pre_progress(input):
    """Normalize Vietnamese text before it is fed to the classifier.

    Steps: word-segment with pyvi's ``ViTokenizer``, lowercase the result,
    split it with ``nltk.wordpunct_tokenize``, then drop every token that
    exactly matches an entry of the punctuation list or the stop-word list.

    Args:
        input: Raw text. The name shadows the ``input`` builtin but is kept
            so keyword callers keep working.

    Returns:
        The surviving tokens joined by single spaces.
    """
    # Build hash sets once per call; the previous list membership tests cost
    # O(len(list)) for every token. `or ()` mirrors the loader's None return.
    stop_words = set(open2list_vn(STW_PATH) or ())
    punctuations = set(open2list_vn(PUNCT_PATH) or ())

    segmented = ViTokenizer.tokenize(input).lower()
    # Matching is exact: a multi-character run such as "?!" is only removed
    # when that exact string appears in the punctuation list.
    kept = [
        token
        for token in nltk.wordpunct_tokenize(segmented)
        if token not in punctuations and token not in stop_words
    ]
    return " ".join(kept)


# from tensorflow import keras
import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraForSequenceClassification

MODEL_NAME = "google/electra-small-discriminator"
MODEL_PATH = 'nguyennghia0902/textming_proj01_electra'

id2label = {0: "FALSE", 1: "TRUE"}
label2id = {"FALSE": 0, "TRUE": 1}


@st.cache_resource(show_spinner=False)
def _load_model():
    """Load the tokenizer and the fine-tuned Electra classifier once.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level without caching rebuilt the TF model on each
    rerun. ``st.cache_resource`` keeps one shared instance across reruns.
    ``show_spinner=False`` keeps the cached call from rendering anything
    before ``st.set_page_config`` runs inside ``main``.

    Returns:
        ``(tokenizer, model)``. The tokenizer comes from the base checkpoint
        ``MODEL_NAME``. The model comes from ``MODEL_PATH`` with the
        FALSE/TRUE label mapping attached to its config.
    """
    tok = ElectraTokenizer.from_pretrained(MODEL_NAME)
    model = TFElectraForSequenceClassification.from_pretrained(
        MODEL_PATH, id2label=id2label, label2id=label2id
    )
    return tok, model


# Keep the original module-level names so predict() and any importers work.
tokenizer, loaded_model = _load_model()

def predict(question, text):
    """Classify a (question, context) pair with the loaded Electra model.

    The question and context are joined with a single space, normalized by
    ``pre_progress``, and encoded as one sequence (truncated to the
    tokenizer's maximum length). Presumably the label says whether the
    context is relevant to the question; confirm against the model card.

    Returns:
        The label string from ``loaded_model.config.id2label``, which is
        configured at load time as ``{0: "FALSE", 1: "TRUE"}``.
    """
    joined = " ".join((question, text))
    cleaned = pre_progress(joined)

    encoded = tokenizer(cleaned, truncation=True, padding=True, return_tensors='tf')
    scores = loaded_model(**encoded).logits
    # Single-item batch: take the argmax class of the first (only) row.
    label_id = int(tf.math.argmax(scores, axis=-1)[0])

    return loaded_model.config.id2label[label_id]


def main():
    """Render the Streamlit page and run predictions on the user's input.

    There are two input modes:

    * An uploaded ``.txt`` file: every non-blank line is classified against
      the question as soon as the file is present (no button press).
    * Otherwise: the text-area context is classified when "Predict" is
      clicked.
    """
    st.set_page_config(page_title="Information Retrieval", page_icon="📝")

    # Page header: title plus a link to the model card on the Hugging Face Hub.
    col1, col2 = st.columns([2, 1])
    col1.title("Information Retrieval")

    col2.link_button("Explore my model", "https://huggingface.co/nguyennghia0902/textming_proj01_electra")

    question = st.text_area(
        "QUESTION: Please enter a question:",
        placeholder="Enter your question here",
        # st.text_area requires a height of at least 68px; the previous
        # value (15) is rejected by newer Streamlit versions.
        height=68,
    )
    text = st.text_area(
        "CONTEXT: Please enter a context:",
        placeholder="Enter your context here",
        height=100,
    )

    upload_file = st.file_uploader("CONTEXT: Or upload a file with some contexts", type=["txt"])
    if upload_file is not None:
        _predict_uploaded_contexts(question, upload_file)
    # The Predict button is only offered when no file is uploaded.
    elif st.button("Predict"):
        _predict_single_context(question, text)


def _predict_uploaded_contexts(question, upload_file):
    """Classify each non-blank line of an uploaded text file against *question*."""
    # Same validation as the button path: an empty question is meaningless.
    if not question.strip():
        st.error("Please enter a question.")
        return
    try:
        # getvalue() returns the whole buffer regardless of read position;
        # "utf-8-sig" also drops a leading BOM, which str.strip() would keep.
        contents = upload_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        st.error("The uploaded file must be UTF-8 encoded text.")
        return

    for line in contents.splitlines():
        line = line.strip()
        if not line:
            continue
        if predict(question, line) == "TRUE":
            st.success(line + "\n\nTRUE 😄")
        else:
            st.warning(line + "\n\nFALSE 😟")


def _predict_single_context(question, text):
    """Validate the text-area inputs, then classify the single context."""
    if not text.strip():
        st.error("Please enter a context.")
        return
    if not question.strip():
        st.error("Please enter a question.")
        return

    if predict(question, text) == "TRUE":
        st.success("TRUE 😄")
    else:
        st.warning("FALSE 😟")

# Entry point: build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()