TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
9.42 kB
import json
import os
import random
import re
from pathlib import Path
import tiktoken
from datasets import Dataset
from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
def get_random_line_by_language(file_path, language):
with open(file_path, 'r', encoding='utf-8') as file:
lines = [
json.loads(line.strip()) for line in file
if json.loads(line.strip())['language'] == language
]
if lines:
random_line = random.choice(lines)
return {
'needle': random_line['needle'],
'retrieval_question': random_line['retrieval_question'],
'keyword': random_line['arg2']
}
else:
return None
@LOAD_DATASET.register_module()
class NeedleBenchOriginDataset(BaseDataset):
@staticmethod
def load(
path: str,
length: int,
depth: int,
tokenizer_model: str,
file_list: list[str],
num_repeats_per_file: int,
length_buffer: int,
guide: bool,
language: str,
needle_file_name: str,
):
data = {'prompt': [], 'answer': []}
tokenizer = tiktoken.encoding_for_model(tokenizer_model)
def _generate_context(tokens_context, depth_percent, needle):
tokens_needle = _get_tokens_from_context(needle)
insertion_point = int(len(tokens_context) * (depth_percent / 100))
tokens_context = (tokens_context[:insertion_point] +
tokens_needle + tokens_context[insertion_point:])
new_context = _decode_tokens(tokens_context)
return new_context
def _get_tokens_from_context(context):
return tokenizer.encode(context)
def _decode_tokens(tokens):
return tokenizer.decode(tokens)
def _modify_retrieval_question(retrieval_question):
if language == 'Chinese':
parts = retrieval_question.split('请按照')
guide_retrieval_question = (parts[0] + '在回答之前,请思考文档中与此问题'
'最相关的内容是什么。请按照' + parts[1])
return guide_retrieval_question
elif language == 'English':
parts = retrieval_question.split('Please answer in the format')
guide_retrieval_question = (
parts[0] + 'Before answering, please consider'
' what in the document is most relevant to this question.'
' Please answer in the format' + parts[1])
return guide_retrieval_question
else:
raise ValueError(f"Language '{language}' is not supported.")
def _generate_prompt(context, retrieval_question):
if guide:
retrieval_question = _modify_retrieval_question(
retrieval_question)
if language == 'Chinese':
prompt = ('你是一个善于回答用户问题的智能AI助手\n'
'请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
',或重复你的回答\n'
f'用户现在给你的文档是{context}\n\n'
f'现在请问:{retrieval_question}')
elif language == 'English':
prompt = ('You are an intelligent AI assistant skilled in '
'answering user questions.\n'
'Please keep your answers concise and clear. Do not'
' talk about irrelevant topics or repeat your '
'answers.\n'
f'The document given to you by the user is {context}'
f'\n\nNow, the question is: {retrieval_question}')
else:
raise ValueError(f"Language '{language}' is not supported.")
return prompt
files = Path(path).glob('*.jsonl')
for file in files:
if file.name not in file_list:
continue
with open(file, 'r', encoding='utf-8') as f:
lines_bak = [json.loads(line.strip()) for line in f]
lines = lines_bak.copy()
for counter in range(num_repeats_per_file):
random.seed(counter)
random.shuffle(lines)
needle_file_path = os.path.join(path, needle_file_name)
random_needle = get_random_line_by_language(
needle_file_path, language)
needle = '\n' + random_needle['needle'] + '\n'
retrieval_question = random_needle['retrieval_question']
keyword = random_needle['keyword']
context_length = length - length_buffer
target_length_per_record = context_length - len(
_get_tokens_from_context(needle))
target_length_per_record = max(target_length_per_record, 0)
accumulated_tokens = []
for line in lines:
tokens_current_line = _get_tokens_from_context(
line['text'])
accumulated_tokens.extend(tokens_current_line)
if len(accumulated_tokens) >= target_length_per_record:
break
processed_text = _generate_context(
accumulated_tokens[:target_length_per_record], depth,
needle)
processed_prompt = _generate_prompt(processed_text,
retrieval_question)
data['prompt'].append(processed_prompt)
data['answer'].append(needle + '*' + keyword)
dataset = Dataset.from_dict({
'prompt': data['prompt'],
'answer': data['answer'],
})
return dataset
class NeedleBenchOriginEvaluator(BaseEvaluator):
def __init__(self, use_trim=False):
self.use_trim = use_trim
@staticmethod
def _trim_prediction(prediction, reference):
"""Trims the prediction string based on the length of the reference
string.
Args:
prediction (str): The prediction string.
reference (str): The reference string.
Returns:
str: The trimmed prediction string.
"""
l08 = int(0.8 * len(reference))
l12 = int(1.2 * len(reference))
trimmed_prediction = prediction[:l12]
if len(trimmed_prediction) > l08 and \
reference[-1] in trimmed_prediction[l08:]:
end_pos = l08 + trimmed_prediction[l08:].index(reference[-1]) + 1
trimmed_prediction = trimmed_prediction[:end_pos]
return trimmed_prediction
def levenshtein_distance(self, s1, s2):
if len(s1) < len(s2):
return self.levenshtein_distance(s2, s1)
if len(s2) == 0:
return len(s1)
previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
def score(self, predictions, gold):
if len(predictions) != len(gold):
return {'error': 'predictions and gold have different lengths'}
total_score = 0
details = []
for prediction, reference in zip(predictions, gold):
keyword = reference.split('*')[1]
reference = reference.split('*')[0]
raw_prediction = prediction
prediction = re.sub(r'\s+', '', prediction)
reference = re.sub(r'\s+', '', reference)
if self.use_trim:
prediction = NeedleBenchOriginEvaluator._trim_prediction(
prediction, reference)
edit_distance = self.levenshtein_distance(prediction, reference)
max_len = max(len(prediction), len(reference))
score = 100 * (1 -
edit_distance / max_len) if max_len != 0 else 100
if keyword in raw_prediction:
print(f'{keyword} is in {prediction}')
score = 100
else:
print(f'{keyword} is not in {prediction}')
score = 0.2 * score
detail = {
'pred': prediction,
'answer': reference,
'edit_distance': edit_distance,
'score': score
}
total_score += score
details.append(detail)
average_score = total_score / len(predictions) if predictions else 0
result = {'score': average_score, 'details': details}
return result
@TEXT_POSTPROCESSORS.register_module('needlebench')
def needlebench_postprocess(text: str) -> str:
return text
@TEXT_POSTPROCESSORS.register_module('needlebench_dataset')
def needlebench_dataset_postprocess(text: str) -> str:
return text