"""Streamlit app: classify whether a Vietnamese context answers a question, using a fine-tuned ELECTRA model."""

from os import path

import nltk
import streamlit as st
import tensorflow as tf
from pyvi import ViTokenizer
from transformers import ElectraTokenizer, TFElectraForSequenceClassification

# Vietnamese stopword and punctuation lists shipped alongside this script
STW_FILENAME = "Vstopword_new.txt"
PUNCT_FILENAME = "punctuation.txt"
STW_PATH = path.join(path.dirname(__file__), STW_FILENAME)
PUNCT_PATH = path.join(path.dirname(__file__), PUNCT_FILENAME)


@st.cache_resource
def open2list_vn(file_path):
    """Read a UTF-8 word-list file and return its lines as a list."""
    if file_path:
        with open(file_path, encoding="utf8") as f:
            return f.read().splitlines()
    return []


def pre_progress(raw_text):
    """Word-segment Vietnamese text, lowercase it, and drop punctuation and stopwords."""
    stopwords = open2list_vn(STW_PATH)
    punctuations = open2list_vn(PUNCT_PATH)
    segmented = ViTokenizer.tokenize(raw_text).lower()
    tokens = [t for t in nltk.wordpunct_tokenize(segmented) if t not in punctuations]
    kept = [t for t in tokens if t not in stopwords]
    return " ".join(kept)


# Tokenizer comes from the base ELECTRA checkpoint; classification weights come from the fine-tuned model
MODEL_NAME = "google/electra-small-discriminator"
MODEL_PATH = "nguyennghia0902/textming_proj01_electra"

tokenizer = ElectraTokenizer.from_pretrained(MODEL_NAME)
id2label = {0: "FALSE", 1: "TRUE"}
label2id = {"FALSE": 0, "TRUE": 1}
loaded_model = TFElectraForSequenceClassification.from_pretrained(
    MODEL_PATH, id2label=id2label, label2id=label2id
)


def predict(question, text):
    """Return "TRUE" if the model predicts the context answers the question, else "FALSE"."""
    combined = pre_progress(question + " " + text)
    inputs = tokenizer(combined, truncation=True, padding=True, return_tensors="tf")
    logits = loaded_model(**inputs).logits
    predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
    return loaded_model.config.id2label[predicted_class_id]


def main():
    st.set_page_config(page_title="Information Retrieval", page_icon="📝")

    # Page title and a link to the model card
    col1, col2 = st.columns([2, 1])
    col1.title("Information Retrieval")
    col2.link_button(
        "Explore my model",
        "https://huggingface.co/nguyennghia0902/textming_proj01_electra",
    )

    question = st.text_area(
        "QUESTION: Please enter a question:",
        placeholder="Enter your question here",
        height=15,
    )
    text = st.text_area(
        "CONTEXT: Please enter a context:",
        placeholder="Enter your context here",
        height=100,
    )

    upload_file = st.file_uploader("CONTEXT: Or upload a file with some contexts", type=["txt"])

    if upload_file is not None:
        # Classify every non-empty line of the uploaded file as a separate context
        text = upload_file.read().decode("utf-8")
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            prediction = predict(question, line)
            if prediction == "TRUE":
                st.success(line + "\n\nTRUE 😄")
            else:
                st.warning(line + "\n\nFALSE 😟")
    # Otherwise classify the single context typed into the text area
    elif st.button("Predict"):
        if not text.strip():
            st.error("Please enter a context.")
            return
        if not question.strip():
            st.error("Please enter a question.")
            return
        prediction = predict(question, text)
        if prediction == "TRUE":
            st.success("TRUE 😄")
        else:
            st.warning("FALSE 😟")


if __name__ == "__main__":
    main()