import os
import sys
import time

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TextStreamer

# Make sibling modules in this directory importable regardless of the
# working directory the app is launched from.
path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, path)
st.title("Dialogue Text Summarization")
st.caption("Natural Language Processing Project 20232")
st.write("---")
class StreamlitTextStreamer(TextStreamer):
    """TextStreamer subclass that renders decoded text into a Streamlit
    container as it is generated and reports timing statistics at the end."""

    def __init__(self, tokenizer, st_container, st_info_container, skip_prompt=False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)
        self.st_container = st_container
        self.st_info_container = st_info_container
        self.text = ""
        self.start_time = None
        self.first_token_time = None
        self.total_tokens = 0
    def on_finalized_text(self, text: str, stream_end: bool = False):
        if self.start_time is None:
            self.start_time = time.time()
        if self.first_token_time is None and len(text.strip()) > 0:
            self.first_token_time = time.time()

        self.text += text
        # Approximate the token count with whitespace-separated words.
        self.total_tokens += len(text.split())
        self.st_container.markdown("###### " + self.text)
        time.sleep(0.03)  # brief pause so the streaming effect is visible

        if stream_end:
            total_time = time.time() - self.start_time
            first_token_wait_time = self.first_token_time - self.start_time if self.first_token_time else None
            tokens_per_second = self.total_tokens / total_time if total_time > 0 else None
            df = pd.DataFrame(data={
                "First token (s)": [first_token_wait_time],
                "Total tokens": [self.total_tokens],
                "Time taken (s)": [total_time],
                "Tokens per second": [tokens_per_second],
            })
            self.st_info_container.table(df)
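# Note on the hook above: during `model.generate`, the streamer's `put` method
# receives new token ids, and the base TextStreamer decodes them and forwards
# completed text chunks to `on_finalized_text`. Overriding only that callback
# leaves all decoding logic in the base class.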
def generate_summary(model, input_text, generation_config, tokenizer, st_container, st_info_container) -> None:
    prefix = "Summarize the following conversation: \n###\n"
    suffix = "\n### Summary:"
    # Ask for a summary roughly 15% as long as the input dialogue.
    target_length = max(1, int(0.15 * len(input_text.split())))
    prompt = f"{prefix}{input_text}\nThe generated summary should be around {target_length} words.{suffix}"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    # The streamer writes decoded text into the Streamlit containers as it arrives.
    streamer = StreamlitTextStreamer(tokenizer, st_container, st_info_container, skip_special_tokens=True)
    model.generate(input_ids, streamer=streamer, do_sample=True, generation_config=generation_config)
with st.sidebar:
    checkpoint = st.selectbox("Model", options=[
        "Choose model",
        "dtruong46me/train-bart-base",
        "dtruong46me/flant5-small",
        "dtruong46me/flant5-base",
        "dtruong46me/flan-t5-s",
        "ntluongg/bart-base-luong",
    ])
    st.button("Model detail", use_container_width=True)
    st.write("-----")
    st.write("**Generate Options:**")
    min_new_tokens = st.number_input("Min new tokens", min_value=1, max_value=64, value=10)
    max_new_tokens = st.number_input("Max new tokens", min_value=64, max_value=128, value=64)
    temperature = st.number_input("Temperature", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
    top_k = st.number_input("Top_k", min_value=1, max_value=50, step=1, value=20)
    top_p = st.number_input("Top_p", min_value=0.01, max_value=1.00, step=0.01, value=1.0)
height = 200
input_text = st.text_area("Dialogue", height=height)

generation_config = GenerationConfig(
    min_new_tokens=min_new_tokens,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if checkpoint=="Choose model":
tokenizer = None
model = None
if checkpoint!="Choose model":
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
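# Streamlit reruns this script on every widget interaction, so the checkpoint
# above is reloaded from disk each time. A cached loader is one fix; a minimal
# sketch, assuming Streamlit >= 1.18 for st.cache_resource:
#
#   @st.cache_resource
#   def load_checkpoint(name: str):
#       tok = AutoTokenizer.from_pretrained(name)
#       mdl = AutoModelForSeq2SeqLM.from_pretrained(name).to(device)
#       return tok, mdl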
if st.button("Submit"):
st.write("---")
st.write("## Summary")
if checkpoint=="Choose model":
st.error("Please selece a model!")
else:
if input_text=="":
st.error("Please enter a dialogue!")
# generate_summary(model, " ".join(input_text.split()), generation_config, tokenizer)
st_container = st.empty()
st_info_container = st.empty()
generate_summary(model, " ".join(input_text.split()), generation_config, tokenizer, st_container, st_info_container) |