File size: 8,322 Bytes
6eff5e7
 
580aef7
6eff5e7
 
2fad322
6eff5e7
 
 
 
963bf46
6eff5e7
 
 
 
 
963bf46
6eff5e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
963bf46
 
 
 
6eff5e7
 
 
 
580aef7
 
6eff5e7
 
 
 
580aef7
6eff5e7
 
 
 
 
 
a6756ef
963bf46
 
 
 
580aef7
 
a6756ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
091bb76
a6756ef
091bb76
 
580aef7
 
963bf46
 
 
6eff5e7
 
 
 
 
 
 
 
 
580aef7
6eff5e7
 
580aef7
6eff5e7
580aef7
6eff5e7
 
 
 
 
 
091bb76
580aef7
6eff5e7
580aef7
 
091bb76
580aef7
 
 
 
6eff5e7
 
 
 
 
 
 
 
 
580aef7
963bf46
 
 
6eff5e7
 
580aef7
 
6eff5e7
 
 
580aef7
 
6eff5e7
091bb76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6eff5e7
963bf46
 
6eff5e7
963bf46
6eff5e7
 
 
 
 
 
 
 
 
 
 
963bf46
2fad322
963bf46
6eff5e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
963bf46
6eff5e7
 
 
 
0532283
 
 
6eff5e7
0532283
6eff5e7
 
 
 
 
 
091bb76
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
from sentence_transformers import util
from nltk.tokenize import sent_tokenize
from nltk import word_tokenize, pos_tag
import torch
import numpy as np
import tqdm

def compute_sentencewise_scores(model, query_sents, candidate_sents):
    """
    Score every (query sentence, candidate sentence) pair by cosine similarity.

    Input: sentence encoder, list of query sentences, list of candidate sentences
    Output: similarity matrix from util.cos_sim; rows follow query_sents,
            columns follow candidate_sents
    """
    query_vecs, cand_vecs = get_embedding(model, query_sents, candidate_sents)
    similarity = util.cos_sim(query_vecs, cand_vecs)
    return similarity
    
def get_embedding(model, query_sents, candidate_sents):
    """
    Encode the query and candidate sentences with the same model.

    Input: object exposing ``encode`` (e.g. a SentenceTransformer),
           list of query sentences, list of candidate sentences
    Output: (query embeddings, candidate embeddings), exactly as returned by
            ``model.encode``; the query is encoded first
    """
    encode = model.encode
    query_vecs = encode(query_sents)
    cand_vecs = encode(candidate_sents)
    return query_vecs, cand_vecs

def get_top_k(score_mat, K=3):
    """
    Pick the top K candidate sentences to show for every query sentence.

    Input: 2-D score tensor (rows = query sentences, columns = candidate sentences), K
    Output: (picked_sent, picked_scores), both of shape (num_rows, min(K, num_columns));
            each row lists candidate indices / scores from highest to lowest score
    """
    # Sort each row in descending score order and keep the first K column indices.
    order = torch.argsort(-score_mat)
    picked_sent = order[:, :K]
    # Fetch the matching scores in one vectorized call. Unlike stacking a
    # per-row Python list, this also works when score_mat has zero rows.
    picked_scores = torch.gather(score_mat, 1, picked_sent)
    return picked_sent, picked_scores

def get_words(sent):
    """
    Tokenize sentences into words and record where each sentence starts.

    Input: list of sentences
    Output: (words, all_words, sent_start_id)
        words: list of word lists, one per sentence
        all_words: all words of all sentences, flattened in order
        sent_start_id: index into all_words of the first word of each sentence
    An empty input yields three empty lists.
    """
    words = []
    all_words = []
    sent_start_id = []  # word index where each new sentence starts
    for x in sent:
        # Record the offset before appending, so entry i is where sentence i begins.
        sent_start_id.append(len(all_words))
        w = word_tokenize(x)
        words.append(w)
        all_words.extend(w)
    return words, all_words, sent_start_id

