File size: 31,525 Bytes
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
83cb3c8
 
 
 
 
 
 
 
 
 
13f64cc
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
 
 
 
 
 
 
 
 
83cb3c8
 
 
 
 
 
13f64cc
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
83cb3c8
 
13f64cc
 
 
 
83cb3c8
 
 
 
 
 
 
 
 
 
13f64cc
 
 
 
83cb3c8
13f64cc
 
83cb3c8
 
 
 
 
13f64cc
83cb3c8
 
13f64cc
 
83cb3c8
 
 
 
 
13f64cc
83cb3c8
13f64cc
 
 
83cb3c8
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
13f64cc
83cb3c8
 
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
 
 
13f64cc
 
83cb3c8
13f64cc
 
 
 
 
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
 
83cb3c8
13f64cc
 
 
83cb3c8
13f64cc
 
 
83cb3c8
13f64cc
 
83cb3c8
 
 
13f64cc
 
83cb3c8
 
 
 
 
 
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
 
 
 
 
 
 
 
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
 
 
 
 
 
 
 
 
13f64cc
 
83cb3c8
 
 
13f64cc
83cb3c8
13f64cc
83cb3c8
 
 
13f64cc
 
83cb3c8
 
 
13f64cc
 
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
 
83cb3c8
 
 
 
13f64cc
83cb3c8
 
 
13f64cc
 
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
 
83cb3c8
 
 
 
13f64cc
83cb3c8
13f64cc
 
 
83cb3c8
 
 
 
 
 
13f64cc
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
 
13f64cc
 
83cb3c8
13f64cc
83cb3c8
 
 
13f64cc
83cb3c8
 
 
 
13f64cc
 
83cb3c8
 
 
 
 
13f64cc
83cb3c8
13f64cc
 
 
 
 
 
 
 
 
 
 
 
 
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
13f64cc
 
 
 
 
 
 
 
83cb3c8
13f64cc
 
 
 
 
83cb3c8
 
 
 
 
 
 
13f64cc
83cb3c8
 
 
 
 
 
 
 
 
 
13f64cc
83cb3c8
 
 
13f64cc
 
 
 
 
 
 
83cb3c8
 
 
 
 
13f64cc
83cb3c8
 
 
 
 
 
 
13f64cc
83cb3c8
 
13f64cc
83cb3c8
 
 
13f64cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83cb3c8
13f64cc
83cb3c8
13f64cc
 
 
 
83cb3c8
13f64cc
 
 
83cb3c8
13f64cc
83cb3c8
 
 
13f64cc
83cb3c8
 
 
 
13f64cc
 
 
83cb3c8
13f64cc
 
 
 
83cb3c8
13f64cc
 
 
 
 
83cb3c8
13f64cc
 
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
 
 
83cb3c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
 
120df80
83cb3c8
13f64cc
 
 
 
83cb3c8
13f64cc
120df80
13f64cc
120df80
13f64cc
83cb3c8
120df80
13f64cc
120df80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f64cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120df80
13f64cc
120df80
13f64cc
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TODO: Add a description here."""

import copy
import re
from typing import List, Dict, Union, Callable
import numpy as np

import datasets
import evaluate
from rouge_chinese import Rouge
from scipy.optimize import linear_sum_assignment

# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
title = {quad match score},
authors={huggingface, Inc.},
year={2020}
}
"""

# TODO: Add description of the module here
_DESCRIPTION = """\
evaluate sentiment quadruples.
评估生成模型的情感四元组
"""

# TODO: Add description of the arguments of the module here
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
    predictions: list of predictions to score. Each predictions
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
Returns:
    score: sentiment quadruple match score
    
