alberti / app.py
alvp's picture
Typo fix
fcc4493
raw
history blame
6.44 kB
import random
import re
from poems import SAMPLE_POEMS
import langid
import numpy as np
import streamlit as st
import torch
from icu_tokenizer import Tokenizer
from transformers import pipeline
# Hugging Face Hub checkpoints compared side by side in the UI.
MODELS = dict(
    ALBERTI="flax-community/alberti-bert-base-multilingual-cased",
    mBERT="bert-base-multilingual-cased",
)
# How many top-probability fill-mask predictions are kept per masked line.
TOPK = 50
# Use the full browser width so the three poem columns fit next to each other.
st.set_page_config(layout="wide")
def mask_line(line, language="es", restrictive=True):
    """Replace one randomly chosen token of *line* with ``[MASK]``.

    The line is split with an ICU tokenizer for *language*, and the tokens
    are re-joined with single spaces, so the result is space-separated.

    Args:
        line: A single line of the poem.
        language: Language code (as produced by ``langid``). It selects the
            tokenizer and enables the ``"zh"``-specific eligibility rule.
        restrictive: When true, only tokens longer than three characters
            (and, for ``"zh"``, also any alphabetic token) are eligible for
            masking. For languages other than ``"zh"`` this is relaxed
            automatically when the line has no token longer than three
            characters.

    Returns:
        The space-joined tokens with exactly one token replaced by
        ``[MASK]``. If the tokenizer yields no tokens (e.g. a whitespace-only
        line), *line* is returned unchanged.
    """
    tokenizer = Tokenizer(lang=language)
    token_list = tokenizer.tokenize(line)
    if not token_list:
        # Nothing to mask; random.randint(0, -1) would raise here.
        return line

    if language != "zh":
        # Only restrict when a long token actually exists. Also honor an
        # explicit restrictive=False instead of overwriting it.
        restrictive = restrictive and any(len(token) > 3 for token in token_list)

    candidates = []
    if restrictive:
        candidates = [
            idx
            for idx, token in enumerate(token_list)
            if len(token) > 3 or (language == "zh" and token.isalpha())
        ]
    if not candidates:
        # Fall back to any token. A "zh" line with no eligible token used to
        # recurse forever.
        candidates = list(range(len(token_list)))

    # Uniform choice among eligible positions: the same distribution as the
    # former retry-until-eligible recursion, but without unbounded recursion.
    token_list[random.choice(candidates)] = "[MASK]"
    return " ".join(token_list)
def filter_candidates(candidates, get_any_candidate=False):
    """Return the ``"sequence"`` of the best acceptable fill-mask candidate.

    *candidates* is a list of dicts with at least ``"sequence"`` and
    ``"token_str"`` keys, considered in the order given (``infer_candidates``
    returns them best-first).

    Unless *get_any_candidate* is true, the first candidate whose predicted
    token is a whole alphabetic word wins. That means it is not a ``##``
    subword piece and contains no digits or punctuation. If no candidate
    qualifies, or *get_any_candidate* is true, the first candidate is used.

    Raises:
        ValueError: if *candidates* is empty.
    """
    if not candidates:
        raise ValueError("filter_candidates() needs at least one candidate")
    if not get_any_candidate:
        for candidate in candidates:
            token_str = candidate["token_str"]
            if not token_str.startswith("##") and token_str.isalpha():
                return candidate["sequence"]
    # No whole-word prediction (or any candidate was requested): take the top one.
    return candidates[0]["sequence"]
def infer_candidates(nlp, line):
    """Return the ``TOPK`` fill-mask predictions for the mask token in *line*.

    Args:
        nlp: A ``transformers`` ``fill-mask`` pipeline whose tokenizer marks
            subword continuations with a ``##`` prefix.
        line: Text containing exactly one mask token.

    Returns:
        A list of ``TOPK`` dicts in ``topk`` order, each with these keys:

        - ``sequence``: the re-assembled line as HTML. Its text is
          HTML-escaped and the predicted token is wrapped in a red ``<b>``.
        - ``score``: the softmax probability of the prediction.
        - ``token``: the predicted token id.
        - ``token_str``: the decoded prediction, not escaped.
        - ``masked_index``: the mask position in the model input.

    Raises:
        ValueError: if *line* does not contain exactly one mask token.
    """
    import html  # only needed here, to escape user text rendered as HTML

    # Normalize typographic apostrophe/ellipsis to their ASCII forms.
    line = line.replace("’", "'").replace("…", "...")
    tokenizer = nlp.tokenizer
    # NOTE(review): `_parse_and_tokenize` and `_forward(..., return_tensors=True)`
    # are private helpers of the transformers version this app targets; they
    # are not a stable public API.
    inputs = nlp._parse_and_tokenize(line)
    outputs = nlp._forward(inputs, return_tensors=True)
    input_ids = inputs["input_ids"][0]

    mask_positions = torch.nonzero(input_ids == tokenizer.mask_token_id,
                                   as_tuple=False)
    if mask_positions.numel() != 1:
        raise ValueError(
            f"expected exactly one mask token in the line, "
            f"found {mask_positions.numel()}: {line!r}"
        )
    masked_index = mask_positions.item()

    probs = outputs[0, masked_index, :].softmax(dim=0)
    values, predictions = probs.topk(TOPK)

    def decode(token_id):
        return tokenizer.decode([token_id], skip_special_tokens=True)

    # Decode and escape the context once; only the masked slot changes per
    # candidate. Padding positions become None and are skipped. `.tolist()`
    # copies the ids, so `inputs` is not mutated (`.numpy()` shared memory).
    pad_id = tokenizer.pad_token_id
    pieces = [
        None if token_id == pad_id else html.escape(decode(token_id), quote=False)
        for token_id in input_ids.tolist()
    ]

    result = []
    for score, token_id in zip(values.tolist(), predictions.tolist()):
        pieces[masked_index] = html.escape(decode(token_id), quote=False)
        words = []
        for idx, piece in enumerate(pieces):
            if piece is None:
                continue
            if piece.startswith("##") and words:
                # Subword continuation: glue it onto the previous word.
                words[-1] += piece[2:]
            elif idx == masked_index:
                words += ['<b style="color: #ff0000;">', piece, "</b>"]
            else:
                words.append(piece)
        result.append(
            {
                "sequence": " ".join(words).strip(),
                "score": score,
                "token": token_id,
                "token_str": tokenizer.decode(token_id),
                "masked_index": masked_index,
            }
        )
    return result
