# Source: Hugging Face Spaces page capture (Space status at capture time: "Runtime error").
from os import path | |
import streamlit as st | |
import nltk, subprocess, sys | |
# File names of the stop-word list and the punctuation list that the
# preprocessing step loads.
stwfilename = "Vstopword_new.txt"
punfilename = "punctuation.txt"

# Both files are looked up in the directory that contains this script.
STW_PATH = path.join(path.dirname(__file__), stwfilename)
PUNCT_PATH = path.join(path.dirname(__file__), punfilename)
from pyvi import ViTokenizer | |
def open2list_vn(path):
    """Read a text file and return its lines as a list of strings.

    Args:
        path: Path of the file to read. Inside this function the name
            shadows the module-level ``path`` (``os.path``) import; it is
            kept unchanged so keyword callers keep working.

    Returns:
        The file's lines without their line terminators, or an empty list
        when ``path`` is empty or ``None``.
    """
    if not path:
        # Nothing to read: return an empty list instead of failing on a
        # result variable that was never assigned.
        return []
    # Decode explicitly rather than relying on the platform default
    # encoding (assumes the list files are stored as UTF-8).
    with open(path, encoding="utf-8") as f:
        return f.read().splitlines()
def pre_progress(input):
    """Tokenize, lowercase and filter a text, returning it as one string.

    The text is segmented with ``ViTokenizer.tokenize``, lowercased and
    split with ``nltk.wordpunct_tokenize``; every token that appears in the
    punctuation list or in the stop-word list is dropped.

    Args:
        input: Raw text to preprocess. The name shadows the ``input``
            builtin; it is kept so keyword callers keep working.

    Returns:
        The remaining tokens, in their original order, joined by single
        spaces.
    """
    # Sets make each membership test O(1) instead of a scan over the list.
    # Note that both files are re-read on every call.
    stw = set(open2list_vn(STW_PATH))
    punctuations = set(open2list_vn(PUNCT_PATH))

    text = ViTokenizer.tokenize(input).lower()
    kept = [
        token
        for token in nltk.wordpunct_tokenize(text)
        if token not in punctuations and token not in stw
    ]
    return " ".join(kept)
# from tensorflow import keras | |
import tensorflow as tf | |
from transformers import ElectraTokenizer, TFElectraForSequenceClassification | |
# Checkpoint the tokenizer is loaded from.
MODEL_NAME = "google/electra-small-discriminator"
# Checkpoint the sequence-classification model is loaded from.
MODEL_PATH = 'nguyennghia0902/textming_proj01_electra'

# Class-index <-> label-name mappings handed to the model configuration.
id2label = {0: "FALSE", 1: "TRUE"}
label2id = {"FALSE": 0, "TRUE": 1}


# show_spinner=False keeps the cached loaders from emitting a UI element
# before main() calls st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_tokenizer():
    """Load the tokenizer; the result is cached by Streamlit."""
    return ElectraTokenizer.from_pretrained(MODEL_NAME)


@st.cache_resource(show_spinner=False)
def _load_model():
    """Load the classifier with the label mappings; cached by Streamlit."""
    return TFElectraForSequenceClassification.from_pretrained(
        MODEL_PATH, id2label=id2label, label2id=label2id
    )


# Streamlit re-executes this script on every interaction, so the expensive
# loads go through cached loaders instead of running again on each rerun.
tokenizer = _load_tokenizer()
loaded_model = _load_model()
def predict(question, text):
    """Classify a question/text pair and return the predicted label name.

    The question and the text are concatenated (question first, separated
    by one space), preprocessed with ``pre_progress`` and fed to the loaded
    classifier as a single sequence.

    Args:
        question: The question string.
        text: The text string the question is paired with.

    Returns:
        The label looked up in ``loaded_model.config.id2label`` for the
        highest-scoring class.
    """
    combined = pre_progress(question + ' ' + text)
    # truncation=True cuts inputs that exceed the model's maximum length.
    inputs = tokenizer(combined, truncation=True, padding=True, return_tensors='tf')
    logits = loaded_model(**inputs).logits
    # Only one sequence is passed in, so read the arg-max of the first row.
    predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
    return loaded_model.config.id2label[predicted_class_id]
def main():
    """Render the page and show a prediction when the button is pressed.

    Draws two text areas (text and question) and a "Predict" button. On
    press, both inputs are validated as non-empty, whitespace-normalized,
    passed to ``predict`` and the outcome is shown as a success ("TRUE") or
    warning (anything else) message.
    """
    st.set_page_config(page_title="Information Retrieval", page_icon="📝")
    # giving a title to our page
    st.title("Information Retrieval")
    text = st.text_area(
        "Please enter a text:",
        placeholder="Enter your text here",
        height=200,
    )
    question = st.text_area(
        "Please enter a question:",
        placeholder="Enter your question here",
        height=200,
    )

    # Nothing to validate or display until the prediction button is pressed.
    if not st.button("Predict"):
        return

    # Collapse every whitespace run (including newlines) into one space.
    # Deleting newlines outright would glue the last word of a line to the
    # first word of the next one.
    text = " ".join(text.split())
    question = " ".join(question.split())

    if not text:
        st.error("Please enter some text.")
        return
    if not question:
        st.error("Please enter a question.")
        return

    prediction = predict(question, text)
    if prediction == "TRUE":
        st.success("TRUE 😄")
    else:
        st.warning("FALSE 😟")
# Entry point: build the page only when this file is executed as a script.
if __name__ == "__main__":
    main()