future-xy
commited on
Commit
•
b9f0099
1
Parent(s):
9ceb74b
add system performance metrics
Browse files
src/backend/hflm_with_measurement.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import copy
|
2 |
import os
|
3 |
from datetime import timedelta
|
|
|
4 |
from pathlib import Path
|
5 |
from typing import List, Literal, Optional, Tuple, Union
|
6 |
|
@@ -195,7 +196,7 @@ class HFLMWithMeasurement(HFLM):
|
|
195 |
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
|
196 |
s = s.split(term)[0]
|
197 |
|
198 |
-
res.append(s)
|
199 |
|
200 |
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
|
201 |
pbar.update(1)
|
|
|
1 |
import copy
|
2 |
import os
|
3 |
from datetime import timedelta
|
4 |
+
import random
|
5 |
from pathlib import Path
|
6 |
from typing import List, Literal, Optional, Tuple, Union
|
7 |
|
|
|
196 |
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
|
197 |
s = s.split(term)[0]
|
198 |
|
199 |
+
res.append((s, random.random()))
|
200 |
|
201 |
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
|
202 |
pbar.update(1)
|
src/backend/tasks/measurement_task_utils.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
from lm_eval.api.metrics import mean
|
3 |
+
|
4 |
+
|
5 |
+
def process_results_decorator(func):
|
6 |
+
# This decorator processes the results of a task before passing them to the original process_results function
|
7 |
+
@functools.wraps(func)
|
8 |
+
def wrapper(self, doc, results, *args, **kwargs):
|
9 |
+
# We process the results here
|
10 |
+
processed_results = [r[0] for r in results]
|
11 |
+
|
12 |
+
latency = sum([r[1] for r in results]) / len(results)
|
13 |
+
print(f"Average latency: {latency}")
|
14 |
+
|
15 |
+
# Now call the original process_results with the processed results
|
16 |
+
result_dict = func(self, doc, processed_results, *args, **kwargs)
|
17 |
+
result_dict["latency"] = latency
|
18 |
+
return result_dict
|
19 |
+
return wrapper
|
20 |
+
|
21 |
+
|
22 |
+
def aggregation_decorator(func):
|
23 |
+
@functools.wraps(func)
|
24 |
+
def wrapper(self, *args, **kwargs):
|
25 |
+
aggregation_list = func(self, *args, **kwargs)
|
26 |
+
aggregation_list["latency"] = mean
|
27 |
+
return aggregation_list
|
28 |
+
return wrapper
|
29 |
+
|
30 |
+
|
31 |
+
def higher_is_better_decorator(func):
|
32 |
+
@functools.wraps(func)
|
33 |
+
def wrapper(self, *args, **kwargs):
|
34 |
+
higher_is_better_dict = func(self, *args, **kwargs)
|
35 |
+
higher_is_better_dict["latency"] = False
|
36 |
+
return higher_is_better_dict
|
37 |
+
return wrapper
|
38 |
+
|
39 |
+
|
40 |
+
def measure_system_metrics(cls):
|
41 |
+
method_decorators = {
|
42 |
+
'process_results': [process_results_decorator],
|
43 |
+
'aggregation': [aggregation_decorator],
|
44 |
+
'higher_is_better': [higher_is_better_decorator],
|
45 |
+
}
|
46 |
+
for method_name, decorators in method_decorators.items():
|
47 |
+
if callable(getattr(cls, method_name, None)):
|
48 |
+
original_method = getattr(cls, method_name)
|
49 |
+
for decorator in reversed(decorators):
|
50 |
+
original_method = decorator(original_method)
|
51 |
+
setattr(cls, method_name, original_method)
|
52 |
+
return cls
|
src/backend/tasks/selfcheckgpt/task.py
CHANGED
@@ -12,8 +12,11 @@ from src.backend.envs import DEVICE
|
|
12 |
import spacy
|
13 |
from selfcheckgpt.modeling_selfcheck import SelfCheckMQAG, SelfCheckNLI, SelfCheckBERTScore, SelfCheckNgram
|
14 |
|
|
|
|
|
15 |
|
16 |
# @register_task("selfcheckgpt")
|
|
|
17 |
class SelfCheckGPT(ConfigurableTask):
|
18 |
VERSION = 0.0
|
19 |
DATASET_PATH = "potsawee/wiki_bio_gpt3_hallucination"
|
|
|
12 |
import spacy
|
13 |
from selfcheckgpt.modeling_selfcheck import SelfCheckMQAG, SelfCheckNLI, SelfCheckBERTScore, SelfCheckNgram
|
14 |
|
15 |
+
from src.backend.tasks.measurement_task_utils import measure_system_metrics
|
16 |
+
|
17 |
|
18 |
# @register_task("selfcheckgpt")
|
19 |
+
@measure_system_metrics
|
20 |
class SelfCheckGPT(ConfigurableTask):
|
21 |
VERSION = 0.0
|
22 |
DATASET_PATH = "potsawee/wiki_bio_gpt3_hallucination"
|