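"""MedBench datasets and evaluators for OpenCompass.

This module loads the MedBench tasks and implements the per-task answer
post-processing and scoring (accuracy, F1, Macro-F1 and ROUGE-L).
"""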

import json
import re

from datasets import Dataset
from rouge_chinese import Rouge
from sklearn.metrics import classification_report
from transformers import BasicTokenizer

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET

from ..base import BaseDataset
from .math_equivalence import is_equiv
from .post_process import parse_qa_multiple_answer

basic_tokenizer = BasicTokenizer(tokenize_chinese_chars=True)


@LOAD_DATASET.register_module()
class MedBenchDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str, setting_name: str):
        from .dataset_loader import load_dataset, load_dataset_as_result_schema

        assert setting_name == 'zero-shot', 'only support zero-shot setting'
        dataset_wo_label = load_dataset(name, setting_name, path)
        dataset_with_label = load_dataset_as_result_schema(name, path)
        dataset = []
        for d1, d2 in zip(dataset_wo_label, dataset_with_label):
            dataset.append({
                'id': d2.index,
                'problem_input': d1['context'],
                'label': d2.label,
            })
        return Dataset.from_list(dataset)
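

# Exact-match accuracy over multi-answer QA: a prediction counts as correct
# only if it is equivalent to the reference answer set.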
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator(BaseEvaluator):

    def score(self, predictions, references):
        # predictions: [[]]
        # references: [[]]
        predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
        details = []
        cnt = 0
        for pred, ref in zip(predictions, references):
            detail = {'pred': pred, 'answer': ref, 'correct': False}
            if is_equiv(pred, ref):
                cnt += 1
                detail['correct'] = True
            details.append(detail)
        score = cnt / len(predictions) * 100
        return {'Accuracy': score, 'details': details}
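

# CMeEE (medical named-entity recognition): turn free-text generations into
# {'type': ..., 'entity': ...} records, one list per sample.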
def process_generated_results_CMeEE(pred_file):
    # Each entity type occupies one line, formatted as
    # "[type]实体:entity1,entity2,entity3\n"; multiple entities are
    # separated by a Chinese or ASCII comma.
    structured_output = []
    answer_choices = ['药物', '设备', '医院科室', '微生物类', '身体部位',
                      '医疗操作', '医学检验项目', '症状', '疾病']
    for pred in pred_file:
        list_entities = []
        for choice in answer_choices:
            for piece in re.split('\n', pred):
                if piece.startswith(f'{choice}'):
                    mentions = re.split(
                        r'[,,]',
                        piece.replace(f'{choice}:', '').replace(f'{choice}:', ''))
                    for ment in mentions:
                        list_entities.append({'type': choice, 'entity': ment})
        structured_output.append(list_entities)
    return structured_output


def process_generated_results_EMR(pred_file):
    structured_output = []
    regex = r'^(主诉|现病史|既往史|个人史|婚育史|家族史)[::]([\s\S]+)$'
    for prediction in pred_file:
        entities: dict = {}
        if '\n\n' in prediction:
            blocks = prediction.split('\n\n')
        else:
            blocks = prediction.splitlines()
        for line in blocks:
            if match := re.match(regex, line.strip()):
                type_ = match[1]
                mention = match[2].strip()
                entities[type_] = mention
        structured_output.append(entities)
    return structured_output
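

# CMeIE (medical relation extraction): pull (predicate, subject, object)
# triples out of generations shaped like
# 关系:"<predicate>",头实体:"<subject>",尾实体:"<object>".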
def process_generated_results_CMeIE(pred_file):
    structured_output = []
    answer_choices = "相关(导致)、鉴别诊断、遗传因素、发病性别倾向、相关(症状)、手术治疗、预防、辅助检查、筛查、阶段、临床表现、风险评估因素、同义词、发病年龄、预后生存率、病史、传播途径、治疗后症状、药物治疗、辅助治疗、化疗、死亡率、放射治疗、病因、组织学检查、内窥镜检查、多发群体、并发症、实验室检查、就诊科室、病理生理、高危因素、发病率、多发地区、病理分型、影像学检查、转移部位、发病部位、相关(转化)、外侵部位、预后状况、发病机制、多发季节"
    re_choices = '|'.join(re.escape(choice) for choice in answer_choices.split('、'))
    regex = (
        rf'关系[::]["“]({re_choices})["”][,,]'
        r'头实体[::]["“]([^"”]+)["”][,,]尾实体[::]["“]([^"”]+)["”]'
    )
    for gen_output in pred_file:
        list_spos = []
        for answer_str in gen_output.split('\n'):
            for match in re.finditer(regex, answer_str):
                list_spos.append({
                    'predicate': match[1],
                    'subject': match[2],
                    'object': match[3],
                })
        structured_output.append(list_spos)
    return structured_output
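

# CHIP-CDN (clinical term normalization): keep only answers that appear in
# the official entity list, deduplicated.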
def process_generated_results_CDN(pred_file):
    structured_output = []
    with open('./opencompass/datasets/medbench/entity_list.jsonl', 'r') as f:
        answer_choices = json.load(f)
    for gen_output in pred_file:
        # Only the last line of the generation carries the answers.
        answer_str = gen_output.split('\n')[-1]
        answers = answer_str.split(',')
        answers = [w.strip() for w in answers if len(w.strip()) > 0]
        answers = [w for w in answers if w in answer_choices]
        answers = list(set(answers))
        answers = [{'entity': w, 'type': 'normalization'} for w in answers]
        structured_output.append(answers)
    return structured_output
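

# CHIP-CDEE (clinical event extraction): each event line holds
# 主体词/发生状态/描述词/解剖部位 key-value pairs separated by ";".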
def process_generated_results_CDEE(pred_file):
    structured_output = []
    for prediction in pred_file:
        events: list[dict] = []
        for line in prediction.splitlines():
            if '主体词' in line:
                line = line.rstrip('。')
                kvs = line.split(';')
                kv_dict = dict(
                    kv.split(':', maxsplit=1) for kv in kvs if ':' in kv)
                events.append({
                    '主体词': kv_dict.get('主体词', ''),
                    '发生状态': (
                        v if (v := kv_dict.get('发生状态', '不确定'))
                        in ('不确定', '否定') else ''
                    ),
                    '描述词': (
                        v.split(',')
                        if (v := kv_dict.get('描述词', '空')) != '空' else []
                    ),
                    '解剖部位': (
                        v.split(',')
                        if (v := kv_dict.get('解剖部位', '空')) != '空' else []
                    ),
                })
        structured_output.append(events)
    return structured_output


def process_generated_results_CTC(pred_file):
    structured_output = []
    for gen_output in pred_file:
        # Answer format: the generation is the classification label itself.
        answer_str = gen_output.strip()
        structured_output.append(answer_str)
    return structured_output
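

# Physical-examination report parsing: extract numeric vital signs
# (体温, 脉搏, 心率, 收缩压, 舒张压, 呼吸) plus three yes/no findings
# from free-text predictions.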
def process_generated_results_doc_parsing(pred_file):
    float_field_regex = r'(体温|脉搏|心率|收缩压|舒张压|呼吸)[^\d]*(\d+(?:\.\d+)?)'
    output = []
    for prediction in pred_file:
        entities = {
            '体温': '未扪及',
            '脉搏': '未扪及',
            '心率': '未扪及',
            '收缩压': '未扪及',
            '舒张压': '未扪及',
            '呼吸': '未扪及',
            '是否上腹部深压痛': None,
            '是否腹部反跳痛': None,
            '上腹部肿块': None,
        }
        # Numeric vital signs are searched once over the whole prediction.
        for match in re.finditer(float_field_regex, prediction):
            entities[match[1]] = match[2]
        # Yes/no findings are decided per sentence.
        for sentence in re.split('[,|。|\n]', prediction):
            if '上腹部深压痛' in sentence:
                if re.search('是(?!否)|(?:^|[^不])存在|有', sentence):
                    entities['是否上腹部深压痛'] = '是'
                else:
                    entities['是否上腹部深压痛'] = '否'
            elif '腹部反跳痛' in sentence:
                if re.search('是(?!否)|(?:^|[^不])存在|有', sentence):
                    entities['是否腹部反跳痛'] = '是'
                else:
                    entities['是否腹部反跳痛'] = '否'
            elif '上腹部肿块' in sentence:
                if re.search('是(?!否)|(?:^|[^不])存在|有', sentence):
                    entities['上腹部肿块'] = '扪及'
                else:
                    entities['上腹部肿块'] = '未扪及'
        result = [
            {'type': '体温(℃)', 'entity': entities['体温']},
            {'type': '脉搏(次/分)', 'entity': entities['脉搏']},
            {'type': '心率(次/分)', 'entity': entities['心率']},
            {'type': '收缩压(mmHg)', 'entity': entities['收缩压']},
            {'type': '舒张压(mmHg)', 'entity': entities['舒张压']},
            {'type': '呼吸(次/分)', 'entity': entities['呼吸']},
        ]
        for key in ('是否上腹部深压痛', '是否腹部反跳痛', '上腹部肿块'):
            if entities[key]:
                result.append({'type': key, 'entity': entities[key]})
        output.append(result)
    return output


def process_generated_results_mrg(pred_file):
    structured_output = []
    regex = r'^(主诉|现病史|辅助检查|既往史|诊断|建议)[::]([\s\S]+)$'
    for prediction in pred_file:
        entities = {}
        if '\n\n' in prediction:
            blocks = prediction.split('\n\n')
        else:
            blocks = prediction.splitlines()
        for line in blocks:
            if match := re.match(regex, line.strip()):
                type_ = match[1]
                mention = match[2].strip()
                entities[type_] = mention
        structured_output.append(entities)
    return structured_output
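

# Micro-averaged precision / recall / F1 over structured instances. Each
# predicted or gold dict is serialized (sorted keys, JSON values) so that
# identical instances compare equal inside a set.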
def calc_info_extract_task_scores(list_structured_predict, list_structured_golden):
    assert len(list_structured_golden) == len(list_structured_predict)
    tp = 0
    fp = 0
    fn = 0
    for samp_golden, samp_predict in zip(list_structured_golden,
                                         list_structured_predict):
        set_golden = set()
        for inst in samp_golden:
            assert isinstance(inst, dict)
            keys = sorted(inst.keys())
            set_golden.add(
                tuple(json.dumps(inst[w], ensure_ascii=False) for w in keys))
        set_predict = set()
        for inst in samp_predict:
            assert isinstance(inst, dict)
            keys = sorted(inst.keys())
            set_predict.add(
                tuple(json.dumps(inst[w], ensure_ascii=False) for w in keys))
        tp += len(set_golden & set_predict)
        fp += len(set_predict - set_golden)
        fn += len(set_golden - set_predict)
    if tp:
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
    else:
        precision, recall, f1 = 0, 0, 0
    return precision, recall, f1
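

# Classification metrics via sklearn's classification_report; returns either
# macro- or weighted-average precision / recall / F1.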
def calc_cls_task_scores(list_structured_golden,
                         list_structured_predict,
                         list_labels=None,
                         return_macro=False):
    predictions = []
    ground_truths = []
    assert len(list_structured_golden) == len(list_structured_predict)
    for pred_label, gt_label in zip(list_structured_predict,
                                    list_structured_golden):
        # Empty labels fall back to the first label in the label list.
        if gt_label == '':
            gt_label = list_labels[0]
        if pred_label == '':
            pred_label = list_labels[0]
        predictions.append(pred_label)
        ground_truths.append(gt_label)
    cls_report = classification_report(
        ground_truths,
        predictions,
        output_dict=True,
        zero_division=0,
    )
    if return_macro:
        return (cls_report['macro avg']['precision'],
                cls_report['macro avg']['recall'],
                cls_report['macro avg']['f1-score'])
    return (cls_report['weighted avg']['precision'],
            cls_report['weighted avg']['recall'],
            cls_report['weighted avg']['f1-score'])
def calc_nlg_task_scores(list_structured_golden, list_structured_predict):
    assert len(list_structured_golden) == len(list_structured_predict)
    predictions = []
    references = []
    details = []
    for answer_golden, answer_predict in zip(list_structured_golden,
                                             list_structured_predict):
        if not (answer_predict and answer_golden):
            continue
        # The basic tokenizer splits Chinese into single characters and
        # keeps English words intact.
        answer_predict = ' '.join(basic_tokenizer.tokenize(answer_predict)).strip()
        answer_golden = ' '.join(basic_tokenizer.tokenize(answer_golden)).strip()
        # Substitute a non-empty placeholder so ROUGE has tokens to score.
        if answer_golden.strip() == '':
            answer_golden = '无 。'
        if answer_predict.strip() == '':
            answer_predict = '无 。'
        predictions.append(answer_predict)
        references.append(answer_golden)
        details.append({
            'pred': answer_predict,
            'answer': answer_golden,
            'correct': False,
        })
    rouge = Rouge()
    scores = rouge.get_scores(predictions, references, avg=True)
    rougeL = scores['rouge-l']['f']
    return {'RougeL': rougeL, 'details': details}


def calc_scores_f1(dict_pred, dict_gt):
    details = []
    for pred, gt in zip(dict_pred, dict_gt):
        details.append({'pred': pred, 'answer': gt, 'correct': None})
    precision, recall, f1 = calc_info_extract_task_scores(dict_pred, dict_gt)
    return {'F1': f1, 'details': details}


def calc_scores_ctc(dict_pred, dict_gt):
    details = []
    for pred, gt in zip(dict_pred, dict_gt):
        details.append({'pred': pred, 'answer': gt, 'correct': None})
    precision, recall, f1 = calc_cls_task_scores(
        dict_gt,
        dict_pred,
        list_labels=['非上述类型', '疾病', '症状(患者感受)',
                     '体征(医生检测)', '怀孕相关', '肿瘤进展',
                     '疾病分期', '过敏耐受', '器官组织状态',
                     '预期寿命', '口腔相关', '药物',
                     '治疗或手术', '设备', '护理',
                     '诊断', '实验室检查', '风险评估',
                     '受体状态', '年龄', '特殊病人特征',
                     '读写能力', '性别', '教育情况',
                     '居住情况', '种族', '知情同意',
                     '参与其它试验', '研究者决定', '能力',
                     '伦理审查', '依存性', '成瘾行为',
                     '睡眠', '锻炼', '饮食', '酒精使用',
                     '性取向', '吸烟状况', '献血',
                     '病例来源', '残疾群体', '健康群体',
                     '数据可及性', '含有多个类别'],
        return_macro=True,
    )
    return {'Macro-F1': f1, 'details': details}


def calc_scores_nlg(dict_pred, dict_gt):
    return calc_nlg_task_scores(dict_gt, dict_pred)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CMeEE(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CMeEE(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_DBMHG(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_EMR(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_IMCS_V2_MRG(BaseEvaluator):

    def score(self, predictions, references):
        # predictions = process_generated_results_mrg(predictions)
        # References are given as [{'type': ..., 'entity': ...}, ...];
        # flatten each one into "type:entity" lines for ROUGE scoring.
        references_revise = []
        for item in references:
            temp_ref = ''
            for sub_item in item:
                temp_ref += sub_item['type'] + ':' + sub_item['entity'] + '\n'
            references_revise.append(temp_ref)
        return calc_nlg_task_scores(references_revise, predictions)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CMeIE(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CMeIE(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CHIP_CDEE(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CDEE(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CHIP_CDN(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CDN(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CHIP_CTC(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CTC(predictions)
        return calc_scores_ctc(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Doc_parsing(BaseEvaluator):

    def score(self, predictions, references):
        # predictions = process_generated_results_doc_parsing(predictions)
        # References are given as [{'type': ..., 'entity': ...}, ...];
        # flatten each one into "type:entity" lines for ROUGE scoring.
        references_revise = []
        for item in references:
            temp_ref = ''
            for sub_item in item:
                temp_ref += sub_item['type'] + ':' + sub_item['entity'] + '\n'
            references_revise.append(temp_ref)
        return calc_nlg_task_scores(references_revise, predictions)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_NLG(BaseEvaluator):

    def score(self, predictions, references):
        return calc_scores_nlg(predictions, references)
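

# Cloze accuracy: a prediction is correct only if every reference answer
# string appears somewhere in the prediction.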
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Cloze(BaseEvaluator):

    def score(self, predictions, references):
        details = []
        cnt = 0
        for pred, ref in zip(predictions, references):
            detail = {'pred': pred, 'answer': ref, 'correct': False}
            if all(item in pred for item in ref):
                cnt += 1
                detail['correct'] = True
            details.append(detail)
        score = cnt / len(predictions) * 100
        return {'Accuracy': score, 'details': details}
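

# True/false questions: any negation marker ('不' or '否') in the prediction
# maps it to '不可以', otherwise '可以', then compare with the reference.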
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_TF(BaseEvaluator):

    def score(self, predictions, references):
        details = []
        cnt = 0
        for pred, ref in zip(predictions, references):
            if '不' in pred or '否' in pred:
                cur_pred = '不可以'
            else:
                cur_pred = '可以'
            detail = {'pred': cur_pred, 'answer': ref, 'correct': False}
            if cur_pred == ref:
                cnt += 1
                detail['correct'] = True
            details.append(detail)
        score = cnt / len(predictions) * 100
        return {'Accuracy': score, 'details': details}