Spaces:
Sleeping
Sleeping
1-ARIjitS
committed on
Commit
•
86f6253
1
Parent(s):
52ee7a9
tagging included
Browse files- llm_res.py +149 -54
llm_res.py
CHANGED
@@ -44,23 +44,106 @@ def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str
|
|
44 |
return clinical_records
|
45 |
|
46 |
|
47 |
-
def process_json_data_for_llm(data):
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
# Iterate through the dictionary and keep only the desired fields
|
62 |
filtered_data = []
|
63 |
-
for item in
|
64 |
try:
|
65 |
organization_name = item["protocolSection"]["identificationModule"][
|
66 |
"organization"
|
@@ -132,22 +215,24 @@ def process_json_data_for_llm(data):
|
|
132 |
"eligibility": eligibility,
|
133 |
}
|
134 |
filtered_data.append(filtered_item)
|
135 |
-
|
136 |
-
|
137 |
-
# print(ele)
|
138 |
-
|
139 |
|
140 |
def get_short_summary_out_of_json_files(data_json):
|
141 |
-
|
|
|
|
|
|
|
142 |
|
143 |
-
#
|
144 |
-
You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
|
145 |
|
146 |
-
|
147 |
|
148 |
-
|
149 |
|
150 |
-
|
|
|
|
|
151 |
|
152 |
prompt = PromptTemplate.from_template(prompt_template)
|
153 |
|
@@ -178,18 +263,31 @@ General summary:"""
|
|
178 |
print(f"Combined descriptions: {combined_descriptions}")
|
179 |
|
180 |
result = stuff_chain.run(combined_descriptions)
|
181 |
-
print(f"
|
182 |
|
183 |
return result
|
184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
187 |
class Classification(BaseModel):
|
188 |
-
description: str = Field(
|
189 |
-
|
190 |
-
)
|
191 |
project_title: list = Field(
|
192 |
-
description="Extract the project
|
193 |
)
|
194 |
status: list = Field(
|
195 |
description="Extract the status of all the clinical trials"
|
@@ -207,43 +305,45 @@ def taggingTemplate():
|
|
207 |
# eligibility: list = Field(
|
208 |
# description="get the eligibilityCriteria grouping all the clinical trials"
|
209 |
# )
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
|
219 |
def get_dict(self):
|
220 |
return {
|
221 |
-
"summary": self.description,
|
222 |
"project_title": self.project_title,
|
223 |
"status": self.status,
|
224 |
-
"keywords": self.keywords,
|
225 |
"interventions": self.interventions,
|
226 |
"primary_outcomes": self.primary_outcomes,
|
227 |
# "secondary_outcomes": self.secondary_outcomes,
|
228 |
-
"eligibility": self.eligibility,
|
229 |
-
|
230 |
"minimum_age": self.minimum_age,
|
231 |
"maximum_age": self.maximum_age,
|
232 |
-
"gender": self.gender
|
233 |
}
|
234 |
|
235 |
# LLM
|
236 |
llm = ChatOpenAI(
|
237 |
temperature=0.6,
|
238 |
-
model="gpt-4",
|
239 |
openai_api_key=os.environ["OPENAI_API_KEY"],
|
240 |
).with_structured_output(Classification)
|
241 |
|
242 |
-
stuff_chain = StuffDocumentsChain(llm_chain=llm, document_variable_name="text")
|
243 |
|
244 |
-
|
245 |
|
246 |
-
|
|
|
|
|
|
|
247 |
|
248 |
|
249 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
@@ -252,10 +352,5 @@ def taggingTemplate():
|
|
252 |
# with open('data.json', 'w') as f:
|
253 |
# json.dump(clinical_record_info, f, indent=4)
|
254 |
|
255 |
-
# tagging_chain = llm_config()
|
256 |
-
|
257 |
|
258 |
-
|
259 |
-
processed_data = process_json_data_for_llm(json_contents)
|
260 |
-
# res = tagging_chain.invoke({"input": processed_data})
|
261 |
-
# return res
|
|
|
44 |
return clinical_records
|
45 |
|
46 |
|
47 |
+
# # def process_json_data_for_llm(data):
|
48 |
+
|
49 |
+
# # Define the fields you want to keep
|
50 |
+
# fields_to_keep = [
|
51 |
+
# "class_of_organization",
|
52 |
+
# "title",
|
53 |
+
# "overallStatus",
|
54 |
+
# "descriptionModule",
|
55 |
+
# "conditions",
|
56 |
+
# "interventions",
|
57 |
+
# "outcomesModule",
|
58 |
+
# "eligibilityModule",
|
59 |
+
# ]
|
60 |
+
|
61 |
+
# # Iterate through the dictionary and keep only the desired fields
|
62 |
+
# filtered_data = []
|
63 |
+
# for item in data:
|
64 |
+
# try:
|
65 |
+
# organization_name = item["protocolSection"]["identificationModule"][
|
66 |
+
# "organization"
|
67 |
+
# ]["fullName"]
|
68 |
+
# except:
|
69 |
+
# organization_name = ""
|
70 |
+
# try:
|
71 |
+
# project_title = item["protocolSection"]["identificationModule"][
|
72 |
+
# "officialTitle"
|
73 |
+
# ]
|
74 |
+
# except:
|
75 |
+
# project_title = ""
|
76 |
+
# try:
|
77 |
+
# status = item["protocolSection"]["statusModule"]["overallStatus"]
|
78 |
+
# except:
|
79 |
+
# status = ""
|
80 |
+
# try:
|
81 |
+
# briefDescription = item["protocolSection"]["descriptionModule"][
|
82 |
+
# "briefSummary"
|
83 |
+
# ]
|
84 |
+
# except:
|
85 |
+
# briefDescription = ""
|
86 |
+
# try:
|
87 |
+
# detailedDescription = item["protocolSection"]["descriptionModule"][
|
88 |
+
# "detailedDescription"
|
89 |
+
# ]
|
90 |
+
# except:
|
91 |
+
# detailedDescription = ""
|
92 |
+
# try:
|
93 |
+
# conditions = item["protocolSection"]["conditionsModule"]["conditions"]
|
94 |
+
# except:
|
95 |
+
# conditions = []
|
96 |
+
# try:
|
97 |
+
# keywords = item["protocolSection"]["conditionsModule"]["keywords"]
|
98 |
+
# except:
|
99 |
+
# keywords = []
|
100 |
+
# try:
|
101 |
+
# interventions = item["protocolSection"]["armsInterventionsModule"][
|
102 |
+
# "interventions"
|
103 |
+
# ]
|
104 |
+
# except:
|
105 |
+
# interventions = []
|
106 |
+
# try:
|
107 |
+
# primary_outcomes = item["protocolSection"]["outcomesModule"][
|
108 |
+
# "primaryOutcomes"
|
109 |
+
# ]
|
110 |
+
# except:
|
111 |
+
# primary_outcomes = []
|
112 |
+
# try:
|
113 |
+
# secondary_outcomes = item["protocolSection"]["outcomesModule"][
|
114 |
+
# "secondaryOutcomes"
|
115 |
+
# ]
|
116 |
+
# except:
|
117 |
+
# secondary_outcomes = []
|
118 |
+
# try:
|
119 |
+
# eligibility = item["protocolSection"]["eligibilityModule"]
|
120 |
+
# except:
|
121 |
+
# eligibility = {}
|
122 |
+
# filtered_item = {
|
123 |
+
# "organization_name": organization_name,
|
124 |
+
# "project_title": project_title,
|
125 |
+
# "status": status,
|
126 |
+
# "briefDescription": briefDescription,
|
127 |
+
# "detailedDescription": detailedDescription,
|
128 |
+
# "keywords": keywords,
|
129 |
+
# "interventions": interventions,
|
130 |
+
# "primary_outcomes": primary_outcomes,
|
131 |
+
# "secondary_outcomes": secondary_outcomes,
|
132 |
+
# "eligibility": eligibility,
|
133 |
+
# }
|
134 |
+
# filtered_data.append(filtered_item)
|
135 |
+
|
136 |
+
# return filtered_data
|
137 |
+
# # for ele in filtered_data:
|
138 |
+
# # print(ele)
|
139 |
+
|
140 |
+
def process_dictionaty_with_llm_to_generate_response(json_data):
|
141 |
+
# processed_data = process_json_data_for_llm(json_data)
|
142 |
+
# res = tagging_chain.invoke({"input": processed_data})
|
143 |
+
# return res
|
144 |
# Iterate through the dictionary and keep only the desired fields
|
145 |
filtered_data = []
|
146 |
+
for item in json_data:
|
147 |
try:
|
148 |
organization_name = item["protocolSection"]["identificationModule"][
|
149 |
"organization"
|
|
|
215 |
"eligibility": eligibility,
|
216 |
}
|
217 |
filtered_data.append(filtered_item)
|
218 |
+
|
219 |
+
return filtered_data
|
|
|
|
|
220 |
|
221 |
def get_short_summary_out_of_json_files(data_json):
|
222 |
+
# prompt_template = """ You are an expert clinician working on the analysis of reports of clinical trials.
|
223 |
+
|
224 |
+
# # Task
|
225 |
+
# You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
|
226 |
|
227 |
+
# To write your summary, you will need to read the following examples, labeled as "Report 1", "Report 2", and so on. Your answer should be a single paragraph (100-200 words) that summarizes the general content of all the reports.
|
|
|
228 |
|
229 |
+
# {text}
|
230 |
|
231 |
+
# General summary:"""
|
232 |
|
233 |
+
prompt_template = """ You are an expert on clinical trials and the analysis of their reports.
|
234 |
+
# Task
|
235 |
+
You will be given a text of descriptions of multiple clinical trials related to similar diseases. Your job is to come up with a short and detailed summary of the descriptions of the clinical trials. Your users are clinical researchers, so you should be technical and specific, including scientific terms in the summary."""
|
236 |
|
237 |
prompt = PromptTemplate.from_template(prompt_template)
|
238 |
|
|
|
263 |
print(f"Combined descriptions: {combined_descriptions}")
|
264 |
|
265 |
result = stuff_chain.run(combined_descriptions)
|
266 |
+
print(f"Result_summarization: {result}")
|
267 |
|
268 |
return result
|
269 |
|
270 |
+
def tagging_insights_from_json(data_json):
|
271 |
+
processed_json= process_dictionaty_with_llm_to_generate_response(data_json)
|
272 |
+
|
273 |
+
tagging_prompt = ChatPromptTemplate.from_template(
|
274 |
+
"""
|
275 |
+
You are an expert on clinical trials and analysis of their reports.
|
276 |
+
|
277 |
+
Extract the desired information from the following JSON data.
|
278 |
|
279 |
+
Only extract the properties mentioned in the 'Classification' function.
|
280 |
+
|
281 |
+
JSON data:
|
282 |
+
{input}
|
283 |
+
"""
|
284 |
+
)
|
285 |
class Classification(BaseModel):
|
286 |
+
# description: str = Field(
|
287 |
+
# description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
|
288 |
+
# )
|
289 |
project_title: list = Field(
|
290 |
+
description="Extract the project titles of all the clinical trials"
|
291 |
)
|
292 |
status: list = Field(
|
293 |
description="Extract the status of all the clinical trials"
|
|
|
305 |
# eligibility: list = Field(
|
306 |
# description="get the eligibilityCriteria grouping all the clinical trials"
|
307 |
# )
|
308 |
+
healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
|
309 |
+
minimum_age: list = Field(
|
310 |
+
description="get the minimum age from each experiment"
|
311 |
+
)
|
312 |
+
maximum_age: list = Field(
|
313 |
+
description="get the maximum age from each experiment"
|
314 |
+
)
|
315 |
+
gender: list = Field(description="get the gender from each experiment")
|
316 |
|
317 |
def get_dict(self):
|
318 |
return {
|
|
|
319 |
"project_title": self.project_title,
|
320 |
"status": self.status,
|
321 |
+
# "keywords": self.keywords,
|
322 |
"interventions": self.interventions,
|
323 |
"primary_outcomes": self.primary_outcomes,
|
324 |
# "secondary_outcomes": self.secondary_outcomes,
|
325 |
+
# "eligibility": self.eligibility,
|
326 |
+
"healthy_volunteers": self.healthy_volunteers,
|
327 |
"minimum_age": self.minimum_age,
|
328 |
"maximum_age": self.maximum_age,
|
329 |
+
"gender": self.gender
|
330 |
}
|
331 |
|
332 |
# LLM
|
333 |
llm = ChatOpenAI(
|
334 |
temperature=0.6,
|
335 |
+
model="gpt-4-turbo",
|
336 |
openai_api_key=os.environ["OPENAI_API_KEY"],
|
337 |
).with_structured_output(Classification)
|
338 |
|
339 |
+
# stuff_chain = StuffDocumentsChain(llm_chain=llm, document_variable_name="text")
|
340 |
|
341 |
+
tagging_chain = tagging_prompt | llm
|
342 |
|
343 |
+
res= tagging_chain.invoke({"input": processed_json})
|
344 |
+
result_dict= res.get_dict()
|
345 |
+
print(f"Result_tagging: {result_dict}")
|
346 |
+
return result_dict
|
347 |
|
348 |
|
349 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
|
|
352 |
# with open('data.json', 'w') as f:
|
353 |
# json.dump(clinical_record_info, f, indent=4)
|
354 |
|
|
|
|
|
355 |
|
356 |
+
# tagging_chain = tagging_insights_from_json(json_data)
|
|
|
|
|
|