Gabriela Nicole Gonzalez Saez committed on
Commit
8d6b878
1 Parent(s): bb65bbe
Files changed (3)
  1. app.py +758 -0
  2. plotsjs.js +744 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,758 @@
1
+ import gradio as gr
2
+ from time import time
3
+
4
+ import torch
5
+ import os
6
+ # import nltk
7
+ import argparse
8
+ import random
9
+ import numpy as np
10
+ import faiss
11
+ from argparse import Namespace
12
+ from tqdm.notebook import tqdm
13
+ from torch.utils.data import DataLoader
14
+ from functools import partial
15
+ from sklearn.manifold import TSNE
16
+
17
+ from transformers import AutoTokenizer, MarianTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel
18
+ import os
19
+ dir_path = os.path.dirname(os.path.realpath(__file__))
20
+ print(dir_path)
21
+
22
+ metadata_all = {}
23
+ model_es = "Helsinki-NLP/opus-mt-en-es"
24
+ model_fr = "Helsinki-NLP/opus-mt-en-fr"
25
+ model_zh = "Helsinki-NLP/opus-mt-en-zh"
26
+
27
+ tokenizer_es = AutoTokenizer.from_pretrained(model_es)
28
+ tokenizer_fr = AutoTokenizer.from_pretrained(model_fr)
29
+ tokenizer_zh = AutoTokenizer.from_pretrained(model_zh)
30
+
31
+ model_tr_es = MarianMTModel.from_pretrained(model_es)
32
+ model_tr_fr = MarianMTModel.from_pretrained(model_fr)
33
+ model_tr_zh = MarianMTModel.from_pretrained(model_zh)
34
+
35
+ dict_models = {
36
+ 'en-es': model_es,
37
+ 'en-fr': model_fr,
38
+ 'en-zh': model_zh,
39
+ }
40
+
41
+ dict_models_tr = {
42
+ 'en-es': model_tr_es,
43
+ 'en-fr': model_tr_fr,
44
+ 'en-zh': model_tr_zh,
45
+ }
46
+
47
+ dict_tokenizer_tr = {
48
+ 'en-es': tokenizer_es,
49
+ 'en-fr': tokenizer_fr,
50
+ 'en-zh': tokenizer_zh,
51
+ }
52
+
53
+ from faiss import write_index, read_index
54
+ import pickle
55
+
56
+
57
+
58
+ def translation_model(w1, model):
59
+ inputs = dict_tokenizer_tr[model](w1, return_tensors="pt")
60
+ # embeddings = get_tokens_embeddings(inputs, model)
61
+ input_embeddings = dict_models_tr[model].get_encoder().embed_tokens(inputs.input_ids)
62
+ # model_tr_es.get_input_embeddings()
63
+ print(inputs)
64
+ num_ret_seq = 1
65
+ translated = dict_models_tr[model].generate(**inputs,
66
+ num_beams=5,
67
+ num_return_sequences=num_ret_seq,
68
+ return_dict_in_generate=True,
69
+ output_attentions =False,
70
+ output_hidden_states = True,
71
+ output_scores=True,)
72
+
73
+ tgt_text = dict_tokenizer_tr[model].decode(translated.sequences[0], skip_special_tokens=True)
74
+
75
+ target_embeddings = dict_models_tr[model].get_decoder().embed_tokens(translated.sequences)
76
+
77
+ return tgt_text, translated, inputs.input_ids, input_embeddings, target_embeddings
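+ # Example (hypothetical input): translation_model("Hello world", "en-es") returns
+ # (translated text, the generate() output object, source token ids, encoder input embeddings, decoder token embeddings of the generated sequence).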
78
+
79
+ def create_vocab_multiple(embeddings_list, model):
80
+ """_summary_
81
+
82
+ Args:
83
+ embeddings_list (list): dicts with 'tokens' (token ids per sentence) and 'embeddings' (one vector per token)
84
+
85
+ Returns:
86
+ (dict, list): vocabulary keyed by token id, plus the ordered token ids of each sentence
87
+ """
88
+ print("START VOCAB CREATION MULTIPLE \n \n ")
89
+ vocab = {} ## add embedds.
90
+ sentence_tokens_text_list = []
91
+ for embeddings in embeddings_list:
92
+ tokens_id = embeddings['tokens'] # [[tokens_id]x n_sentences ]
93
+ for sent_i, sentence in enumerate(tokens_id):
94
+ sentence_tokens = []
95
+ for tok_i, token in enumerate(sentence):
96
+ sentence_tokens.append(token)
97
+ if not (token in vocab):
98
+ vocab[token] = {
99
+ 'token' : token,
100
+ 'count': 1,
101
+ # 'text': embeddings['texts'][sent_i][tok_i],
102
+ 'text': dict_tokenizer_tr[model].decode([token]),
103
+ # 'text': src_token_lists[sent_i][tok_i],
104
+ 'embed': embeddings['embeddings'][sent_i][tok_i]}
105
+ else:
106
+ vocab[token]['count'] = vocab[token]['count'] + 1
107
+ # print(vocab)
108
+ sentence_tokens_text_list.append(sentence_tokens)
109
+ print("END VOCAB CREATION MULTIPLE \n \n ")
110
+ return vocab, sentence_tokens_text_list
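+ # vocab maps token_id -> {'token', 'count', 'text', 'embed'}; sentence_tokens_text_list keeps the ordered token ids of each sentence.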
111
+
112
+ def vocab_words_all_prefix(token_embeddings, model, suffix="@@", prefix='▁'):
113
+ vocab = {}
114
+ # inf_model = dict_models_tr[model]
115
+ sentence_words_text_list = []
116
+ if prefix :
117
+ n_prefix = len(prefix)
118
+ for input_sentences in token_embeddings:
119
+ # n_tokens_in_word
120
+ for sent_i, sentence in enumerate(input_sentences['tokens']):
121
+ words_text_list = []
122
+ # embedding = input_sentences['embed'][sent_i]
123
+ word = ''
124
+ tokens_ids = []
125
+ embeddings = []
126
+ ids_to_tokens = dict_tokenizer_tr[model].convert_ids_to_tokens(sentence)
127
+ # print("validate same len", len(sentence) == len(ids_to_tokens), len(sentence), len(ids_to_tokens), ids_to_tokens)
128
+
129
+ to_save= False
130
+ for tok_i, token_text in enumerate(ids_to_tokens):
131
+ token_id = sentence[tok_i]
132
+ if token_text[:n_prefix] == prefix :
133
+ #first we save the previous word
134
+ if to_save:
135
+ vocab[word] = {
136
+ 'word' : word,
137
+ 'text': word,
138
+ 'count': 1,
139
+ 'tokens_ids' : tokens_ids,
140
+ 'embed': np.mean(np.array(embeddings), 0).tolist()
141
+ }
142
+ words_text_list.append(word)
143
+ #word is starting if prefix
144
+ tokens_ids = [token_id]
145
+ embeddings = [input_sentences['embeddings'][sent_i][tok_i]]
146
+ word = token_text[n_prefix:]
147
+ ## if word
148
+ to_save = True
149
+
150
+ else :
151
+ if (token_text in dict_tokenizer_tr[model].special_tokens_map.values()):
152
+ # print('final or save', token_text, token_id, to_save, word)
153
+ if to_save:
154
+ # vocab[word] = ids
155
+ vocab[word] = {
156
+ 'word' : word,
157
+ 'text': word,
158
+ 'count': 1,
159
+ 'tokens_ids' : tokens_ids,
160
+ 'embed': np.mean(np.array(embeddings), 0).tolist()
161
+ }
162
+ words_text_list.append(word)
163
+ #special token is one token element, no continuation
164
+ # vocab[token_text] = [token_id]
165
+ tokens_ids = [token_id]
166
+ embeddings = [input_sentences['embeddings'][sent_i][tok_i]]
167
+ vocab[token_text] = {
168
+ 'word' : token_text,
169
+ 'count': 1,
170
+ 'text': token_text,
171
+ 'tokens_ids' : tokens_ids,
172
+ 'embed': np.mean(np.array(embeddings), 0).tolist()
173
+ }
174
+ words_text_list.append(token_text)
175
+ to_save = False
176
+ else:
177
+ # is a continuation; we do not know if it is final; we don't save here.
178
+ to_save = True
179
+ word += token_text
180
+ tokens_ids.append(token_id)
181
+ embeddings.append(input_sentences['embeddings'][sent_i][tok_i])
182
+ if to_save:
183
+ # print('final save', token_text, token_id, to_save, word)
185
+ if not (word in vocab):
186
+ vocab[word] = {
187
+ 'word' : word,
188
+ 'count': 1,
189
+ 'text': word,
190
+ 'tokens_ids' : tokens_ids,
191
+ 'embed': np.mean(np.array(embeddings), 0).tolist()
192
+ }
193
+ words_text_list.append(word)
194
+ else:
195
+ vocab[word]['count'] = vocab[word]['count'] + 1
196
+ sentence_words_text_list.append(words_text_list)
197
+
198
+ return vocab, sentence_words_text_list
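+ # SentencePiece marks word-initial pieces with '▁', so consecutive pieces are merged into words and their
+ # token embeddings are averaged (np.mean) to give one vector per word; special tokens become single-token entries.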
199
+
200
+ # nb_ids.append(token_values['token']) # for x in vocab_tokens]
201
+ # nb_embds.append(token_values['embed']) # for x in vocab_tokens]
202
+
203
+ def create_index_voronoi(vocab):
204
+ """
205
+ Build an IVF (Voronoi) FAISS index over the token vocabulary; returns the index and a metadata dict mapping FAISS row positions to tokens.
206
+ """
207
+ d = 1024
208
+ nb_embds = [] ##ordered embeddings list
209
+ metadata = {}
210
+ i_pos = 0
211
+ for key_token, token_values in vocab.items():
212
+ nb_embds.append(token_values['embed']) # for x in vocab_tokens]
213
+ metadata[i_pos] = {'token': token_values['token'], 'text': token_values['text']}
214
+ i_pos += 1
215
+ # nb_embds = [x['embed'] for x in vocab_tokens]
216
+
217
+ # print(len(nb_embds),len(nb_embds[0]) )
218
+ xb = np.array(nb_embds).astype('float32') #elements to index
219
+ # ids = np.array(nb_ids)
220
+ d = len(xb[0]) # dimension of each element
221
+
222
+ nlist = 5 # Nb of Voronois
223
+ quantizer = faiss.IndexFlatL2(d)
224
+ index = faiss.IndexIVFFlat(quantizer, d, nlist)
225
+ index.train(xb)
226
+ index.add(xb)
227
+ # index.add(xb)
228
+
229
+ return index, metadata## , nb_embds, nb_ids
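+ # IndexIVFFlat partitions the vectors into nlist=5 Voronoi cells and must be trained before add();
+ # searches visit index.nprobe cells (1 by default), which trades recall for speed.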
230
+
231
+ def create_index_voronoi_words(vocab):
232
+ """
233
+ Build an IVF (Voronoi) FAISS index over the word vocabulary; returns the index and a metadata dict mapping FAISS row positions to words.
234
+ """
235
+ d = 1024
236
+ nb_embds = [] ##ordered embeddings list
237
+ metadata = {}
238
+ i_pos = 0
239
+ for key_token, token_values in vocab.items():
240
+ nb_embds.append(token_values['embed']) # for x in vocab_tokens]
241
+ metadata[i_pos] = {'word': token_values['word'], 'tokens': token_values['tokens_ids'],'text': token_values['text']}
242
+ i_pos += 1
243
+ # nb_embds = [x['embed'] for x in vocab_tokens]
244
+
245
+ # print(len(nb_embds),len(nb_embds[0]) )
246
+ xb = np.array(nb_embds).astype('float32') #elements to index
247
+ # ids = np.array(nb_ids)
248
+ d = len(xb[0]) # dimension of each element
249
+
250
+ nlist = 5 # Nb of Voronois
251
+ quantizer = faiss.IndexFlatL2(d)
252
+ index = faiss.IndexIVFFlat(quantizer, d, nlist)
253
+ index.train(xb)
254
+ index.add(xb)
255
+ # index.add(xb)
256
+
257
+ return index, metadata## , nb_embds, nb_ids
258
+
259
+ def search_query_vocab(index, vocab_queries, topk = 10, limited_search = []):
260
+ """ the embed queries are a vocabulary of words : embds_input_voc
261
+
262
+ Args:
263
+ index (faiss.Index): FAISS index to search.
264
+ vocab_queries (dict): query vocabulary; each entry has the form
265
+ { 'token' : token,
266
+ 'count': 1,
267
+ 'text': src_token_lists[sent_i][tok_i],
268
+ 'embed': embeddings[0]['embeddings'][sent_i][tok_i] }
269
+ limited_search (list, optional): currently unused.
270
+ topk (int, optional): nb of similar tokens. Defaults to 10.
271
+
272
+ Returns:
273
+ (D, I, metadata): distance matrix D, index matrix I, and metadata mapping each query row to its word and token ids.
274
+ """
275
+ # nb_qi_ids = [] ##ordered ids list
276
+ nb_q_embds = [] ##ordered embeddings list
277
+ metadata = {}
278
+ qi_pos = 0
279
+ for key , token_values in vocab_queries.items():
280
+ # nb_qi_ids.append(token_values['token']) # for x in vocab_tokens]
281
+ metadata[qi_pos] = {'word': token_values['word'], 'tokens': token_values['tokens_ids'], 'text': token_values['text']}
282
+ qi_pos += 1
283
+ nb_q_embds.append(token_values['embed']) # for x in vocab_tokens]
284
+
285
+ xq = np.array(nb_q_embds).astype('float32') #elements to query
286
+
287
+ D,I = index.search(xq, topk)
288
+
289
+ return D,I, metadata
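+ # D and I have shape (n_queries, topk); FAISS fills missing neighbours with id -1, which callers filter out.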
290
+
291
+ def search_query_vocab_token(index, vocab_queries, topk = 10, limited_search = []):
292
+ """ the embed queries are a vocabulary of words : embds_input_vov
293
+ Returns:
294
+ (D, I, metadata): distance matrix D, index matrix I, and metadata mapping each query row to its token id and text.
295
+ """
296
+ # nb_qi_ids = [] ##ordered ids list
297
+ nb_q_embds = [] ##ordered embeddings list
298
+ metadata = {}
299
+ qi_pos = 0
300
+ for key , token_values in vocab_queries.items():
301
+ # nb_qi_ids.append(token_values['token']) # for x in vocab_tokens]
302
+ metadata[qi_pos] = {'token': token_values['token'], 'text': token_values['text']}
303
+ qi_pos += 1
304
+ nb_q_embds.append(token_values['embed']) # for x in vocab_tokens]
305
+
306
+ xq = np.array(nb_q_embds).astype('float32') #elements to query
307
+
308
+ D,I = index.search(xq, topk)
309
+
310
+ return D,I, metadata
311
+
312
+ def build_search(query_embeddings, model,type="input"):
313
+ global metadata_all
314
+
315
+ ## build vocab for index
316
+ vocab_queries, sentence_tokens_list = create_vocab_multiple(query_embeddings, model)
317
+ words_vocab_queries, sentence_words_list = vocab_words_all_prefix(query_embeddings, model, suffix="@@", prefix="▁")
318
+
319
+ index_vor_tokens = metadata_all[type]['tokens'][1]
320
+ md_tokens = metadata_all[type]['tokens'][2]
321
+ D, I, meta = search_query_vocab_token(index_vor_tokens, vocab_queries)
322
+
323
+ qi_pos = 0
324
+ similar_tokens = {}
325
+ # similar_tokens = []
326
+ for dist, ind in zip(D,I):
327
+ try:
328
+ # similar_tokens.append({
329
+ similar_tokens[str(meta[qi_pos]['token'])] = {
330
+ 'token': meta[qi_pos]['token'],
331
+ 'text': meta[qi_pos]['text'],
332
+ # 'text': dict_tokenizer_tr[model].decode(meta[qi_pos]['token'])
333
+ # 'text': meta[qi_pos]['text'],
334
+ "similar_topk": [md_tokens[i_index]['token'] for i_index in ind if (i_index != -1) ],
335
+ "distance": [dist[i] for (i, i_index) in enumerate(ind) if (i_index != -1)],
336
+ }
337
+ # )
338
+ except:
339
+ print("\n ERROR ", qi_pos, dist, ind)
340
+ qi_pos += 1
341
+
342
+
343
+ index_vor_words = metadata_all[type]['words'][1]
344
+ md_words = metadata_all[type]['words'][2]
345
+
346
+ Dw, Iw, metaw = search_query_vocab(index_vor_words, words_vocab_queries)
347
+ # D, I, meta, vocab_words, sentence_words_list = result_input['words']# [2] # D ; I ; meta
348
+ qi_pos = 0
349
+ # similar_words = []
350
+ similar_words = {}
351
+ for dist, ind in zip(Dw,Iw):
352
+ try:
353
+ # similar_words.append({
354
+ similar_words[str(metaw[qi_pos]['word']) ] = {
355
+ 'word': metaw[qi_pos]['word'],
356
+ 'text': metaw[qi_pos]['word'],
357
+ "similar_topk": [md_words[i_index]['word'] for i_index in ind if (i_index != -1) ],
358
+ "distance": [dist[i] for (i, i_index) in enumerate(ind) if (i_index != -1)],
359
+ }
360
+ # )
361
+ except:
362
+ print("\n ERROR ", qi_pos, dist, ind)
363
+ qi_pos += 1
364
+
365
+
366
+ return {'tokens': {'D': D, 'I': I, 'meta': meta, 'vocab_queries': vocab_queries, 'similar':similar_tokens, 'sentence_key_list': sentence_tokens_list},
367
+ 'words': {'D':Dw,'I': Iw, 'meta': metaw, 'vocab_queries':words_vocab_queries, 'sentence_key_list': sentence_words_list, 'similar': similar_words}
368
+ }
369
+
370
+ def build_reference(all_embeddings, model):
371
+
372
+ ## build vocab for index
373
+ vocab, sentence_tokens = create_vocab_multiple(all_embeddings,model)
374
+ words_vocab, sentences = vocab_words_all_prefix(all_embeddings, model, suffix="@@", prefix="▁")
375
+
376
+ index_tokens, meta_tokens = create_index_voronoi(vocab)
377
+ index_words, meta_words = create_index_voronoi_words(words_vocab)
378
+
379
+
380
+
381
+ return {'tokens': [vocab, index_tokens, meta_tokens],
382
+ 'words': [words_vocab, index_words, meta_words]
383
+ } # , index, meta
384
+
385
+
386
+ def embds_input_projection_vocab(vocab, key="token"):
387
+ t0 = time()
388
+
389
+ nb_ids = [] ##ordered ids list
390
+ nb_embds = [] ##ordered embeddings list
391
+ nb_text = [] ##ordered embeddings list
392
+ tnse_error = []
393
+ for _ , token_values in vocab.items():
394
+ tnse_error.append([0,0])
395
+ nb_ids.append(token_values[key]) # for x in vocab_tokens]
396
+ nb_text.append(token_values['text']) # for x in vocab_tokens]
397
+ nb_embds.append(token_values['embed']) # for x in vocab_tokens]
398
+
399
+ X = np.array(nb_embds).astype('float32') #elements to project
400
+ try:
401
+ tsne = TSNE(random_state=0, n_iter=1000)
402
+ tsne_results = tsne.fit_transform(X)
403
+
404
+ tsne_results = np.c_[tsne_results, nb_ids, nb_text, range(len(nb_ids))] ## zip array: [[tsne_x, tsne_y, token_id, token_text, row_index], ...]
405
+ except:
406
+ tsne_results = np.c_[tnse_error, nb_ids, nb_text, range(len(nb_ids))] ## fallback when t-SNE fails: coordinates default to (0, 0)
407
+
408
+ t1 = time()
409
+ print("t-SNE: %.2g sec" % (t1 - t0))
410
+ print(tsne_results)
411
+
412
+ return tsne_results.tolist()
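+ # Each returned row is [tsne_x, tsne_y, id, text, row_index]; downstream code relies on this column order.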
413
+
414
+ def filtered_projection(similar_key, vocab, type="input", key="word"):
415
+ global metadata_all
416
+ vocab_proj = vocab.copy()
417
+ ## tnse projection Input words
418
+ source_words_voc_similar = set()
419
+ # for words_set in similar_key:
420
+ for key_i in similar_key:
421
+ words_set = similar_key[key_i]
422
+ source_words_voc_similar.update(words_set['similar_topk'])
423
+
424
+ print(len(source_words_voc_similar))
425
+ # source_embeddings_filtered = {key: metadata_all['input']['words'][0][key] for key in source_words_voc_similar}
426
+ source_embeddings_filtered = {key_value: metadata_all[type][key][0][key_value] for key_value in source_words_voc_similar}
427
+ vocab_proj.update(source_embeddings_filtered)
428
+ ## vocab_proj add
429
+ try:
430
+ result_TSNE = embds_input_projection_vocab(vocab_proj, key=key[:-1]) ## singular => without 's'
431
+ dict_projected_embds_all = {str(embds[2]): [embds[0], embds[1], embds[2], embds[3], embds[4]] for embds in result_TSNE}
432
+ except:
433
+ print('TSNE error', type, key)
434
+ dict_projected_embds_all = {}
435
+
436
+
437
+
438
+ # print(result_TSNE)
439
+ return dict_projected_embds_all
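+ # Result: {token_id_or_word (str): [x, y, id, text, row_index]} covering the query vocabulary plus every
+ # reference entry that appeared in a similar_topk list, so queries and their neighbours share one projection.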
440
+
441
+ def first_function(w1, model):
442
+ global metadata_all
443
+ #translate and get internal values
444
+ # print(w1)
445
+ sentences = w1.split("\n")
446
+ all_sentences = []
447
+ translated_text = ''
448
+ input_embeddings = []
449
+ output_embeddings = []
450
+ for sentence in sentences :
451
+ # print(sentence, end=";")
452
+ params = translation_model(sentence, model)
453
+ all_sentences.append(params)
454
+ # print(len(params))
455
+ translated_text += params[0] + ' \n'
456
+ input_embeddings.append({
457
+ 'embeddings': params[3].detach(), ## create a vocabulary with the set of embeddings
458
+ 'tokens': params[2].tolist(), # one translation = one sentence
459
+ # 'texts' : dict_tokenizer_tr[model].decode(params[2].tolist())
460
+
461
+ })
462
+ output_embeddings.append({
463
+ 'embeddings' : params[4].detach(),
464
+ 'tokens': params[1].sequences.tolist(),
465
+ # 'texts' : dict_tokenizer_tr[model].decode(params[1].sequences.tolist())
466
+ })
467
+ # print(input_embeddings)
468
+ # print(output_embeddings)
469
+
470
+ ## Build FAISS index
471
+ # ---> preload faiss using the respective model with an initial dataset.
472
+ result_input = build_reference(input_embeddings,model)
473
+ result_output = build_reference(output_embeddings,model)
474
+ # print(result_input, result_output)
475
+
476
+ metadata_all = {'input': result_input, 'output': result_output}
477
+
478
+ ### get translation
479
+
480
+ return [translated_text, params]
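+ # first_function builds the reference FAISS indexes (kept in the global metadata_all) that the queries
+ # issued later by first_function_tr are compared against.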
481
+
482
+ def first_function_tr(w1, model, var2={}):
483
+ global metadata_all
484
+ # Translate and find similar tokens/words in the reference FAISS index
485
+ print("SEARCH -- ")
486
+ sentences = w1.split("\n")
487
+ all_sentences = []
488
+ translated_text = ''
489
+ input_embeddings = []
490
+ output_embeddings = []
491
+ for sentence in sentences :
492
+ # print(sentence, end=";")
493
+ params = translation_model(sentence, model)
494
+ all_sentences.append(params)
495
+ # print(len(params))
496
+ translated_text += params[0] + ' \n'
497
+ input_embeddings.append({
498
+ 'embeddings': params[3].detach(), ## create a vocabulary with the set of embeddings
499
+ 'tokens': params[2].tolist(), # one translation = one sentence
500
+ # 'texts' : dict_tokenizer_tr[model].decode(params[2].tolist()[0])
501
+ })
502
+ output_embeddings.append({
503
+ 'embeddings' : params[4].detach(),
504
+ 'tokens': params[1].sequences.tolist(),
505
+ # 'texts' : dict_tokenizer_tr[model].decode(params[1].sequences.tolist())
506
+ })
507
+
508
+ ## Build FAISS index
509
+ # ---> preload faiss using the respective model with an initial dataset.
510
+ result_search = {}
511
+ result_search['input'] = build_search(input_embeddings, model, type='input')
512
+ result_search['output'] = build_search(output_embeddings, model, type='output')
513
+
514
+ # D, I, meta, vocab_words, sentence_words_list = result_input['words']# [2] # D ; I ; meta
515
+ # md = metadata_all['input']['words'][2]
516
+ # qi_pos = 0
517
+ # similar_words = []
518
+ # for dist, ind in zip(D,I):
519
+ # try:
520
+ # similar_words.append({
521
+ # 'word': meta[qi_pos]['word'],
522
+ # "similar_topk": [md[i_index]['word'] for i_index in ind if (i_index != -1) ],
523
+ # "distance": [D[qi_pos][i] for (i, i_index) in enumerate(ind) if (i_index != -1)],
524
+ # })
525
+ # except:
526
+ # print("\n ERROR ", qi_pos, dist, ind)
527
+ # qi_pos += 1
528
+ # similar_vocab_queries = similar_vocab_queries[3]
529
+
530
+ # result_output = build_search(output_embeddings, model, type="output")
531
+ ## {'tokens': {'D': D, 'I': I, 'meta': meta, 'vocab_queries': vocab_queries, 'similar':similar_tokens},
532
+ ## 'words': {'D':Dw,'I': Iw, 'meta': metaw, 'vocab_queries':words_vocab_queries, 'sentence_key_list': sentence_words_list, 'similar': similar_words}
533
+ ## }
534
+
535
+ # print(result_input, result_output)
536
+
537
+
538
+ # json_out['input']['tokens'] = { 'similar_queries' : result_input['token'][5], # similarity and distance dict.
539
+ # 'tnse': dict_projected_embds_all, #projected points (all)
540
+ # 'key_text_list': result_input['token'][4], # current sentences keys
541
+ # }
542
+
543
+ json_out = {'input': {'tokens': {}, 'words': {}}, 'output': {'tokens': {}, 'words': {}}}
544
+ dict_projected = {}
545
+ for type in ['input', 'output']:
546
+ dict_projected[type] = {}
547
+ for key in ['tokens', 'words']:
548
+ similar_key = result_search[type][key]['similar']
549
+ vocab = result_search[type][key]['vocab_queries']
550
+ dict_projected[type][key] = filtered_projection(similar_key, vocab, type=type, key=key)
551
+ json_out[type][key]['similar_queries'] = similar_key
552
+ json_out[type][key]['tnse'] = dict_projected[type][key]
553
+ json_out[type][key]['key_text_list'] = result_search[type][key]['sentence_key_list']
554
+
555
+ return [translated_text, [ json_out, json_out['output']['words'], json_out['output']['tokens']] ]
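+ # json_out['input'|'output']['tokens'|'words'] holds {'similar_queries', 'tnse', 'key_text_list'} and is
+ # consumed client-side by testFn_out_json_tr in plotsjs.js.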
556
+
557
+
558
+ from pathlib import Path
559
+ ## First create html and divs
560
+ html = """
561
+ <html>
562
+ <script async src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>
563
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min.js"></script>
564
+ <script async data-require="d3@3.5.3" data-semver="3.5.3"
565
+ src="//cdnjs.cloudflare.com/ajax/libs/d3/3.5.3/d3.js"></script>
566
+ <body>
567
+ <div id="select_div">
568
+ <select id="select_type" class="form-select" aria-label="select example" hidden>
569
+ <option selected value="words">Words</option>
570
+ <option value="tokens">Tokens</option>
571
+ </select>
572
+ </div>
573
+ <div id="d3_embed_div">
574
+ <div class="row">
575
+ <div class="col-6">
576
+ <div id="d3_embeds_input_words" class="d3_embed words"></div>
577
+ </div>
578
+ <div class="col-6">
579
+ <div id="d3_embeds_output_words" class="d3_embed words"></div>
580
+
581
+ </div>
582
+ <div class="col-6">
583
+ <div id="d3_embeds_input_tokens" class="d3_embed tokens"></div>
584
+ </div>
585
+ <div class="col-6">
586
+ <div id="d3_embeds_output_tokens" class="d3_embed tokens"></div>
587
+ </div>
588
+ </div>
589
+ </div>
590
+ <div id="d3_graph_div">
591
+ <div class="row">
592
+ <div class="col-4">
593
+ <div id="d3_graph_input_words" class="d3_graph words"></div>
594
+
595
+ </div>
596
+ <div class="col-4">
597
+ <div id="similar_input_words" class=""></div>
598
+ </div>
599
+ <div class="col-4">
600
+ <div id="d3_graph_output_words" class="d3_graph words"></div>
601
+ <div id="similar_output_words" class="d3_graph words"></div>
602
+ </div>
603
+ </div>
604
+ <div class="row">
605
+ <div class="col-6">
606
+ <div id="d3_graph_input_tokens" class="d3_graph tokens"></div>
607
+ <div id="similar_input_tokens" class="d3_graph tokens"></div>
608
+ </div>
609
+ <div class="col-6">
610
+ <div id="d3_graph_output_tokens" class="d3_graph tokens"></div>
611
+ <div id="similar_output_tokens" class="d3_graph tokens"></div>
612
+ </div>
613
+ </div>
614
+ </div>
615
+ </body>
616
+
617
+ </html>
618
+ """
619
+ html0 = """
620
+ <html>
621
+ <script async src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>
622
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min.js"></script>
623
+ <script async data-require="d3@3.5.3" data-semver="3.5.3"
624
+ src="//cdnjs.cloudflare.com/ajax/libs/d3/3.5.3/d3.js"></script>
625
+ <body>
626
+ <div id="select_div">
627
+ <select id="select_type" class="form-select" aria-label="select example" hidden>
628
+ <option selected value="words">Words</option>
629
+ <option value="tokens">Tokens</option>
630
+ </select>
631
+ </div>
632
+ </body>
633
+
634
+ </html>
635
+ """
636
+
637
+ html_col1 = """
638
+ <div id="d3_graph_input_words" class="d3_graph words"></div>
639
+ <div id="d3_graph_input_tokens" class="d3_graph tokens"></div>
640
+ """
641
+
642
+ html_col2 = """
643
+ <div id="similar_input_words" class=""></div>
644
+ <div id="similar_output_words" class=""></div>
645
+ <div id="similar_input_tokens" class=" "></div>
646
+ <div id="similar_output_tokens" class=" "></div>
647
+
648
+ """
649
+
650
+
651
+ html_col3 = """
652
+ <div id="d3_graph_output_words" class="d3_graph words"></div>
653
+ <div id="d3_graph_output_tokens" class="d3_graph tokens"></div>
654
+ """
655
+
656
+
657
+ # # <div class="row">
658
+ # <div class="col-6" id="d3_legend_data_source"> </div>
659
+ # <div class="col-6" id="d3_legend_similar_source"> </div>
660
+ # </div>
661
+ def second_function(w1,j2):
662
+ # json_value = {'one':1}# return f"{w1['two']} in sentence22..."
663
+ # to transfer the data to json.
664
+ print("second_function -- after the js", w1,j2)
665
+ return "transition to second js function finished."
666
+
667
+ paths = []
668
+ def save_index(model) :
669
+ names = []
670
+ with open(model + '_metadata_ref.pkl', 'wb') as f:
671
+ pickle.dump(metadata_all, f)
672
+ names.append(model + '_metadata_ref.pkl')
673
+ for type in ['tokens','words']:
674
+ for kind in ['input', 'output']:
675
+ ## save index file
676
+ name = model + "_" + kind + "_"+ type + ".index"
677
+ write_index(metadata_all[kind][type][1], name)
678
+ names.append(name)
679
+ print("in save index done")
680
+ return gr.File(names)
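+ # Writes <model>_metadata_ref.pkl plus one .index file per (input|output) x (tokens|words) pair and
+ # returns them through the gr.File output so they can be downloaded.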
681
+
682
+
683
+ with gr.Blocks(js="plotsjs.js") as demo:
684
+ gr.Markdown(
685
+ """
686
+ # MAKE NMT Workshop \t `Embeddings representation`
687
+ """)
688
+ with gr.Row():
689
+ with gr.Column(scale=1):
690
+ model_radio_c = gr.Radio(choices=['en-es', 'en-zh', 'en-fr'], value="en-es", label= '', container=False)
691
+
692
+ with gr.Column(scale=2):
693
+ gr.Markdown(
694
+ """
695
+ ### Reference Translation Sentences
696
+ Enter at least 50 sentences to be used as comparison.
697
+ This is submitted just once.
698
+ """)
699
+ in_text = gr.Textbox(lines=2, label="reference source text")
700
+ out_text = gr.Textbox(label="reference target text", interactive=False)
701
+ out_text2 = gr.Textbox(visible=False)
702
+ var2 = gr.JSON(visible=False)
703
+ btn = gr.Button("Reference Translation")
704
+ # save_index_btn = gr.Button("Download reference index")
705
+ # file_obj = gr.File(label="Input File")
706
+ # input = file_obj
707
+ save_index_btn = gr.Button("Generate index files to download ",)
708
+ tab2_outputs = gr.File()
709
+ input = tab2_outputs
710
+
711
+ # save_output = gr.Button("Download", link="/file=en-es_input_tokens.index")
712
+
713
+
714
+ with gr.Column(scale=3):
715
+
716
+ gr.Markdown(
717
+ """
718
+ ### Translation Sentences
719
+ Sentences to be analysed.
720
+ """)
721
+ in_text_tr = gr.Textbox(lines=2, label="source text")
722
+ out_text_tr = gr.Textbox(label="target text", interactive=False)
723
+ out_text2_tr = gr.Textbox(visible=False)
724
+ var2_tr = gr.JSON(visible=False)
725
+ btn_faiss= gr.Button("Translation ")
726
+ gr.Button("Download", link="/file=en-es_input_tokens.index")
727
+
728
+ with gr.Row():
729
+ # input_mic = gr.HTML(html)
730
+ with gr.Column(scale=1):
731
+ input_mic = gr.HTML(html0)
732
+ input_html2 = gr.HTML(html_col2)
733
+
734
+ with gr.Column(scale=2):
735
+ input_html1 = gr.HTML(html_col1)
736
+ # with gr.Column(scale=2):
737
+
738
+ with gr.Column(scale=2):
739
+ input_html3 = gr.HTML(html_col3)
740
+
741
+ ## first function input w1, model ; return out_text, var2; it does first function and js;
742
+ btn.click(first_function, [in_text, model_radio_c], [out_text,var2], js="(in_text,model_radio_c) => testFn_out(in_text,model_radio_c)") #should return an output comp.
743
+ btn_faiss.click(first_function_tr, [in_text_tr, model_radio_c], [out_text_tr,var2_tr], js="(in_text_tr,model_radio_c) => testFn_out(in_text_tr,model_radio_c)") #should return an output comp.
744
+ ## second function input out_text(returned in first_function), [json]var2(returned in first_function) ;
745
+ ## second function returns out_text2, var2; it does second function and js(with the input params);
746
+ out_text.change(second_function, [out_text, var2], out_text2, js="(out_text,var2) => testFn_out_json(var2)") #
747
+ out_text_tr.change(second_function, [out_text_tr, var2_tr], out_text2_tr, js="(out_text_tr,var2_tr) => testFn_out_json_tr(var2_tr)") #
748
+ save_index_btn.click(save_index, [model_radio_c], [tab2_outputs])
749
+
750
+ # tab2_submit_button.click(func2,
751
+ # inputs=tab2_inputs,
752
+ # outputs=tab2_outputs)
753
+
754
+ # run script function on load,
755
+ # demo.load(None,None,None,js="plotsjs.js")
756
+ # allowed_paths
757
+ if __name__ == "__main__":
758
+ demo.launch(allowed_paths=["./", ".", "/"])
plotsjs.js ADDED
@@ -0,0 +1,744 @@
1
+
2
+ async () => {
3
+ // set testFn() on globalThis, so your html onclick can access it
4
+
5
+
6
+ globalThis.testFn = () => {
7
+ document.getElementById('demo').innerHTML = "Hello?"
8
+ };
9
+
10
+ const d3 = await import("https://cdn.jsdelivr.net/npm/d3@7/+esm");
11
+ // const d3 = await import("https://cdn.jsdelivr.net/npm/d3@5/+esm");
12
+ const $ = await import("https://cdn.jsdelivr.net/npm/[email protected]/dist/jquery.min.js");
13
+
14
+ globalThis.$ = $;
15
+ globalThis.d3 = d3;
16
+
17
+ globalThis.d3Fn = () => {
18
+ d3.select('#viz').append('svg')
19
+ .append('rect')
20
+ .attr('width', 50)
21
+ .attr('height', 50)
22
+ .attr('fill', 'black')
23
+ .on('mouseover', function(){d3.select(this).attr('fill', 'red')})
24
+ .on('mouseout', function(){d3.select(this).attr('fill', 'black')});
25
+
26
+ };
27
+
28
+ globalThis.testFn_out = (val,model_radio_c) => {
29
+ // document.getElementById('demo').innerHTML = val
30
+ console.log(val, "testFn_out");
31
+ // globalThis.d3Fn();
32
+ return([val,model_radio_c]);
33
+ };
34
+
35
+
36
+ globalThis.testFn_out_json = (data) => {
37
+ console.log(data, "testFn_out_json --");
38
+ // var $ = jQuery;
39
+ // console.log( d3.select('#d3_embeddings'));
40
+ return(['string', {}])
41
+ }
42
+
43
+ globalThis.testFn_out_json_tr = (data) => {
44
+ // data['input|output']['words|tokens']
45
+
46
+ console.log(data, "testFn_out_json_tr new");
47
+ var $ = jQuery;
48
+ console.log("$('#d3_embeddings')");
49
+ console.log($('#d3_embeddings'));
50
+ // d3.select('#d3_embeddings').html("");
51
+
52
+
53
+ d3.select("#d3_embeds_source").html("here");
54
+
55
+ // words or token visualization ?
56
+ console.log(d3.select("#select_type").node().value);
57
+ d3.select("#select_type").attr("hidden", null);
58
+ d3.select("#select_type").on("change", change);
59
+ change();
60
+ // tokens
61
+ // network plots;
62
+ ['input', 'output'].forEach(text_type => {
63
+ ['tokens', 'words'].forEach(text_key => {
64
+ // console.log(type, key, data[0][text_type]);
65
+ data_i = data[0][text_type][text_key];
66
+ embeddings_network([], data_i['tnse'], data_i['similar_queries'], type=text_type +"_"+text_key, )
67
+ });
68
+ });
69
+
70
+
71
+
72
+
73
+
74
+ // data_proj = data['tsne']; // it is not a dict.
75
+ // d3.select("#d3_embeds_" + type).html(scatterPlot(data_proj, data_sentences, dict_token_sentence_id, similar_vocab_queries, 'd3_embeds_'+type, type ));
76
+ // d3.select('#d3_embeddings').append(function(){return Tree(root);});
77
+ // embeddings_network(data['source_tokens'], data['dict_projected_embds_all']['source'], data['similar_vocab_queries']['source'], "source")
78
+
79
+ // source
80
+ // embeddings_graph(data['dict_projected_embds_all'],source_tks_list, data['source_tokens'], data['similar_vocab_queries'], "source"); //, data['similar_text'], data['similar_embds']);
81
+ // target decision: all tokens ? or separated by language ? hint: do not assume they share the same dict.
82
+ // embeddings_graph(data['dict_projected_embds_all'], translated_tks_text, translated_tks_ids_by_sent, data['similar_vocab_queries'], "target"); //, data['similar_text'], data['similar_embds']);
83
+
84
+ return(['string', {}])
85
+
86
+ }
87
+
88
+ function change() {
89
+ show_type = d3.select("#select_type").node().value;
90
+ // hide all
91
+ d3.selectAll(".d3_embed").attr("hidden",'');
92
+ d3.selectAll(".d3_graph").attr("hidden", '');
93
+ // show current type;
94
+ d3.select("#d3_embeds_input_" + show_type).attr("hidden", null);
95
+ d3.select("#d3_embeds_output_" + show_type).attr("hidden", null);
96
+ d3.select("#d3_graph_input_" + show_type).attr("hidden", null);
97
+ d3.select("#d3_graph_output_" + show_type).attr("hidden", null);
98
+ }
99
+
100
+
101
+
102
+ function embeddings_network(tokens_text, dict_projected_embds, similar_vocab_queries, type="source", ){
103
+ // tokens_text : not used;
104
+ // dict_projected_embds = tnse
105
+ console.log("Each token is a node; distance if in similar list", type );
106
+ console.log(tokens_text, dict_projected_embds, similar_vocab_queries);
107
+ // similar_vocab_queries_target[key]['similar_topk']
108
+
109
+ var nodes_tokens = {}
110
+ var nodeHash = {};
111
+ var nodes = []; // [{id: , label: }]
112
+ var edges = []; // [{source: , target: weight: }]
113
+ var edges_ids = []; // [{source: , target: weight: }]
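+ // Graph construction: each query token/word becomes a 'sentence' node (type_i 0); each of its similar_topk
+ // neighbours becomes a 'similar' node (type_i 1) linked with the FAISS distance as edge weight; consecutive
+ // query items are chained with weight-1 edges to preserve sentence order.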
114
+
115
+ // similar_vocab_queries {key: {similar_topk : [], distance : []}}
116
+ console.log('similar_vocab_queries', similar_vocab_queries);
117
+ prev_node = '';
118
+ for ([sent_token, value] of Object.entries(similar_vocab_queries)) {
119
+ // console.log('dict_projected_embds',sent_token, parseInt(sent_token), value, dict_projected_embds);
120
+ // sent_token = parseInt(sent_token); // Object.entries assumes key:string;
121
+ token_text = dict_projected_embds[sent_token][3]
122
+ if (!nodeHash[sent_token]) {
123
+ nodeHash[sent_token] = {id: sent_token, label: token_text, type: 'sentence', type_i: 0};
124
+ nodes.push(nodeHash[sent_token]);
125
+ }
126
+ sim_tokens = value['similar_topk']
127
+ dist_tokens = value['distance']
128
+
129
+ for (let index = 0; index < sim_tokens.length; index++) {
130
+ const sim = sim_tokens[index];
131
+ const dist = dist_tokens[index];
132
+
133
+ token_text_sim = dict_projected_embds[sim][3]
134
+ if (!nodeHash[sim]) {
135
+ nodeHash[sim] = {id: sim, label: token_text_sim, type:'similar', type_i: 1};
136
+ nodes.push(nodeHash[sim]);
137
+ }
138
+ edges.push({source: nodeHash[sent_token], target: nodeHash[sim], weight: dist});
139
+ edges_ids.push({source: sent_token, target: sim, weight: dist});
140
+ }
141
+
142
+ if (prev_node != '' ) {
143
+ edges.push({source: nodeHash[prev_node], target:nodeHash[sent_token], weight: 1});
144
+ edges_ids.push({source: prev_node, target: sent_token, weight: 1});
145
+ }
146
+ prev_node = sent_token;
147
+
148
+ }
149
+ console.log("TYPE", type, edges, nodes, edges_ids, similar_vocab_queries)
150
+ // d3.select('#d3_graph_input_tokens').html(networkPlot({nodes: nodes, links:edges}, similar_vocab_queries, div_type=type) );
151
+ // type +"_"+key
152
+ d3.select('#d3_graph_'+type).html("");
153
+ d3.select('#d3_graph_'+type).append(function(){return networkPlot({nodes: nodes, links:edges}, similar_vocab_queries, dict_projected_embds,div_type=type);});
154
+
155
+ // $('#d3_embeds_network_target').html(networkPlot({nodes: nodes, links:edges}));
156
+ // $('#d3_embeds_network_'+type).html(etworkPlot({nodes: nodes, link:edges}));
157
+ }
158
+
159
+ function embeddings_graph(data, source_tokens_text_list, source_tokens, similar_vocab_queries, type="source") {
160
+ /*
161
+ ### source
162
+ data: dict_projected_embds_all = { token_id: [tns1, tns2, token_id, token_text] ...}
163
+ ### target
164
+ */
165
+ console.log("embeddings_graph");
166
+ active_sentences = get_sentences();
167
+ console.log("active_sentences", active_sentences, type); // working
168
+
169
+ active_sentences_tokens_text = active_sentences.map((x) => source_tokens_text_list[x]);
170
+ active_sentences_tokens = active_sentences.map((x) => source_tokens[x]);
171
+
172
+ console.log(active_sentences_tokens);
173
+
174
+ data_sentences = []
175
+ dict_token_sentence_id = {}
176
+ // active_sentences_tokens.forEach((sentence, i) => {
177
+ source_tokens_text_list.forEach((sentence, i) => {
178
+ /// opt1
179
+ proj = []
180
+ sentence.forEach((tok, tok_j) => {
181
+ console.log("tok,tok_j", tok, tok_j);
182
+ token_text = source_tokens_text_list[i][tok_j];
183
+ proj.push([data[tok][0], data[tok][1], token_text, i, tok_j, tok])
184
+ if (token_text in dict_token_sentence_id){
185
+ dict_token_sentence_id[token_text].push(i);
186
+ }
187
+ else{
188
+ dict_token_sentence_id[token_text] = [i];
189
+ }
190
+ });
191
+ data_sentences.push(proj);
192
+ });
193
+ console.log("data_sentences error here in target", data_sentences);
194
+
195
+ console.log(data);
196
+
197
+ $('#d3_embeds_' + type).html(scatterPlot(data, data_sentences, dict_token_sentence_id, similar_vocab_queries, 'd3_embeds_'+type, type ));
198
+ }
199
+
200
+
201
+ /*
202
+ data: dict_projected_embds_all = { token_id: [tns1, tns2, token_id, token_text] ...}
203
+ */
204
+ function scatterPlot(data, data_sentences, dict_token_sentence_id, similar_vocab_queries, div_name, div_type="source", {
205
+ width = 400, // outer width, in pixels
206
+ height , // outer height, in pixels
207
+ r = 3, // radius of nodes
208
+ padding = 1, // horizontal padding for first and last column
209
+ // text = d => d[2],
210
+ } = {}){
211
+ // data_dict = data[div_type];
212
+ var data_dict = { ...data[div_type] };
213
+ data = Object.values(data[div_type]);
214
+ // similar_vocab_queries = similar_vocab_queries[div_type];
215
+ var similar_vocab_queries = { ...similar_vocab_queries[div_type] };
216
+ console.log("div_type, data, data_dict, data_sentences, dict_token_sentence_id, similar_vocab_queries");
217
+ console.log(div_type, data, data_dict, data_sentences, dict_token_sentence_id, similar_vocab_queries);
218
+
219
+ // Create the SVG container.
220
+ var margin = {top: 10, right: 10, bottom: 30, left: 50 },
221
+ width = width - margin.left - margin.right,
222
+ height = 400 - margin.top - margin.bottom;
223
+
224
+ // append the svg object to the body of the page
225
+ var svg = d3.create("svg")
226
+ // .attr("style", "max-width: 100%; height: auto; height: intrinsic;")
227
+ .attr("width", width + margin.left + margin.right)
228
+ .attr("height", height + margin.top + margin.bottom)
229
+
230
+ svg.append("g")
231
+ .attr("transform",
232
+ "translate(" + margin.left + "," + margin.top + ")");
233
+
234
+ // const svg = d3.create("svg")
235
+ // .attr("width", width)
236
+ // .attr("height", height);
237
+
238
+ // Add X axis
239
+ min_value_x = d3.min(data, d => d[0])
240
+ max_value_x = d3.max(data, d => d[0])
241
+
242
+
243
+ var x = d3.scaleLinear()
244
+ .domain([min_value_x, max_value_x])
245
+ .range([ margin.left , width ]);
246
+
247
+ svg.append("g")
248
+ // .attr("transform", "translate("+ margin.left +"," + height + ")")
249
+ .attr("transform", "translate(0," + height + ")")
250
+ .call(d3.axisBottom(x));
251
+
252
+ // Add Y axis
253
+ min_value_y = d3.min(data, d => d[1])
254
+ max_value_y = d3.max(data, d => d[1])
255
+
256
+ var y = d3.scaleLinear()
257
+ .domain([min_value_y, max_value_y])
258
+ .range([ height, margin.top]);
259
+
260
+ svg.append("g")
261
+ .attr("transform", "translate("+ margin.left +", 0)")
262
+ .call(d3.axisLeft(y));
263
+
264
+ svg.selectAll()
265
+ .data(data)
266
+ .enter()
267
+ .append('circle')
268
+ .attr("class", function (d) { return "dot-" + d[2] } )
269
+ // .attr("cx", function (d) { return x(d[0] + margin.left); } )
270
+ .attr("cx", function (d) { return x(d[0]); } )
271
+ .attr("cy", function (d) { return y(d[1] - margin.bottom); } )
272
+ .attr("r", 5)
273
+ .style("fill", "#e85252")
274
+ .style("fillOpacity",0.2)
275
+ .style("stroke", "#000000ff")
276
+ .style("strokeWidth", 1)
277
+ .style("opacity", 0.7);
278
+
279
+ // svg.selectAll()
280
+ // .data(data)
281
+ // .enter()
282
+ // .append('text')
283
+ // .text(d => d[3])
284
+ // .attr("class", function (d) { return "text-" + d[2] } )
285
+ // // .attr("cx", function (d) { return x(d[0] + margin.left); } )
286
+ // .attr("x", function (d) { return x(d[0]); } )
287
+ // .attr("y", function (d) { return y(d[1] - margin.bottom); } )
288
+ // .attr("dy", "0.35em");
289
+
290
+ // colors = ['#cb1dd1',"#e0ac2b", "#e85252", "#6689c6", "#9a6fb0", "#a53253"];
291
+ colors = ['#6689c6',"#e0ac2b", "#e0ac2b", "#cb1dd1", "#cb1dd1", "#cb1dd1"];
292
+
293
+ // create a tooltip
294
+ var Tooltip = d3.select("#"+div_name)
295
+ .append("div")
296
+ .style("opacity", 0)
297
+ .attr("class", "tooltip")
298
+ .style("background-color", "white")
299
+ .style("border", "solid")
300
+ .style("border-width", "2px")
301
+ .style("border-radius", "5px")
302
+ .style("padding", "5px")
303
+ .text("I'm a circle!");
304
+
305
+ // const colorScale = d3.scaleOrdinal()
306
+ // .domain(domain_values)
307
+ // .range(["#e0ac2b", "#e85252", "#6689c6", "#9a6fb0", "#a53253"]);
308
+ // colorScale(d.group)
309
+
310
+ for (let i_snt = 0; i_snt < data_sentences.length; i_snt++) {
311
+ const sentence = data_sentences[i_snt];
312
+ // similar_tokens;
313
+ console.log("sentence: ", sentence);
314
+
315
+ svg.selectAll()
316
+ .data(sentence)
317
+ .enter()
318
+ .append('text')
319
+ .text(d => d[2])
320
+ .attr("class", function (d) { return "text-" + d[2] + " sent-" + i_snt } )
321
+ // .attr("cx", function (d) { return x(d[0] + margin.left); } )
322
+ .attr("x", function (d) { return x(d[0]); } )
323
+ .attr("y", function (d) { return y(d[1] - margin.bottom); } )
324
+ .attr("dy", "0.35em")
325
+ .attr("sentence_i", i_snt );
326
+
327
+ svg.selectAll()
328
+ .data(sentence)
329
+ .enter()
330
+ .append('circle')
331
+ .attr("class", function (d) { return "dot " + d[2] + " " + i_snt } )
332
+ // .attr("cx", function (d) { return x(d[0] + margin.left); } )
333
+ .attr("cx", function (d) { return x(d[0]); } )
334
+ .attr("cy", function (d) { return y(d[1] - margin.bottom); } )
335
+ .attr("sentence_i", i_snt )
336
+ .attr("r", 6)
337
+ .style("fill", colors[0])
338
+ .style("fillOpacity",0.2)
339
+ .style("stroke", "#000000")
340
+ .style("strokeWidth", 1)
341
+ .style("opacity", 1)
342
+ .on('click', change_legend )
343
+ .on('mouseover', highlight_mouseover )
344
+ .on('mouseout', highlight_mouseout )
345
+ // .on("mousemove", mousemove);
346
+
347
+
348
+ }
349
+
350
+
351
+ function change_legend(d,i) {
352
+ console.log(d,i);
353
+ if (i[2] in dict_token_sentence_id){
354
+ show_sentences(dict_token_sentence_id[i[2]], i[2]);
355
+
356
+ show_similar_tokens(i[5], '#d3_legend_similar_'+type);
357
+
358
+ console.log(dict_token_sentence_id[i[2]]);
359
+ }
360
+ else{console.log("no sentence")};
361
+ }
362
+
363
+ function highlight_mouseover(d,i) {
364
+ console.log("highlight_mouseover", d,i);
365
+ token_id = parseInt(i[5])
366
+ similar_ids = similar_vocab_queries[token_id]['similar_topk'];
367
+ d3.select(this).transition()
368
+ .duration('50')
369
+ .style('opacity', '1')
370
+ .attr("r", 12)
371
+
372
+ similar_ids.forEach(similar_token => {
373
+ d3.selectAll('.dot-' + similar_token).attr("r",12 ).style('opacity', '1')//.raise()
374
+ });
375
+
376
+ Tooltip
377
+ .style("opacity", 1)
378
+ .style("visibility", "visible")
379
+ // .style("top", (event.pageY-height)+"px").style("left",(event.pageX-width)+"px")
380
+ d3.select(this)
381
+ .style("stroke", "red")
382
+ .attr("strokeWidth", 2)
383
+ .style("opacity", 0.7)
384
+
385
+ // .html("The exact value of<br>this cell is: ")
386
+ // .style("left", (d3.mouse(this)[0]+70) + "px")
387
+ // .style("top", (d3.mouse(this)[1]) + "px")
388
+
389
+ }
390
+ function highlight_mouseout(d,i) {
391
+ token_id = parseInt(i[5])
392
+ console.log("similar_vocab_queries", similar_vocab_queries);
393
+ similar_ids = similar_vocab_queries[token_id]['similar_topk'];
394
+ // clean_sentences();
395
+ d3.select(this).transition()
396
+ .duration('50')
397
+ .style('opacity', '.7')
398
+ .attr("r", 6)
399
+
400
+ similar_ids.forEach(similar_token => {
401
+ d3.selectAll('.dot-' + similar_token).attr("r",6 ).style('opacity', '.7')
402
+ });
403
+
404
+ Tooltip
405
+ .style("opacity", 0)
406
+ d3.select(this)
407
+ .style("stroke", "none")
408
+ .style("opacity", 0.8)
409
+ }
410
+
411
+ function mousemove(d,i) {
412
+ console.log("mousemove", d, i)
413
+ pointer = d3.pointer(d);
414
+ Tooltip
415
+ .html("The exact value of<br> ")
416
+ // .style("top", ((e.pageY ) - (height*2)) +"px")
417
+ // .attr("transform", `translate(${pointer[0]},0)`)
418
+ .style("top", height - pointer[1] +"px")
419
+ .style("left", pointer[0]+"px")
420
+ }
421
+
422
+
423
+ function show_sentences(sentences_id, token) {
424
+
425
+ // Show sentences with token "token"
426
+ d3.select('#d3_legend_data_'+div_type).html("");
427
+ console.log("show_sentences", data_sentences, sentences_id);
428
+ sentences_id.forEach(sent_id => {
429
+ console.log(data_sentences[sent_id])
430
+ // console.log(data_sentences[sent_id].map( x => x[2] ));
431
+ // p = d3.select('#d3_legend_data').append("p").enter();
432
+ d3.select('#d3_legend_data_'+div_type)
433
+ .selectAll().append("p")
434
+ .data(data_sentences[sent_id])
435
+ .enter()
436
+ .append('text')
437
+ .attr('class_data', sent_id)
438
+ .attr('class_id', d => d[5])
439
+ .style("background", d=> {if (d[2]== token) return "yellow"} )
440
+ .text( d => d[2] + " ");
441
+ d3.select('#d3_legend_data_'+div_type).append("p").enter();
442
+ });
443
+ // $("#d3_legend_data")
444
+ // data_sentences
445
+ }
446
+
447
+ function clean_sentences() {
448
+ d3.select('#d3_legend_data_'+div_type).html("");
449
+ }
450
+
451
+ function show_similar_tokens(token, div_name_similar= '#d3_legend_similar_') {
452
+ d3.select(div_name_similar).html("");
453
+ console.log("token", token);
454
+ console.log("similar_vocab_queries[token]", similar_vocab_queries[token]);
455
+ token_data = similar_vocab_queries[token];
456
+ console.log(token, token_data);
457
+ var decForm = d3.format(".3f");
458
+
459
+ d3.select(div_name_similar)
460
+ .selectAll().append("p")
461
+ .data(token_data['similar_topk'])
462
+ .enter()
463
+ .append("p").append('text')
464
+ // .attr('class_data', sent_id)
465
+ .attr('class_id', d => d)
466
+ .style("background", d=> {if (d == token) return "yellow"} )
467
+ // .text( d => d + " \n ");
468
+ .text((d,i) => do_text(d,i) );
469
+
470
+ function do_text(d,i){
471
+ console.log("do_text d,i" );
472
+ console.log(d,i);
473
+ console.log("data_dict[d], data_dict");
474
+ // console.log(data_dict[d], data_dict);
475
+ // return data_dict[d][3] + " " + decForm(token_data['distance'][i]) + " ";
476
+ return " " + decForm(token_data['distance'][i]) + " ";
477
+ }
478
+
479
+
480
+ }
481
+ // data_sentences
482
+
483
+ // .attr('x', (d) => x_scale(d[0]) + margin.left)
484
+ // .attr('y', (d) => y_scale(d[1]) + margin_top_extra)
485
+ // .attr("rx", 4)
486
+ // .attr("ry", 4)
487
+ // .attr("stroke", "#F7F7F7")
488
+ // .attr("stroke-width","2px")
489
+ // .attr('width', x_scale.bandwidth())
490
+ // .attr('height', (d) => height_text);
491
+ // // .attr('fill', (d) => color_scale(d.value));
492
+
493
+ // Add dots
494
+ // svg.append('g')
495
+ // // .selectAll("dot")
496
+ // .data(data)
497
+ // .enter()
498
+ // .append("circle")
499
+ // .attr("class", function (d) { return "dot " + d[2] } )
500
+ // .attr("cx", function (d) { return x(d[0]); } )
501
+ // .attr("cy", function (d) { return y(d[1]); } )
502
+ // .attr("r", 5)
503
+ // .style("fill", function (d) { return color(d.Species) } )
504
+ // .on("mouseover", highlight)
505
+ // .on("mouseleave", doNotHighlight )
506
+
507
+
508
+
509
+ return svg.node();
510
+ }
511
+
512
+
513
+
514
+ function networkPlot(data, similar_vocab_queries,dict_proj, div_type="source", {
515
+ width = 400, // outer width, in pixels
516
+ height , // outer height, in pixels
517
+ r = 3, // radius of nodes
518
+ padding = 1, // horizontal padding for first and last column
519
+ // text = d => d[2],
520
+ } = {}){
521
+ // data_dict = data;
522
+ data = data// [div_type];
523
+ similar_vocab_queries = similar_vocab_queries// [div_type];
524
+ console.log("data, similar_vocab_queries, div_type");
525
+ console.log(data, similar_vocab_queries, div_type);
526
+
527
+ // Create the SVG container.
528
+ var margin = {top: 10, right: 10, bottom: 30, left: 50 },
529
+ width = width //- margin.left - margin.right,
530
+ height = 400 //- margin.top - margin.bottom;
531
+
532
+ width_box = width + margin.left + margin.right;
533
+ height_box = height + margin.top + margin.bottom
534
+ totalWidth = width*2;
535
+ // append the svg object to the body of the page
536
+ // const parent = d3.create("div");
537
+ // const body = parent.append("div")
538
+ // .style("overflow-x", "scroll")
539
+ // .style("-webkit-overflow-scrolling", "touch");
540
+
541
+
542
+ var svg = d3.create("svg")
543
+ // var svg = body.create("svg")
544
+ // .style("display", "block")
545
+ // .attr("style", "max-width: 100%; height: auto; height: intrinsic;")
546
+ .attr("width", width + margin.left + margin.right)
547
+ .attr("height", height + margin.top + margin.bottom)
548
+ // .attr("viewBox", [-width_box / 2, -height_box / 2, width_box, height_box])
549
+ // .attr("viewBox", [0, 0, width, height]);
550
+ // .attr("style", "max-width: 100%; height: auto;");
551
+
552
+ // svg.append("g")
553
+ // .attr("transform",
554
+ // "translate(" + margin.left + "," + margin.top + ")");
555
+
556
+
557
+
558
+ // Initialize the links
559
+ var link = svg
560
+ .selectAll("line")
561
+ .data(data.links)
562
+ .enter()
563
+ .append("line")
564
+ .style("fill", d => d.weight == 1 ? "#dfd5d5" : "#000000") // , "#69b3a2" : "#69b3a2")
565
+ .style("stroke", "#aaa")
566
+
567
+
568
+
569
+ var text = svg
570
+ .selectAll("text")
571
+ .data(data.nodes)
572
+ .enter()
573
+ .append("text")
574
+ .style("text-anchor", "middle")
575
+ .attr("y", 15)
576
+ .attr("class", d => 'text_token-'+ dict_proj[d.id][4] + div_type)
577
+ .attr("div-type", div_type)
578
+ // .attr("class", d => 'text_token-'+ d.index)
579
+ .text(function (d) {return d.label} )
580
+ // .on('mouseover', function(d) { (d.type_i == 0) ? highlight_mouseover_text : console.log(0)} )
581
+ // .on('mouseover', function(d) { (d.type_i == 0) ? highlight_mouseout_text : '' } )
582
+ // .on('mouseout', highlight_mouseout_text )
583
+ // .join('text')
584
+ // .text(function(d) {
585
+ // return d.id
586
+ // })
587
+
588
+ // Initialize the nodes
589
+ var node = svg
590
+ .selectAll("circle")
591
+ .data(data.nodes)
592
+ .enter()
593
+ .append("circle")
594
+ .attr("r", 6)
595
+ // .attr("class", d => 'node_token-'+ d.id)
596
+ .attr("class", d => 'node_token-'+ dict_proj[d.id][4] + div_type)
597
+ .attr("div-type", div_type)
598
+ .style("fill", d => d.type_i ? "#e85252" : "#6689c6") // , "#69b3a2" : "#69b3a2")
599
+ .on('mouseover', highlight_mouseover )
600
+ // .on('mouseover', function(d) { return (d.type_i == 0) ? highlight_mouseover : console.log(0)} )
601
+ .on('mouseout',highlight_mouseout )
602
+ .on('click', change_legend )
603
+ // .on('click', show_similar_tokens )
604
+
605
+
606
+
607
+ // Let's list the force we wanna apply on the network
608
+ var simulation = d3.forceSimulation(data.nodes) // Force algorithm is applied to data.nodes
609
+ .force("link", d3.forceLink() // This force provides links between nodes
610
+ .id(function(d) { return d.id; }) // This provide the id of a node
611
+ .links(data.links) // and this the list of links
612
+ )
613
+ .force("charge", d3.forceManyBody(-400)) // This adds repulsion between nodes. Play with the -400 for the repulsion strength
614
+ .force("center", d3.forceCenter(width / 2, height / 2)) // This force attracts nodes to the center of the svg area
615
+ // .force("collision", d3.forceCollide())
616
+ .on("end", ticked);
617
+
618
+ // This function is run at each iteration of the force algorithm, updating the nodes position.
619
+ function ticked() {
620
+ link
621
+ .attr("x1", function(d) { return d.source.x; })
622
+ .attr("y1", function(d) { return d.source.y; })
623
+ .attr("x2", function(d) { return d.target.x; })
624
+ .attr("y2", function(d) { return d.target.y; });
625
+
626
+ node
627
+ .attr("cx", function (d) { return d.x+3; })
628
+ .attr("cy", function(d) { return d.y-3; });
629
+
630
+ text
631
+ .attr("transform", function(d) { return "translate(" + d.x + "," + d.y + ")"; })
632
+ }
633
+
634
+ function highlight_mouseover(d,i) {
635
+ console.log("highlight_mouseover", d,i, d3.select(this).attr("div-type"));
636
+ if (i.type_i == 0 ){
637
+ token_id = i.id
638
+ similar_ids = similar_vocab_queries[token_id]['similar_topk'];
639
+ d3.select(this).transition()
640
+ .duration('50')
641
+ .style('opacity', '1')
642
+ .attr("r", 12)
643
+ type = d3.select(this).attr("div-type")
644
+ similar_ids.forEach(similar_token => {
645
+ node_id_name = dict_proj[similar_token][4]
646
+ d3.selectAll('.node_token-'+ node_id_name + type).attr("r",12 ).style('opacity', '1')//.raise()
647
+ // d3.selectAll('.text_token-'+ node_id_name).raise()
648
+ });
649
+ }
650
+ }
651
+
652
+
653
+ function highlight_mouseout(d,i) {
654
+ if (i.type_i == 0 ){
655
+ token_id = i.id
656
+ console.log("similar_vocab_queries", similar_vocab_queries, "this type:", d3.select(this).attr("div-type"));
657
+ similar_ids = similar_vocab_queries[token_id]['similar_topk'];
658
+ // clean_sentences();
659
+ d3.select(this).transition()
660
+ .duration('50')
661
+ .style('opacity', '.7')
662
+ .attr("r", 6)
663
+ type = d3.select(this).attr("div-type")
664
+ similar_ids.forEach(similar_token => {
665
+ node_id_name = dict_proj[similar_token][4]
666
+ d3.selectAll('.node_token-' + node_id_name + type).attr("r",6 ).style('opacity', '.7')
667
+ d3.selectAll("circle").raise()
668
+ });
669
+ }
670
+ }
671
+
672
+ function change_legend(d,i,j) {
673
+ console.log(d,i,dict_proj);
674
+ if (i['id'] in dict_proj){
675
+ // show_sentences(dict_proj[i[2]], i[2]);
676
+
677
+ show_similar_tokens(i['id'], '#similar_'+type);
678
+
679
+ console.log(dict_proj[i['id']]);
680
+ }
681
+ else{console.log("no sentence")};
682
+ }
683
+
684
+ function show_similar_tokens(token, div_name_similar='#similar_input_tokens') {
685
+ d3.select(div_name_similar).html("");
686
+ console.log("token", token);
687
+ console.log("similar_vocab_queries[token]", similar_vocab_queries[token]);
688
+ token_data = similar_vocab_queries[token];
689
+ console.log(token, token_data);
690
+ var decForm = d3.format(".3f");
691
+
692
+ d3.select(div_name_similar)
693
+ .selectAll().append("p")
694
+ .data(token_data['similar_topk'])
695
+ .enter()
696
+ .append("p").append('text')
697
+ // .attr('class_data', sent_id)
698
+ .attr('class_id', d => d)
699
+ .style("background", d=> {if (d == token) return "yellow"} )
700
+ // .text( d => d + " \n ");
701
+ .text((d,i) => do_text(d,i) );
702
+
703
+ function do_text(d,i){
704
+ console.log("do_text d,i" );
705
+ console.log(d,i);
706
+ console.log("data_dict[d], data_dict");
707
+ console.log(dict_proj[d], dict_proj);
708
+ return dict_proj[d][3] + " " + decForm(token_data['distance'][i]) + " ";
709
+ }
710
+
711
+
712
+ }
713
+
714
+ // svg.call(d3.zoom()
715
+ // .extent([[0, 0], [width, height]])
716
+ // .scaleExtent([1, 8])
717
+ // .on("zoom", zoomed));
718
+
719
+ // function zoomed({transform}) {
720
+ // circle.attr("transform", d => `translate(${transform.apply(d)})`);
721
+ // }
722
+
723
+ // svg.call(
724
+ // d3.zoom().on("zoom", (event) => {
725
+ // g.attr("transform", event.transform);
726
+ // })
727
+ // );
728
+ // body.node().scrollBy(totalWidth, 0);
729
+
730
+
731
+ return svg.node();
732
+ // return parent.node();
733
+
734
+ };
735
+
736
+
737
+
738
+
739
+
740
+
741
+
742
+
743
+
744
+ }
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ inseq
2
+ bertviz
3
+ jupyter
4
+ scikit-learn
5
+ faiss-cpu