ACMCMC commited on
Commit
1b1c01c
1 Parent(s): 3276764
MATLAB/main.m CHANGED
@@ -15,7 +15,7 @@ function display_elementsForKey(connectionsMap, key)
15
  end
16
  end
17
 
18
- data = readtable('MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
19
  data = renamevars(data,"#CUI1","CUI1");
20
  data = data(1:1000,:);
21
 
 
15
  end
16
  end
17
 
18
+ data = readtable('../MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
19
  data = renamevars(data,"#CUI1","CUI1");
20
  data = data(1:1000,:);
21
 
MATLAB/visualize_app.mlapp CHANGED
Binary files a/MATLAB/visualize_app.mlapp and b/MATLAB/visualize_app.mlapp differ
 
MATLAB/visualize_connectedNodes_continuous.m CHANGED
@@ -1,6 +1,6 @@
1
  function visualize_connectedNodes_continuous()
2
  % Read the data and create the connections map
3
- data = readtable('MGREL.RRF', 'Delimiter', '|', 'FileType', 'text', 'NumHeaderLines', 0, 'VariableNamingRule', 'preserve');
4
  data = renamevars(data, '#CUI1', 'CUI1');
5
  data = data(1:10000,:);
6
 
 
1
  function visualize_connectedNodes_continuous()
2
  % Read the data and create the connections map
3
+ data = readtable('../MGREL.RRF', 'Delimiter', '|', 'FileType', 'text', 'NumHeaderLines', 0, 'VariableNamingRule', 'preserve');
4
  data = renamevars(data, '#CUI1', 'CUI1');
5
  data = data(1:10000,:);
6
 
app.py CHANGED
@@ -105,15 +105,17 @@ with st.container():
105
  status.divider()
106
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
107
  status.write("Getting a summary of the clinical trials...")
108
- response, stats_dict = get_short_summary_out_of_json_files(json_of_clinical_trials)
109
  disease_overview = response
110
- status.write(f'Response from LLM summarization: {response}')
111
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
112
- status.write("Getting summary statistics of the clinical trials...")
113
- #response = tagging_insights_from_json(json_of_clinical_trials)
114
- response = ""
115
- print(f'Response from LLM tagging: {response}')
116
- status.write(f'Response from LLM tagging: {response}')
 
 
117
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
118
  status.update(label="Done!", state="complete")
119
  status.balloons()
@@ -187,52 +189,57 @@ with st.container():
187
  with tabs[i]:
188
  render_trial_details(trials[i])
189
 
 
 
 
 
 
 
 
 
 
190
 
191
- chosen_disease_name = st.selectbox(
192
- "Choose a disease",
193
- get_all_diseases_name(engine))
194
-
195
- st.write("You selected:", chosen_disease_name)
196
- chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)
197
-
198
- nodes = []
199
- edges = []
200
 
 
 
201
 
202
- nodes.append( Node(id=chosen_disease_uri,
203
- label=chosen_disease_name,
204
- size=25,
205
- shape="circular")
206
- )
207
 
208
- similar_diseases = get_most_similar_diseases_from_uri(engine, chosen_disease_uri, threshold=0.6)
209
- print(similar_diseases)
210
- for uri, name, weight in similar_diseases:
211
- nodes.append( Node(id=uri,
212
- label=name,
213
  size=25,
214
  shape="circular")
215
  )
216
 
217
- print(True if float(weight) > 0.7 else False)
218
- edges.append( Edge(source=chosen_disease_uri,
219
- target=uri,
220
- color="red" if float(weight) > 0.7 else "blue",
221
- weight=float(weight)**10,
222
- type="CURVE_SMOOTH"
223
- # type="STRAIGHT"
224
- )
225
- )
226
-
227
- config = Config(width=750,
228
- height=950,
229
- directed=False,
230
- physics=True,
231
- hierarchical=False,
232
- collapsible=False,
233
- # **kwargs
234
- )
235
-
236
- return_value = agraph(nodes=nodes,
237
- edges=edges,
238
- config=config)
 
 
 
 
 
 
 
 
 
 
105
  status.divider()
106
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
107
  status.write("Getting a summary of the clinical trials...")
108
+ response = get_short_summary_out_of_json_files(json_of_clinical_trials)
109
  disease_overview = response
110
+ try:
111
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
112
+ status.write("Getting summary statistics of the clinical trials...")
113
+ response = tagging_insights_from_json(json_of_clinical_trials)
114
+ print(f'Response from LLM tagging: {response}')
115
+ status.write(f'Response from LLM tagging: {response}')
116
+ except Exception as e:
117
+ print(f'Error while extracting numerical data from the clinical trials: {e}')
118
+ status.warning(f'Error while extracting numerical data from the clinical trials. This information will not be shown.')
119
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
120
  status.update(label="Done!", state="complete")
121
  status.balloons()
 
189
  with tabs[i]:
190
  render_trial_details(trials[i])
191
 
192
+ show_graph_of_all_diseases = False
193
+ if show_graph_of_all_diseases:
194
+ # If disease_names is not defined, define it
195
+ if "disease_names" not in st.session_state:
196
+ st.session_state.disease_names = get_all_diseases_name(engine)
197
+ chosen_disease_name = st.selectbox(
198
+ "Choose a disease",
199
+ st.session_state.disease_names,
200
+ )
201
 
202
+ st.write("You selected:", chosen_disease_name)
203
+ chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)
 
 
 
 
 
 
 
