mohdelgaar committed on
Commit
54ba470
•
1 Parent(s): 01ac9fe

implement imputation

Browse files
Files changed (1) hide show
  1. app.py +27 -7
app.py CHANGED
@@ -13,6 +13,9 @@ from model import get_model
13
  from options import parse_args
14
  from transformers import T5Tokenizer
15
  from compute_lng import compute_lng
 
 
 
16
 
17
 
18
  def process_examples(samples, full_names):
@@ -35,10 +38,12 @@ examples = process_examples(examples, lng_names)
35
 
36
  stats = json.load(open('assets/stats.json'))
37
 
38
- ling_collection = np.load('assets/ling_collection.npy')
39
  scaler = joblib.load('assets/scaler.bin')
40
  scale_ratio = np.load('assets/ratios.npy')
41
 
 
 
 
42
  model, ling_disc, sem_emb = get_model(args, tokenizer, device)
43
 
44
  state = torch.load(args.ckpt, map_location=torch.device('cpu'))
@@ -201,6 +206,21 @@ def sub(ling):
201
  ling['Target'] = x
202
  return ling
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  title = """
205
  <h1 style="text-align: center;">Controlled Paraphrase Generation with Linguistic Feature Control</h1>
206
 
@@ -255,6 +275,8 @@ css = """
255
  #mode {border: 0px; box-shadow: none}
256
  #mode .block {padding: 0px}
257
 
 
 
258
  div.gradio-container {color: black}
259
  div.form {background: inherit}
260
 
@@ -336,6 +358,7 @@ with gr.Blocks(
336
  generate_btn = gr.Button("Generate", variant='primary', visible=False)
337
  with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
338
  rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
 
339
  with gr.Row():
340
  estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
341
  copy_btn = gr.Button("Copy linguistic indices of source to target", size='lg', visible=False)
@@ -344,28 +367,25 @@ with gr.Blocks(
344
  add_btn = gr.Button('Add \u03B5 to target linguistic indices', visible=False)
345
  with gr.Row():
346
  estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence →", visible=False)
347
- sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False, container=False)
348
  ling.render()
349
  #####################
350
 
351
  estimate_src_btn.click(estimate_src, inputs=[sent1, ling, approx], outputs=[ling])
352
  estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling_est, ling, approx], outputs=[ling])
353
- # estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling, ling], outputs=[ling])
354
  estimate_gen_btn.click(estimate_gen, inputs=[sent1, sent_ling_gen, ling, approx], outputs=[sent2, interpolation, ling])
355
- # rand_btn.click(rand_target, inputs=[ling], outputs=[ling])
356
  rand_ex_btn.click(rand_ex_target, inputs=[ling], outputs=[ling])
 
357
  copy_btn.click(copy, inputs=[ling], outputs=[ling])
358
  generate_btn.click(generate_with_feedback, inputs=[sent1, ling, approx], outputs=[sent2, interpolation])
359
  generate_random_btn.click(generate_random, inputs=[sent1, ling, count, approx],
360
  outputs=[sent2, interpolation, ling])
361
- # generate_fb_btn.click(generate_with_feedback, inputs=[sent1, ling], outputs=sent2s)
362
- # generate_fb_s_btn.click(generate_with_feedbacks, inputs=[sent1, ling], outputs=sent2s)
363
  add_btn.click(add, inputs=[ling], outputs=[ling])
364
  sub_btn.click(sub, inputs=[ling], outputs=[ling])
365
 
366
  group1 = [generate_random_btn, count]
367
  group2 = [estimate_gen_btn, sent_ling_gen]
368
- group3 = [generate_btn, estimate_src_btn, estimate_tgt_btn, sent_ling_est, rand_ex_btn, copy_btn, add_btn, sub_btn, ling, ling_tools]
369
  components = group1 + group2 + group3
370
  mode.change(visibility, inputs=[mode], outputs=[sent2, interpolation] + components)