def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
    """Mask one token per line of *poem* (optionally) and refill it with *ml_model*.

    Args:
        poem: Sequence of lines. Blank lines (empty or whitespace-only) are
            passed through as empty lines.
        ml_model: Model id used to build a ``fill-mask`` pipeline.
        masking: If true, each line is masked with ``mask_line``. If false,
            the lines are used as-is and must already contain a mask token.
        language: Language code forwarded to ``mask_line``.

    Returns:
        A tuple ``(unmasked_poem, masked_lines)``:

        - ``unmasked_poem``: the refilled lines joined with ``<br>``.
        - ``masked_lines``: the lines that were fed to the model, so the same
          masks can be reused with another model (``masking=False``).
    """
    nlp = pipeline("fill-mask", model=ml_model)
    unmasked_lines = []
    masked_lines = []
    for line in poem:
        if not line.strip():
            # Whitespace-only lines (e.g. a stray "\r") have nothing to mask
            # and would crash the masking/inference steps.
            unmasked_lines.append("")
            masked_lines.append("")
            continue
        masked_line = mask_line(line, language) if masking else line
        masked_lines.append(masked_line)
        candidates = infer_candidates(nlp, masked_line)
        unmasked_lines.append(filter_candidates(candidates))
    return "<br>".join(unmasked_lines), masked_lines
# Sidebar: title and project blurb, the sample-poem picker, and usage notes.
st.sidebar.markdown(
    """# ALBERTI vs BERT 🥊
We present ALBERTI, our BERT-based multilingual model for poetry.""")
st.sidebar.markdown(
    """We have trained BERT on a huge (for poetry, that is) corpus of
multilingual poetry to try to get a more 'poetic' model. This is the result
of our work.
You can find more information on the [project's site](https://huggingface.co/flax-community/alberti-bert-base-multilingual-cased)""")
# Each SAMPLE_POEMS key is one option in the picker.
sample_chooser = st.sidebar.selectbox(
    "Choose a poem",
    list(SAMPLE_POEMS.keys())
)
st.sidebar.markdown("""# How to use
You can choose from a list of example poems in Spanish, English, French, German,
Chinese and Arabic, but you can also paste a poem, or write it yourself!
Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen
poem and get the two new versions for each of the models.
The list of languages used on the training of ALBERTI are:
* Arabic
* Chinese
* Czech
* English
* Finnish
* French
* German
* Hungarian
* Italian
* Portuguese
* Russian
* Spanish""")
# Three panes: input poem, ALBERTI output, mBERT output.
col1, col2, col3 = st.beta_columns(3)

# Bold the widget labels and shrink the page body's horizontal padding.
_CUSTOM_CSS = """
<style>
label {
font-size: 1rem !important;
font-weight: bold !important;
}
.block-container {
padding-left: 1rem !important;
padding-right: 1rem !important;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
if sample_chooser:
    user_input = col1.text_area("Input poem",
                                "\n".join(SAMPLE_POEMS[sample_chooser]),
                                height=600)
    poem = user_input.split("\n")
    rewrite_button = col1.button("Rewrite!")
    # The app inserts its own [MASK] in every line. A user-supplied mask would
    # leave a line with two masks, which the fill-mask step cannot handle, so
    # only rewrite once the user has removed it.
    user_masked = "[MASK]" in user_input or "<mask>" in user_input
    if user_masked:
        col1.error("You don't have to mask the poem, we'll do it for you!")
    if rewrite_button and not user_masked:
        # `lang` is deliberately a module-level name: mask_line may read it
        # as a global.
        lang = langid.classify(user_input)[0]
        unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
        col2.write(f"""<b>Output poem from ALBERTI</b>
{unmasked_poem}""", unsafe_allow_html=True)
        # Reuse ALBERTI's masked lines so both models fill the same gaps.
        unmasked_poem_2, _ = rewrite_poem(masked_poem, ml_model=MODELS["mBERT"],
                                          masking=False)
        col3.write(f"""<b>Output poem from mBERT</b>
{unmasked_poem_2}""", unsafe_allow_html=True)