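# NOTE(review): the recovered copy of this file began with three lines reading
# "Spaces:", "Sleeping", "Sleeping". They look like web-page residue rather than
# code and are preserved only as this comment so the module can be parsed.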
def run_gradio(model, tokenizer, scaler, ling_collection, examples=None, lng_names=None, M=None):
    """Build and launch the LingConv Gradio demo.

    Args:
        model: Conversion model. This function uses its ``backbone.device``,
            ``infer``, ``ling_embed`` and ``ling_disc`` members.
        tokenizer: Tokenizer providing ``encode`` and ``batch_decode``.
        scaler: Fitted scaler with ``transform`` / ``inverse_transform``.
            Linguistic indices are passed through ``transform`` before they
            are given to the model, and shown to the user in raw form.
        ling_collection: Indexable collection of raw linguistic-index vectors.
            ``generate_random`` samples target indices from it.
        examples: Optional rows for ``gr.Examples`` over
            ``[source text, index table]``. ``rand_ex_target`` reads the index
            table at position 1 of a row. When None or empty, the examples
            accordion is not shown.
        lng_names: Names of the linguistic indices, one per table row. It is
            required in practice: the index table is built by iterating over it.
        M: Unused.

    Side effects:
        Reads ``assets/m.npy`` relative to the working directory and calls
        ``demo.launch(share=True)``.
    """
    import numpy as np
    import torch
    from compute_lng import compute_lng
    import gradio as gr

    # Per-index step ("epsilon") for the Add/Subtract buttons, applied in the
    # scaler's transformed space. Zero entries in the file would produce an
    # infinite step, so every non-finite value (+inf, -inf, NaN) is zeroed.
    with np.errstate(divide='ignore', invalid='ignore'):
        m = -1 / np.load('assets/m.npy')
    m[~np.isfinite(m)] = 0
    m /= 100
    device = model.backbone.device
    has_examples = examples is not None and len(examples) > 0

    def _to_model_ling(values):
        """Scale one raw index vector and return it as a single-row float tensor on ``device``."""
        return torch.tensor(scaler.transform([values])).float().to(device)

    def visibility(mode):
        """Show only the controls of the selected mode and clear both text outputs.

        ``mode`` is the index of the selected choice of the mode radio (0, 1
        or 2). Any other value falls back to the first mode, which is also the
        radio's default.
        """
        groups = (group1, group2, group3)
        vis_group = groups[mode] if mode in (0, 1, 2) else group1
        output = [gr.update(value=''), gr.update(value='')]
        output.extend(gr.update(visible=component in vis_group) for component in components)
        return output

    @torch.no_grad()
    def generate(sent1, ling):
        """Single-pass generation towards ``ling['Target']``, without the feedback loop.

        Not wired to any UI control; kept for experimentation.
        """
        input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
        inputs = {'sentence1_input_ids': input_ids,
                  'sentence1_ling': _to_model_ling(ling['Source']),
                  'sentence2_ling': _to_model_ling(ling['Target']),
                  'sentence1_attention_mask': torch.ones_like(input_ids)}
        pred = model.infer(inputs).cpu().numpy()
        return tokenizer.batch_decode(pred, skip_special_tokens=True)[0]

    @torch.no_grad()
    def generate_with_feedbacks(sent1, ling):
        """Experimental variant that repeatedly re-paraphrases its own output.

        Each step feeds the previous prediction back in as the source sentence.
        It stops when the ``ling_disc`` estimate of the prediction is within
        1e-5 (mean squared distance in embedding space) of the target, or
        after 50 refinement steps. Not wired to any UI control; kept for
        experimentation.
        """
        input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
        ling1_embed = model.ling_embed(_to_model_ling(ling['Source']))
        ling2_embed = model.ling_embed(_to_model_ling(ling['Target']))
        inputs = {'sentence1_input_ids': input_ids,
                  'sent1_ling_embed': ling1_embed,
                  'sent2_ling_embed': ling2_embed,
                  'sentence1_attention_mask': torch.ones_like(input_ids)}
        steps = 0
        while True:
            pred = model.infer(inputs)
            inputs_pred = {**inputs, 'input_ids': pred,
                           'attention_mask': torch.ones_like(pred)}
            ling_pred_embed = model.ling_embed(model.ling_disc(**inputs_pred))
            diff = torch.mean((ling2_embed - ling_pred_embed) ** 2)
            if diff < 1e-5 or steps >= 50:
                break
            inputs.update({'sentence1_input_ids': pred,
                           'sentence1_attention_mask': torch.ones_like(pred)})
            steps += 1
        return tokenizer.batch_decode(pred.cpu().numpy(), skip_special_tokens=True)[0]

    @torch.no_grad()
    def generate_with_feedback(sent1, ling, approx):
        """Paraphrase ``sent1`` towards ``ling['Target']`` with index-space quality control.

        After each generation, the linguistic indices of the output are
        estimated: by ``model.ling_disc`` when ``approx`` contains
        'approximate', or by ``compute_lng`` when it contains 'exact'. While
        the mean squared distance between the target embedding and the
        embedding of those indices is at least 10, the target embedding is
        moved a fraction ``eta`` towards the achieved one and generation is
        repeated. At most 50 adjustments are made.

        Returns:
            ``[final_text, log]``. ``log`` lists, one per line and prefixed by
            ``'-- '``, each intermediate output that differs from the one
            before it. When ``sent1`` is empty or blank, returns a prompt
            message and ``''`` instead.

        Raises:
            ValueError: if ``approx`` names neither computation mode.
        """
        if not sent1 or not sent1.strip():
            return ['Please input a source text.', '']
        input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
        ling1_embed = model.ling_embed(_to_model_ling(ling['Source']))
        ling2_embed = model.ling_embed(_to_model_ling(ling['Target']))
        inputs = {'sentence1_input_ids': input_ids,
                  'sent1_ling_embed': ling1_embed,
                  'sent2_ling_embed': ling2_embed,
                  'sentence1_attention_mask': torch.ones_like(input_ids)}
        eta = 0.3  # fraction of the target/achieved gap closed per adjustment
        interpolations = []
        steps = 0
        while True:
            pred = model.infer(inputs)
            pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
                                               skip_special_tokens=True)[0]
            if 'approximate' in approx:
                inputs_pred = {**inputs, 'input_ids': pred,
                               'attention_mask': torch.ones_like(pred)}
                ling_pred = model.ling_disc(**inputs_pred)
            elif 'exact' in approx:
                ling_pred = scaler.transform([compute_lng(pred_text)])[0]
                ling_pred = torch.tensor(ling_pred).to(pred.device).float()
            else:
                raise ValueError(f'Unknown index computation mode: {approx!r}')
            ling_pred_embed = model.ling_embed(ling_pred)
            if not interpolations or pred_text != interpolations[-1]:
                interpolations.append(pred_text)
            diff = torch.mean((ling2_embed - ling_pred_embed) ** 2)
            if diff < 10 or steps >= 50:
                break
            # Pull the requested indices towards what was actually achieved, so
            # the next attempt aims at a closer, more achievable target.
            ling2_embed = ling2_embed + eta * (ling_pred_embed - ling2_embed)
            inputs['sent2_ling_embed'] = ling2_embed
            steps += 1
        return [pred_text, '-- ' + '\n-- '.join(interpolations)]

    def generate_random(sent1, ling, count, approx):
        """Paraphrase ``sent1`` ``count`` times, each towards indices sampled from ``ling_collection``.

        Returns the paraphrases and their interpolation logs, each group
        separated by ``***`` lines, plus the index table. The table's Target
        column holds the last sampled indices.
        """
        if not sent1 or not sent1.strip():
            return 'Please input a source text.', '', ling
        preds, logs = [], []
        for _ in range(int(count or 0)):
            ling['Target'] = ling_collection[np.random.randint(0, len(ling_collection))]
            pred, log = generate_with_feedback(sent1, ling, approx)
            preds.append(pred)
            logs.append(log)
        return '\n***\n'.join(preds), '\n***\n'.join(logs), ling

    @torch.no_grad()
    def _estimate_indices(text, approx):
        """Return the raw (unscaled) linguistic indices of ``text``.

        Uses ``model.ling_disc`` plus the inverse scaler when ``approx``
        contains 'approximate', or ``compute_lng`` when it contains 'exact'.
        """
        if 'approximate' in approx:
            input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
            ling_pred = model.ling_disc(input_ids=input_ids).cpu().numpy()
            return scaler.inverse_transform(ling_pred)[0]
        if 'exact' in approx:
            return compute_lng(text)
        raise ValueError(f'Unknown index computation mode: {approx!r}')

    def estimate_gen(sent1, sent2, ling, approx):
        """Use the indices of ``sent2`` as the target, then paraphrase ``sent1``.

        Returns ``[text, log, ling]``.
        """
        ling['Target'] = _estimate_indices(sent2, approx)
        return generate_with_feedback(sent1, ling, approx) + [ling]

    def estimate_tgt(sent2, ling, approx):
        """Set the Target column to the estimated indices of ``sent2``."""
        ling['Target'] = _estimate_indices(sent2, approx)
        return ling

    def estimate_src(sent1, ling, approx):
        """Set the Source column to the estimated indices of ``sent1``."""
        ling['Source'] = _estimate_indices(sent1, approx)
        return ling

    def rand_target(ling):
        """Replace the target with a standard-normal sample drawn in scaled space.

        Not wired to any UI control; kept for experimentation.
        """
        sample = np.random.randn(*ling['Target'].shape)
        ling['Target'] = scaler.inverse_transform([sample])[0]
        return ling

    def rand_ex_target(ling):
        """Copy the target indices of a randomly chosen example.

        Leaves the table unchanged when no examples are available.
        """
        if not has_examples:
            return ling
        ling_ex = examples[np.random.randint(0, len(examples))][1]
        ling['Target'] = ling_ex['Target']
        return ling

    def copy_source_to_target(ling):
        """Set the target indices equal to the source indices."""
        ling['Target'] = ling['Source']
        return ling

    def _shift_target(ling, delta):
        """Add ``delta`` to the target indices in scaled space and store the raw result."""
        x = scaler.transform([ling['Target']]) + delta
        ling['Target'] = scaler.inverse_transform(x)[0]
        return ling

    def add_noise(ling):
        """Add standard-normal noise (in scaled space) to the target.

        Not wired to any UI control; kept for experimentation.
        """
        return _shift_target(ling, np.random.randn(*ling['Target'].shape))

    def add(ling):
        """Add the epsilon vector ``m`` to the target (in scaled space)."""
        return _shift_target(ling, m)

    def sub(ling):
        """Subtract the epsilon vector ``m`` from the target (in scaled space)."""
        return _shift_target(ling, -m)

    title = """
# LingConv: A System for Controlled Linguistic Conversion

## Description

This system is an encoder-decoder model for complexity controlled text generation, guided by 241
linguistic complexity indices as key attributes. Given a sentence and a desired level of linguistic
complexity, the model can generate diverse paraphrases that maintain consistent meaning, adjusted for
different linguistic complexity levels. However, it's important to note that not all index combinations are
feasible (such as requesting a sentence of "length" 5 with 10 "unique words"). To ensure high quality
outputs, our approach interpolates the embedding of linguistic indices to locate the most closely matched,
achievable set of indices for the given target.
"""
    guide = """
First, choose one of the following ways to use the system:

**Randomized Paraphrase Generation**: Select this option to produce multiple paraphrases with a range
of linguistic complexity. You need to provide a source text, specify the number of paraphrases you want,
and click "Generate." The linguistic complexity of the paraphrases will be determined randomly.

**Complexity-Matched Paraphrasing**: Select this option to generate a paraphrase of the given source
sentence that closely mirrors the linguistic complexity of another given sentence. Input your source
sentence along with another sentence (which will serve only to extract linguistic indices for the
paraphrase generation). Then, click "Generate."

**Manual Linguistic Control**: Select this option to manually control the linguistic complexity of the
generated text. We provide a set of tools for manual adjustment of the desired linguistic complexity of
the target sentence. These tools enable the user to extract linguistic indices from a given sentence,
generate a random (yet coherent) set of linguistic indices, and add or remove noise from the indices.
These tools are designed for experimental use and require the user to possess linguistic expertise for
effective input of linguistic indices. To use these tools, select "Tools to assist in setting linguistic
indices." Once indices are entered, click "Generate."

Second, you may select exact or approximate computation of linguistic indices (used when estimating
indices from a sentence and in the quality control of the generation). Approximate computation is
significantly faster.

Third, you may view the intermediate sentences of the quality control process by selecting the checkbox.

Fourth, you may try out some examples by clicking on "Examples...". Each example consists of a source
sentence, the indices of the source sentence, and a sample set of target linguistic indices.

Please make your choice below.
"""
    sent1 = gr.Textbox(label='Source text')
    ling = gr.Dataframe(value=[[x, 0, 0] for x in lng_names],
                        headers=['Index', 'Source', 'Target'],
                        datatype=['str', 'number', 'number'], visible=False)
    # NOTE(review): `svelte-s1r2yt` looks like a generated Svelte class name,
    # which may change between Gradio versions — confirm it still matches.
    css = """
    #guide span.svelte-s1r2yt {font-size: 22px !important;
    font-weight: 600 !important}
    """
    # NOTE(review): this file was recovered without indentation; the layout
    # nesting below is a best-effort reconstruction — confirm against the app.
    with gr.Blocks(css=css) as demo:
        gr.Markdown(title)
        with gr.Accordion("Quick Start Guide", open=False, elem_id='guide'):
            gr.Markdown(guide)
        mode = gr.Radio(value='Randomized Paraphrase Generation',
                        label='How would you like to use this system?',
                        type="index",
                        choices=['Randomized Paraphrase Generation',
                                 'Complexity-Matched Paraphrasing', 'Manual Linguistic Control'])
        approx = gr.Radio(value='Use approximate computation of linguistic indices (faster)',
                          choices=['Use approximate computation of linguistic indices (faster)',
                                   'Use exact computation of linguistic indices'],
                          container=False, show_label=False)
        control_interpolation = gr.Checkbox(label='View the intermediate sentences in the interpolation of linguistic indices')
        if has_examples:
            with gr.Accordion("Examples...", open=False):
                gr.Examples(examples, [sent1, ling], examples_per_page=4, label=None)
        with gr.Row():
            sent1.render()
            with gr.Column():
                sent2 = gr.Textbox(label='Generated text')
                interpolation = gr.Textbox(label='Quality control interpolation', visible=False, lines=5)
        # Controls for "Randomized Paraphrase Generation".
        with gr.Row():
            generate_random_btn = gr.Button("Generate",
                                            variant='primary', scale=1, visible=True)
            count = gr.Number(label='Number of generated sentences', value=3, precision=0, scale=1, visible=True)
        # Controls for "Complexity-Matched Paraphrasing".
        with gr.Row():
            estimate_gen_btn = gr.Button("Generate",
                                         variant='primary',
                                         scale=1, visible=False)
            sent_ling_gen = gr.Textbox(label='Text to estimate linguistic indices', scale=1, visible=False)
        # Controls for "Manual Linguistic Control".
        generate_btn = gr.Button("Generate", variant='primary', visible=False)
        with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
            with gr.Row():
                estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence", visible=False)
                sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False)
                estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
                rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
                copy_btn = gr.Button("Copy linguistic indices of source to target", size='sm', visible=False)
            with gr.Row():
                add_btn = gr.Button('Add \u03B5 to target linguistic indices', visible=False)
                sub_btn = gr.Button('Subtract \u03B5 from target linguistic indices', visible=False)
        ling.render()

        # Event wiring.
        estimate_src_btn.click(estimate_src, inputs=[sent1, ling, approx], outputs=[ling])
        estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling_est, ling, approx], outputs=[ling])
        estimate_gen_btn.click(estimate_gen, inputs=[sent1, sent_ling_gen, ling, approx],
                               outputs=[sent2, interpolation, ling])
        rand_ex_btn.click(rand_ex_target, inputs=[ling], outputs=[ling])
        copy_btn.click(copy_source_to_target, inputs=[ling], outputs=[ling])
        generate_btn.click(generate_with_feedback, inputs=[sent1, ling, approx],
                           outputs=[sent2, interpolation])
        generate_random_btn.click(generate_random, inputs=[sent1, ling, count, approx],
                                  outputs=[sent2, interpolation, ling])
        add_btn.click(add, inputs=[ling], outputs=[ling])
        sub_btn.click(sub, inputs=[ling], outputs=[ling])

        # One group of controls per mode; the order must match the choices of
        # the `mode` radio, whose selected index `visibility` receives.
        group1 = [generate_random_btn, count]
        group2 = [estimate_gen_btn, sent_ling_gen]
        group3 = [generate_btn, estimate_src_btn, estimate_tgt_btn, sent_ling_est,
                  rand_ex_btn, copy_btn, add_btn, sub_btn, ling, ling_tools]
        components = group1 + group2 + group3
        mode.change(visibility, inputs=[mode], outputs=[sent2, interpolation] + components)
        control_interpolation.change(lambda v: gr.update(visible=v), inputs=[control_interpolation],
                                     outputs=[interpolation])
    demo.launch(share=True)