def get_match_phrase(w1, w2, method='pos'):
    """
    Find phrases (currently single words) shared by a query and a candidate.

    Input: w1 -- list of query words, w2 -- list of candidate words,
           method -- matching strategy; only 'pos' is implemented
    Output: (mask1, mask2), float numpy arrays of length len(w1) and len(w2)
            holding 1 where the word is part of a match and 0 elsewhere.
            Any other method returns all-zero masks.

    With method='pos', a candidate word matches when its POS tag is one of the
    noun-like tags in ``include`` and the same word (case-insensitive) occurs
    in the query; on the query side the first such occurrence is marked.
    """
    mask1 = np.zeros(len(w1))
    mask2 = np.zeros(len(w2))
    if method == 'pos':
        # POS tags that should be considered for matching phrase
        include = {'NN', 'NNS', 'NNP', 'NNPS', 'LS', 'SYM', 'FW'}
        # First position of each lowercased query word. Lowercasing both sides
        # lets capitalized words (e.g. proper nouns) match; it also replaces an
        # O(len(w1)) list search per candidate word with a dict lookup.
        first_pos = {}
        for j, w in enumerate(w1):
            first_pos.setdefault(w.lower(), j)
        # Only the candidate words need POS tags; query words are matched by identity.
        for i, (w, p) in enumerate(pos_tag(w2)):
            j = first_pos.get(w.lower())
            if j is not None and p in include:
                mask2[i] = 1
                mask1[j] = 1
    return mask1, mask2

def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
    """
    Mark the candidate words that are highlighted, both at sentence and phrase level.

    Input:
        query_sents: list of query sentences
        words, all_words, sent_start_id: candidate tokenization (per-sentence
            word lists, flattened words, start index of each sentence)
        sent_ids, sent_scores: for each query sentence, the selected candidate
            sentence indices and their scores (e.g. from get_top_k)
    Output: dict with keys
        'all_words', 'words_by_sentence': the candidate words as passed in
        i (one per query sentence): dict of numpy arrays aligned with all_words:
            'is_selected_sent': 1 for words inside a selected candidate sentence
            'is_selected_phrase': phrase-match mask from get_match_phrase
            'scores': sentence score on selected words, -1 where a word is both
                selected and phrase-matched, 0 elsewhere
    """
    n_words = len(all_words)
    n_sents = len(sent_start_id)
    output = {'all_words': all_words, 'words_by_sentence': words}

    # for each query sentence, mark the highlight information
    for qi in range(sent_ids.shape[0]):
        query_words = word_tokenize(query_sents[qi])
        sent_mask = np.zeros(n_words)
        phrase_mask = np.zeros(n_words)
        scores = np.zeros(n_words)

        for sid, sscore in zip(sent_ids[qi], sent_scores[qi]):
            start = sent_start_id[sid]
            # A sentence ends where the next one starts; the last one runs to the end.
            stop = sent_start_id[sid + 1] if sid + 1 < n_sents else n_words
            sent_mask[start:stop] = 1
            scores[start:stop] = sscore
            _, phrase_mask[start:stop] = get_match_phrase(query_words, all_words[start:stop])

        # -1 marks selected phrase words so gradio renders them in a different color
        scores[sent_mask + phrase_mask == 2] = -1

        output[qi] = {
            'is_selected_sent': sent_mask,
            'is_selected_phrase': phrase_mask,
            'scores': scores,
        }

    return output

def get_highlight_info(model, text1, text2, K=None):
    """
    Get highlight information for a query text and a candidate text.

    Input:
        model: sentence encoder used by compute_sentencewise_scores
        text1: query text
        text2: candidate text
        K: number of candidate sentences to select per query sentence; when
           None, one third of the candidate's sentences (at least 1)
    Output: (sent_ids, sent_scores, info, top_pairs_info)
        sent_ids, sent_scores: top-K candidate sentence indices / scores per query sentence
        info: word-level highlight data from mark_words
        top_pairs_info: dict mapping rank (0 = best) to the Gradio Interpretation
            data of one of the up-to-5 best (query sentence, candidate sentence) pairs
    """
    sent1 = sent_tokenize(text1)  # query
    sent2 = sent_tokenize(text2)  # candidate
    if K is None:
        # Scale with the candidate length, but never fall to 0 for candidates
        # with fewer than 3 sentences -- K=0 would highlight nothing at all.
        K = max(1, len(sent2) // 3)
    score_mat = compute_sentencewise_scores(model, sent1, sent2)

    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    words2, all_words2, sent_start_id2 = get_words(sent2)
    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)

    top_pairs = _top_sentence_pairs(sent_ids, sent_scores, top_pair_num=5)
    top_pairs_info = {
        rank: _pair_interpretation(sent1[sidq], sent2[sidc], s, (sidq, sidc))
        for rank, (s, (sidq, sidc)) in enumerate(top_pairs)
    }
    return sent_ids, sent_scores, info, top_pairs_info


