"""Summarize text-generation evaluation results stored as .jsonl files.

For every .jsonl file in a directory, report the mean BLEU score, how many
generations have a length ratio to the reference outside [0.2, 2.0], and how
many generations look degenerately repetitive.
"""

import argparse
import json
import os
import re
import statistics
from collections import Counter, defaultdict


def is_duplicated(text, top_k=10, min_word_len=0):
    """Heuristically detect degenerate repetition in generated text.

    Returns the top-k (word, count) pairs when the text looks duplicated,
    and False otherwise.
    """
    words = re.findall(r'\b\w+\b', text)
    word_freq = Counter(words)

    # Optionally drop short words so stop-word noise does not dominate.
    if min_word_len > 0:
        for word in list(word_freq):
            if len(word) <= min_word_len:
                del word_freq[word]

    if len(word_freq) == 0:
        return False

    # A text made of a single word repeated more than five times is
    # trivially duplicated.
    if len(word_freq) == 1 and word_freq.most_common(1)[0][1] > 5:
        return word_freq.most_common(1)

    top_items = word_freq.most_common(top_k)
    frequencies = [frequency for _, frequency in top_items]
    mean_frequency = sum(frequencies) / len(frequencies)

    # If even the most frequent words are rare, nothing is duplicated.
    if mean_frequency < 5:
        return False

    # Scan for a sharp drop between consecutive frequencies. A drop right
    # after the single top word is tolerated; a drop after two or more
    # dominant words flags the text as duplicated.
    prev_frequency = 0
    for index, (_, frequency) in enumerate(top_items):
        if (prev_frequency - frequency) > mean_frequency:
            if index <= 1:
                return False
            return top_items
        prev_frequency = frequency

    return False
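
# A quick is_duplicated illustration: two dominant words are followed by a
# sharp frequency drop (26 > the mean of 20), so the top items are returned.
#
#   is_duplicated("spam " * 30 + "ham " * 28 + "eggs " * 2)
#   -> [('spam', 30), ('ham', 28), ('eggs', 2)]
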
def is_length_exceed(reference, generation, min_ratio=0.2, max_ratio=2):
    """Return True when the generation/reference length ratio falls outside
    [min_ratio, max_ratio]."""
    return not min_ratio <= (len(generation) / len(reference)) <= max_ratio
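
# For example, a 2-character generation against a 20-character reference has
# ratio 0.1 < 0.2, so is_length_exceed("x" * 20, "xx") returns True.
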
def get_average(a):
    """Average a list to two decimal places; return scalars unchanged."""
    if isinstance(a, list):
        return round(sum(a) / len(a), 2)
    return a
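
# E.g. get_average([1, 2, 6]) -> 3.0, while get_average(7) -> 7.
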
def main():
    parser = argparse.ArgumentParser(
        description="Summarize BLEU, length-ratio, and duplication statistics "
                    "for .jsonl result files."
    )
    parser.add_argument(
        "directory",
        type=str,
        help="directory containing .jsonl result files",
    )
    parser.add_argument('--detail', action='store_true',
                        help='print per-example length and duplication details')
    args = parser.parse_args()

    # Per-file accumulators, keyed by result filename.
    file_src_bleu_scores = defaultdict(list)
    file_length_ratio = defaultdict(list)
    file_duplicated = defaultdict(list)
    file_duplicated_detail = defaultdict(list)

    for filename in os.listdir(args.directory):
        if filename.endswith('.jsonl'):
            file_path = os.path.join(args.directory, filename)
            with open(file_path, 'r', encoding='utf-8') as file:
                for index, line in enumerate(file):
                    data = json.loads(line)
                    file_src_bleu_scores[filename].append(data['bleu'])

                    reference_length = len(data['reference'])
                    generation_length = len(data['generation'])
                    file_length_ratio[filename].append(round(generation_length / reference_length, 1))

                    # Record whether this generation looks duplicated, and
                    # keep the offending examples for the --detail report.
                    word_count = is_duplicated(data['generation'])
                    file_duplicated[filename].append(0 if word_count is False else 1)
                    if word_count is not False:
                        file_duplicated_detail[filename].append(
                            {'index': index, 'count': word_count, 'generation': data['generation']})

    # Report files in ascending order of mean BLEU.
    sorted_items = sorted(file_src_bleu_scores.items(), key=lambda x: statistics.mean(x[1]))

    print('bleu scores')
    for filename, src_bleu_scores in sorted_items:
        avg_bleu = sum(src_bleu_scores) / len(src_bleu_scores)
        # Collect (index, ratio) pairs whose length ratio is out of range.
        out_of_range = []
        for index, ratio in enumerate(file_length_ratio[filename]):
            if ratio < 0.2 or ratio > 2.0:
                out_of_range.append((index, ratio))
        print(f"{filename}: {avg_bleu:.2f}, out_of_range_count={len(out_of_range)}, "
              f"duplicate={sum(file_duplicated[filename])}")
        if args.detail:
            print(f'\t error length: {out_of_range}')
            print('\t duplication')
            for info in file_duplicated_detail[filename]:
                print('\t\t', info)


if __name__ == "__main__":
    main()
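
# Example invocation (assuming this script is saved as eval_summary.py and
# the result files live under results/):
#
#   python eval_summary.py results/ --detail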