DeDeckerThomas commited on
Commit
8339421
β€’
1 Parent(s): 55b038b

Fix last bugs with annotation system

Browse files
app.py CHANGED
@@ -7,30 +7,11 @@ import orjson
7
  from annotated_text.util import get_annotated_html
8
  from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
9
  import re
 
10
  import numpy as np
11
 
12
 
13
- if "config" not in st.session_state:
14
- with open("config.json", "r") as f:
15
- content = f.read()
16
- st.session_state.config = orjson.loads(content)
17
- st.session_state.data_frame = pd.DataFrame(columns=["model"])
18
- st.session_state.keyphrases = []
19
-
20
- st.set_page_config(
21
- page_icon="πŸ”‘",
22
- page_title="Keyphrase extraction/generation with Transformers",
23
- layout="wide",
24
- )
25
-
26
- if "select_rows" not in st.session_state:
27
- st.session_state.selected_rows = []
28
-
29
- st.header("πŸ”‘ Keyphrase extraction/generation with Transformers")
30
- col1, col2 = st.empty().columns(2)
31
-
32
-
33
- @st.cache(allow_output_mutation=True)
34
  def load_pipeline(chosen_model):
35
  if "keyphrase-extraction" in chosen_model:
36
  return KeyphraseExtractionPipeline(chosen_model)
@@ -67,18 +48,38 @@ def extract_keyphrases():
67
  def get_annotated_text(text, keyphrases):
68
  for keyphrase in keyphrases:
69
  text = re.sub(
70
- f"({keyphrase})",
71
- keyphrase.replace(" ", "$K"),
72
  text,
73
  flags=re.I,
 
74
  )
75
 
76
  result = []
77
  for i, word in enumerate(text.split(" ")):
78
- if re.sub(r"[^\w\s]", "", word) in keyphrases:
79
- result.append((word, "KEY", "#21c354"))
80
- elif "$K" in word:
81
- result.append((" ".join(word.split("$K")), "KEY", "#21c354"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  else:
83
  if i == len(st.session_state.input_text.split(" ")) - 1:
84
  result.append(f" {word}")
@@ -113,12 +114,39 @@ def rerender_output(layout):
113
  ],
114
  )
115
 
116
- result = get_annotated_text(text, keyphrases)
117
 
118
  layout.markdown(
119
  get_annotated_html(*result),
120
  unsafe_allow_html=True,
121
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
 
124
  chosen_model = col1.selectbox(
@@ -127,14 +155,17 @@ chosen_model = col1.selectbox(
127
  )
128
  st.session_state.chosen_model = chosen_model
129
 
130
- pipe = load_pipeline(
131
- f"{st.session_state.config.get('model_author')}/{st.session_state.chosen_model}"
132
- )
 
 
133
 
134
  st.session_state.input_text = col1.text_area(
135
  "Input", st.session_state.config.get("example_text"), height=300
136
- )
137
- pressed = col1.button("Extract", on_click=extract_keyphrases)
 
138
 
139
 
140
  if len(st.session_state.data_frame.columns) > 0:
 
7
  from annotated_text.util import get_annotated_html
8
  from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
9
  import re
10
+ import string
11
  import numpy as np
12
 
13
 
14
+ @st.cache(allow_output_mutation=True, show_spinner=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def load_pipeline(chosen_model):
16
  if "keyphrase-extraction" in chosen_model:
17
  return KeyphraseExtractionPipeline(chosen_model)
 
48
  def get_annotated_text(text, keyphrases):
49
  for keyphrase in keyphrases:
50
  text = re.sub(
51
+ rf"({keyphrase})([^A-Za-z])",
52
+ rf"$K:{keyphrases.index(keyphrase)}\2",
53
  text,
54
  flags=re.I,
55
+ count=1
56
  )
57
 
58
  result = []
59
  for i, word in enumerate(text.split(" ")):
60
+ if "$K" in word and re.search(
61
+ "(\d+)$", word.translate(str.maketrans("", "", string.punctuation))
62
+ ):
63
+ result.append(
64
+ (
65
+ re.sub(
66
+ r"\$K:\d+",
67
+ keyphrases[
68
+ int(
69
+ re.search(
70
+ "(\d+)$",
71
+ word.translate(
72
+ str.maketrans("", "", string.punctuation)
73
+ ),
74
+ ).group(1)
75
+ )
76
+ ],
77
+ word,
78
+ ),
79
+ "KEY",
80
+ "#21c354",
81
+ )
82
+ )
83
  else:
84
  if i == len(st.session_state.input_text.split(" ")) - 1:
85
  result.append(f" {word}")
 
114
  ],
115
  )
116
 
117
+ result = get_annotated_text(text, list(keyphrases))
118
 
119
  layout.markdown(
120
  get_annotated_html(*result),
121
  unsafe_allow_html=True,
122
  )
123
+ if "generation" in st.session_state.chosen_model:
124
+ abstractive_keyphrases = [
125
+ keyphrase
126
+ for keyphrase in keyphrases
127
+ if keyphrase.lower() not in text.lower()
128
+ ]
129
+ layout.write(", ".join(abstractive_keyphrases))
130
+
131
+
132
+ if "config" not in st.session_state:
133
+ with open("config.json", "r") as f:
134
+ content = f.read()
135
+ st.session_state.config = orjson.loads(content)
136
+ st.session_state.data_frame = pd.DataFrame(columns=["model"])
137
+ st.session_state.keyphrases = []
138
+
139
+ if "select_rows" not in st.session_state:
140
+ st.session_state.selected_rows = []
141
+
142
+ st.set_page_config(
143
+ page_icon="πŸ”‘",
144
+ page_title="Keyphrase extraction/generation with Transformers",
145
+ layout="wide",
146
+ )
147
+
148
+ st.header("πŸ”‘ Keyphrase extraction/generation with Transformers")
149
+ col1, col2 = st.columns(2)
150
 
151
 
152
  chosen_model = col1.selectbox(
 
155
  )
156
  st.session_state.chosen_model = chosen_model
157
 
158
+ with st.spinner("Loading pipeline..."):
159
+ pipe = load_pipeline(
160
+ f"{st.session_state.config.get('model_author')}/{st.session_state.chosen_model}"
161
+ )
162
+
163
 
164
  st.session_state.input_text = col1.text_area(
165
  "Input", st.session_state.config.get("example_text"), height=300
166
+ ).replace("\n", " ")
167
+ with st.spinner("Extracting keyphrases..."):
168
+ pressed = col1.button("Extract", on_click=extract_keyphrases)
169
 
170
 
171
  if len(st.session_state.data_frame.columns) > 0:
pipelines/__pycache__/keyphrase_generation_pipeline.cpython-39.pyc CHANGED
Binary files a/pipelines/__pycache__/keyphrase_generation_pipeline.cpython-39.pyc and b/pipelines/__pycache__/keyphrase_generation_pipeline.cpython-39.pyc differ