def _top_sentence_pairs(sent_ids, sent_scores, top_pair_num=5):
    """Return up to top_pair_num (score, (query_idx, candidate_idx)) tuples, best first."""
    flat_order = np.argsort(np.array(sent_scores).ravel())[-top_pair_num:]
    rows, cols = np.unravel_index(flat_order, sent_scores.shape)
    pairs = []
    for i, j in zip(rows[::-1], cols[::-1]):
        # sent_ids[i, j] maps the j-th selected slot back to a candidate sentence index.
        pairs.append((sent_scores[i, j], (i, sent_ids[i, j].item())))
    return pairs


def _pair_interpretation(q_sent, c_sent, score, sent_idx):
    """Build the Gradio Interpretation payload for one query/candidate sentence pair."""
    q_words = word_tokenize(q_sent)
    c_words = word_tokenize(c_sent)
    mask1, mask2 = get_match_phrase(q_words, c_words)
    # mark matching phrases as blue (negative interpretation values)
    mask1 *= -1
    mask2 *= -1
    assert len(mask1) == len(q_words) and len(mask2) == len(c_words)
    return {
        'query': {
            'original': q_sent,
            'interpretation': list(zip(q_words, mask1)),
        },
        'candidate': {
            'original': c_sent,
            'interpretation': list(zip(c_words, mask2)),
        },
        'score': score,
        'sent_idx': sent_idx,
    }

### Document-level operations

def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    """
    Score documents against a query by cosine similarity of first-token embeddings.

    Input:
        doc_model: model called with the tokenizer output; its result must have
            ``last_hidden_state``, and the model must expose ``device``
        tokenizer: tokenizer paired with doc_model
        query: query text
        titles, abstracts: parallel lists; pairs where either item is None are skipped
        batch: number of documents encoded per forward pass
    Output: list of similarity scores, one per kept (title, abstract) pair, in input order
    """
    # Each document is "title [SEP] abstract"; incomplete pairs are dropped.
    title_abs = [
        t + ' [SEP] ' + a
        for t, a in zip(titles, abstracts)
        if t is not None and a is not None
    ]
    num_docs = len(title_abs)
    n_batches = int(np.ceil(num_docs / batch))

    scores = []
    with torch.no_grad():
        for b in tqdm.tqdm(range(n_batches)):
            chunk = title_abs[b * batch:(b + 1) * batch]
            # The query is prepended to every batch, so row 0 of the output is the query.
            inputs = tokenizer(
                [query] + chunk,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            inputs.to(doc_model.device)
            hidden = doc_model(**inputs).last_hidden_state
            # The first token of each sequence serves as its embedding.
            emb = hidden[:, 0, :].detach().cpu().numpy()
            q_emb, p_emb = emb[0, :], emb[1:, :]
            denom = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores.extend(np.dot(p_emb, q_emb) / denom)

    assert len(scores) == num_docs
    return scores

def compute_document_score(doc_model, tokenizer, query, papers, batch=5):
    """
    Rank papers by document-level similarity to the query.

    Input: doc_model and tokenizer (forwarded to predict_docscore), query text,
           papers as dicts with 'title' and 'abstract' keys, batch size
    Output: (titles, abstracts, scores), each ordered by descending score.
            Papers whose title or abstract is None are excluded.
    """
    kept = [p for p in papers if p['title'] is not None and p['abstract'] is not None]
    titles = [p['title'] for p in kept]
    abstracts = [p['abstract'] for p in kept]

    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    assert len(scores) == len(abstracts)

    # Highest score first.
    order = np.argsort(scores)[::-1]
    return (
        [titles[k] for k in order],
        [abstracts[k] for k in order],
        [scores[k] for k in order],
    )