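# Gradio demo for controlled paraphrase generation: an encoder-decoder model generates
# paraphrases whose linguistic complexity is steered by 40 linguistic complexity indices.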
import nltk
import spacy

nltk.download('wordnet')
spacy.cli.download('en_core_web_sm')

import torch
import joblib, json
import numpy as np
import pandas as pd
import gradio as gr

from const import used_indices, name_map
from model import get_model
from options import parse_args
from transformers import T5Tokenizer
from compute_lng import compute_lng
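# Turn the raw example entries into [source sentence, DataFrame of source/target indices]
# pairs in the format expected by gr.Examples below.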
def process_examples(samples, full_names):
    processed = []
    for sample in samples:
        processed.append([
            sample['sentence1'],
            pd.DataFrame({'Index': full_names,
                          'Source': sample['sentence1_ling'],
                          'Target': sample['sentence2_ling']})
        ])
    return processed
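# Parse options, load the T5 tokenizer, and load the precomputed assets: example sentences,
# per-index statistics, a collection of linguistic-index vectors, the fitted feature scaler,
# and the per-index ratios used for the +/- epsilon adjustments.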
args, args_list, lng_names = parse_args(ckpt='./ckpt/model.pt')
tokenizer = T5Tokenizer.from_pretrained(args.model_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

lng_names = [name_map[x] for x in lng_names]
examples = json.load(open('assets/examples.json'))
examples = process_examples(examples, lng_names)
stats = json.load(open('assets/stats.json'))
ling_collection = np.load('assets/ling_collection.npy')
scaler = joblib.load('assets/scaler.bin')
scale_ratio = np.load('assets/ratios.npy')
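# Build the paraphrase generator, the linguistic discriminator, and the semantic embedding
# model, then restore their checkpoints and switch everything to eval mode.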
model, ling_disc, sem_emb = get_model(args, tokenizer, device)

state = torch.load(args.ckpt, map_location=torch.device('cpu'))
model.load_state_dict(state['model'], strict=True)
model.eval()
ling_disc.eval()

state = torch.load(args.sem_ckpt)
sem_emb.load_state_dict(state['model'], strict=True)
sem_emb.eval()

device = model.backbone.device
############# Start demo code
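# Round each linguistic index and clip it to the min/max range stored in assets/stats.json.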
def round_ling(x):
    is_int = stats['is_int']
    mins = stats['min']
    maxs = stats['max']
    for i in range(len(x)):
        # if is_int[i]:
        #     x[i] = round(x[i])
        # else:
        #     x[i] = round(x[i], 3)
        x[i] = round(x[i], 3)
    return np.clip(x, mins, maxs)
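# Show only the components that belong to the selected mode and clear the two output boxes.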
def visibility(mode):
    if mode == 0:
        vis_group = group1
    elif mode == 1:
        vis_group = group2
    elif mode == 2:
        vis_group = group3

    output = [gr.update(value=''), gr.update(value='')]
    for component in components:
        if component in vis_group:
            output.append(gr.update(visible=True))
        else:
            output.append(gr.update(visible=False))
    return output
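# Single-pass generation from explicit source/target indices (not currently wired to a button).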
def generate(sent1, ling):
    input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
    ling1 = scaler.transform([ling['Source']])
    ling2 = scaler.transform([ling['Target']])
    inputs = {'sentence1_input_ids': input_ids,
              'sentence1_ling': torch.tensor(ling1).float().to(device),
              'sentence2_ling': torch.tensor(ling2).float().to(device),
              'sentence1_attention_mask': torch.ones_like(input_ids)}
    with torch.no_grad():
        pred = model.infer(inputs).cpu().numpy()
    pred = tokenizer.batch_decode(pred, skip_special_tokens=True)[0]
    return pred
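# Generation with quality control: infer_with_feedback_BP uses the linguistic discriminator and
# the semantic embedding model to interpolate toward an achievable set of target indices, and
# returns the final text together with the intermediate sentences.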
def generate_with_feedback(sent1, ling, approx):
    if sent1 == '':
        return ['Please input a source text.', '']

    input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
    ling2 = torch.tensor(scaler.transform([ling['Target']])).float().to(device)
    inputs = {
        'sentence1_input_ids': input_ids,
        'sentence2_ling': ling2,
        'sentence1_attention_mask': torch.ones_like(input_ids)
    }

    pred, (pred_text, interpolations) = model.infer_with_feedback_BP(ling_disc, sem_emb, inputs, tokenizer)

    interpolation = '-- ' + '\n-- '.join(interpolations)
    return [pred_text, interpolation]
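# Mode 1 (Randomized Paraphrase Generation): sample `count` target index vectors from the
# precomputed collection and generate one paraphrase for each.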
def generate_random(sent1, ling, count, approx):
    preds, interpolations = [], []
    for c in range(count):
        idx = np.random.randint(0, len(ling_collection))
        ling_ex = ling_collection[idx]
        ling['Target'] = ling_ex
        pred, interpolation = generate_with_feedback(sent1, ling, approx)
        preds.append(pred)
        interpolations.append(interpolation)
    return '\n***\n'.join(preds), '\n***\n'.join(interpolations), ling
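# Mode 2 (Complexity-Matched Paraphrasing): estimate the indices of the second sentence
# (approximately via the discriminator, or exactly via compute_lng), set them as the target,
# and paraphrase the source sentence.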
def estimate_gen(sent1, sent2, ling, approx):
    if 'approximate' in approx:
        input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
        with torch.no_grad():
            ling_pred = ling_disc(input_ids=input_ids).cpu().numpy()
        ling_pred = scaler.inverse_transform(ling_pred)[0]
    elif 'exact' in approx:
        ling_pred = np.array(compute_lng(sent2))[used_indices]
    else:
        raise ValueError(f'Invalid computation mode: {approx}')
    ling_pred = round_ling(ling_pred)
    ling['Target'] = ling_pred

    gen = generate_with_feedback(sent1, ling, approx)
    results = gen + [ling]
    return results
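# Estimate the linguistic indices of a user-provided sentence and write them into the Target column.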
def estimate_tgt(sent2, ling, approx):
    if 'approximate' in approx:
        input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
        with torch.no_grad():
            ling_pred = ling_disc(input_ids=input_ids).cpu().numpy()
        ling_pred = scaler.inverse_transform(ling_pred)[0]
    elif 'exact' in approx:
        ling_pred = np.array(compute_lng(sent2))[used_indices]
    else:
        raise ValueError(f'Invalid computation mode: {approx}')
    ling_pred = round_ling(ling_pred)
    ling['Target'] = ling_pred
    return ling
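# Estimate the linguistic indices of the source sentence and write them into the Source column.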
def estimate_src(sent1, ling, approx):
    if 'approximate' in approx:
        input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
        with torch.no_grad():
            ling_pred = ling_disc(input_ids=input_ids).cpu().numpy()
        ling_pred = scaler.inverse_transform(ling_pred)[0]
    elif 'exact' in approx:
        ling_pred = np.array(compute_lng(sent1))[used_indices]
    else:
        raise ValueError(f'Invalid computation mode: {approx}')
    ling['Source'] = ling_pred
    return ling
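# Manual-control helpers: sample a random target (from a standard normal in scaled space or
# from the precomputed collection), copy the source indices to the target, or nudge the target
# up/down by a random multiple of the per-index ratios.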
def rand_target(ling):
    ling['Target'] = scaler.inverse_transform([np.random.randn(*ling['Target'].shape)])[0]
    return ling

def rand_ex_target(ling):
    idx = np.random.randint(0, len(ling_collection))
    ling_ex = ling_collection[idx]
    ling['Target'] = ling_ex
    return ling

def copy(ling):
    ling['Target'] = ling['Source']
    return ling

def add(ling):
    scale_stepsize = np.random.uniform(1.0, 5.0)
    x = ling['Target'] + scale_stepsize * scale_ratio
    x = round_ling(x)
    ling['Target'] = x
    return ling

def sub(ling):
    scale_stepsize = np.random.uniform(1.0, 5.0)
    x = ling['Target'] - scale_stepsize * scale_ratio
    x = round_ling(x)
    ling['Target'] = x
    return ling
title = """
<h1 style="text-align: center;">Controlled Paraphrase Generation with Linguistic Feature Control</h1>
<p style="font-size:1.2em;">This system uses an encoder-decoder model to generate text with controlled complexity, guided by 40 linguistic complexity indices.
The model can generate diverse paraphrases of a given sentence, each preserving the meaning while varying
in linguistic complexity according to the desired level.</p>
<p style="font-size:1.2em;">Note that not all index combinations are feasible (e.g., a sentence of "length" 5 with 10 "unique words").
To ensure high-quality outputs, our approach interpolates the embeddings of the linguistic indices to identify the closest
achievable set of indices for the given target.</p>
"""
guide = """
You may use the system in one of the following ways:

**Randomized Paraphrase Generation**: Select this option to produce multiple paraphrases with a range
of linguistic complexity. Provide a source text, specify the number of paraphrases you want,
and click "Generate." The linguistic complexity of the paraphrases will be chosen randomly.

**Complexity-Matched Paraphrasing**: Select this option to generate a paraphrase of the given source
sentence that closely mirrors the linguistic complexity of another given sentence. Input your source
sentence along with another sentence (which serves only to extract the linguistic indices used for the
paraphrase generation), then click "Generate."

**Manual Linguistic Control**: Select this option to manually control the linguistic complexity of the
generated text. We provide a set of tools for manually adjusting the desired linguistic complexity of
the target sentence. These tools let you extract the linguistic indices of a given sentence,
generate a random (yet coherent) set of linguistic indices, and add or remove noise from the indices.
They are intended for experimental use and require some linguistic expertise to set the indices
effectively. To use them, open "Tools to assist in the setting of linguistic indices..." and, once the
indices are entered, click "Generate."

In addition, you may choose between exact and approximate computation of the linguistic indices (used in
Complexity-Matched Paraphrasing and in the quality control of the generation). Approximate computation is significantly faster.

You may also view the intermediate sentences of the quality control process by selecting the checkbox under
"Advanced Options", and you may try out some examples by clicking on "Examples...". Each example consists of a source sentence,
the indices of the source sentence, and a sample set of target linguistic indices.

Please make your choice below.
"""
sent1 = gr.Textbox(label='Source text')
ling = gr.Dataframe(value=[[x, 0, 0] for x in lng_names],
                    headers=['Index', 'Source', 'Target'],
                    datatype=['str', 'number', 'number'], visible=False)
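# Custom CSS: enlarges the guide and mode labels, sets default text sizes, and defines the
# .separator divider rendered between sections of the layout.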
css = """
#guide span.svelte-1w6vloh {font-size: 22px !important; font-weight: 600 !important}
#mode span.svelte-1gfkn6j {font-size: 18px !important; font-weight: 600 !important}
#mode {border: 0px; box-shadow: none}
#mode .block {padding: 0px}
div.gradio-container {color: black}
div.form {background: inherit}
body {
    --text-sm: 12px;
    --text-md: 16px;
    --text-lg: 18px;
    --input-text-size: 16px;
    --section-text-size: 16px;
    --input-background: var(--neutral-50);
}
.separator {
    width: 100%;
    height: 3px; /* Adjust the height for boldness */
    background-color: #000; /* Adjust the color as needed */
    margin: 20px 0; /* Adjust the margin as needed */
}
"""
with gr.Blocks(
        theme=gr.themes.Default(
            spacing_size=gr.themes.sizes.spacing_md,
            text_size=gr.themes.sizes.text_md,
        ),
        css=css) as demo:

    gr.Image('assets/logo.png', height=100, container=False, show_download_button=False)
    gr.Markdown(title)

    with gr.Accordion("Quick Start Guide", open=False, elem_id='guide'):
        gr.Markdown(guide)

    with gr.Group(elem_classes='separator'):
        pass
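    # Mode selector, advanced options (exact vs. approximate index computation, interpolation
    # view), built-in examples, and the main input/output boxes.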
    with gr.Group(elem_id='mode'):
        mode = gr.Radio(
            value='Randomized Paraphrase Generation',
            label='How would you like to use this system?',
            type="index",
            choices=['Randomized Paraphrase Generation',
                     'Complexity-Matched Paraphrasing',
                     'Manual Linguistic Control'],
        )

    with gr.Accordion("Advanced Options", open=False):
        approx = gr.Radio(value='Use approximate computation of linguistic indices (faster)',
                          choices=['Use approximate computation of linguistic indices (faster)',
                                   'Use exact computation of linguistic indices'],
                          container=False, show_label=False)
        control_interpolation = gr.Checkbox(label='View the intermediate sentences in the interpolation of linguistic indices')

    with gr.Accordion("Examples...", open=False):
        gr.Examples(examples, [sent1, ling], examples_per_page=4, label=None)
    with gr.Row():
        sent1.render()
        with gr.Column():
            sent2 = gr.Textbox(label='Generated text')
            interpolation = gr.Textbox(label='Quality control interpolation', visible=False, lines=5)

    with gr.Group(elem_classes='separator'):
        pass
    #####################
    with gr.Row():
        generate_random_btn = gr.Button("Generate", variant='primary', scale=1, visible=True)
        count = gr.Number(label='Number of generated sentences', value=3, precision=0, scale=1, visible=True)
        # generate_fb_btn = gr.Button("Generate with auto-adjust (towards pred)")
        # generate_fb_s_btn = gr.Button("Generate with auto-adjust (moving s)")
    #####################
    with gr.Row():
        estimate_gen_btn = gr.Button("Generate", variant='primary', scale=1, visible=False)
        sent_ling_gen = gr.Textbox(label='Text to estimate linguistic indices', scale=1, visible=False)
    #####################
    generate_btn = gr.Button("Generate", variant='primary', visible=False)
    with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
        with gr.Row():
            estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence", visible=False)
            sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False)
            estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
            # rand_btn = gr.Button("Random target")
            rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
            copy_btn = gr.Button("Copy linguistic indices of source to target", size='sm', visible=False)
        with gr.Row():
            sub_btn = gr.Button('Subtract \u03B5 from target linguistic indices', visible=False)
            add_btn = gr.Button('Add \u03B5 to target linguistic indices', visible=False)
        ling.render()
    #####################
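    # Wire the controls to their callbacks. The ling DataFrame is passed as both an input and
    # an output so edits to the indices persist across calls.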
    estimate_src_btn.click(estimate_src, inputs=[sent1, ling, approx], outputs=[ling])
    estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling_est, ling, approx], outputs=[ling])
    # estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling, ling], outputs=[ling])
    estimate_gen_btn.click(estimate_gen, inputs=[sent1, sent_ling_gen, ling, approx], outputs=[sent2, interpolation, ling])
    # rand_btn.click(rand_target, inputs=[ling], outputs=[ling])
    rand_ex_btn.click(rand_ex_target, inputs=[ling], outputs=[ling])
    copy_btn.click(copy, inputs=[ling], outputs=[ling])
    generate_btn.click(generate_with_feedback, inputs=[sent1, ling, approx], outputs=[sent2, interpolation])
    generate_random_btn.click(generate_random, inputs=[sent1, ling, count, approx],
                              outputs=[sent2, interpolation, ling])
    # generate_fb_btn.click(generate_with_feedback, inputs=[sent1, ling], outputs=sent2s)
    # generate_fb_s_btn.click(generate_with_feedbacks, inputs=[sent1, ling], outputs=sent2s)
    add_btn.click(add, inputs=[ling], outputs=[ling])
    sub_btn.click(sub, inputs=[ling], outputs=[ling])
    group1 = [generate_random_btn, count]
    group2 = [estimate_gen_btn, sent_ling_gen]
    group3 = [generate_btn, estimate_src_btn, estimate_tgt_btn, sent_ling_est, rand_ex_btn, copy_btn,
              add_btn, sub_btn, ling, ling_tools]
    components = group1 + group2 + group3

    mode.change(visibility, inputs=[mode], outputs=[sent2, interpolation] + components)
    control_interpolation.change(lambda v: gr.update(visible=v), inputs=[control_interpolation],
                                 outputs=[interpolation])
print('Finished loading')

demo.launch(share=True)