Examples:
    Examples should be written in doctest format, and should illustrate how
    to use the function.

    >>> import evaluate
    >>> module = evaluate.load("yuyijiong/quad_match_score")
    >>> predictions=["food | good | food#taste | pos"]
    >>> references=["food | good | food#taste | pos & service | bad | service#general | neg"]
    >>> result=module.compute(predictions=predictions, references=references)
    >>> print(result)
    result={'ave match score of weight (1, 1, 1, 1)': 0.375, 
    'f1 score of exact match': 0.0, 
    'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}
"""


# Averaged ROUGE-L F1 between predicted and reference texts.
def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
    """Return the averaged ROUGE-L F1 between predicted and reference texts.

    Texts containing Chinese characters are first split into space-separated
    characters, because ``Rouge`` scores whitespace-separated tokens.

    Args:
        text_pred_list: predicted texts.
        text_true_list: reference texts, aligned with ``text_pred_list``.

    Returns:
        The ``rouge-l`` ``f`` score averaged over all pairs, or ``0.0`` when the
        lists are empty or the first prediction or reference is blank.

    Raises:
        AssertionError: if the two lists have different lengths.
    """
    assert len(text_pred_list) == len(text_true_list), "文本数量不一致"

    # Treat a blank prediction OR reference as a complete mismatch instead of
    # handing it to Rouge (presumably it rejects empty texts -- TODO confirm).
    if not text_pred_list or not text_pred_list[0].strip() or not text_true_list[0].strip():
        return 0.0

    # Split into characters when either side contains Chinese, so both sides
    # are tokenized the same way.
    zh_pattern = u"[\u4e00-\u9fa5]+"
    if re.search(zh_pattern, text_pred_list[0]) or re.search(zh_pattern, text_true_list[0]):
        text_pred_list = [' '.join(text_pred) for text_pred in text_pred_list]
        text_true_list = [' '.join(text_true) for text_true in text_true_list]

    rouge = Rouge()
    return rouge.get_scores(text_pred_list, text_true_list, avg=True)['rouge-l']['f']


# Container for the sentiment quadruples of one text, with parsing and comparison helpers.
class CommentUnitsSim:
    """A list of sentiment quadruples with helpers to parse, normalize and compare them.

    Every quadruple is a dict with the keys ``target_text``, ``opinion_text``,
    ``aspect`` and ``polarity``. Positions used by ``tuple_len`` arguments
    (e.g. ``'013'``) refer to this key order.
    """

    # Canonical key order of one quadruple.
    _QUAD_KEYS = ("target_text", "opinion_text", "aspect", "polarity")

    def __init__(self, data: List[Dict[str, str]], data_source: object = None, abnormal=False, language=None):
        """Store a deep copy of ``data``.

        Args:
            data: quadruple dicts. In the copy, the legacy keys ``target`` /
                ``opinion`` are renamed to ``target_text`` / ``opinion_text``.
            data_source: the raw value the quadruples were built from; stored as-is.
            abnormal: whether a formatting problem was found while parsing.
            language: ``'zh'`` or ``'en'``. When None, it is ``'zh'`` if any
                target/opinion text contains Chinese characters, else ``'en'``.
        """
        self.data_source = data_source
        self.abnormal = abnormal
        self.data = [self._rename_legacy_keys(quad_dict) for quad_dict in copy.deepcopy(data)]

        # Polarity label maps; both accept long English, short English and Chinese labels.
        self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性', 'pos': '积极', 'neg': '消极',
                               'neu': '中性', '积极': '积极', '消极': '消极', '中性': '中性'}
        self.polarity_zh2en = {'积极': 'pos', '消极': 'neg', '中性': 'neu', 'pos': 'pos', 'neg': 'neg', 'neu': 'neu',
                               'positive': 'pos', 'negative': 'neg', 'neutral': 'neu'}

        self.language = language if language is not None else 'zh' if self.check_zh() else 'en'
        # Text that marks a missing target/opinion (see check_target_opinion_null / check_any_null).
        self.none_sign = 'null'

    @staticmethod
    def _rename_legacy_keys(quad_dict: dict) -> dict:
        """Rename ``target``/``opinion`` to ``target_text``/``opinion_text`` in place; return the dict."""
        if 'target' in quad_dict:
            quad_dict['target_text'] = quad_dict.pop('target')
        if 'opinion' in quad_dict:
            quad_dict['opinion_text'] = quad_dict.pop('opinion')
        return quad_dict

    @staticmethod
    def _parse_tuple_index(tuple_len: Union[int, list, str]) -> list:
        """Turn ``tuple_len`` (e.g. 4, [0, 1, 3] or '013') into a list of quadruple positions."""
        if isinstance(tuple_len, int):
            return list(range(tuple_len))
        if isinstance(tuple_len, list):
            return tuple_len
        if isinstance(tuple_len, str):
            # e.g. '012' -> [0, 1, 2]
            return [int(i) for i in tuple_len]
        raise Exception('tuple_len参数错误')

    @staticmethod
    def _space_after_separators(text: str, sep_token1: str, sep_token2: str) -> str:
        """Insert a space after every separator that is directly followed by a non-space character.

        The separators are the stripped ``sep_token1``/``sep_token2``. A separator
        at the very end of ``text`` is left alone.
        """
        seps = [sep.strip() for sep in (sep_token1, sep_token2) if sep.strip()]
        if not seps:
            return text
        pattern = '(' + '|'.join(re.escape(sep) for sep in seps) + ')(?=[^ ])'
        return re.sub(pattern, r'\1 ', text)

    @property
    def num(self):
        """Number of quadruples."""
        return len(self.data)

    def check_zh(self):
        """Return True if any target or opinion text contains a Chinese character."""
        return any(re.search('[\u4e00-\u9fa5]', quad_dict['target_text'])
                   or re.search('[\u4e00-\u9fa5]', quad_dict['opinion_text'])
                   for quad_dict in self.data)

    def check_polarity(self):
        """Return True if every polarity is a known label.

        Otherwise set ``self.abnormal`` to True and return False.
        """
        valid_polarities = ('positive', 'negative', 'neutral', 'pos', 'neg', 'neu', '积极', '消极', '中性')
        for quad_dict in self.data:
            if quad_dict['polarity'] not in valid_polarities:
                self.abnormal = True
                return False
        return True

    def convert_polarity_en2zh(self):
        """Rewrite every polarity as a Chinese label in place; raises KeyError on unknown labels."""
        for quad_dict in self.data:
            quad_dict['polarity'] = self.polarity_en2zh[quad_dict['polarity']]
        return self

    def convert_polarity_zh2en(self):
        """Rewrite every polarity as a short English label in place; raises KeyError on unknown labels."""
        for quad_dict in self.data:
            quad_dict['polarity'] = self.polarity_zh2en[quad_dict['polarity']]
        return self

    def del_duplicate(self):
        """Drop repeated quadruples, keeping the first occurrence and the original order."""
        unique = []
        for quad_dict in self.data:
            if quad_dict not in unique:
                unique.append(quad_dict)
        self.data = unique
        return self

    def check_target_opinion_null(self):
        """Return True if some quadruple has both target and opinion equal to ``none_sign``."""
        return any(quad_dict['target_text'] == self.none_sign and quad_dict['opinion_text'] == self.none_sign
                   for quad_dict in self.data)

    def check_any_null(self):
        """Return True if some quadruple has its target or opinion equal to ``none_sign``."""
        return any(quad_dict['target_text'] == self.none_sign or quad_dict['opinion_text'] == self.none_sign
                   for quad_dict in self.data)

    @classmethod
    def from_str(cls, quadruple_str: str, tuple_len: Union[int, list, str] = 4, format_code=0, sep_token1=' & ',
                 sep_token2=' | '):
        """Parse a string such as ``'t | o | a | pos & t2 | o2 | a2 | neg'``.

        Args:
            quadruple_str: quadruples joined by ``sep_token1``; units joined by ``sep_token2``.
            tuple_len: which quadruple positions the string contains (see ``_parse_tuple_index``).
                Positions not present are filled with ``'None'``.
            format_code: only ``0`` is supported.
            sep_token1: separator between quadruples.
            sep_token2: separator between the units of one quadruple.

        Returns:
            An instance whose ``abnormal`` flag is set when some quadruple had
            too many (truncated) or too few (left-padded with ``'None'``) units.

        Raises:
            Exception: on an invalid ``tuple_len`` or ``format_code``.
        """
        abnormal = False
        # Make sure every separator character is followed by a space, so the
        # space-delimited separators split reliably.
        quadruple_str = cls._space_after_separators(quadruple_str, sep_token1, sep_token2)
        tuple_index = cls._parse_tuple_index(tuple_len)
        if format_code != 0:
            raise Exception('answer_format参数错误')
        quadruple_keys = [cls._QUAD_KEYS[i] for i in tuple_index]

        data = []
        for quadruple in quadruple_str.split(sep_token1):
            units = [unit.strip() for unit in quadruple.split(sep_token2)]
            if len(units) > len(tuple_index):
                print('quadruple格式错误,过多元素', quadruple_str)
                abnormal = True
                units = units[:len(tuple_index)]  # too long: truncate
            elif len(units) < len(tuple_index):
                print('quadruple格式错误,过少元素', quadruple_str)
                abnormal = True
                units = ["None"] * (len(tuple_index) - len(units)) + units  # too short: left-pad with 'None'

            quad = dict.fromkeys(cls._QUAD_KEYS, 'None')
            quad.update(zip(quadruple_keys, units))
            # Polarity labels are only reported here, not flagged as abnormal.
            if quad['polarity'] not in ['pos', 'neg', 'neu', 'None', '积极', '消极', '中性']:
                print('quadruple格式错误,极性格式不对', quadruple_str)
            data.append(quad)

        return cls(data, quadruple_str, abnormal)

    @classmethod
    def from_list(cls, quadruple_list: List[List[str]], **kwargs):
        """Build from ``[[target, opinion, aspect, polarity], ...]``; extra kwargs go to ``__init__``."""
        data = [{"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
                 "polarity": quadruple[3]} for quadruple in quadruple_list]
        return cls(data, quadruple_list, **kwargs)

    @classmethod
    def from_list_dict(cls, quadruple_list: List[dict], **kwargs):
        """Build from a list of quadruple dicts without modifying them.

        Legacy ``target``/``opinion`` keys are accepted. Missing keys are filled
        with ``'None'``. Extra kwargs go to ``__init__``.
        """
        data = []
        for quadruple in quadruple_list:
            quad = dict.fromkeys(cls._QUAD_KEYS, 'None')
            # Rename on a shallow copy so the caller's dicts stay untouched.
            quad.update(cls._rename_legacy_keys(dict(quadruple)))
            data.append(quad)
        return cls(data, quadruple_list, **kwargs)

    def to_list(self):
        """Return the quadruples as ``[[target, opinion, aspect, polarity], ...]``."""
        return [[quad_dict['target_text'], quad_dict['opinion_text'], quad_dict['aspect'], quad_dict['polarity']]
                for quad_dict in self.data]

    def get_quadruple_str(self, format_code=0, tuple_len: Union[int, list, str] = 4, sep_token1=' & ',
                          sep_token2=' | '):
        """Serialize the selected quadruple positions back into a string.

        Polarities are first converted in place to the instance language
        (``self.data`` is modified).

        Special cases:
            * ``tuple_len`` selecting only position 3 returns the merged overall
              polarity (see ``merge_polarity``).
            * Positions [2, 3] drop repeated ``aspect | polarity`` pairs, keeping order.

        Raises:
            Exception: on an unknown polarity label, an invalid ``tuple_len`` or
                an unsupported ``format_code``.
        """
        tuple_index = self._parse_tuple_index(tuple_len)

        try:
            if self.language == 'zh':
                self.convert_polarity_en2zh()
            else:
                self.convert_polarity_zh2en()
        except KeyError as err:
            print('语言参数错误', self.data)
            print(self.language)
            raise Exception('语言参数错误') from err

        if tuple_index == [3]:
            return self.merge_polarity()

        if format_code != 0:
            raise Exception('answer_format参数错误')

        quad_strs = []
        for quad_dict in self.data:
            units = [quad_dict[key] for key in self._QUAD_KEYS]
            quad_strs.append(sep_token2.join([units[i] for i in tuple_index]))

        # The same aspect/polarity pair may occur for several targets; keep it once.
        if tuple_index == [2, 3]:
            quad_strs = list(dict.fromkeys(quad_strs))

        return sep_token1.join(quad_strs)

    def compare_same(self, other) -> int:
        """Count the quadruples of ``self`` that also appear, exactly, in ``other``."""
        return sum(1 for quad_dict in self.data if quad_dict in other.data)

    def check_target_repeat(self):
        """Return True if some target text occurs more than once."""
        targets = self.get_target_list()
        return len(targets) != len(set(targets))

    def check_opinion_repeat(self):
        """Return True if some opinion text occurs more than once."""
        opinions = self.get_opinion_list()
        return len(opinions) != len(set(opinions))

    def check_aspect_repeat(self):
        """Return True if some aspect occurs more than once."""
        aspects = self.get_aspect_list()
        return len(aspects) != len(set(aspects))

    def get_aspect_list(self):
        """Return the aspects in order."""
        return [quad_dict['aspect'] for quad_dict in self.data]

    def get_target_list(self):
        """Return the target texts in order."""
        return [quad_dict['target_text'] for quad_dict in self.data]

    def get_opinion_list(self):
        """Return the opinion texts in order."""
        return [quad_dict['opinion_text'] for quad_dict in self.data]

    def get_polarity_list(self):
        """Return the polarities in order."""
        return [quad_dict['polarity'] for quad_dict in self.data]

    def merge_polarity(self):
        """Combine all polarities into one label in the instance language.

        Returns positive if only positive labels occur, negative if only negative
        ones occur, and neutral otherwise (both or neither).
        """
        if self.language == 'en':
            pos, neg, neu = 'pos', 'neg', 'neu'
        else:
            pos, neg, neu = '积极', '消极', '中性'
        polarity_list = self.get_polarity_list()
        has_pos = pos in polarity_list
        has_neg = neg in polarity_list
        if has_pos and not has_neg:
            return pos
        if has_neg and not has_pos:
            return neg
        return neu

    def check_opinion_in_comment(self, comment_text):
        """Return True if every opinion text is ``'*'`` or a substring of ``comment_text``."""
        return all(quad_dict['opinion_text'] == '*' or quad_dict['opinion_text'] in comment_text
                   for quad_dict in self.data)

    def check_target_in_comment(self, comment_text):
        """Return True if every target text is ``'*'`` or a substring of ``comment_text``."""
        return all(quad_dict['target_text'] == '*' or quad_dict['target_text'] in comment_text
                   for quad_dict in self.data)

    @staticmethod
    def get_similarity(units1, units2: 'CommentUnitsSim'):
        """Placeholder for a similarity between two quadruple sets; not implemented (returns None)."""
        pass

    def apply(self, func: Callable, field: str):
        """Replace ``field`` of every quadruple with ``func(value)`` in place; return self."""
        for quad_dict in self.data:
            quad_dict[field] = func(quad_dict[field])
        return self

# Matches predicted quadruples against reference quadruples and scores the assignment.
class CommentUnitsMatch:
    """Score how well two sets of quadruples match.

    The cost of pairing two quadruples is a weighted sum of per-feature costs:
    0/1 exact-match cost for polarity and aspect, and ``1 - ROUGE-L F1`` for the
    target and opinion texts. Weights are normalized to sum to 1. A feature that
    is missing (checked on the first quadruple of each side) gets weight 0
    for that comparison.
    """

    def __init__(self, target_weight=0.5, opinion_weight=0.5, aspect_weight=0.5, polarity_weight=0.5, one_match=True):
        """
        Args:
            target_weight, opinion_weight, aspect_weight, polarity_weight: relative
                feature weights; normalized to sum to 1.
            one_match: if True, use a one-to-one (Hungarian) assignment. Otherwise
                pair every quadruple of the larger side with its cheapest partner.
        """
        weight_sum = target_weight + opinion_weight + aspect_weight + polarity_weight
        self.target_weight = target_weight / weight_sum
        self.opinion_weight = opinion_weight / weight_sum
        self.aspect_weight = aspect_weight / weight_sum
        self.polarity_weight = polarity_weight / weight_sum

        self.one_match = one_match

    def set_zero(self, feature: str = 'polarity'):
        """Set the weight of ``feature`` to 0 (without re-normalizing).

        Raises:
            Exception: on an unknown feature name.
        """
        if feature == 'polarity':
            self.polarity_weight = 0
        elif feature == 'aspect':
            self.aspect_weight = 0
        elif 'opinion' in feature:
            self.opinion_weight = 0
        elif 'target' in feature:
            self.target_weight = 0
        else:
            raise Exception('feature参数错误')

    def re_normalize(self):
        """Rescale the weights to sum to 1 (ZeroDivisionError if they are all 0)."""
        weight_sum = self.target_weight + self.opinion_weight + self.aspect_weight + self.polarity_weight
        self.target_weight = self.target_weight / weight_sum
        self.opinion_weight = self.opinion_weight / weight_sum
        self.aspect_weight = self.aspect_weight / weight_sum
        self.polarity_weight = self.polarity_weight / weight_sum

    @staticmethod
    def _feature_missing(units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str) -> bool:
        """True if the first quadruple of either side lacks ``feature`` (absent, None or 'None')."""
        return any(units.data[0].get(feature) in (None, 'None') for units in (units1, units2))

    def get_cost_matrix(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'polarity'):
        """Return a 0/1 exact-match cost matrix of shape (units1.num, units2.num) for ``feature``.

        If the feature is missing, returns an all-zero matrix, and sets the
        feature's weight on this instance to 0 and re-normalizes the weights.
        """
        if self._feature_missing(units1, units2, feature):
            self.set_zero(feature)
            self.re_normalize()
            return np.zeros((len(units1.data), len(units2.data)))

        return np.array([[0 if quad_dict1[feature] == quad_dict2[feature] else 1 for quad_dict2 in units2.data]
                         for quad_dict1 in units1.data])

    def get_cost_matrix_rouge(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'target_text'):
        """Return a ``1 - ROUGE-L F1`` cost matrix of shape (units1.num, units2.num) for ``feature``.

        Identical texts cost 0. If the feature is missing, returns an all-zero
        matrix, and sets the feature's weight on this instance to 0 and
        re-normalizes the weights.
        """
        if self._feature_missing(units1, units2, feature):
            self.set_zero(feature)
            self.re_normalize()
            return np.zeros((len(units1.data), len(units2.data)))

        return np.array([[0 if quad_dict1[feature] == quad_dict2[feature]
                          else 1 - get_rougel_f1([quad_dict1[feature]], [quad_dict2[feature]])
                          for quad_dict2 in units2.data]
                         for quad_dict1 in units1.data])

    def match_units(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim') -> tuple:
        """Optimally match predictions ``units1`` to references ``units2``.

        The configured weights are restored afterwards, so a missing feature in
        one comparison does not affect later comparisons.

        Returns:
            ``(cost_per_quadruple, TP, FP, FN)``:
                * ``cost_per_quadruple``: matching cost divided by the larger
                  count, in [0, 1]; unmatched quadruples cost 1 each.
                * ``TP``: soft true positives (sum of match scores of paired quadruples).
                * ``FP``: ``units1.num - TP``.
                * ``FN``: ``units2.num - TP``.
            If either side is empty, nothing matches: cost 1 (0 if both are empty) and TP 0.
        """
        max_units_num = max(units1.num, units2.num)
        if units1.num == 0 or units2.num == 0:
            return (1.0 if max_units_num else 0.0), 0, units1.num, units2.num

        # The cost-matrix helpers zero the weights of missing features on self;
        # undo that after this comparison.
        saved_weights = (self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)
        try:
            return self._match_nonempty(units1, units2, max_units_num)
        finally:
            self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight = saved_weights

    def _match_nonempty(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', max_units_num: int) -> tuple:
        """Match two non-empty quadruple sets with the current weights; see ``match_units``."""
        cost_matrix_polarity = self.get_cost_matrix(units1, units2, feature='polarity')
        cost_matrix_aspect = self.get_cost_matrix(units1, units2, feature='aspect')
        cost_matrix_target = self.get_cost_matrix_rouge(units1, units2, feature='target_text')
        cost_matrix_opinion = self.get_cost_matrix_rouge(units1, units2, feature='opinion_text')

        # Rows are units1 (predictions), columns units2 (references).
        cost_matrix = (self.target_weight * cost_matrix_target + self.opinion_weight * cost_matrix_opinion
                       + self.aspect_weight * cost_matrix_aspect + self.polarity_weight * cost_matrix_polarity)
        score_matrix = 1 - cost_matrix

        if self.one_match:
            # One-to-one assignment: min(num1, num2) pairs.
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
        elif units1.num > units2.num:
            # Every prediction is paired with its cheapest reference.
            row_ind = np.arange(units1.num)
            col_ind = np.argmin(cost_matrix, axis=1)
        else:
            # Every reference is paired with its cheapest prediction.
            row_ind = np.argmin(cost_matrix, axis=0)
            col_ind = np.arange(units2.num)

        cost = cost_matrix[row_ind, col_ind].sum()
        # Soft true positives: match score of every assigned pair.
        TP = score_matrix[row_ind, col_ind].sum()
        FP = units1.num - TP
        FN = units2.num - TP

        # Quadruples left unassigned by one-to-one matching cost 1 each.
        if self.one_match:
            cost += max_units_num - len(row_ind)

        cost_per_quadruple = cost / max_units_num
        if cost_per_quadruple > 1 or cost_per_quadruple < 0:
            print('cost错误', cost_per_quadruple, 'pred:', units1.data, 'true:', units2.data)
            print(self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)

        return cost_per_quadruple, TP, FP, FN


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class QuadMatch(evaluate.Metric):
    """Metric scoring generated sentiment quadruples against reference quadruples."""

    def _info(self):
        """Describe the module: description, citation and accepted input formats."""
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # A reference is either a list of acceptable answers or a single string.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence")),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            # Homepage of the module for documentation (placeholder).
            homepage="http://module.homepage",
            # Additional links to the codebase or references (placeholders).
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _download_and_prepare(self, dl_manager):
        """No external resources are needed."""
        pass

    def _compute(self,
                 predictions: List[str],
                 references: Union[List[str], List[List[str]]],
                 quad_weights: tuple = (1, 1, 1, 1),
                 one_match: bool = True,
                 **kwargs) -> dict:
        '''Compute the optimal-match score/F1 and the exact-match F1.

        :param predictions: list of predictions of sentiment quads
        :param references: list of references of sentiment quads (a string, or a
            list of acceptable answers, per prediction)
        :param quad_weights: weight of target, opinion, aspect, polarity for cost compute
        :param one_match: use one-to-one matching for the optimal-match metrics

        :param kwargs: forwarded to the quadruple parser
                :param tuple_len: indicate the format of the quad, see the following mapping
                :param sep_token1: the token to separate quads
                :param sep_token2: the token to separate units of one quad

        :return: dict with the optimal-match score, optimal-match F1 and exact-match F1

        #mapping (the Chinese prompts use the same keys)
        id2prompt={'0123':"quadruples (target | opinion | aspect | polarity)",
           '':"quadruples (target | opinion | aspect | polarity)",
           '01':'pairs (target | opinion)',
           '012':'triples (target | opinion | aspect)',
           '013':'triples (target | opinion | polarity)',
           '023':'triples (target | aspect | polarity)',
           '23':'pairs (aspect | polarity)',
           '03':'pairs (target | polarity)',
           '13':'pairs (opinion | polarity)',
           '3':'single (polarity)'}
        '''
        f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(
            predictions, references, quad_weights, one_match=one_match, **kwargs)
        f1 = self.quad_f1_of_exact_match(predictions=predictions, references=references, **kwargs)

        # The optimal-match score is 1 - average matching cost.
        return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
                'f1 of optimal match of weight ' + str(quad_weights): f1_of_optimal_match,
                'f1 of exact match': f1}

    @staticmethod
    def quad_f1_of_exact_match(predictions: List[str], references: Union[List[str], List[List[str]]],
                               return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
        """Micro-averaged F1 counting only exactly matching quadruples.

        With several acceptable references for a prediction, the one giving the
        best per-sample F1 is used. Results are rounded to 4 decimals; empty
        counts yield 0.0 instead of dividing by zero.

        Args:
            predictions: prediction strings.
            references: per prediction, a reference string or a list of them.
            return_dict: also return precision and recall.
            **kwargs: forwarded to ``CommentUnitsSim.from_str``.
        """
        assert len(predictions) == len(references), "文本数量不一致"
        correct, pred_num, true_num = 0, 0, 0

        for pred, refer in zip(predictions, references):
            pred = CommentUnitsSim.from_str(pred, **kwargs)
            refers = [refer] if isinstance(refer, str) else refer
            refers = [CommentUnitsSim.from_str(t, **kwargs) for t in refers]

            # Exact-match count and per-sample F1 against each acceptable reference.
            correct_list = [pred.compare_same(t) for t in refers]
            f1_list = [2 * same / (pred.num + t.num) for same, t in zip(correct_list, refers)]
            best_index = f1_list.index(max(f1_list))
            pred_num += pred.num
            true_num += refers[best_index].num
            correct += correct_list[best_index]

        precision = correct / pred_num if pred_num else 0.0
        recall = correct / true_num if true_num else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precision, recall, f1 = round(precision, 4), round(recall, 4), round(f1, 4)

        if return_dict:
            return {"precision": precision, "recall": recall, "f1": f1}
        return f1

    @staticmethod
    def quad_f1_of_optimal_match(
            predictions: List[str],
            references: Union[List[str], List[List[str]]],
            quad_weights: tuple = (1, 1, 1, 1),
            one_match=True,
            **kwargs):
        """Soft F1 and average score of the optimal quadruple matching.

        Args:
            predictions: prediction strings (a single string is also accepted).
            references: per prediction, a reference string or a list of them;
                the reference with the lowest matching cost is used.
            quad_weights: weights of (target, opinion, aspect, polarity).
            one_match: one-to-one matching (see ``CommentUnitsMatch``).
            **kwargs: forwarded to ``CommentUnitsSim.from_str``.

        Returns:
            ``(f1_match, 1 - average_cost)``. F1 is 0.0 when nothing matches.
        """
        # Normalize single-string input before comparing lengths.
        if isinstance(predictions, str):
            predictions = [predictions]
            references = [references]
        assert len(predictions) == len(references)

        cost = 0
        TP, FP, FN = 0, 0, 0

        for pred, refer in zip(predictions, references):
            pred = CommentUnitsSim.from_str(pred, **kwargs)
            refers = [refer] if isinstance(refer, str) else refer
            refers = [CommentUnitsSim.from_str(t, **kwargs) for t in refers]

            # A fresh matcher per comparison: CommentUnitsMatch's cost-matrix
            # helpers zero the weights of missing features on the instance.
            results = [CommentUnitsMatch(*quad_weights, one_match=one_match).match_units(pred, t) for t in refers]
            # Keep the reference with the lowest cost (first one on ties).
            cost_, TP_, FP_, FN_ = min(results, key=lambda result: result[0])
            cost += cost_
            TP += TP_
            FP += FP_
            FN += FN_

        cost = cost / len(predictions)
        precision_match = TP / (TP + FP) if TP + FP else 0.0
        recall_match = TP / (TP + FN) if TP + FN else 0.0
        if precision_match + recall_match:
            f1_match = 2 * precision_match * recall_match / (precision_match + recall_match)
        else:
            f1_match = 0.0

        return f1_match, 1 - cost