|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import tempfile |
|
|
|
|
|
model_name = "potsawee/t5-large-generation-squad-QuestionAnswer"


@st.cache_resource
def _load_model(name):
    """Load the tokenizer and seq2seq model for checkpoint *name*, cached.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, the t5-large checkpoint would be reloaded from disk
    (or the hub) each time the user types or clicks.
    ``st.cache_resource`` keeps one shared instance per distinct *name*.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForSeq2SeqLM.from_pretrained(name),
    )


tokenizer, model = _load_model(model_name)
|
|
|
uploaded_file = st.file_uploader("Upload Document or Paragraph")

# Produce `document_text`, either from an uploaded file or from a free-text area.
if uploaded_file is not None:
    # Decode the upload directly from memory. Nothing is written to disk,
    # so no temporary file is left behind on each script rerun.
    try:
        document_text = uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        # Binary formats (e.g. PDF/DOCX) are not plain UTF-8 text. Report the
        # problem instead of crashing, and continue with an empty document.
        document_text = ""
        st.error("Could not read the file: only UTF-8 text documents are supported.")
    else:
        st.success("Document uploaded successfully!")
else:
    document_text = st.text_area("Enter Text (Optional)", height=200)
|
|
|
question = st.text_input("Ask a Question")

bouton_ok = st.button("Answer")

# Generate and display an answer when the button is pressed.
if bouton_ok:
    if not question.strip():
        # Nothing to answer; avoid running an expensive generate() call.
        st.warning("Please enter a question.")
    else:
        # Use a placeholder so the prompt always contains a context section.
        context = document_text if document_text else "Empty document."
        # NOTE(review): this checkpoint appears to be a question/answer
        # *generation* model (context in, "question <sep> answer" out) per its
        # model card. Confirm it is suitable for answering a user-supplied
        # question with this prompt format.
        prompt = f"Question: {question} Context: {context}"
        with st.spinner("Generating answer..."):
            # The prompt is truncated to 512 tokens. With the tokenizer's default
            # right-side truncation, long contexts lose their end, not the question.
            # Calling the tokenizer directly also yields an attention mask for generate().
            inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            # No min_length: an answer may legitimately be short. Forcing a long
            # minimum makes the model pad it with extra generated text.
            outputs = model.generate(**inputs, max_length=150, length_penalty=5, num_beams=2)
        # skip_special_tokens removes <pad>/</s> markers from the displayed text.
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.text("Answer:")
        st.text(answer)
|
|