import torch
import numpy as np
import gradio as gr
from nltk import sent_tokenize
from transformers import RobertaTokenizer, RobertaForMaskedLM

cuda = torch.cuda.is_available()

tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
model = RobertaForMaskedLM.from_pretrained("roberta-large")
if cuda:
    model = model.cuda()

# Default generation hyperparameters (overridable from the UI sliders below).
max_len = 20  # number of <mask> tokens inserted between each pair of sentences
top_k = 100  # sample only from the k most probable words after burn-in
temperature = 1.0  # softmax temperature applied to the logits
burnin = 250  # iterations of full-distribution sampling before top-k kicks in
max_iter = 500  # total resampling iterations per inserted span
# adapted from https://github.com/nyu-dl/bert-gen
def generate_step(out,
                  gen_idx,
                  temperature=None,
                  top_k=0,
                  sample=False,
                  return_list=True):
    """ Generate a word from out[gen_idx]

    args:
        - out: model output whose .logits field is a tensor of size
          batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate
        - temperature (float): optional temperature applied to the logits
        - top_k (int): if > 0, only sample from the top k most probable words
        - sample (bool): if True, sample from the full distribution; overridden by top_k
    """
    logits = out.logits[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        # Sample among the k highest-scoring tokens only.
        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        idx = kth_idx.gather(dim=1,
                             index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        # Sample from the full vocabulary distribution.
        dist = torch.distributions.categorical.Categorical(logits=logits)
        idx = dist.sample()
    else:
        # Greedy: take the single most probable token.
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx
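
# A minimal usage sketch for generate_step (hypothetical input, not part of
# the app's flow): run one forward pass and resample the masked position.
# The mask index depends on tokenization, so it is looked up, not hard-coded.
#
#   enc = tokenizer("The <mask> sat on the mat.", return_tensors='pt')
#   pos = int((enc['input_ids'][0] == tokenizer.mask_token_id).nonzero())
#   generate_step(model(**enc), gen_idx=pos, top_k=10)
#   # -> a one-element list containing the sampled token id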
# adapted from https://github.com/nyu-dl/bert-gen
def parallel_sequential_generation(seed_text,
                                   seed_end_text,
                                   max_len=max_len,
                                   top_k=top_k,
                                   temperature=temperature,
                                   max_iter=max_iter,
                                   burnin=burnin):
    """ Generate for one random masked position at a time

    args:
        - burnin: during the burn-in period, sample from the full distribution;
          afterwards sample from the top k (or take the argmax if top_k == 0)
    """
    # Wrap a run of <mask> tokens between the two seed sentences.
    inp = tokenizer(seed_text + tokenizer.mask_token * max_len + seed_end_text,
                    return_tensors='pt')
    masked_tokens = np.where(
        inp['input_ids'][0].numpy() == tokenizer.mask_token_id)[0]
    seed_len = masked_tokens[0]
    if cuda:
        inp = inp.to('cuda')
    for ii in range(max_iter):
        # Pick one masked position at random and resample it, Gibbs-style.
        kk = np.random.randint(0, max_len)
        out = model(**inp)
        topk = top_k if (ii >= burnin) else 0
        idxs = generate_step(out,
                             gen_idx=seed_len + kk,
                             top_k=topk,
                             temperature=temperature,
                             sample=(ii < burnin))
        inp['input_ids'][0][seed_len + kk] = idxs[0]
    # Decode only the generated span, dropping any stray BOS/EOS tokens.
    tokens = inp['input_ids'].cpu().numpy()[0][masked_tokens]
    tokens = tokens[np.where((tokens != tokenizer.eos_token_id)
                             & (tokens != tokenizer.bos_token_id))]
    return tokenizer.decode(tokens)
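
# A minimal sketch of calling this directly (hypothetical sentences):
#
#   bridge = parallel_sequential_generation('It was a dark and stormy night.',
#                                           ' The door creaked open.')
#   # `bridge` is a string of up to max_len generated tokens that links
#   # the two seed sentences.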
def inbertolate(doc,
                max_len=15,
                top_k=0,
                temperature=None,
                max_iter=300,
                burnin=200):
    new_doc = ''
    paras = doc.split('\n')
    for para in paras:
        para = sent_tokenize(para)
        if not para:  # blank line: sent_tokenize returned no sentences
            new_doc += '\n'
            continue
        # Append an empty sentinel so the last sentence also gets a successor.
        para += ['']
        for sentence in range(len(para) - 1):
            new_doc += para[sentence] + ' '
            new_doc += parallel_sequential_generation(para[sentence],
                                                      para[sentence + 1],
                                                      max_len=max_len,
                                                      top_k=top_k,
                                                      temperature=temperature,
                                                      burnin=burnin,
                                                      max_iter=max_iter) + ' '
        new_doc += '\n'
    return new_doc
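
# Example (hypothetical input): pad out a short two-sentence paragraph.
#
#   expanded = inbertolate('I went home. It was late.', max_len=10, top_k=50)
#   # -> 'I went home. <generated filler> It was late. <generated filler> \n'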
if __name__ == '__main__':
    block = gr.Blocks()
    with block:
        gr.Markdown("<h1><center>inBERTolate</center></h1>")
        gr.Markdown(
            "<center>Hit your word count by using BERT to pad out your essays!</center>"
        )
        gr.Interface(
            fn=inbertolate,
            inputs=[
                gr.Textbox(label="Text", lines=7),
                gr.Slider(label="Maximum length to insert between sentences",
                          minimum=1,
                          maximum=40,
                          step=1,
                          value=max_len),
                gr.Slider(label="Top k",
                          minimum=0,
                          maximum=200,
                          step=1,
                          value=top_k),
                gr.Slider(label="Temperature",
                          minimum=0,
                          maximum=2,
                          value=temperature),
                gr.Slider(label="Maximum iterations",
                          minimum=0,
                          maximum=1000,
                          step=1,
                          value=max_iter),
                gr.Slider(label="Burn-in",
                          minimum=0,
                          maximum=500,
                          step=1,
                          value=burnin),
            ],
            outputs=gr.Textbox(label="Expanded text", lines=24))
    block.launch(server_name='0.0.0.0')