Lamp Socrates commited on
Commit
7551cdd
1 Parent(s): 5c82e3e
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -2,6 +2,8 @@ import streamlit as st
2
  from transformers import pipeline
3
  from transformers import AutoTokenizer, AutoModelForTokenClassification
4
  import pandas as pd
 
 
5
 
6
  @st.cache_resource()
7
  def load_trained_model():
@@ -18,6 +20,13 @@ def load_trained_model():
18
  ner_pipeline = pipeline("ner", model=model, tokenizer = tokenizer)
19
  return ner_pipeline
20
 
 
 
 
 
 
 
 
21
  def load_random_examples(dataset_name, num_examples=5):
22
  """
23
  Load random examples from the specified Hugging Face dataset.
@@ -28,11 +37,11 @@ def load_random_examples(dataset_name, num_examples=5):
28
  pd.DataFrame: A DataFrame containing the random examples.
29
  """
30
  # Load the dataset
31
- from datasets import load_dataset
32
- dataset = load_dataset("surrey-nlp/PLOD-CW")
33
 
 
 
34
  # Convert the dataset to a pandas DataFrame
35
- df = pd.DataFrame(dataset['test'])
36
 
37
  # Select random examples
38
  random_examples = df.sample(n=1)
@@ -162,6 +171,8 @@ def prep_page():
162
  if text:
163
  st.write("Entities recognized:")
164
  entities = model(text)
 
 
165
 
166
  # Create a dictionary to map entity labels to colors
167
  label_colors = {
@@ -173,7 +184,7 @@ def prep_page():
173
 
174
  # Prepare the HTML output with styled entities
175
  def get_entity_html(text, entities):
176
- html = ""
177
  last_idx = 0
178
  for entity in entities:
179
  start = entity['start']
@@ -181,17 +192,18 @@ def prep_page():
181
  label = entity['entity']
182
  entity_text = text[start:end]
183
  color = label_colors.get(label, 'lightgray')
184
-
185
  # Append the text before the entity
186
- html += text[last_idx:start]
187
  # Append the entity with styling
188
- html += f'<mark style="background-color: {color}; border-radius: 3px;">{entity_text}</mark>'
189
  last_idx = end
190
-
191
  # Append any remaining text after the last entity
192
- html += text[last_idx:]
 
193
  return html
194
-
195
  # Generate and display the styled HTML
196
  styled_text = get_entity_html(text, entities)
197
 
@@ -209,7 +221,10 @@ if __name__ == '__main__':
209
  if 'api' in query_params:
210
  sentence = query_params.get('sentence')
211
  entities = predict_using_trained(sentence)
212
- st.write({"sentence" : sentence , "entities" : entities})
 
 
 
213
  else:
214
  prep_page()
215
 
 
2
  from transformers import pipeline
3
  from transformers import AutoTokenizer, AutoModelForTokenClassification
4
  import pandas as pd
5
+ from pprint import pprint
6
+
7
 
8
  @st.cache_resource()
9
  def load_trained_model():
 
20
  ner_pipeline = pipeline("ner", model=model, tokenizer = tokenizer)
21
  return ner_pipeline
22
 
23
+
24
+ @st.cache_data()
25
+ def load_plod_cw_dataset():
26
+ from datasets import load_dataset
27
+ dataset = load_dataset("surrey-nlp/PLOD-CW")
28
+ return dataset
29
+
30
  def load_random_examples(dataset_name, num_examples=5):
31
  """
32
  Load random examples from the specified Hugging Face dataset.
 
37
  pd.DataFrame: A DataFrame containing the random examples.
38
  """
39
  # Load the dataset
 
 
40
 
41
+ dat = load_plod_cw_dataset()
42
+
43
  # Convert the dataset to a pandas DataFrame
44
+ df = pd.DataFrame(dat['test'])
45
 
46
  # Select random examples
47
  random_examples = df.sample(n=1)
 
171
  if text:
172
  st.write("Entities recognized:")
173
  entities = model(text)
174
+
175
+ pprint(entities)
176
 
177
  # Create a dictionary to map entity labels to colors
178
  label_colors = {
 
184
 
185
  # Prepare the HTML output with styled entities
186
  def get_entity_html(text, entities):
187
+ html = "<div>"
188
  last_idx = 0
189
  for entity in entities:
190
  start = entity['start']
 
192
  label = entity['entity']
193
  entity_text = text[start:end]
194
  color = label_colors.get(label, 'lightgray')
195
+
196
  # Append the text before the entity
197
+ html += text[last_idx:start].replace(" ", "<br>")
198
  # Append the entity with styling
199
+ html += f'<div style="background-color: {color}; padding: 5px; border-radius: 3px; margin: 5px 0;">{entity_text}</div>'
200
  last_idx = end
201
+
202
  # Append any remaining text after the last entity
203
+ html += text[last_idx:].replace(" ", "<br>")
204
+ html += "</div>"
205
  return html
206
+
207
  # Generate and display the styled HTML
208
  styled_text = get_entity_html(text, entities)
209
 
 
221
  if 'api' in query_params:
222
  sentence = query_params.get('sentence')
223
  entities = predict_using_trained(sentence)
224
+ response = {"sentence" : sentence , "entities" : entities}
225
+ pprint(response)
226
+
227
+ st.write(response)
228
  else:
229
  prep_page()
230