# Hugging Face Space status banner captured with this file: "Spaces: Sleeping" (page chrome, not part of the module).
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""TODO: Add a description here.""" | |
import copy
import re
from typing import Any, Callable, Dict, List, Optional, Union

import datasets
import evaluate
import numpy as np
from rouge_chinese import Rouge
from scipy.optimize import linear_sum_assignment
# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
title = {quad match score},
authors={huggingface, Inc.},
year={2020}
}
"""

# Short description of the module, shown on the module page.
_DESCRIPTION = """\
evaluate sentiment quadruples.
评估生成模型的情感四元组
"""

# Description of the arguments and return values, shown on the module page.
# The result keys listed here mirror the keys built in QuadMatch._compute.
_KWARGS_DESCRIPTION = """
Calculates how well predicted sentiment quadruples match the references.
Args:
    predictions: list of predictions to score. Each prediction is a string
        holding one or more quadruples joined by " & "; the units of one
        quadruple (target | opinion | aspect | polarity) are joined by " | ".
    references: list of references, one per prediction, in the same format.
        A reference may also be a list of such strings when several answers
        are acceptable; the best-scoring one is used.
    quad_weights: weights of (target, opinion, aspect, polarity) used by the
        optimal match. Defaults to (1, 1, 1, 1).
Returns:
    score of optimal match of weight <quad_weights>: 1 - average matching cost
    f1 of optimal match of weight <quad_weights>: F1 derived from the optimal (soft) matching
    f1 of exact match: F1 over quadruples that are exactly identical
Examples:
    >>> import evaluate
    >>> module = evaluate.load("yuyijiong/quad_match_score")
    >>> predictions = ["food | good | food#taste | pos"]
    >>> references = ["food | good | food#taste | pos & service | bad | service#general | neg"]
    >>> result = module.compute(predictions=predictions, references=references)
    >>> print({key: round(float(value), 4) for key, value in result.items()})
    {'score of optimal match of weight (1, 1, 1, 1)': 0.5, 'f1 of optimal match of weight (1, 1, 1, 1)': 0.6667, 'f1 of exact match': 0.6667}
"""
# Compute the ROUGE-L F1 score.
def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
    """Return the average ROUGE-L F1 between predictions and references.

    Args:
        text_pred_list: predicted texts.
        text_true_list: reference texts, aligned one-to-one with the predictions.

    Returns:
        The averaged ROUGE-L F1 reported by ``Rouge.get_scores``, or ``0.0``
        when there is nothing to score (empty input, or a blank first
        prediction).

    Raises:
        AssertionError: if the two lists differ in length.
    """
    assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
    # Bail out before touching Rouge: an empty list has no first element to
    # inspect, and a blank first prediction scores 0 by definition here.
    if not text_pred_list or not text_pred_list[0].strip():
        return 0.0
    rouge = Rouge()
    # If the first prediction contains Chinese characters, separate every
    # character with a space so that each character is scored as one token.
    if re.search(u"[\u4e00-\u9fa5]+", text_pred_list[0]):
        text_pred_list = [' '.join(list(text_pred)) for text_pred in text_pred_list]
        text_true_list = [' '.join(list(text_true)) for text_true in text_true_list]
    return rouge.get_scores(text_pred_list, text_true_list, avg=True)['rouge-l']['f']
