# tlem/tlem.py
import asyncio
import logging
from copy import deepcopy
from typing import Any

import datasets
import evaluate
import numpy as np
import pandas as pd
from evaluate.evaluation_suite import EvaluationSuite
from tqdm.auto import tqdm

from .tasks import *
from .utils import *
from . import utils


class ReasoningMetric(evaluate.Metric):
    """Score responses against references by dispatching to the ``Metrics``
    entry named by this metric's ``config_name`` (e.g. "gsm8k", "MATH")."""
    def _info(self):
        # Responses and references are both plain strings; task-specific
        # answer extraction happens inside the metric functions.
        features = datasets.Features(
            {
                "responses": datasets.Value("string"),
                "references": datasets.Value("string"),
            }
        )
        return evaluate.EvaluationModuleInfo(
            description="Reasoning metrics for the tlem evaluation suite.",
            citation="",
            inputs_description="",
            # This defines the format of each prediction and reference.
            features=features,
            # Placeholder template URLs, not yet filled in.
            homepage="http://module.homepage",
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"],
        )
    def _compute(self, responses, references):
        # Each ``Metrics`` entry is a callable of (responses, references).
        return_value = getattr(Metrics, self.config_name)(responses, references)
        match return_value:
            case dict():
                # Already a name -> score mapping; pass it through.
                results = return_value
            case list():
                # Per-example scores; report their mean. This case must
                # precede the pair pattern below so that a two-element list
                # of scores is not mistaken for a (responses, references) pair.
                results = {self.config_name: np.mean(return_value)}
            case (extract_responses, extract_references):
                # A pair of extracted answers; score by exact match.
                results = {
                    self.config_name: np.mean(
                        sync_pipe(lambda x, y: x == y)(
                            zip(extract_responses, extract_references)
                        )
                    )
                }
            case _:
                raise NotImplementedError
        return results
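
# Example shapes handled by ReasoningMetric._compute (hypothetical values,
# for illustration only; the real metric functions live in .tasks):
#   {"em": 0.71, "f1": 0.80}      -> passed through unchanged
#   [1.0, 0.0, 1.0]               -> averaged: {"<config_name>": 0.67}
#   (["42", "7"], ["42", "8"])    -> exact-match rate: {"<config_name>": 0.5}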


class Suite(EvaluationSuite):
    """A named collection of evaluation tasks, optionally grouped into
    nested categories."""

    task_class = Task
    utils = utils
    supported_datasets = [
        "arc",
        "hellaswag",
        "mmlu-chat",
        "winogrande",
        "gsm8k",
        "cmmlu-chat",
        "ceval-chat",
        "bbh",
        "drop",
        "MATH",
        "boolq",
        "truthfulqa_mc1",
    ]
    def __getitem__(self, key) -> Task | list[Task]:
        # String keys look up a category; ints and slices index the flat
        # task list.
        match key:
            case str():
                return self.suite[key]
            case slice() | int():
                return self.tasks[key]
            case _:
                raise KeyError(key)
    def agg(self, suite):
        # Recursively collapse each category's task list into the mean of
        # its tasks' results; nested categories are aggregated in place.
        for cate, tasks in suite.items():
            if isinstance(tasks, dict):
                suite[cate] = self.agg(tasks)
            else:
                suite[cate] = np.mean(
                    [pd.Series(task.result).mean() for task in tasks]
                )
        return suite
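
    # For example (hypothetical results), agg turns
    #   {"english": {"arc": [task_a]}, "gsm8k": [task_b, task_c]}
    # into
    #   {"english": {"arc": 0.43}, "gsm8k": 0.61}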
    def run(
        self,
        model_or_pipeline: Any,
    ) -> dict[str, float]:
        self.assert_suite_nonempty()
        self.suite: dict[str, list[Task]]
        for task in (bar := tqdm(self.tasks)):
            bar.set_description(f"running {task.name}")
            _ = task.run(model_or_pipeline)
            logging.info(f"{task.name} {task.result=}")
        # Aggregate a deep copy so the tasks' raw results stay intact.
        return self.agg(deepcopy(self.suite))
    def arun(self, model_or_pipeline):
        # Run every task concurrently; tqdm.gather is asyncio.gather with
        # a progress bar.
        async def gather_tasks():
            return await tqdm.gather(
                *[task.arun(model_or_pipeline) for task in self.tasks],
                leave=False,
            )

        asyncio.run(gather_tasks())
        return self.agg(deepcopy(self.suite))
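
    # Note: asyncio.run cannot be nested inside an already running event
    # loop (e.g. a Jupyter cell); use the synchronous run() there instead.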
    def get_suite(self, name) -> dict[str, Task]:
        # A "-chat" suffix in the name selects the chat-formatted variant.
        chat = "chat" in name
        suite = {}
match name:
case _ if name.startswith("mmlu"):
suite = MMLU.suite(chat=chat)
case _ if name.startswith("cmmlu"):
suite = CMMLU.suite(chat=chat)
case _ if name.startswith("ceval"):
suite = CEVAL.suite(chat=chat)
case "gsm8k":
suite = Task(
dataset_name=("gsm8k", "main"),
metric_name=("sustech/tlem", "gsm8k"),
input_column="question",
label_column="answer",
)
case "bbh":
suite = BBH.suite()
case "arc":
suite = ARC.suite()
case "hellaswag":
suite = HellaSwag.suite()
case "drop":
suite = DROP.suite()
case "winogrande":
suite = Winogrande.suite()
case "truthfulqa_mc1":
suite = TruthfulQAMC1.suite()
case _ if name.startswith("boolq"):
suite = BoolQ.suite(chat=chat)
case "mt_bench":
suite = Task(
dataset_name="SUSTech/mt_bench_judge",
split="train",
prompt=mt_bench_prompt
# metric_name=("sustech/tlem", "gsm8k"),
)
case "MATH" | "competition_math":
suite = Task(
dataset_name="hendrycks/competition_math",
prompt="This is a math problem, please think step by step and slove it: {input_column}. Simplify your final answer as much as possible and surround them with '$' in TeX form.",
metric_name=("sustech/tlem", "MATH"),
input_column="problem",
label_column="solution",
)
case "open-leaderboard":
for name in [
"arc",
"hellaswag",
"mmlu-chat",
"winogrande",
"gsm8k",
# "truthful_qa",
"drop",
]:
suite.update(self.get_suite(name))
case "tlem":
for name in [
"arc",
"hellaswag",
"mmlu-chat",
"winogrande",
"gsm8k",
# "truthful_qa",
"cmmlu-chat",
"ceval-chat",
"bbh",
]:
suite.update(self.get_suite(name))
case "all":
for name in self.supported_datasets:
suite.update(self.get_suite(name))
case _:
raise NotImplementedError(
f"{name} is not supported in {self.supported_datasets}"
)
if isinstance(suite, Task):
suite = [suite]
suite = {name: suite}
return suite
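
    # For instance, get_suite("gsm8k") returns {"gsm8k": [<gsm8k Task>]},
    # while a composite name such as "tlem" maps that name to the merged
    # dict of its member suites.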
    def singleton(self, task):
        # Return the canonical instance of an equivalent task already in
        # the suite, registering the task first if it is new.
        try:
            return self.tasks[self.tasks.index(task)]
        except ValueError:
            logging.debug(f"add {task.name} to suite.")
            self.tasks.append(task)
            logging.debug(self.tasks)
            return self.tasks[-1]
    def drop_duplicates(self, suite):
        # Swap every task for its canonical instance so a task shared by
        # several suites is only run once.
        for category, tasks in suite.items():
            match tasks:
                case list():
                    suite[category] = [self.singleton(task) for task in tasks]
                case dict():
                    suite[category] = self.drop_duplicates(tasks)
                case _:
                    raise NotImplementedError
        return suite
    def load(self, name):
        sub_suite = self.get_suite(name)
        self.suite.update(sub_suite)
        self.suite = self.drop_duplicates(self.suite)
def __init__(self, name="tlem"):
super().__init__(name)
self.tasks = []
self.suite = {}
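

if __name__ == "__main__":
    # Minimal usage sketch. ``echo_pipeline`` is a hypothetical stand-in:
    # substitute any model or pipeline object that Task.run accepts.
    def echo_pipeline(prompts):
        return ["" for _ in prompts]

    suite = Suite("tlem")
    suite.load("gsm8k")
    print(suite.run(echo_pipeline))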