jskim committed
Commit 963bf46
1 Parent(s): b1499f3
Files changed (3)
  1. app.py +31 -14
  2. input_format.py +1 -16
  3. score.py +23 -12
app.py CHANGED
@@ -28,7 +28,7 @@ def get_similar_paper(
     author_id_input,
     num_papers_show=10
     ):
-    print('-- retrieving similar papers')
+    print('retrieving similar papers')
     input_sentences = sent_tokenize(abstract_text_input)
 
     # TODO handle pdf file input
@@ -41,8 +41,8 @@ def get_similar_paper(
     name, papers = get_text_from_author_id(author_id_input)
 
     # Compute Doc-level affinity scores for the Papers
-    print('---- computing scores')
-    titles, abstracts, doc_scores = compute_overall_score(
+    print('computing scores')
+    titles, abstracts, doc_scores = compute_document_score(
         doc_model,
         tokenizer,
         abstract_text_input,
@@ -63,9 +63,15 @@ def get_similar_paper(
     doc_scores = doc_scores[:num_papers_show]
 
     display_title = ['[ %0.3f ] %s'%(s, t) for t, s in zip(titles, doc_scores)]
-    print('----- done')
-
-    return gr.update(choices=display_title, interactive=True, visible=True), gr.update(choices=input_sentences, interactive=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
+    print('retrieval done')
+
+    return (
+        gr.update(choices=display_title, interactive=True, visible=True), # set of papers
+        gr.update(choices=input_sentences, interactive=True), # submission sentences
+        gr.update(visible=True), # title row
+        gr.update(visible=True), # abstract row
+        gr.update(visible=True) # button
+    )
 
 def get_highlights(
     abstract_text_input,
@@ -73,7 +79,7 @@ def get_highlights(
     abstract,
     K=2
     ):
-    print('-- obtaining highlights')
+    print('obtaining highlights')
     # Compute sent-level and phrase-level affinity scores for each papers
     sent_ids, sent_scores, info = get_highlight_info(
         sent_model,
@@ -86,18 +92,20 @@ def get_highlights(
     num_sents = len(input_sentences)
 
     word_scores = dict()
-    # different highlights for each input sentences
+
+    # different highlights for each input sentence
     for i in range(num_sents):
         word_scores[str(i)] = {
             "original": abstract,
             "interpretation": list(zip(info['all_words'], info[i]['scores']))
-        }
+        } # format to feed to the Gradio Interpretation component
 
     tmp = {
         'source_sentences': input_sentences,
         'highlight': word_scores
     }
     pickle.dump(tmp, open('highlight_info.pkl', 'wb'))
+    print('done')
 
     # update the visibility of radio choices
     return gr.update(visible=True)
@@ -105,11 +113,12 @@ def get_highlights(
 def update_name(author_id_input):
     # update the name of the author based on the id input
     name, _ = get_text_from_author_id(author_id_input)
+
     return gr.update(value=name)
 
 def change_output_highlight(source_sent_choice):
-    fname = 'highlight_info.pkl'
     # change the output highlight based on the sentence selected from the submission
+    fname = 'highlight_info.pkl'
     if os.path.exists(fname):
         tmp = pickle.load(open(fname, 'rb'))
         source_sents = tmp['source_sentences']
@@ -122,7 +131,7 @@ def change_output_highlight(source_sent_choice):
     return
 
 def change_paper(selected_papers_radio):
-    # change the paper to show
+    # change the paper to show based on the paper selected
     fname = 'paper_info.pkl'
     if os.path.exists(fname):
         tmp = pickle.load(open(fname, 'rb'))
@@ -130,7 +139,7 @@ def change_paper(selected_papers_radio):
         display_title = '[ %0.3f ] %s'%(aff_score, title)
         if display_title == selected_papers_radio:
             #print('changing paper')
-            return title, abstract, aff_score
+            return title, abstract, aff_score # update title, abstract, and affinity score fields
         else:
             return
 
@@ -150,7 +159,9 @@ with gr.Blocks() as demo:
     author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)
     with gr.Row():
         compute_btn = gr.Button('Search Similar Papers from the Reviewer')
-
+
+    ### PAPER INFORMATION
+
     # show multiple papers in radio check box to select from
     with gr.Row():
         selected_papers_radio = gr.Radio(
@@ -159,7 +170,7 @@ with gr.Blocks() as demo:
             label='Selected Top Papers from the Reviewer'
         )
 
-    ### PAPER INFORMATION
+    # selected paper information
    with gr.Row(visible=False) as title_row:
         with gr.Column(scale=3):
             paper_title = gr.Textbox(label='Title', interactive=False)
@@ -183,6 +194,9 @@ with gr.Blocks() as demo:
         with gr.Column(scale=3): # highlighted text from paper
             highlight = gr.components.Interpretation(paper_abstract)
 
+    ### EVENT LISTENERS
+
+    # retrieve similar papers
     compute_btn.click(
         fn=get_similar_paper,
         inputs=[
@@ -199,6 +213,7 @@ with gr.Blocks() as demo:
         ]
     )
 
+    # get highlights
     explain_btn.click(
         fn=get_highlights,
         inputs=[
@@ -209,12 +224,14 @@ with gr.Blocks() as demo:
         outputs=source_sentences
     )
 
+    # change highlight based on selected sentences from submission
     source_sentences.change(
         fn=change_output_highlight,
         inputs=source_sentences,
         outputs=highlight
     )
 
+    # change paper to show based on selected papers
     selected_papers_radio.change(
         fn=change_paper,
         inputs=selected_papers_radio,
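
Note: the reworked get_similar_paper returns one gr.update per output component, and compute_btn.click matches them to outputs positionally. Below is a minimal, self-contained sketch of that Gradio pattern; the component names and stand-in values are illustrative, not this app's actual layout.

import gradio as gr

def search(abstract_text):
    # stand-in results; the real app computes these with compute_document_score
    titles = ['[ 0.912 ] Paper A', '[ 0.877 ] Paper B']
    sentences = ['First sentence.', 'Second sentence.']
    return (
        gr.update(choices=titles, interactive=True, visible=True),  # radio of papers
        gr.update(choices=sentences, interactive=True),             # submission sentences
        gr.update(visible=True),                                    # reveal a hidden row
    )

with gr.Blocks() as demo:
    abstract_in = gr.Textbox(label='Abstract')
    btn = gr.Button('Search')
    papers = gr.Radio(choices=[], label='Papers', visible=False)
    sents = gr.Radio(choices=[], label='Sentences')
    with gr.Row(visible=False) as hidden_row:
        gr.Textbox(label='Details')
    # the i-th returned update is applied to the i-th output component
    btn.click(fn=search, inputs=abstract_in, outputs=[papers, sents, hidden_row])

if __name__ == '__main__':
    demo.launch()

Returning gr.update objects rather than raw values lets one callback both repopulate choices and toggle the visibility of rows that start hidden, which is why the multi-line tuple return reads more clearly than the old one-liner.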
input_format.py CHANGED
@@ -94,19 +94,4 @@ def get_introduction(text):
     pass
 
 def get_conclusion(text):
-    pass
-
-
-if __name__ == '__main__':
-    def run_sample():
-        url = 'https://arxiv.org/abs/2105.06506'
-        text = get_text_from_url(url)
-        assert(text[0].split('\n')[0] == 'Sanity Simulations for Saliency Methods')
-
-        text2 = get_text_from_url('https://arxiv.org/pdf/2105.06506.pdf')
-        assert(text2[0].split('\n')[0] == 'Sanity Simulations for Saliency Methods')
-
-        # text = get_text_from_url('https://arxiv.org/paetseths.pdf')
-
-    # test the code
-    run_sample()
+    pass
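
Note: this commit drops the inline __main__ smoke test from input_format.py. If that check is still wanted, it can live in a standalone script instead; here is a sketch built from the removed lines, assuming get_text_from_url keeps its current signature and return format (the filename is hypothetical).

# save as e.g. test_input_format.py next to input_format.py (hypothetical filename)
from input_format import get_text_from_url

def run_sample():
    text = get_text_from_url('https://arxiv.org/abs/2105.06506')
    assert text[0].split('\n')[0] == 'Sanity Simulations for Saliency Methods'

    text2 = get_text_from_url('https://arxiv.org/pdf/2105.06506.pdf')
    assert text2[0].split('\n')[0] == 'Sanity Simulations for Saliency Methods'

if __name__ == '__main__':
    run_sample()
    print('ok')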
score.py CHANGED
@@ -5,16 +5,16 @@ import torch
 import numpy as np
 
 def compute_sentencewise_scores(model, query_sents, candidate_sents):
+    # TODO make this more general for different types of models
     # list of sentences from query and candidate
-
     q_v, c_v = get_embedding(model, query_sents, candidate_sents)
+
     return util.cos_sim(q_v, c_v)
 
 def get_embedding(model, query_sents, candidate_sents):
-
     q_v = model.encode(query_sents)
     c_v = model.encode(candidate_sents)
-
+
     return q_v, c_v
 
 def get_top_k(score_mat, K=3):
@@ -30,6 +30,10 @@ def get_top_k(score_mat, K=3):
     return picked_sent, picked_scores
 
 def get_words(sent):
+    """
+    Input: list of sentences
+    Output: list of words per sentence, a flat list of all words, and the word index where each sentence starts
+    """
     words = []
     sent_start_id = [] # keep track of the word index where the new sentence starts
     counter = 0
@@ -48,8 +52,10 @@ def get_words(sent):
     return words, all_words, sent_start_id
 
 def get_match_phrase(w1, w2):
-    # list of words for query and candidate as input
-    # return the word list and binary mask of matching phrases
+    """
+    Input: list of words for query and candidate text
+    Output: word list and binary mask of matching phrases between the inputs
+    """
     # POS tags that should be considered for matching phrase
     include = [
         'JJ',
@@ -80,6 +86,9 @@ def get_match_phrase(w1, w2):
     return mask2
 
 def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
+    """
+    Mark the words that are highlighted, both in terms of sentences and phrases
+    """
     num_query_sent = sent_ids.shape[0]
     num_words = len(all_words)
 
@@ -121,6 +130,9 @@ def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
     return output
 
 def get_highlight_info(model, text1, text2, K=None):
+    """
+    Get highlight information from two texts
+    """
     sent1 = sent_tokenize(text1) # query
     sent2 = sent_tokenize(text2) # candidate
     if K is None: # if K is not set, select based on the length of the candidate
@@ -128,15 +140,15 @@ def get_highlight_info(model, text1, text2, K=None):
     score_mat = compute_sentencewise_scores(model, sent1, sent2)
 
     sent_ids, sent_scores = get_top_k(score_mat, K=K)
-    #print(sent_ids, sent_scores)
     words2, all_words2, sent_start_id2 = get_words(sent2)
-    #print(all_words1, sent_start_id1)
     info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
 
     return sent_ids, sent_scores, info
 
-## Document-level operations
+### Document-level operations
+
 def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
+    # compute document scores for each paper
 
     # concatenate title and abstract
     title_abs = []
@@ -146,12 +158,11 @@ def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
 
     num_docs = len(title_abs)
     no_iter = int(np.ceil(num_docs / batch))
-
-    # preprocess the input
     scores = []
     with torch.no_grad():
-        # batch
+        # batch
         for i in range(no_iter):
+            # preprocess the input
             inputs = tokenizer(
                 [query] + title_abs[i*batch:(i+1)*batch],
                 padding=True,
@@ -175,7 +186,7 @@ def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
 
     return scores
 
-def compute_overall_score(doc_model, tokenizer, query, papers, batch=5):
+def compute_document_score(doc_model, tokenizer, query, papers, batch=5):
     scores = []
     titles = []
     abstracts = []
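
Note: for orientation, a rough sketch of the sentence-level scoring path that compute_sentencewise_scores and get_top_k implement, written against sentence-transformers directly. The encoder name and the texts are placeholders, and get_top_k's exact selection logic in score.py may differ from the plain top-k shown here.

import torch
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('all-MiniLM-L6-v2')  # placeholder encoder

query_sents = ['We propose a saliency evaluation benchmark.']
candidate_sents = [
    'This paper studies saliency methods.',
    'We introduce a new dataset.',
    'Results show strong correlations.',
]

q_v = model.encode(query_sents)      # (num_query, dim)
c_v = model.encode(candidate_sents)  # (num_cand, dim)
score_mat = util.cos_sim(q_v, c_v)   # (num_query, num_cand) cosine similarities

# pick the K best candidate sentences for each query sentence
K = 2
picked_scores, picked_sent = torch.topk(score_mat, K, dim=1)
print(picked_sent, picked_scores)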