# Container for the sentiment quadruples extracted from one comment.
class CommentUnitsSim:
    """A list of sentiment quadruples (target, opinion, aspect, polarity).

    Every quadruple is stored as a dict with the keys ``target_text``,
    ``opinion_text``, ``aspect`` and ``polarity``. Instances are normally built
    with one of the alternate constructors ``from_str``, ``from_list`` or
    ``from_list_dict``.
    """

    # Canonical field order of a quadruple.
    _FIELDS = ("target_text", "opinion_text", "aspect", "polarity")
    # Matches a single character in the CJK range used for language detection.
    _ZH_CHAR = re.compile('[\u4e00-\u9fa5]')

    def __init__(self, data: List[Dict[str, str]], data_source: Any = None, abnormal=False,
                 language: Optional[str] = None):
        """Store a private copy of ``data``.

        Args:
            data: quadruple dicts; the legacy keys ``target`` / ``opinion`` are
                renamed to ``target_text`` / ``opinion_text``.
            data_source: the raw object the quadruples were parsed from.
            abnormal: flag marking that the source was malformed.
            language: ``'zh'`` or ``'en'``; detected from the texts when None.
        """
        self.data_source = data_source
        self.abnormal = abnormal
        # Deep copy so that the renaming below never touches the caller's dicts.
        data = copy.deepcopy(data)
        for quad_dict in data:
            if 'target' in quad_dict:
                quad_dict['target_text'] = quad_dict.pop('target')
            if 'opinion' in quad_dict:
                quad_dict['opinion_text'] = quad_dict.pop('opinion')
        self.data = data
        self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性', 'pos': '积极', 'neg': '消极',
                               'neu': '中性', '积极': '积极', '消极': '消极', '中性': '中性'}
        self.polarity_zh2en = {'积极': 'pos', '消极': 'neg', '中性': 'neu', 'pos': 'pos', 'neg': 'neg', 'neu': 'neu',
                               'positive': 'pos', 'negative': 'neg', 'neutral': 'neu'}
        self.language = language if language is not None else 'zh' if self.check_zh() else 'en'
        self.none_sign = 'null'

    @property
    def num(self) -> int:
        """Number of quadruples held by this object."""
        return len(self.data)

    # Check whether any quadruple contains Chinese characters.
    def check_zh(self):
        """Return True if any target or opinion text contains a Chinese character."""
        for quad_dict in self.data:
            if self._ZH_CHAR.search(quad_dict['target_text']) or self._ZH_CHAR.search(quad_dict['opinion_text']):
                return True
        return False

    # Check that every polarity label is valid.
    def check_polarity(self):
        """Return True when every polarity is a known label.

        On the first unknown label the object is flagged ``abnormal`` and
        False is returned.
        """
        valid = ('positive', 'negative', 'neutral', 'pos', 'neg', 'neu', '积极', '消极', '中性')
        for quad_dict in self.data:
            if quad_dict['polarity'] not in valid:
                self.abnormal = True
                return False
        return True

    # Convert polarity labels from English to Chinese.
    def convert_polarity_en2zh(self):
        """Rewrite every polarity in place with its Chinese label; return self."""
        for quad_dict in self.data:
            quad_dict['polarity'] = self.polarity_en2zh[quad_dict['polarity']]
        return self

    # Convert polarity labels from Chinese to English.
    def convert_polarity_zh2en(self):
        """Rewrite every polarity in place with its short English label; return self."""
        for quad_dict in self.data:
            quad_dict['polarity'] = self.polarity_zh2en[quad_dict['polarity']]
        return self

    # Remove duplicated quadruples.
    def del_duplicate(self):
        """Drop repeated quadruples, keeping first occurrences in order; return self."""
        new_data = []
        for quad_dict in self.data:
            if quad_dict not in new_data:
                new_data.append(quad_dict)
        self.data = new_data
        return self

    # Check for a quadruple whose target and opinion are both null.
    def check_target_opinion_null(self):
        """Return True if some quadruple has both target and opinion equal to the null sign."""
        return any(quad_dict['target_text'] == self.none_sign and quad_dict['opinion_text'] == self.none_sign
                   for quad_dict in self.data)

    # Check for a quadruple whose target or opinion is null.
    def check_any_null(self):
        """Return True if some quadruple has its target or its opinion equal to the null sign."""
        return any(quad_dict['target_text'] == self.none_sign or quad_dict['opinion_text'] == self.none_sign
                   for quad_dict in self.data)

    @staticmethod
    def _resolve_tuple_index(tuple_len: Union[int, list, str]) -> list:
        """Turn ``tuple_len`` into the list of quadruple positions it selects.

        An int ``n`` selects the first ``n`` positions, a list is used as is,
        and a digit string such as ``'012'`` becomes ``[0, 1, 2]``.
        """
        if isinstance(tuple_len, int):
            return list(range(tuple_len))
        if isinstance(tuple_len, list):
            return tuple_len
        if isinstance(tuple_len, str):
            return [int(i) for i in tuple_len]
        raise Exception('tuple_len参数错误')

    @staticmethod
    def _with_canonical_keys(quad_dict: dict) -> dict:
        """Return a copy of ``quad_dict`` with ``target`` / ``opinion`` renamed."""
        renamed = dict(quad_dict)
        if 'target' in renamed:
            renamed['target_text'] = renamed.pop('target')
        if 'opinion' in renamed:
            renamed['opinion_text'] = renamed.pop('opinion')
        return renamed

    @classmethod
    def from_str(cls, quadruple_str: str, tuple_len: Union[int, list, str] = 4, format_code=0, sep_token1=' & ',
                 sep_token2=' | '):
        """Parse a string of tuples into a new instance.

        Args:
            quadruple_str: tuples joined by ``sep_token1``; the units of one
                tuple are joined by ``sep_token2``.
            tuple_len: which quadruple positions the units stand for (see
                ``_resolve_tuple_index``); positions not listed are 'None'.
            format_code: only 0 is supported.
            sep_token1: separator between tuples.
            sep_token2: separator between the units of one tuple.

        Returns:
            A new instance. Its ``abnormal`` flag is set when some tuple has
            too many units (extra ones are cut off) or too few (padded with
            'None' at the front); a diagnostic line is printed in both cases
            and when a polarity label is not recognised.

        Raises:
            Exception: on an unsupported ``tuple_len`` or ``format_code``.
        """
        data = []
        abnormal = False
        # Make sure a single-character separator is always followed by a space.
        # The text is rebuilt from the original characters, so separators near
        # the end are still handled after earlier insertions lengthened it.
        markers = {sep_token1.strip(), sep_token2.strip()}
        last = len(quadruple_str) - 1
        pieces = []
        for i, char in enumerate(quadruple_str):
            pieces.append(char)
            if char in markers and i < last and quadruple_str[i + 1] != ' ':
                pieces.append(' ')
        quadruple_str = ''.join(pieces)
        # Positions of the quadruple that the units of each tuple represent.
        tuple_index = cls._resolve_tuple_index(tuple_len)
        for quadruple in quadruple_str.split(sep_token1):
            if format_code == 0:
                # A tuple may carry all four units or only a subset of them.
                quadruple_split = [unit.strip() for unit in quadruple.split(sep_token2)]
                if len(quadruple_split) > len(tuple_index):
                    print('quadruple格式错误,过多元素', quadruple_str)
                    abnormal = True
                    quadruple_split = quadruple_split[0:len(tuple_index)]  # too long: truncate
                elif len(quadruple_split) < len(tuple_index):
                    print('quadruple格式错误,过少元素', quadruple_str)
                    abnormal = True
                    quadruple_split = ["None"] * (
                        len(tuple_index) - len(quadruple_split)) + quadruple_split  # too short: pad with 'None'
                quadruple_keys = [cls._FIELDS[i] for i in tuple_index]
                q = dict.fromkeys(cls._FIELDS, 'None')
                q.update(zip(quadruple_keys, quadruple_split))
                # Report polarity labels that are not recognised.
                if q['polarity'] not in ['pos', 'neg', 'neu', 'None', '积极', '消极', '中性']:
                    print('quadruple格式错误,极性格式不对', quadruple_str)
            else:
                raise Exception('answer_format参数错误')
            data.append(q)
        return cls(data, quadruple_str, abnormal)

    @classmethod
    def from_list(cls, quadruple_list: List[List[str]], **kwargs):
        """Build an instance from ``[target, opinion, aspect, polarity]`` lists.

        Extra keyword arguments are forwarded to the constructor.
        """
        data = [
            {"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
             "polarity": quadruple[3]}
            for quadruple in quadruple_list
        ]
        return cls(data, quadruple_list, **kwargs)

    @classmethod
    def from_list_dict(cls, quadruple_list: List[dict], **kwargs):
        """Build an instance from dicts, filling missing fields with 'None'.

        The legacy keys ``target`` / ``opinion`` are accepted. The input dicts
        are left untouched. Extra keyword arguments are forwarded to the
        constructor.
        """
        data = []
        for quadruple in quadruple_list:
            q = dict.fromkeys(cls._FIELDS, 'None')
            q.update(cls._with_canonical_keys(quadruple))
            data.append(q)
        return cls(data, quadruple_list, **kwargs)

    # Convert to plain lists, keeping only the dict values.
    def to_list(self):
        """Return the quadruples as ``[target, opinion, aspect, polarity]`` lists."""
        return [[quad_dict[field] for field in self._FIELDS] for quad_dict in self.data]

    # Render the data as a string of n-tuples.
    def get_quadruple_str(self, format_code=0, tuple_len: Union[int, list, str] = 4, sep_token1=' & ',
                          sep_token2=' | '):
        """Render the quadruples as text.

        The polarity labels stored in ``self.data`` are first converted in
        place to the object's language. When ``tuple_len`` selects only the
        polarity (``[3]``), the merged overall polarity is returned instead.

        Args:
            format_code: only 0 is supported.
            tuple_len: which quadruple positions to output.
            sep_token1: separator between tuples.
            sep_token2: separator between the units of one tuple.

        Raises:
            Exception: on an unsupported ``tuple_len`` or ``format_code``, or
                when a polarity label cannot be converted.
        """
        new_text_list = []
        tuple_index = self._resolve_tuple_index(tuple_len)
        try:
            # Chinese data uses Chinese polarity labels, anything else English.
            if self.language == 'zh':
                self.convert_polarity_en2zh()
            else:
                self.convert_polarity_zh2en()
        except (KeyError, TypeError) as err:
            print('语言参数错误', self.data)
            print(self.language)
            raise Exception('语言参数错误') from err
        # Only the polarity is requested: return the merged sentiment.
        if tuple_index == [3]:
            return self.merge_polarity()
        for quad_dict in self.data:
            units = [quad_dict[field] for field in self._FIELDS]
            if format_code == 0:
                new_text = sep_token2.join([units[i] for i in tuple_index])
            else:
                raise Exception('answer_format参数错误')
            new_text_list.append(new_text)
        # (aspect, polarity) pairs may repeat; keep first occurrences in order.
        if tuple_index == [2, 3]:
            res = []
            for t in new_text_list:
                if t not in res:
                    res.append(t)
            new_text_list = res
        if format_code == 0:
            return sep_token1.join(new_text_list)

    # Count the quadruples shared with another object.
    def compare_same(self, other) -> int:
        """Return how many of this object's quadruples also appear in ``other``."""
        return sum(1 for quad_dict in self.data if quad_dict in other.data)

    # Check whether targets repeat.
    def check_target_repeat(self):
        """Return True if two quadruples share the same target text."""
        target_list = self.get_target_list()
        return len(target_list) != len(set(target_list))

    # Check whether opinions repeat.
    def check_opinion_repeat(self):
        """Return True if two quadruples share the same opinion text."""
        opinion_list = self.get_opinion_list()
        return len(opinion_list) != len(set(opinion_list))

    # Check whether aspects repeat.
    def check_aspect_repeat(self):
        """Return True if two quadruples share the same aspect."""
        aspect_list = self.get_aspect_list()
        return len(aspect_list) != len(set(aspect_list))

    # List all aspects.
    def get_aspect_list(self):
        """Return the aspect of every quadruple, in order."""
        return [quad_dict['aspect'] for quad_dict in self.data]

    # List all targets.
    def get_target_list(self):
        """Return the target text of every quadruple, in order."""
        return [quad_dict['target_text'] for quad_dict in self.data]

    # List all opinions.
    def get_opinion_list(self):
        """Return the opinion text of every quadruple, in order."""
        return [quad_dict['opinion_text'] for quad_dict in self.data]

    # List all polarities.
    def get_polarity_list(self):
        """Return the polarity of every quadruple, in order."""
        return [quad_dict['polarity'] for quad_dict in self.data]

    # Merge all polarities into one overall label.
    def merge_polarity(self):
        """Return the overall polarity of the comment.

        Only positive labels give positive, only negative labels give
        negative; a mix of both, or neither, gives neutral. English short
        labels are used when the language is 'en', Chinese labels otherwise.
        """
        if self.language == 'en':
            positive, negative, neutral = 'pos', 'neg', 'neu'
        else:
            positive, negative, neutral = '积极', '消极', '中性'
        polarity_list = self.get_polarity_list()
        has_positive = positive in polarity_list
        has_negative = negative in polarity_list
        if has_positive and not has_negative:
            return positive
        if has_negative and not has_positive:
            return negative
        return neutral

    # Check that every opinion occurs in the comment.
    def check_opinion_in_comment(self, comment_text):
        """Return False if an opinion other than '*' is not a substring of ``comment_text``."""
        for quad_dict in self.data:
            if quad_dict['opinion_text'] != '*' and quad_dict['opinion_text'] not in comment_text:
                return False
        return True

    # Check that every target occurs in the comment.
    def check_target_in_comment(self, comment_text):
        """Return False if a target other than '*' is not a substring of ``comment_text``."""
        for quad_dict in self.data:
            if quad_dict['target_text'] != '*' and quad_dict['target_text'] not in comment_text:
                return False
        return True

    # Similarity between two sets of quadruples.
    def get_similarity(units1, units2: 'CommentUnitsSim'):
        """Placeholder: not implemented, always returns None."""
        pass

    # Apply a function to one field of every quadruple.
    def apply(self, func: Callable, field: str):
        """Replace ``field`` of every quadruple with ``func(value)``; return self."""
        for quad_dict in self.data:
            quad_dict[field] = func(quad_dict[field])
        return self