371
  control_interpolation.change(lambda v: gr.update(visible=v), inputs=[control_interpolation],
 
13
  from options import parse_args
14
  from transformers import T5Tokenizer
15
  from compute_lng import compute_lng
16
+ from sklearn.experimental import enable_iterative_imputer
17
+ from sklearn.impute import IterativeImputer
18
+ from sklearn.linear_model import Ridge
19
 
20
 
21
  def process_examples(samples, full_names):
 
38
 
39
  stats = json.load(open('assets/stats.json'))
40
 
 
41
  scaler = joblib.load('assets/scaler.bin')
42
  scale_ratio = np.load('assets/ratios.npy')
43
 
44
+ ling_collection = np.load('assets/ling_collection.npy')
45
+ ling_collection_scaled = scaler.transform(ling_collection)
46
+
47
  model, ling_disc, sem_emb = get_model(args, tokenizer, device)
48
 
49
  state = torch.load(args.ckpt, map_location=torch.device('cpu'))
 
206
  ling['Target'] = x
207
  return ling
208
 
209
def impute(ling):
    """Fill the empty cells of ``ling['Target']`` by iterative imputation.

    Empty-string cells in the ``Target`` column are treated as missing. The
    target vector is mapped into the scaler's feature space, appended as the
    last row of the scaled reference collection, and the missing entries are
    estimated by an ``IterativeImputer`` driven by a Ridge regressor. The
    completed row is mapped back to raw units, rounded with ``round_ling`` and
    written back to ``ling['Target']``.

    Args:
        ling: table holding the linguistic indices; it must expose a
            ``Target`` column supporting ``.replace`` (presumably a pandas
            DataFrame supplied by the UI -- confirm against the caller).

    Returns:
        The same ``ling`` object with ``Target`` replaced by the imputed,
        rounded values.
    """
    # Blank cells are how missing values arrive here; mark them as NaN so
    # both the scaler and the imputer recognise them as missing.
    target_raw = ling['Target'].replace("", np.nan)
    target_scaled = scaler.transform([target_raw])[0]

    estimator = Ridge(alpha=1e3, fit_intercept=False)
    imputer = IterativeImputer(estimator=estimator, imputation_order='random', max_iter=100)

    # The target row is in scaled space, so the reference rows must be too.
    # Stacking it onto the raw collection would fit the imputer on rows from
    # two different feature spaces and then inverse-transform the result as
    # if it were scaled.
    combined_matrix = np.vstack([ling_collection_scaled, target_scaled])
    interpolated_vector = imputer.fit_transform(combined_matrix)[-1]

    # Back to raw units before the values are written to the table.
    interp_raw = scaler.inverse_transform([interpolated_vector])[0]

    ling['Target'] = round_ling(interp_raw)
    return ling
223
+
224
  title = """
225
  <h1 style="text-align: center;">Controlled Paraphrase Generation with Linguistic Feature Control</h1>
226
 
 
275
  #mode {border: 0px; box-shadow: none}
276
  #mode .block {padding: 0px}
277
 
278
+ #estimate textarea {border: 1px solid; border-radius: 7px}
279
+
280
  div.gradio-container {color: black}
281
  div.form {background: inherit}
282
 
 
358
  generate_btn = gr.Button("Generate", variant='primary', visible=False)
359
  with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
360
  rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
361
+ impute_btn = gr.Button("Impute Missing Values", size='lg', visible=False)
362
  with gr.Row():
363
  estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
364
  copy_btn = gr.Button("Copy linguistic indices of source to target", size='lg', visible=False)
 
367
  add_btn = gr.Button('Add \u03B5 to target linguistic indices', visible=False)
368
  with gr.Row():
369
  estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence →", visible=False)
370
+ sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False, container=False, elem_id='estimate')
371
  ling.render()
372
  #####################
373
 
374
  estimate_src_btn.click(estimate_src, inputs=[sent1, ling, approx], outputs=[ling])
375
  estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling_est, ling, approx], outputs=[ling])
 
376
  estimate_gen_btn.click(estimate_gen, inputs=[sent1, sent_ling_gen, ling, approx], outputs=[sent2, interpolation, ling])
 
377
  rand_ex_btn.click(rand_ex_target, inputs=[ling], outputs=[ling])
378
+ impute_btn.click(impute, inputs=[ling], outputs=[ling])
379
  copy_btn.click(copy, inputs=[ling], outputs=[ling])
380
  generate_btn.click(generate_with_feedback, inputs=[sent1, ling, approx], outputs=[sent2, interpolation])
381
  generate_random_btn.click(generate_random, inputs=[sent1, ling, count, approx],
382
  outputs=[sent2, interpolation, ling])
 
 
383
  add_btn.click(add, inputs=[ling], outputs=[ling])
384
  sub_btn.click(sub, inputs=[ling], outputs=[ling])
385
 
386
  group1 = [generate_random_btn, count]
387
  group2 = [estimate_gen_btn, sent_ling_gen]
388
+ group3 = [generate_btn, estimate_src_btn, impute_btn, estimate_tgt_btn, sent_ling_est, rand_ex_btn, copy_btn, add_btn, sub_btn, ling, ling_tools]
389
  components = group1 + group2 + group3
390
  mode.change(visibility, inputs=[mode], outputs=[sent2, interpolation] + components)
391
  control_interpolation.change(lambda v: gr.update(visible=v), inputs=[control_interpolation],