204
 
205
+ nodes = []
206
+ edges = []
207
 
 
 
 
 
 
208
 
209
+ nodes.append( Node(id=chosen_disease_uri,
210
+ label=chosen_disease_name,
 
 
 
211
  size=25,
212
  shape="circular")
213
  )
214
 
215
+ similar_diseases = get_most_similar_diseases_from_uri(engine, chosen_disease_uri, threshold=0.6)
216
+ print(similar_diseases)
217
+ for uri, name, weight in similar_diseases:
218
+ nodes.append( Node(id=uri,
219
+ label=name,
220
+ size=25,
221
+ shape="circular")
222
+ )
223
+
224
+ print(True if float(weight) > 0.7 else False)
225
+ edges.append( Edge(source=chosen_disease_uri,
226
+ target=uri,
227
+ color="red" if float(weight) > 0.7 else "blue",
228
+ weight=float(weight)**10,
229
+ type="CURVE_SMOOTH"
230
+ # type="STRAIGHT"
231
+ )
232
+ )
233
+
234
+ config = Config(width=750,
235
+ height=950,
236
+ directed=False,
237
+ physics=True,
238
+ hierarchical=False,
239
+ collapsible=False,
240
+ # **kwargs
241
+ )
242
+
243
+ return_value = agraph(nodes=nodes,
244
+ edges=edges,
245
+ config=config)
llm_res.py CHANGED
@@ -301,7 +301,7 @@ def tagging_insights_from_json(data_json):
301
 
302
  Extract the desired information from the following JSON data.
303
 
304
- Only extract the properties mentioned in the 'Classification' function.
305
 
306
  JSON data:
307
  {input}
@@ -317,20 +317,20 @@ def tagging_insights_from_json(data_json):
317
  # status: list = Field(
318
  # description="Extract the status of all the clinical trials"
319
  # )
320
- keywords: list = Field(
321
- description="Extract the most relevant keywords for each clinical trials"
322
- )
323
  # interventions: list = Field(
324
  # description="describe the interventions for each clinical trial using title, name and description"
325
  # )
326
- primary_outcomes: list = Field(
327
- description="get the timeframe of each clinical trial"
328
- )
329
- secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
330
- eligibility: list = Field(
331
- description="get the timeframe of each clinical trial"
332
- )
333
- healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
334
  minimum_age: list = Field(
335
  description="get the minimum age from each experiment"
336
  )