# Matcher that aligns predicted quadruples with reference quadruples.
class CommentUnitsMatch:
    """Match two sets of quadruples and score the match.

    Each of the four features (target, opinion, aspect, polarity) contributes
    a cost between 0 and 1, combined with the normalised weights given at
    construction. Features that cannot be compared for a given pair are left
    out for that pair only; the matcher's own weights are never modified by
    ``match_units`` or the cost-matrix methods, so one malformed sample cannot
    change how later samples are scored.
    """

    def __init__(self, target_weight=0.5, opinion_weight=0.5, aspect_weight=0.5, polarity_weight=0.5, one_match=True):
        """Store the weights, normalised to sum to 1.

        Args:
            target_weight, opinion_weight, aspect_weight, polarity_weight:
                relative importance of each feature.
            one_match: if True, quadruples are matched one-to-one.
        """
        weight_sum = target_weight + opinion_weight + aspect_weight + polarity_weight
        self.target_weight = target_weight / weight_sum
        self.opinion_weight = opinion_weight / weight_sum
        self.aspect_weight = aspect_weight / weight_sum
        self.polarity_weight = polarity_weight / weight_sum
        self.one_match = one_match

    # Set the weight of one feature to zero.
    def set_zero(self, feature: str = 'polarity'):
        """Zero the weight of ``feature`` (no renormalisation).

        Raises:
            Exception: if ``feature`` names no known feature.
        """
        if feature == 'polarity':
            self.polarity_weight = 0
        elif feature == 'aspect':
            self.aspect_weight = 0
        elif 'opinion' in feature:
            self.opinion_weight = 0
        elif 'target' in feature:
            self.target_weight = 0
        else:
            raise Exception('feature参数错误')

    def re_normalize(self):
        """Rescale the four weights so that they sum to 1 again.

        When every weight is zero there is nothing to rescale and the weights
        are left unchanged.
        """
        weight_sum = self.target_weight + self.opinion_weight + self.aspect_weight + self.polarity_weight
        if weight_sum == 0:
            return
        self.target_weight = self.target_weight / weight_sum
        self.opinion_weight = self.opinion_weight / weight_sum
        self.aspect_weight = self.aspect_weight / weight_sum
        self.polarity_weight = self.polarity_weight / weight_sum

    @staticmethod
    def _feature_missing(units1, units2, feature: str) -> bool:
        """Return True if ``feature`` is absent or 'None' in the first quad of either side."""
        first1 = units1.data[0].get(feature)
        first2 = units2.data[0].get(feature)
        return first1 is None or first2 is None or first1 == 'None' or first2 == 'None'

    # Cost matrix by exact comparison: 0 when equal, 1 otherwise.
    def get_cost_matrix(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'polarity'):
        """Return a ``(len(units1.data), len(units2.data))`` array of 0/1 costs.

        An all-zero matrix is returned when the feature is missing on either
        side.
        """
        if self._feature_missing(units1, units2, feature):
            return np.zeros((len(units1.data), len(units2.data)))
        return np.array([
            [0 if quad_dict1[feature] == quad_dict2[feature] else 1 for quad_dict2 in units2.data]
            for quad_dict1 in units1.data
        ])

    # Cost matrix based on ROUGE-L.
    def get_cost_matrix_rouge(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'target_text'):
        """Return a ``(len(units1.data), len(units2.data))`` array of soft costs.

        Identical texts cost 0, other pairs cost ``1 - ROUGE-L F1``. An
        all-zero matrix is returned when the feature is missing on either side.
        """
        if self._feature_missing(units1, units2, feature):
            return np.zeros((len(units1.data), len(units2.data)))
        return np.array([
            [0 if quad_dict1[feature] == quad_dict2[feature]
             else 1 - get_rougel_f1([quad_dict1[feature]], [quad_dict2[feature]])
             for quad_dict2 in units2.data]
            for quad_dict1 in units1.data
        ])

    # Match the quadruples and compute the cost.
    def match_units(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim') -> tuple:
        """Match predictions (``units1``) against references (``units2``).

        Returns:
            ``(cost_per_quadruple, TP, FP, FN)``. The cost is averaged over
            the larger of the two quadruple counts. ``TP`` is the summed score
            (1 - cost) of the matched pairs, so it may be fractional; ``FP``
            and ``FN`` are the prediction / reference counts minus ``TP``.
        """
        n_pred, n_true = len(units1.data), len(units2.data)
        # Per-call weights: drop features that cannot be compared for this
        # pair and rescale the remaining ones, without touching self.
        weights = {'target_text': self.target_weight, 'opinion_text': self.opinion_weight,
                   'aspect': self.aspect_weight, 'polarity': self.polarity_weight}
        missing = [feature for feature in weights if self._feature_missing(units1, units2, feature)]
        for feature in missing:
            weights[feature] = 0
        weight_sum = sum(weights.values())
        if missing and weight_sum > 0:
            weights = {feature: weight / weight_sum for feature, weight in weights.items()}

        if missing and weight_sum == 0:
            # No weighted feature is comparable: treat every pair as a full
            # mismatch rather than dividing by zero or rewarding the pair.
            cost_matrix = np.ones((n_pred, n_true))
        else:
            # Rows are predictions, columns are references; entries lie in [0, 1].
            cost_matrix = (
                weights['target_text'] * self.get_cost_matrix_rouge(units1, units2, feature='target_text')
                + weights['opinion_text'] * self.get_cost_matrix_rouge(units1, units2, feature='opinion_text')
                + weights['aspect'] * self.get_cost_matrix(units1, units2, feature='aspect')
                + weights['polarity'] * self.get_cost_matrix(units1, units2, feature='polarity')
            )
        score_matrix = 1 - cost_matrix

        if self.one_match:
            # One-to-one assignment: min(n_pred, n_true) pairs are produced.
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
        elif n_pred > n_true:
            # One-to-many: every prediction takes its cheapest reference.
            row_ind = np.arange(n_pred)
            col_ind = np.argmin(cost_matrix, axis=1)
        else:
            # One-to-many: every reference takes its cheapest prediction.
            row_ind = np.argmin(cost_matrix, axis=0)
            col_ind = np.arange(n_true)

        cost = sum(cost_matrix[row][col] for row, col in zip(row_ind, col_ind))
        TP = sum(score_matrix[row][col] for row, col in zip(row_ind, col_ind))
        FP = n_pred - TP
        FN = n_true - TP
        # With one-to-one matching the unmatched quadruples each cost 1.
        max_units_num = max(n_pred, n_true)
        if self.one_match:
            cost += (max_units_num - len(row_ind))
        # Normalise so that the cost lies in [0, 1].
        cost_per_quadruple = cost / max_units_num
        if cost_per_quadruple > 1 or cost_per_quadruple < 0:
            print('cost错误', cost_per_quadruple, 'pred:', units1.data, 'true:', units2.data)
            print(self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)
        return cost_per_quadruple, TP, FP, FN
class QuadMatch(evaluate.Metric):
    """Metric scoring generated sentiment quadruples against references.

    It reports the score and F1 of the optimal (soft) matching between
    predicted and reference quadruples, and the F1 of exact matches.
    """

    def _info(self):
        """Describe the module and the accepted input formats."""
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Two accepted formats: each reference is either a list of
            # strings or a single string.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence")),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            # Homepage of the module for documentation
            homepage="http://module.homepage",
            # Additional links to the codebase or references
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores"""
        # Nothing to download for this module.
        pass

    def _compute(self,
                 predictions: List[str],
                 references: Union[List[str], List[List[str]]],
                 quad_weights: tuple = (1, 1, 1, 1),
                 **kwargs) -> dict:
        """Score predictions against references.

        :param predictions: list of predictions of sentiment quads
        :param references: list of references of sentiment quads
        :param quad_weights: weight of target,opinion,aspect,polarity for cost compute
        :param kwargs: forwarded to the parsing of each string; may contain
            tuple_len: indicate the format of the quad, see the following mapping
            sep_token1: the token to separate quads
            sep_token2: the token to separate units of one quad
            one_match: whether the optimal match is one-to-one
        :return: dict with the optimal-match score, the optimal-match F1 and
            the exact-match F1

        #mapping
        id2prompt={'0123':"quadruples (target | opinion | aspect | polarity)",
        '':"quadruples (target | opinion | aspect | polarity)",
        '01':'pairs (target | opinion)',
        '012':'triples (target | opinion | aspect)',
        '013':'triples (target | opinion | polarity)',
        '023':'triples (target | aspect | polarity)',
        '23':'pairs (aspect | polarity)',
        '03':'pairs (target | polarity)',
        '13':'pairs (opinion | polarity)',
        '3':'single (polarity)'}
        """
        f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(predictions, references,
                                                                                    quad_weights, **kwargs)
        # `one_match` only concerns the optimal matching. The exact-match path
        # forwards its keyword arguments to the string parser, which does not
        # accept it.
        exact_kwargs = {key: value for key, value in kwargs.items() if key != 'one_match'}
        f1 = self.quad_f1_of_exact_match(predictions=predictions, references=references, **exact_kwargs)
        # The score is 1 - cost.
        return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
                'f1 of optimal match of weight ' + str(quad_weights): f1_of_optimal_match,
                'f1 of exact match': f1}

    @staticmethod
    def quad_f1_of_exact_match(predictions: List[str], references: Union[List[str], List[List[str]]],
                               return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
        """Compute the F1 over exactly matching quadruples.

        Each reference may be one string or a list of acceptable strings; with
        several, the one giving the best per-sample F1 is used. Counts are
        accumulated over all samples before precision and recall are formed.

        Args:
            predictions: prediction strings.
            references: references aligned with the predictions.
            return_dict: if True, return precision, recall and f1 in a dict.
            **kwargs: forwarded to ``CommentUnitsSim.from_str``.

        Returns:
            The F1 rounded to 4 decimals, or the dict described above. All
            values are 0.0 when there is nothing to count.
        """
        assert len(predictions) == len(references), "文本数量不一致"
        correct, pred_num, true_num = 0, 0, 0
        for pred, refer in zip(predictions, references):
            pred = CommentUnitsSim.from_str(pred, **kwargs)
            # A single reference string is treated as a one-element list.
            if isinstance(refer, str):
                refer = [refer]
            refer = [CommentUnitsSim.from_str(t, **kwargs) for t in refer]
            # Number of exactly matching quadruples against each reference.
            correct_list = [pred.compare_same(t) for t in refer]
            # Per-reference F1, used only to pick the best reference.
            f1_list = [2 * hits / (len(pred.data) + len(t.data)) for hits, t in zip(correct_list, refer)]
            best_index = f1_list.index(max(f1_list))
            pred_num += len(pred.data)
            true_num += len(refer[best_index].data)
            correct += correct_list[best_index]
        if pred_num == 0 or true_num == 0:
            # Nothing was counted; avoid dividing by zero.
            precision = recall = f1 = 0.0
        else:
            # Rounded to 4 decimals; the small epsilon keeps the F1
            # denominator non-zero when nothing matched.
            precision = round(correct / pred_num, 4) + 1e-8
            recall = round(correct / true_num, 4) + 1e-8
            f1 = round(2 * precision * recall / (precision + recall), 4)
        if return_dict:
            return {"precision": precision, "recall": recall, "f1": f1}
        return f1

    # F1 of the optimal matching.
    @staticmethod
    def quad_f1_of_optimal_match(
            predictions: List[str],
            references: Union[List[str], List[List[str]]],
            quad_weights: tuple = (1, 1, 1, 1),
            one_match=True,
            **kwargs):
        """Compute the F1 and the score of the optimal (soft) matching.

        Args:
            predictions: prediction strings, or a single string.
            references: references aligned with the predictions; each may be
                one string or a list of acceptable strings, of which the one
                with the lowest matching cost is used.
            quad_weights: weights of target, opinion, aspect and polarity.
            one_match: whether quadruples are matched one-to-one.
            **kwargs: forwarded to ``CommentUnitsSim.from_str``.

        Returns:
            ``(f1_match, score)`` where ``score`` is 1 minus the average cost
            per sample. Both are 0.0 for empty input, and the F1 is 0.0 when
            nothing matched.
        """
        # Wrap a single sample first so that the length check compares samples
        # and not the characters of a string.
        if isinstance(predictions, str):
            predictions = [predictions]
            references = [references]
        assert len(predictions) == len(references)
        if len(predictions) == 0:
            return 0.0, 0.0
        cost = 0
        TP, FP, FN = 0, 0, 0
        matcher = CommentUnitsMatch(*quad_weights, one_match=one_match)
        for pred, refer in zip(predictions, references):
            pred = CommentUnitsSim.from_str(pred, **kwargs)
            # A single reference string is treated as a one-element list.
            if isinstance(refer, str):
                refer = [refer]
            refer = [CommentUnitsSim.from_str(t, **kwargs) for t in refer]
            # Each entry is (cost, TP, FP, FN); keep the cheapest reference.
            cost_list = [matcher.match_units(pred, t) for t in refer]
            cost_, TP_, FP_, FN_ = cost_list[np.argmin([c[0] for c in cost_list])]
            cost += cost_
            TP += TP_
            FP += FP_
            FN += FN_
        # Average cost per sample.
        cost = cost / len(predictions)
        # F1 from the accumulated TP/FP/FN; each ratio is 0 when its
        # denominator is 0 (nothing predicted, expected or matched).
        precision_match = TP / (TP + FP) if (TP + FP) > 0 else 0.0
        recall_match = TP / (TP + FN) if (TP + FN) > 0 else 0.0
        if precision_match + recall_match > 0:
            f1_match = 2 * precision_match * recall_match / (precision_match + recall_match)
        else:
            f1_match = 0.0
        return f1_match, 1 - cost