# proj01_textmining / Information_Retrieval.py
# Source: Hugging Face file view (uploader: nguyennghia0902, commit a0a6e98,
# "Upload 3 files", verified, 2.76 kB).
from os import path
import streamlit as st
import nltk, subprocess, sys
# File names of the two word lists, looked up in the same directory as this
# script.
stwfilename = "Vstopword_new.txt"
punfilename = "punctuation.txt"
# Paths to the stop-word and punctuation lists, built relative to this file.
STW_PATH = path.join(path.dirname(__file__), stwfilename)
PUNCT_PATH = path.join(path.dirname(__file__), punfilename)
from pyvi import ViTokenizer
@st.cache_resource
def open2list_vn(path):
    """Read a text file and return its lines as a list of strings.

    The result is cached by Streamlit (``st.cache_resource``), so each
    distinct ``path`` is read from disk only once.

    Args:
        path: Location of the text file to read. Note that inside this
            function the parameter shadows the module-level ``os.path``
            import.

    Returns:
        A list with one entry per line of the file (line endings removed),
        or an empty list when ``path`` is empty/``None``.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    if not path:
        # Nothing to read; return an empty list so callers can still run
        # membership tests against the result.
        return []
    # Decode explicitly as UTF-8: without an encoding argument open() uses
    # the platform's locale default, which breaks non-ASCII word lists on
    # systems whose default is not UTF-8.
    # NOTE(review): assumes the word-list files are stored as UTF-8 - confirm.
    with open(path, encoding="utf-8") as f:
        return f.read().splitlines()
def pre_progress(input):
    """Normalize raw text into a space-separated string of kept tokens.

    Steps, in order:
      1. Run ``ViTokenizer.tokenize`` on the text.
      2. Lower-case the result.
      3. Split it with ``nltk.wordpunct_tokenize``.
      4. Drop every token listed in the punctuation file or the stop-word
         file.

    Args:
        input: The raw text to clean (name kept for caller compatibility
            even though it shadows the builtin).

    Returns:
        The remaining tokens joined by single spaces; an empty string when
        every token was filtered out.
    """
    # Sets give O(1) membership tests; the word lists may be long and are
    # checked once per token.
    stopwords = set(open2list_vn(STW_PATH))
    punctuations = set(open2list_vn(PUNCT_PATH))

    segmented = ViTokenizer.tokenize(input).lower()

    kept_tokens = [
        token
        for token in nltk.wordpunct_tokenize(segmented)
        if token not in punctuations and token not in stopwords
    ]
    return " ".join(kept_tokens)
# from tensorflow import keras
import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraForSequenceClassification
# Hub identifier the tokenizer is loaded from.
MODEL_NAME = "google/electra-small-discriminator"
# Hub identifier the sequence-classification model is loaded from.
MODEL_PATH = 'nguyennghia0902/textming_proj01_electra'

# Mapping between the classifier's output indices and its label names.
id2label = {0: "FALSE", 1: "TRUE"}
label2id = {"FALSE": 0, "TRUE": 1}


# show_spinner=False keeps this loader from rendering a Streamlit element
# at import time, before main() has had a chance to call
# st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_tokenizer_and_model():
    """Load the ELECTRA tokenizer and classifier once and cache them.

    Streamlit's resource cache keeps the returned objects, so later
    executions of this script body reuse them instead of calling
    ``from_pretrained`` again.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    electra_tokenizer = ElectraTokenizer.from_pretrained(MODEL_NAME)
    electra_model = TFElectraForSequenceClassification.from_pretrained(
        MODEL_PATH, id2label=id2label, label2id=label2id
    )
    return electra_tokenizer, electra_model


# Module-level names kept so existing users of `tokenizer` and
# `loaded_model` continue to work unchanged.
tokenizer, loaded_model = _load_tokenizer_and_model()
def predict(question, text):
    """Classify a (question, text) pair and return the label name.

    The question and text are joined with a single space, cleaned with
    ``pre_progress``, encoded by the module-level ``tokenizer`` and passed
    through ``loaded_model``. The index of the largest logit of the first
    (only) sequence is mapped to its label via the model config.

    Args:
        question: The question string.
        text: The text string the question is paired with.

    Returns:
        The entry of ``loaded_model.config.id2label`` for the predicted
        class index.
    """
    cleaned = pre_progress(" ".join((question, text)))
    encoded = tokenizer(
        cleaned,
        truncation=True,
        padding=True,
        return_tensors="tf",
    )
    outputs = loaded_model(**encoded)
    # argmax over the last axis gives one class index per input sequence;
    # only one sequence is fed in, so take element 0.
    best_per_sequence = tf.math.argmax(outputs.logits, axis=-1)
    label_index = int(best_per_sequence[0])
    return loaded_model.config.id2label[label_index]
def main():
    """Render the Streamlit page and run a prediction on demand.

    Shows two text areas (a text and a question) and a "Predict" button.
    When the button is pressed, both fields are checked for non-blank
    content; the first blank field produces an error message and stops.
    Otherwise the pair is passed to ``predict`` and the result is shown as
    a success ("TRUE") or warning (anything else) box.
    """
    st.set_page_config(page_title="Information Retrieval", page_icon="📝")

    # giving a title to our page
    st.title("Information Retrieval")
    text = st.text_area(
        "Please enter a text:",
        placeholder="Enter your text here",
        height=200,
    )
    question = st.text_area(
        "Please enter a question:",
        placeholder="Enter your question here",
        height=200,
    )

    # Nothing more to do until the prediction button is pressed.
    if not st.button("Predict"):
        return

    if not text.strip():
        st.error("Please enter some text.")
        return
    if not question.strip():
        st.error("Please enter a question.")
        return

    # Flatten the text onto one line. Newlines become spaces (not empty
    # strings) so the last word of one line is not glued to the first word
    # of the next.
    text = text.replace("\n", " ")
    prediction = predict(question, text)
    if prediction == "TRUE":
        st.success("TRUE 😄")
    else:
        st.warning("FALSE 😟")
# Start the app only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()