Spaces:
Sleeping
Sleeping
修复多个refer时f1不正常的bug
Browse files- quad_match_score.py +24 -23
quad_match_score.py
CHANGED
@@ -660,7 +660,7 @@ class QuadMatch(evaluate.Metric):
|
|
660 |
'''
|
661 |
f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(predictions, references,
|
662 |
quad_weights, **kwargs)
|
663 |
-
f1 = self.quad_f1_of_exact_match(
|
664 |
|
665 |
# 取1-cost为得分
|
666 |
return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
|
@@ -668,30 +668,31 @@ class QuadMatch(evaluate.Metric):
|
|
668 |
'f1 of exact match': f1}
|
669 |
|
670 |
@staticmethod
|
671 |
-
def quad_f1_of_exact_match(
|
672 |
return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
|
673 |
-
assert len(
|
674 |
correct, pred_num, true_num = 0, 0, 0
|
675 |
|
676 |
-
for pred,
|
677 |
pred = CommentUnitsSim.from_str(pred, **kwargs)
|
678 |
-
#
|
679 |
-
if isinstance(
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
|
|
695 |
|
696 |
# 以下结果保留4位小数
|
697 |
precision = round(correct / pred_num, 4) + 1e-8
|
@@ -733,9 +734,9 @@ class QuadMatch(evaluate.Metric):
|
|
733 |
|
734 |
# 如果true是多个正确答案,取最高分
|
735 |
cost_list = [matcher.match_units(pred, t) for t in refer]
|
736 |
-
#
|
737 |
# 计算每一对样本的cost,TP,FP,FN
|
738 |
-
cost_, TP_, FP_, FN_ = cost_list[np.
|
739 |
cost += cost_
|
740 |
TP += TP_
|
741 |
FP += FP_
|
|
|
660 |
'''
|
661 |
f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(predictions, references,
|
662 |
quad_weights, **kwargs)
|
663 |
+
f1 = self.quad_f1_of_exact_match(predictions=predictions, references=references, **kwargs)
|
664 |
|
665 |
# 取1-cost为得分
|
666 |
return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
|
|
|
668 |
'f1 of exact match': f1}
|
669 |
|
670 |
@staticmethod
|
671 |
+
def quad_f1_of_exact_match(predictions: List[str], references: Union[List[str], List[List[str]]],
|
672 |
return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
|
673 |
+
assert len(predictions) == len(references), "文本数量不一致"
|
674 |
correct, pred_num, true_num = 0, 0, 0
|
675 |
|
676 |
+
for pred, refer in zip(predictions, references):
|
677 |
pred = CommentUnitsSim.from_str(pred, **kwargs)
|
678 |
+
# refer转换为list
|
679 |
+
if isinstance(refer, str):
|
680 |
+
refer =[refer]
|
681 |
+
|
682 |
+
# refer转换为CommentUnitsSim
|
683 |
+
refer = [CommentUnitsSim.from_str(t, **kwargs) for t in refer]
|
684 |
+
|
685 |
+
# 如果refer是list,说明有多个正确答案,取最高分的那个
|
686 |
+
#计算每个refer的TP的个数
|
687 |
+
correct_list = [pred.compare_same(t) for t in refer]
|
688 |
+
#计算每个refer的f1
|
689 |
+
f1_list=[2 * correct_list[i] / (pred.num + refer[i].num) for i in range(len(refer))]
|
690 |
+
# 获取f1得分最高的索引
|
691 |
+
best_index = f1_list.index(max(f1_list))
|
692 |
+
pred_num += pred.num
|
693 |
+
true_num += refer[best_index].num
|
694 |
+
correct += correct_list[best_index]
|
695 |
+
|
696 |
|
697 |
# 以下结果保留4位小数
|
698 |
precision = round(correct / pred_num, 4) + 1e-8
|
|
|
734 |
|
735 |
# 如果true是多个正确答案,取最高分
|
736 |
cost_list = [matcher.match_units(pred, t) for t in refer]
|
737 |
+
# 获取cost最小的值的索引,按元组中第一个元素大小排序
|
738 |
# 计算每一对样本的cost,TP,FP,FN
|
739 |
+
cost_, TP_, FP_, FN_ = cost_list[np.argmin([c[0] for c in cost_list])]
|
740 |
cost += cost_
|
741 |
TP += TP_
|
742 |
FP += FP_
|