File size: 2,758 Bytes
a0a6e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from os import path
import streamlit as st

import nltk, subprocess, sys

# Names of the word-list files that are looked up next to this script.
stwfilename = "Vstopword_new.txt"
punfilename = "punctuation.txt"

# Resolve both files against the directory that contains this script.
_APP_DIR = path.dirname(__file__)
STW_PATH = path.join(_APP_DIR, stwfilename)
PUNCT_PATH = path.join(_APP_DIR, punfilename)


from pyvi import ViTokenizer
@st.cache_resource
def open2list_vn(path):
    """Read a text file and return its lines as a list of strings.

    The function is wrapped in ``st.cache_resource``, so the returned list is
    the cached object itself rather than a copy; callers should treat it as
    read-only.

    Args:
        path: Location of the file to read. NOTE: inside this function the
            parameter shadows the module-level ``from os import path`` name;
            it is kept as-is so keyword callers keep working.

    Returns:
        The file's lines with line terminators removed, or ``None`` when
        ``path`` is falsy.
    """
    if not path:
        return None
    # The word lists are expected to hold non-ASCII (Vietnamese) text, so the
    # decoding must not depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        return f.read().splitlines()
def pre_progress(input):
    """Normalize raw text into a space-joined string of filtered tokens.

    The text is segmented with ``ViTokenizer.tokenize``, lowercased, split
    with ``nltk.wordpunct_tokenize``, and every token that appears as a line
    in the punctuation file or the stop-word file is dropped.

    Args:
        input: Raw text to clean. The name shadows the ``input`` builtin but
            is kept so existing keyword callers continue to work.

    Returns:
        The surviving tokens joined by single spaces (an empty string when
        no token survives).
    """
    stw = open2list_vn(STW_PATH)
    punctuations = open2list_vn(PUNCT_PATH)
    # One set lookup per token replaces two linear scans over the word lists;
    # a token is kept only if it is in neither list, exactly as before.
    excluded = set(punctuations).union(stw)

    text = ViTokenizer.tokenize(input).lower()
    return " ".join(
        token for token in nltk.wordpunct_tokenize(text) if token not in excluded
    )


# from tensorflow import keras
import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraForSequenceClassification

# Tokenizer vocabulary and fine-tuned classifier checkpoint identifiers.
MODEL_NAME = "google/electra-small-discriminator"
MODEL_PATH = 'nguyennghia0902/textming_proj01_electra'

# Mapping between the classifier's output indices and its label strings.
id2label = {0: "FALSE", 1: "TRUE"}
label2id = {"FALSE": 0, "TRUE": 1}


# show_spinner=False keeps the cached call from rendering anything on the
# page, since it runs at import time before main() configures the page.
@st.cache_resource(show_spinner=False)
def _load_tokenizer_and_model():
    """Load the tokenizer and the classifier once and reuse them afterwards.

    Streamlit re-executes this script on user interaction; caching the
    loaded objects avoids re-reading the tokenizer and model weights on
    every run.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``ElectraTokenizer`` loaded from
        ``MODEL_NAME`` and the ``TFElectraForSequenceClassification`` loaded
        from ``MODEL_PATH`` with the label mappings above.
    """
    electra_tokenizer = ElectraTokenizer.from_pretrained(MODEL_NAME)
    electra_model = TFElectraForSequenceClassification.from_pretrained(
        MODEL_PATH, id2label=id2label, label2id=label2id
    )
    return electra_tokenizer, electra_model


# Module-level names used by the rest of the file.
tokenizer, loaded_model = _load_tokenizer_and_model()

def predict(question, text):
    """Classify a question/text pair and return the predicted label string.

    The question and the text are joined with a single space, cleaned by
    ``pre_progress``, encoded with the module-level ``tokenizer`` and scored
    by ``loaded_model``.

    Args:
        question: The question string.
        text: The text string the question is paired with.

    Returns:
        The entry of ``loaded_model.config.id2label`` for the highest-scoring
        class of the first (only) sequence in the batch.
    """
    cleaned = pre_progress(question + ' ' + text)

    encoded = tokenizer(cleaned, truncation=True, padding=True, return_tensors='tf')
    outputs = loaded_model(**encoded)
    # argmax over the class axis, then take the single row of the batch.
    best_index = tf.math.argmax(outputs.logits, axis=-1)[0]

    return loaded_model.config.id2label[int(best_index)]


def main():
    """Render the page and show a prediction when the button is pressed.

    Draws two text areas (text and question) and a "Predict" button. On a
    button press both inputs are validated as non-blank, the text is passed
    with the question to ``predict``, and the result is shown as a success
    (``"TRUE"``) or warning (anything else) message.
    """
    st.set_page_config(page_title="Information Retrieval", page_icon="📝")

    # giving a title to our page
    st.title("Information Retrieval")
    text = st.text_area(
        "Please enter a text:",
        placeholder="Enter your text here",
        height=200,
    )
    question = st.text_area(
        "Please enter a question:",
        placeholder="Enter your question here",
        height=200,
    )

    # Nothing else to do until the prediction button is pressed.
    if not st.button("Predict"):
        return

    if not text.strip():
        st.error("Please enter some text.")
        return
    if not question.strip():
        st.error("Please enter a question.")
        return

    # Replace line breaks with spaces (not with nothing) so the last word of
    # one line is not fused with the first word of the next.
    text = text.replace("\n", " ")
    if predict(question, text) == "TRUE":
        st.success("TRUE 😄")
    else:
        st.warning("FALSE 😟")

# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()