Arabic-NLP / backend /aragpt.py
wissamantoun's picture
re-build
c0f4b1b
raw
history blame
5.9 kB
import streamlit as st
from .services import TextGeneration
from tokenizers import Tokenizer
from functools import lru_cache
@lru_cache(maxsize=1)
def load_text_generator():
    """Create the ``TextGeneration`` service, load it, and return it.

    Memoized with ``lru_cache(maxsize=1)``: since the function takes no
    arguments, every call after the first returns the same, already-loaded
    instance instead of constructing and loading a new one.
    """
    service = TextGeneration()
    service.load()  # load once here so callers receive a ready-to-use object
    return service
# Module-level service instance, created when this module is imported;
# ``write()`` uses it for both free-text generation and zero-shot QA.
generator = load_text_generator()
# Zero-shot QA prompt fragments (Arabic). ``write()`` builds the model input as
#   qa_prompt + question + qa_prompt_post + (qa_prompt_post_year or " : ")
# These strings are sent to the model verbatim, so their exact whitespace matters.
# Translation: "Answer the following question:" (wrapped in newlines).
qa_prompt = """
أجب عن السؤال التالي:
"""
# Translation: " The answer is ".
qa_prompt_post = """ الجواب هو """
# Translation: " in the year: ". Appended when the user ticks the
# "Is the answer a date?" checkbox, to steer the model toward a year.
qa_prompt_post_year = """ في سنة: """
def _render_sidebar():
    """Render the sidebar controls and return ``(model_name, generation_kwargs)``.

    ``generation_kwargs`` maps the keyword names accepted by
    ``generator.generate`` to the values chosen in the sidebar. Both the
    free-generation and the QA sections pass it through unchanged, so the two
    calls cannot drift apart.
    """
    # Sidebar layout adapted from
    # https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
    st.sidebar.subheader("Configurable parameters")
    model_name = st.sidebar.selectbox(
        "Model Selector",
        # "AraGPT2-Medium" and "Aragpt2-Large" are currently not offered.
        options=[
            "AraGPT2-Base",
            "AraGPT2-Mega",
        ],
        index=0,
    )
    max_new_tokens = st.sidebar.number_input(
        "Maximum length",
        min_value=1,  # requesting zero new tokens would generate nothing
        max_value=1024,
        value=100,
        help="The maximum length of the sequence to be generated.",
    )
    temp = st.sidebar.slider(
        "Temperature",
        value=1.0,
        min_value=0.1,
        max_value=100.0,
        help="The value used to modulate the next token probabilities.",
    )
    top_k = st.sidebar.number_input(
        "Top k",
        min_value=0,  # a negative count of kept tokens is meaningless
        value=10,
        help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
    )
    top_p = st.sidebar.number_input(
        "Top p",
        # top_p is a cumulative probability, so restrict it to [0, 1].
        min_value=0.0,
        max_value=1.0,
        value=0.95,
        help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
    )
    do_sample = st.sidebar.selectbox(
        "Sampling?",
        (True, False),
        help="Whether or not to use sampling; use greedy decoding otherwise.",
    )
    num_beams = st.sidebar.number_input(
        "Number of beams",
        min_value=1,
        max_value=10,
        value=3,
        help="The number of beams to use for beam search.",
    )
    repetition_penalty = st.sidebar.number_input(
        "Repetition Penalty",
        min_value=0.0,
        value=3.0,
        step=0.1,
        help="The parameter for repetition penalty. 1.0 means no penalty",
    )
    no_repeat_ngram_size = st.sidebar.number_input(
        "No Repeat N-Gram Size",
        min_value=0,
        value=3,
        help="If set to int > 0, all ngrams of that size can only occur once.",
    )
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temp,
        "top_k": top_k,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "num_beams": num_beams,
        "no_repeat_ngram_size": no_repeat_ngram_size,
    }
    return model_name, generation_kwargs


def _render_intro():
    """Render the AraGPT2 logo, the model description, and the right-aligned text styling."""
    st.write("#")  # an empty Markdown heading, used as vertical spacing
    col = st.columns(2)
    col[0].image("images/AraGPT2.png", width=200)
    st.markdown(
        """
<h3 style="text-align:left;">AraGPT2 is a GPT2 model trained from scratch on 77GB of Arabic text.</h3>
<h4 style="text-align:left;"> More details in our <a href="https://github.com/aub-mind/arabert/tree/master/aragpt2">repo</a>.</h4>
<p style="text-align:left;"></p>
<p style="text-align:left;">Use the generation parameters on the sidebar to adjust generation quality.</p>
<p style="text-align:right;"></p>
""",
        unsafe_allow_html=True,
    )
    # Right-align paragraphs and form inputs so Arabic (right-to-left) text
    # reads naturally in the widgets that follow.
    st.markdown(
        """
<style>
p, div, input, label, textarea{
text-align: right;
}
</style>
""",
        unsafe_allow_html=True,
    )


def _render_generation_section(model_name, generation_kwargs):
    """Render the free-form prompt box; on "Generate", show the model's continuation."""
    prompt = st.text_area(
        "Prompt",
        "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال",
    )
    if st.button("Generate"):
        if not prompt.strip():
            # Nothing to continue from; don't send an empty request to the model.
            st.warning("Please enter a prompt to generate from.")
        else:
            with st.spinner("Generating..."):
                generated_text = generator.generate(
                    prompt=prompt,
                    model_name=model_name,
                    **generation_kwargs,
                )
            st.write(generated_text)


def _render_qa_section(model_name, generation_kwargs):
    """Render the zero-shot QA box; on "Answer", prompt the model with the framed question."""
    st.markdown("---")
    st.subheader("")  # empty subheader, used as spacing
    st.markdown(
        """
<p style="text-align:left;"></p>
<h2 style="text-align:left;">Zero-Shot Question Answering</h2>
<p style="text-align:left;">Adjust the maximum length to closely match the expected output length. Setting the Sampling parameter to False is recommended</p>
<p style="text-align:left;"></p>
""",
        unsafe_allow_html=True,
    )
    question = st.text_input(
        "Question", "من كان رئيس ألمانيا النازية في الحرب العالمية الثانية ؟"
    )
    is_date = st.checkbox("Help the model: Is the answer a date?")
    if st.button("Answer"):
        if not question.strip():
            st.warning("Please enter a question.")
        else:
            # Frame the question as "Answer the following question: <q> The answer is",
            # then either the "in the year:" cue (for dates) or a plain " : ".
            answer_cue = qa_prompt_post_year if is_date else " : "
            qa_input = qa_prompt + question + qa_prompt_post + answer_cue
            with st.spinner("Thinking..."):
                answer = generator.generate(
                    prompt=qa_input,
                    model_name=model_name,
                    **generation_kwargs,
                )
            st.write(answer)


def write():
    """Render the AraGPT2 page.

    The page has a sidebar of generation settings, a free-text generation
    section and a zero-shot question-answering section. Both sections call the
    module-level ``generator`` with the model and settings chosen in the sidebar.
    """
    st.markdown(
        """
<h1 style="text-align:left;">Arabic Language Generation</h1>
""",
        unsafe_allow_html=True,
    )
    model_name, generation_kwargs = _render_sidebar()
    _render_intro()
    _render_generation_section(model_name, generation_kwargs)
    _render_qa_section(model_name, generation_kwargs)