File size: 4,206 Bytes
577164e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import json
from collections import defaultdict
from nltk.translate.bleu_score import corpus_bleu
import statistics
import argparse
import json
import os
import re
from collections import Counter

def is_duplicated(text, top_k=10, min_word_len=0):
    """Detect degenerate word repetition in *text*.

    Returns False when the text looks normal, otherwise a list of
    (word, count) tuples describing the dominant words — callers can
    both truth-test the result and inspect the offending words.

    Args:
        text: string to inspect.
        top_k: number of most frequent words to examine.
        min_word_len: words shorter than this many characters are ignored.
    """
    word_freq = Counter(re.findall(r'\b\w+\b', text))

    # Enforce the minimum word length.  Fixed off-by-one: the original
    # used `len(word) <= min_word_len` and so also discarded words of
    # exactly min_word_len characters.
    if min_word_len > 0:
        word_freq = Counter({word: count for word, count in word_freq.items()
                             if len(word) >= min_word_len})

    if not word_freq:
        return False

    # A single word repeated more than 5 times is trivially degenerate.
    if len(word_freq) == 1 and word_freq.most_common(1)[0][1] > 5:
        return word_freq.most_common(1)

    top_items = word_freq.most_common(top_k)
    frequencies = [count for _, count in top_items]
    mean_frequency = sum(frequencies) / len(frequencies)

    # Low overall frequency: nothing is repeated enough to worry about.
    if mean_frequency < 5:
        return False

    # Look for a sharp drop between consecutive top frequencies: a few
    # words towering over the rest is the signature of repetitive output.
    prev_frequency = 0
    for index, (word, frequency) in enumerate(top_items):
        if (prev_frequency - frequency) > mean_frequency:
            # A drop within the first two entries is normal text shape.
            if index <= 1:
                return False
            return top_items
        prev_frequency = frequency

    return False

def is_length_exceed(reference, generation, min_ratio=0.2, max_ratio=2):
    """Return True when len(generation) / len(reference) falls outside
    the closed interval [min_ratio, max_ratio].

    Fixed: an empty *reference* used to raise ZeroDivisionError; it is now
    treated as exceeded unless *generation* is empty too.
    """
    if len(reference) == 0:
        return len(generation) != 0
    return not min_ratio <= (len(generation) / len(reference)) <= max_ratio

def get_average(a):
    """Return the mean of *a* rounded to two decimals; non-list values
    are passed through unchanged."""
    if not isinstance(a, list):
        return a
    return round(sum(a) / len(a), 2)


def main():
    """Scan a directory of .jsonl evaluation files and report, per file:
    average BLEU, the number of out-of-range length ratios, and the
    number of duplicated (degenerate) generations.

    Each JSONL line must contain 'bleu', 'reference' and 'generation'
    keys; with --detail the offending lines are listed as well.
    """
    parser = argparse.ArgumentParser("argument")
    parser.add_argument(
        "directory",
        type=str,
        help="input_file",
    )
    parser.add_argument('--detail', action='store_true', help='detail')
    args = parser.parse_args()

    # Per-file accumulators, keyed by filename.
    file_src_bleu_scores = defaultdict(list)    # BLEU score per line
    file_length_ratio = defaultdict(list)       # len(generation)/len(reference)
    file_duplicated = defaultdict(list)         # 1 if line is degenerate else 0
    file_duplicated_detail = defaultdict(list)  # info on degenerate lines

    # Gather statistics from every JSONL file in the directory.
    for filename in os.listdir(args.directory):
        if not filename.endswith('.jsonl'):
            continue
        file_path = os.path.join(args.directory, filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            for index, line in enumerate(file):
                data = json.loads(line)
                file_src_bleu_scores[filename].append(data['bleu'])

                # Length ratio between generation and reference.
                # NOTE(review): an empty 'reference' raises
                # ZeroDivisionError here, as in the original — confirm
                # whether inputs guarantee non-empty references.
                ratio = len(data['generation']) / len(data['reference'])
                file_length_ratio[filename].append(round(ratio, 1))

                # Repetition check on the generated text; is_duplicated
                # returns False or a list of (word, count) tuples.
                word_count = is_duplicated(data['generation'])
                file_duplicated[filename].append(0 if word_count is False else 1)
                if word_count is not False:
                    file_duplicated_detail[filename].append(
                        {'index': index, 'count': word_count,
                         'generation': data['generation']})

    # Report files in ascending order of mean BLEU (worst first).
    sorted_items = sorted(file_src_bleu_scores.items(),
                          key=lambda item: statistics.mean(item[1]))
    print('bleu scores')
    for filename, src_bleu_scores in sorted_items:
        avg_bleu = sum(src_bleu_scores) / len(src_bleu_scores)
        # (line_index, ratio) pairs outside the accepted [0.2, 2.0] range.
        out_of_range = [(index, ratio)
                        for index, ratio in enumerate(file_length_ratio[filename])
                        if ratio < 0.2 or ratio > 2.0]
        # Fixed: the original printed the literal "(unknown)" where the
        # filename being reported belongs.
        print(f"{filename}: {avg_bleu:.2f}, "
              f"out_of_range_count={len(out_of_range)}, "
              f"duplicate={sum(file_duplicated[filename])}")
        if args.detail:
            print(f'\t error length:{out_of_range}')
            print(f"\t duplication")
            for info in file_duplicated_detail[filename]:
                print('\t\t', info)


if __name__ == "__main__":
    main()