@@ -343,12 +343,12 @@ def tagging_insights_from_json(data_json):
343
  return {
344
  # "project_title": self.project_title,
345
  # "status": self.status,
346
- "keywords": self.keywords,
347
  # "interventions": self.interventions,
348
- "primary_outcomes": self.primary_outcomes,
349
- "secondary_outcomes": self.secondary_outcomes,
350
  # "eligibility": self.eligibility,
351
- "healthy_volunteers": self.healthy_volunteers,
352
  "minimum_age": self.minimum_age,
353
  "maximum_age": self.maximum_age,
354
  "gender": self.gender
@@ -370,13 +370,13 @@ def tagging_insights_from_json(data_json):
370
 
371
  avg_min_age, avg_max_age, most_common_gender, common_keywords= analyze_data(result_dict)
372
 
373
- stats_dict= {'Average Minimum age': avg_min_age,
374
- 'Average Maximum age': avg_max_age,
375
- 'Most common gender undergoing the trials': most_common_gender,
376
- 'common keywords found in the trials': common_keywords}
377
 
378
  print(f"Result_tagging: {result_dict}")
379
- return result_dict, stats_dict
380
 
381
 
382
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
 
301
 
302
  Extract the desired information from the following JSON data.
303
 
304
+ Only extract the properties mentioned in the 'Classification' function. Output a list of the extracted properties, starting with [ and ending with ].
305
 
306
  JSON data:
307
  {input}
 
317
  # status: list = Field(
318
  # description="Extract the status of all the clinical trials"
319
  # )
320
+ #keywords: list = Field(
321
+ # description="Extract the most relevant keywords for each clinical trials"
322
+ #)
323
  # interventions: list = Field(
324
  # description="describe the interventions for each clinical trial using title, name and description"
325
  # )
326
+ #primary_outcomes: list = Field(
327
+ # description="get the timeframe of each clinical trial"
328
+ #)
329
+ #secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
330
+ #eligibility: list = Field(
331
+ # description="get the timeframe of each clinical trial"
332
+ #)
333
+ # healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
334
  minimum_age: list = Field(
335
  description="get the minimum age from each experiment"
336
  )
 
343
  return {
344
  # "project_title": self.project_title,
345
  # "status": self.status,
346
+ #"keywords": self.keywords,
347
  # "interventions": self.interventions,
348
+ #"primary_outcomes": self.primary_outcomes,
349
+ #"secondary_outcomes": self.secondary_outcomes,
350
  # "eligibility": self.eligibility,
351
+ # "healthy_volunteers": self.healthy_volunteers,
352
  "minimum_age": self.minimum_age,
353
  "maximum_age": self.maximum_age,
354
  "gender": self.gender
 
370
 
371
  avg_min_age, avg_max_age, most_common_gender, common_keywords= analyze_data(result_dict)
372
 
373
+ #stats_dict= {'Average Minimum age': avg_min_age,
374
+ # 'Average Maximum age': avg_max_age,
375
+ # 'Most common gender undergoing the trials': most_common_gender,
376
+ # 'common keywords found in the trials': common_keywords}
377
 
378
  print(f"Result_tagging: {result_dict}")
379
+ return result_dict#, stats_dict
380
 
381
 
382
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
utils.py CHANGED
@@ -18,15 +18,16 @@ engine = create_engine(CONNECTION_STRING)
18
 
19
 
20
  def get_all_diseases_name(engine) -> List[List[str]]:
 
21
  with engine.connect() as conn:
22
  with conn.begin():
23
  sql = f"""
24
- SELECT * FROM Test.EntityEmbeddings
25
  """
26
  result = conn.execute(text(sql))
27
  data = result.fetchall()
28
 
29
- all_diseases = [row[1] for row in data if row[1] != "nan"]
30
  return all_diseases
31
 
32
 
 
18
 
19
 
20
  def get_all_diseases_name(engine) -> List[List[str]]:
21
+ print("Fetching all disease names...")
22
  with engine.connect() as conn:
23
  with conn.begin():
24
  sql = f"""
25
+ SELECT label FROM Test.EntityEmbeddings
26
  """
27
  result = conn.execute(text(sql))
28
  data = result.fetchall()
29
 
30
+ all_diseases = [row[0] for row in data if row[0] != "nan"]
31
  return all_diseases
32
 
33