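"""Quality checks over JSONL generation results.

Scans a directory of .jsonl files whose lines carry ``src``, ``bleu``,
``reference`` and ``generation`` fields, and prints, per file: the mean
BLEU score, how many generations fall outside a 0.2x-2.0x length ratio
against the reference, and how many look like degenerate word repetition.
"""
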
import argparse
import json
import os
import re
import statistics
from collections import Counter, defaultdict


def is_duplicated(text, top_k=10, min_word_len=0):
    """Return the top word counts when the text looks like degenerate
    repetition, otherwise False."""
    words = re.findall(r'\b\w+\b', text)
    word_freq = Counter(words)
    # Drop words at or below the minimum word length.
    if min_word_len > 0:
        for word, count in list(word_freq.items()):
            if len(word) <= min_word_len:
                del word_freq[word]
    if len(word_freq) == 0:
        return False
    # A single distinct word repeated more than five times is duplication.
    if len(word_freq) == 1 and word_freq.most_common(1)[0][1] > 5:
        return word_freq.most_common(1)
    top_items = word_freq.most_common(top_k)
    frequencies = [frequency for item, frequency in top_items]
    mean_frequency = sum(frequencies) / len(frequencies)
    prev_frequency = 0
    index = 0
    if mean_frequency < 5:
        return False
    # Flag the text when the frequency drops off sharply after the first
    # couple of top words, i.e. a few words dominate the generation.
    for item, frequency in top_items:
        if (prev_frequency - frequency) > mean_frequency:
            if index <= 1:
                return False
            return top_items
        prev_frequency = frequency
        index += 1
    return False
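
# A quick sanity example (illustrative, not part of the tool): the call
# is_duplicated("go go go go go go") returns [('go', 6)] because a single
# word repeats more than five times, while ordinary varied text returns
# False since the mean top-word frequency stays below the threshold.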


def is_length_exceed(reference, generation, min_ratio=0.2, max_ratio=2):
    """Return True when the generation/reference length ratio falls
    outside [min_ratio, max_ratio]."""
    return not min_ratio <= (len(generation) / len(reference)) <= max_ratio


def get_average(a):
    """Average a list to two decimals; pass scalars through unchanged."""
    if isinstance(a, list):
        return round(sum(a) / len(a), 2)
    return a


def main():
    parser = argparse.ArgumentParser("argument")
    parser.add_argument(
        "directory",
        type=str,
        help="directory containing the .jsonl result files",
    )
    parser.add_argument('--detail', action='store_true', help='print per-line details')
    args = parser.parse_args()

    # Per-file accumulators: BLEU scores per src, length ratios, and
    # duplication flags/details.
    file_src_bleu_scores = defaultdict(list)
    file_length_ratio = defaultdict(list)
    file_duplicated = defaultdict(list)
    file_duplicated_detail = defaultdict(list)

    # Iterate over every file in the directory.
    for filename in os.listdir(args.directory):
        if filename.endswith('.jsonl'):  # Process only JSONL files.
            file_path = os.path.join(args.directory, filename)
            with open(file_path, 'r', encoding='utf-8') as file:
                for index, line in enumerate(file):
                    data = json.loads(line)
                    src = data['src']
                    bleu_score = data['bleu']
                    file_src_bleu_scores[filename].append(bleu_score)
                    # Check the generation/reference length ratio.
                    reference_length = len(data['reference'])
                    generation_length = len(data['generation'])
                    file_length_ratio[filename].append(round(generation_length / reference_length, 1))
                    # Check for degenerate word repetition.
                    word_count = is_duplicated(data['generation'])
                    file_duplicated[filename].append(0 if word_count is False else 1)
                    if word_count is not False:
                        file_duplicated_detail[filename].append(
                            {'index': index, 'count': word_count, 'generation': data['generation']})

    # Report files in ascending order of mean BLEU per file.
    sorted_items = sorted(file_src_bleu_scores.items(), key=lambda x: statistics.mean(x[1]))
    print('bleu scores')
    for filename, src_bleu_scores in sorted_items:
        avg_bleu = sum(src_bleu_scores) / len(src_bleu_scores)
        length_ratio = []
        cur_length_ratio = file_length_ratio[filename]
        for index, ratio in enumerate(cur_length_ratio):
            if ratio < 0.2 or ratio > 2.0:
                length_ratio.append((index, ratio))
        print(f"{filename}: {avg_bleu:.2f}, out_of_range_count={len(length_ratio)}, duplicate={sum(file_duplicated[filename])}")
        if args.detail:
            print(f'\t error length:{length_ratio}')
            print("\t duplication")
            for info in file_duplicated_detail[filename]:
                print('\t\t', info)


if __name__ == "__main__":
    main()
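

# Example invocation (script and directory names are hypothetical):
#
#   python check_generations.py results/ --detail
#
# where each line of every results/*.jsonl file is expected to look like
#
#   {"src": "...", "bleu": 23.41, "reference": "...", "generation": "..."}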