Spaces:
Sleeping
Sleeping
Upload 14 files
Browse files- README.md +9 -13
- app.py +97 -0
- code_executor.py +111 -0
- data_loader.py +78 -0
- data_utils.py +358 -0
- evaluate.py +71 -0
- grader.py +305 -0
- main.py +468 -0
- prompt_templates/reward_template.md +44 -0
- prompt_templates/sage_template.md +45 -0
- prompt_templates/swift_template.md +152 -0
- run_eval.sh +6 -0
- test.py +28 -0
- utils.py +260 -0
README.md
CHANGED
@@ -1,13 +1,9 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
license: apache-2.0
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
## 🤖 SwiftSage (v2):
|
2 |
+
|
3 |
+
> [!IMPORTANT]
|
4 |
+
> The code of SwiftSage v1 (for the experiments in NeurIPS 2023) is archived in the [`science_world`](https://github.com/SwiftSage/SwiftSage/tree/science_world) branch.
|
5 |
+
|
6 |
+
|
7 |
+
<!-- Github Readme Important Callout box note -->
|
8 |
+
|
9 |
+
|
|
|
|
|
|
|
|
app.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import numpy as np
|
6 |
+
from utils import (PromptTemplate, api_configs, setup_logging)
|
7 |
+
from data_loader import load_data
|
8 |
+
from evaluate import evaluate
|
9 |
+
from main import SwiftSage, run_test, run_benchmark
|
10 |
+
import multiprocessing
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
def solve_problem(problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, reward_model_id, use_retrieval, start_with_sage):
|
15 |
+
# Configuration for each LLM
|
16 |
+
max_iterations = int(max_iterations)
|
17 |
+
reward_threshold = int(reward_threshold)
|
18 |
+
|
19 |
+
swift_config = {
|
20 |
+
"model_id": swift_model_id,
|
21 |
+
"api_config": api_configs['Together']
|
22 |
+
}
|
23 |
+
|
24 |
+
reward_config = {
|
25 |
+
"model_id": reward_model_id,
|
26 |
+
"api_config": api_configs['Together']
|
27 |
+
}
|
28 |
+
|
29 |
+
sage_config = {
|
30 |
+
"model_id": sage_model_id,
|
31 |
+
"api_config": api_configs['Together']
|
32 |
+
}
|
33 |
+
|
34 |
+
# specify the path to the prompt templates
|
35 |
+
prompt_template_dir = './prompt_templates'
|
36 |
+
dataset = []
|
37 |
+
embeddings = [] # TODO: for retrieval augmentation (not implemented yet now)
|
38 |
+
s2 = SwiftSage(
|
39 |
+
dataset,
|
40 |
+
embeddings,
|
41 |
+
prompt_template_dir,
|
42 |
+
swift_config,
|
43 |
+
sage_config,
|
44 |
+
reward_config,
|
45 |
+
use_retrieval=use_retrieval,
|
46 |
+
start_with_sage=start_with_sage,
|
47 |
+
)
|
48 |
+
|
49 |
+
reasoning, solution = s2.solve(problem, max_iterations, reward_threshold)
|
50 |
+
solution = solution.replace("Answer (from running the code):\n ", " ")
|
51 |
+
return reasoning, solution
|
52 |
+
|
53 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
54 |
+
# gr.Markdown("## SwiftSage: A Multi-Agent Framework for Reasoning")
|
55 |
+
# use the html and center the title
|
56 |
+
gr.HTML("<h1 style='text-align: center;'>SwiftSage: A Multi-Agent Framework for Reasoning</h1>")
|
57 |
+
|
58 |
+
with gr.Row():
|
59 |
+
swift_model_id = gr.Textbox(label="😄 Swift Model ID", value="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo")
|
60 |
+
reward_model_id = gr.Textbox(label="🤔 Feedback Model ID", value="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")
|
61 |
+
sage_model_id = gr.Textbox(label="😎 Sage Model ID", value="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")
|
62 |
+
# the following two should have a smaller width
|
63 |
+
|
64 |
+
with gr.Accordion(label="⚙️ Advanced Options", open=False):
|
65 |
+
with gr.Row():
|
66 |
+
with gr.Column():
|
67 |
+
max_iterations = gr.Textbox(label="Max Iterations", value="5")
|
68 |
+
reward_threshold = gr.Textbox(label="Reward Threshold", value="8")
|
69 |
+
# TODO: add top-p and temperature for each module for controlling
|
70 |
+
with gr.Column():
|
71 |
+
top_p_swift = gr.Textbox(label="Top-p for Swift", value="0.9")
|
72 |
+
temperature_swift = gr.Textbox(label="Temperature for Swift", value="0.7")
|
73 |
+
with gr.Column():
|
74 |
+
top_p_sage = gr.Textbox(label="Top-p for Sage", value="0.9")
|
75 |
+
temperature_sage = gr.Textbox(label="Temperature for Sage", value="0.7")
|
76 |
+
with gr.Column():
|
77 |
+
top_p_reward = gr.Textbox(label="Top-p for Feedback", value="0.9")
|
78 |
+
temperature_reward = gr.Textbox(label="Temperature for Feedback", value="0.7")
|
79 |
+
|
80 |
+
use_retrieval = gr.Checkbox(label="Use Retrieval Augmentation", value=False, visible=False)
|
81 |
+
start_with_sage = gr.Checkbox(label="Start with Sage", value=False, visible=False)
|
82 |
+
|
83 |
+
problem = gr.Textbox(label="Input your problem", value="How many letter r are there in the sentence 'My strawberry is so ridiculously red.'?", lines=2)
|
84 |
+
|
85 |
+
solve_button = gr.Button("🚀 Solve Problem")
|
86 |
+
reasoning_output = gr.Textbox(label="Reasoning steps with Code", interactive=False)
|
87 |
+
solution_output = gr.Textbox(label="Final answer", interactive=False)
|
88 |
+
|
89 |
+
solve_button.click(
|
90 |
+
solve_problem,
|
91 |
+
inputs=[problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, reward_model_id, use_retrieval, start_with_sage],
|
92 |
+
outputs=[reasoning_output, solution_output]
|
93 |
+
)
|
94 |
+
|
95 |
+
if __name__ == '__main__':
|
96 |
+
multiprocessing.set_start_method('spawn')
|
97 |
+
demo.launch(share=False, show_api=False)
|
code_executor.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py
|
3 |
+
|
4 |
+
We modified it to be more simple.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import io
|
8 |
+
import pickle
|
9 |
+
import traceback
|
10 |
+
from concurrent.futures import ProcessPoolExecutor, TimeoutError
|
11 |
+
from contextlib import redirect_stdout
|
12 |
+
|
13 |
+
|
14 |
+
class GenericRuntime:
|
15 |
+
GLOBAL_DICT = {}
|
16 |
+
LOCAL_DICT = None
|
17 |
+
HEADERS = []
|
18 |
+
|
19 |
+
def __init__(self):
|
20 |
+
self._global_vars = self.GLOBAL_DICT.copy()
|
21 |
+
self._local_vars = self.LOCAL_DICT.copy() if self.LOCAL_DICT else None
|
22 |
+
|
23 |
+
for c in self.HEADERS:
|
24 |
+
self.exec_code(c)
|
25 |
+
|
26 |
+
def exec_code(self, code_piece: str) -> None:
|
27 |
+
exec(code_piece, self._global_vars)
|
28 |
+
|
29 |
+
def eval_code(self, expr: str) -> any:
|
30 |
+
return eval(expr, self._global_vars)
|
31 |
+
|
32 |
+
def inject(self, var_dict):
|
33 |
+
self._global_vars.update(var_dict)
|
34 |
+
|
35 |
+
@property
|
36 |
+
def answer(self):
|
37 |
+
return self._global_vars['answer']
|
38 |
+
|
39 |
+
|
40 |
+
class PythonExecutor:
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
runtime=None,
|
44 |
+
get_answer_symbol=None,
|
45 |
+
get_answer_expr=None,
|
46 |
+
get_answer_from_stdout=False,
|
47 |
+
timeout_length=5,
|
48 |
+
):
|
49 |
+
self.runtime = runtime if runtime else GenericRuntime()
|
50 |
+
self.answer_symbol = get_answer_symbol
|
51 |
+
self.get_answer_expr = get_answer_expr
|
52 |
+
self.get_answer_from_stdout = get_answer_from_stdout
|
53 |
+
self.timeout_length = timeout_length
|
54 |
+
|
55 |
+
def execute(self, code):
|
56 |
+
try:
|
57 |
+
if self.get_answer_from_stdout:
|
58 |
+
program_io = io.StringIO()
|
59 |
+
with redirect_stdout(program_io):
|
60 |
+
self.runtime.exec_code('\n'.join(code))
|
61 |
+
program_io.seek(0)
|
62 |
+
result = program_io.read()
|
63 |
+
elif self.answer_symbol:
|
64 |
+
self.runtime.exec_code('\n'.join(code))
|
65 |
+
result = self.runtime._global_vars[self.answer_symbol]
|
66 |
+
elif self.get_answer_expr:
|
67 |
+
self.runtime.exec_code('\n'.join(code))
|
68 |
+
result = self.runtime.eval_code(self.get_answer_expr)
|
69 |
+
else:
|
70 |
+
self.runtime.exec_code('\n'.join(code[:-1]))
|
71 |
+
result = self.runtime.eval_code(code[-1])
|
72 |
+
|
73 |
+
report = "Done"
|
74 |
+
pickle.dumps(result) # Serialization check
|
75 |
+
except Exception as e:
|
76 |
+
result = ''
|
77 |
+
report = str(e)
|
78 |
+
|
79 |
+
return result, report
|
80 |
+
|
81 |
+
def apply(self, code):
|
82 |
+
code_snippet = code.split('\n')
|
83 |
+
|
84 |
+
# Use ProcessPoolExecutor to enforce timeout
|
85 |
+
with ProcessPoolExecutor() as executor:
|
86 |
+
future = executor.submit(self.execute, code_snippet)
|
87 |
+
try:
|
88 |
+
result, report = future.result(timeout=self.timeout_length)
|
89 |
+
except TimeoutError:
|
90 |
+
result, report = "", "Timeout Error"
|
91 |
+
|
92 |
+
return result.strip(), report.strip()
|
93 |
+
|
94 |
+
|
95 |
+
# Example usage
|
96 |
+
if __name__ == "__main__":
|
97 |
+
executor = PythonExecutor(get_answer_from_stdout=True)
|
98 |
+
code = """
|
99 |
+
from sympy import Matrix
|
100 |
+
|
101 |
+
def null_space_basis():
|
102 |
+
A = Matrix([[3, 3, -1, -6], [9, -1, -8, -1], [7, 4, -2, -9]])
|
103 |
+
basis = A.nullspace()
|
104 |
+
return [v.evalf(3) for v in basis]
|
105 |
+
|
106 |
+
result = null_space_basis()
|
107 |
+
print(result)
|
108 |
+
"""
|
109 |
+
result, report = executor.apply(code)
|
110 |
+
print("Result:", result)
|
111 |
+
print("Report:", report)
|
data_loader.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import random
|
5 |
+
from typing import Any, Iterable, Union
|
6 |
+
|
7 |
+
from datasets import Dataset, concatenate_datasets, load_dataset
|
8 |
+
|
9 |
+
from data_utils import (
|
10 |
+
lower_keys,
|
11 |
+
parse_question,
|
12 |
+
parse_ground_truth,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
def load_jsonl(file):
|
17 |
+
with open(file, "r", encoding="utf-8") as f:
|
18 |
+
for line in f:
|
19 |
+
try:
|
20 |
+
yield json.loads(line)
|
21 |
+
except:
|
22 |
+
print("Error in loading:", line)
|
23 |
+
exit()
|
24 |
+
|
25 |
+
|
26 |
+
def load_data(
|
27 |
+
data_name,
|
28 |
+
split='test',
|
29 |
+
data_dir='./data',
|
30 |
+
num_test_sample=-1,
|
31 |
+
):
|
32 |
+
if data_name.lower() == "math":
|
33 |
+
data_name = 'MATH' # we use 500 problem test split in "Let's Verify Step-by-Step"
|
34 |
+
data_file = f"{data_dir}/{data_name}/{split}.jsonl"
|
35 |
+
if os.path.exists(data_file):
|
36 |
+
examples = list(load_jsonl(data_file))
|
37 |
+
else:
|
38 |
+
if data_name == "mmlu_stem":
|
39 |
+
dataset = load_dataset("hails/mmlu_no_train", 'all', split='test')
|
40 |
+
# only keep stem subjects
|
41 |
+
stem_subjects = ['abstract_algebra', 'astronomy', 'college_biology', 'college_chemistry',
|
42 |
+
'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security',
|
43 |
+
'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology',
|
44 |
+
'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics',
|
45 |
+
'high_school_physics', 'high_school_statistics', 'machine_learning']
|
46 |
+
dataset = dataset.rename_column("subject", "type")
|
47 |
+
dataset = dataset.filter(lambda x: x['type'] in stem_subjects)
|
48 |
+
elif data_name == "mathvista":
|
49 |
+
raise NotImplementedError(data_name)
|
50 |
+
elif data_name == "gpqa":
|
51 |
+
dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
|
52 |
+
elif data_name == "codeforces":
|
53 |
+
raise NotImplementedError(data_name)
|
54 |
+
else:
|
55 |
+
raise NotImplementedError(data_name)
|
56 |
+
|
57 |
+
examples = list(dataset)
|
58 |
+
examples = [lower_keys(example) for example in examples]
|
59 |
+
dataset = Dataset.from_list(examples)
|
60 |
+
os.makedirs(f"{data_dir}/{data_name}", exist_ok=True)
|
61 |
+
dataset.to_json(data_file)
|
62 |
+
|
63 |
+
# add 'idx' in the first column
|
64 |
+
if 'idx' not in examples[0]:
|
65 |
+
examples = [{'idx': i, **example} for i, example in enumerate(examples)]
|
66 |
+
|
67 |
+
# dedepulicate & sort
|
68 |
+
examples = sorted(examples, key=lambda x: x['idx'])
|
69 |
+
|
70 |
+
if num_test_sample > 0:
|
71 |
+
examples = examples[:num_test_sample]
|
72 |
+
|
73 |
+
return examples
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
examples = load_data("gpqa", "test")
|
78 |
+
print('test')
|
data_utils.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py
|
3 |
+
"""
|
4 |
+
import re
|
5 |
+
import regex
|
6 |
+
import sympy
|
7 |
+
from typing import TypeVar, Iterable, List, Union, Any, Dict
|
8 |
+
from word2number import w2n
|
9 |
+
from utils import *
|
10 |
+
|
11 |
+
|
12 |
+
def lower_keys(example):
|
13 |
+
new_example = {}
|
14 |
+
for key, value in example.items():
|
15 |
+
if key != key.lower():
|
16 |
+
new_key = key.lower()
|
17 |
+
new_example[new_key] = value
|
18 |
+
else:
|
19 |
+
new_example[key] = value
|
20 |
+
return new_example
|
21 |
+
|
22 |
+
|
23 |
+
def _fix_fracs(string):
|
24 |
+
substrs = string.split("\\frac")
|
25 |
+
new_str = substrs[0]
|
26 |
+
if len(substrs) > 1:
|
27 |
+
substrs = substrs[1:]
|
28 |
+
for substr in substrs:
|
29 |
+
new_str += "\\frac"
|
30 |
+
if len(substr) > 0 and substr[0] == "{":
|
31 |
+
new_str += substr
|
32 |
+
else:
|
33 |
+
try:
|
34 |
+
assert len(substr) >= 2
|
35 |
+
except:
|
36 |
+
return string
|
37 |
+
a = substr[0]
|
38 |
+
b = substr[1]
|
39 |
+
if b != "{":
|
40 |
+
if len(substr) > 2:
|
41 |
+
post_substr = substr[2:]
|
42 |
+
new_str += "{" + a + "}{" + b + "}" + post_substr
|
43 |
+
else:
|
44 |
+
new_str += "{" + a + "}{" + b + "}"
|
45 |
+
else:
|
46 |
+
if len(substr) > 2:
|
47 |
+
post_substr = substr[2:]
|
48 |
+
new_str += "{" + a + "}" + b + post_substr
|
49 |
+
else:
|
50 |
+
new_str += "{" + a + "}" + b
|
51 |
+
string = new_str
|
52 |
+
return string
|
53 |
+
|
54 |
+
|
55 |
+
def _fix_a_slash_b(string):
|
56 |
+
if len(string.split("/")) != 2:
|
57 |
+
return string
|
58 |
+
a = string.split("/")[0]
|
59 |
+
b = string.split("/")[1]
|
60 |
+
try:
|
61 |
+
if "sqrt" not in a:
|
62 |
+
a = int(a)
|
63 |
+
if "sqrt" not in b:
|
64 |
+
b = int(b)
|
65 |
+
assert string == "{}/{}".format(a, b)
|
66 |
+
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
67 |
+
return new_string
|
68 |
+
except:
|
69 |
+
return string
|
70 |
+
|
71 |
+
|
72 |
+
def _fix_sqrt(string):
|
73 |
+
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
74 |
+
return _string
|
75 |
+
|
76 |
+
|
77 |
+
def convert_word_number(text:str) -> str:
|
78 |
+
try:
|
79 |
+
text = str(w2n.word_to_num(text))
|
80 |
+
except:
|
81 |
+
pass
|
82 |
+
return text
|
83 |
+
|
84 |
+
# units mainly from MathQA
|
85 |
+
unit_texts = [
|
86 |
+
"east", "degree", "mph", "kmph", "ft", "m sqaure", " m east", "sq m", "deg", "mile",
|
87 |
+
"q .", "monkey", "prime", "ratio", "profit of rs", "rd", "o", "gm",
|
88 |
+
"p . m", "lb", "tile", "per", "dm", "lt", "gain", "ab", "way", "west",
|
89 |
+
"a .", "b .", "c .", "d .", "e .", "f .", "g .", "h .", "t", "a", "h",
|
90 |
+
"no change", "men", "soldier", "pie", "bc", "excess", "st",
|
91 |
+
"inches", "noon", "percent", "by", "gal", "kmh", "c", "acre", "rise",
|
92 |
+
"a . m", "th", "π r 2", "sq", "mark", "l", "toy", "coin",
|
93 |
+
"sq . m", "gallon", "° f", "profit", "minw", "yr", "women",
|
94 |
+
"feet", "am", "pm", "hr", "cu cm", "square", "v â € ™", "are",
|
95 |
+
"rupee", "rounds", "cubic", "cc", "mtr", "s", "ohm", "number",
|
96 |
+
"kmph", "day", "hour", "minute", "min", "second", "man", "woman",
|
97 |
+
"sec", "cube", "mt", "sq inch", "mp", "∏ cm ³", "hectare", "more",
|
98 |
+
"sec", "unit", "cu . m", "cm 2", "rs .", "rs", "kg", "g", "month",
|
99 |
+
"km", "m", "cm", "mm", "apple", "liter", "loss", "yard",
|
100 |
+
"pure", "year", "increase", "decrease", "d", "less", "Surface",
|
101 |
+
"litre", "pi sq m", "s .", "metre", "meter", "inch",
|
102 |
+
]
|
103 |
+
|
104 |
+
unit_texts.extend([t + "s" for t in unit_texts])
|
105 |
+
|
106 |
+
def strip_string(string):
|
107 |
+
string = str(string).strip()
|
108 |
+
# linebreaks
|
109 |
+
string = string.replace("\n", "")
|
110 |
+
|
111 |
+
# right "."
|
112 |
+
string = string.rstrip(".")
|
113 |
+
|
114 |
+
# remove inverse spaces
|
115 |
+
# replace \\ with \
|
116 |
+
string = string.replace("\\!", "")
|
117 |
+
# string = string.replace("\\ ", "")
|
118 |
+
# string = string.replace("\\\\", "\\")
|
119 |
+
|
120 |
+
# matrix
|
121 |
+
string = re.sub(r'\\begin\{array\}\{.*?\}', r'\\begin{pmatrix}', string)
|
122 |
+
string = re.sub(r'\\end\{array\}', r'\\end{pmatrix}', string)
|
123 |
+
string = string.replace("bmatrix", "pmatrix")
|
124 |
+
|
125 |
+
|
126 |
+
# replace tfrac and dfrac with frac
|
127 |
+
string = string.replace("tfrac", "frac")
|
128 |
+
string = string.replace("dfrac", "frac")
|
129 |
+
|
130 |
+
# remove \left and \right
|
131 |
+
string = string.replace("\\left", "")
|
132 |
+
string = string.replace("\\right", "")
|
133 |
+
string = string.replace("\\{", "{")
|
134 |
+
string = string.replace("\\}", "}")
|
135 |
+
|
136 |
+
# Remove unit: miles, dollars if after is not none
|
137 |
+
_string = re.sub(r"\\text{.*?}$", "", string).strip()
|
138 |
+
if _string != "" and _string != string:
|
139 |
+
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
140 |
+
string = _string
|
141 |
+
|
142 |
+
# Remove unit: texts
|
143 |
+
for _ in range(2):
|
144 |
+
for unit_text in unit_texts:
|
145 |
+
# use regex, the prefix should be either the start of the string or a non-alphanumeric character
|
146 |
+
# the suffix should be either the end of the string or a non-alphanumeric character
|
147 |
+
_string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
|
148 |
+
if _string != "":
|
149 |
+
string = _string
|
150 |
+
|
151 |
+
# Remove circ (degrees)
|
152 |
+
string = string.replace("^{\\circ}", "")
|
153 |
+
string = string.replace("^\\circ", "")
|
154 |
+
|
155 |
+
# remove dollar signs
|
156 |
+
string = string.replace("\\$", "")
|
157 |
+
string = string.replace("$", "")
|
158 |
+
|
159 |
+
# convert word number to digit
|
160 |
+
string = convert_word_number(string)
|
161 |
+
|
162 |
+
# replace "\\text{...}" to "..."
|
163 |
+
string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
|
164 |
+
for key in ['x=', 'y=', 'z=', 'x\\in', 'y\\in', 'z\\in', 'x\\to', 'y\\to', 'z\\to']:
|
165 |
+
string = string.replace(key, "")
|
166 |
+
string = string.replace("\\emptyset", r"{}")
|
167 |
+
string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")
|
168 |
+
|
169 |
+
# remove percentage
|
170 |
+
string = string.replace("\\%", "")
|
171 |
+
string = string.replace("\%", "")
|
172 |
+
string = string.replace("%", "")
|
173 |
+
|
174 |
+
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
175 |
+
string = string.replace(" .", " 0.")
|
176 |
+
string = string.replace("{.", "{0.")
|
177 |
+
|
178 |
+
# cdot
|
179 |
+
# string = string.replace("\\cdot", "")
|
180 |
+
if string.startswith("{") and string.endswith("}") and string.isalnum() or \
|
181 |
+
string.startswith("(") and string.endswith(")") and string.isalnum() or \
|
182 |
+
string.startswith("[") and string.endswith("]") and string.isalnum():
|
183 |
+
string = string[1:-1]
|
184 |
+
|
185 |
+
# inf
|
186 |
+
string = string.replace("infinity", "\\infty")
|
187 |
+
if "\\infty" not in string:
|
188 |
+
string = string.replace("inf", "\\infty")
|
189 |
+
string = string.replace("+\\inity", "\\infty")
|
190 |
+
|
191 |
+
# and
|
192 |
+
string = string.replace("and", "")
|
193 |
+
string = string.replace("\\mathbf", "")
|
194 |
+
|
195 |
+
# use regex to remove \mbox{...}
|
196 |
+
string = re.sub(r"\\mbox{.*?}", "", string)
|
197 |
+
|
198 |
+
# quote
|
199 |
+
string.replace("'", "")
|
200 |
+
string.replace("\"", "")
|
201 |
+
|
202 |
+
# i, j
|
203 |
+
if "j" in string and "i" not in string:
|
204 |
+
string = string.replace("j", "i")
|
205 |
+
|
206 |
+
# replace a.000b where b is not number or b is end, with ab, use regex
|
207 |
+
string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
|
208 |
+
string = re.sub(r"(\d+)\.0*$", r"\1", string)
|
209 |
+
|
210 |
+
# if empty, return empty string
|
211 |
+
if len(string) == 0:
|
212 |
+
return string
|
213 |
+
if string[0] == ".":
|
214 |
+
string = "0" + string
|
215 |
+
|
216 |
+
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
217 |
+
if len(string.split("=")) == 2:
|
218 |
+
if len(string.split("=")[0]) <= 2:
|
219 |
+
string = string.split("=")[1]
|
220 |
+
|
221 |
+
string = _fix_sqrt(string)
|
222 |
+
string = string.replace(" ", "")
|
223 |
+
|
224 |
+
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
225 |
+
string = _fix_fracs(string)
|
226 |
+
|
227 |
+
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
228 |
+
string = _fix_a_slash_b(string)
|
229 |
+
|
230 |
+
return string
|
231 |
+
|
232 |
+
|
233 |
+
def extract_multi_choice_answer(pred_str):
|
234 |
+
# TODO: SFT models
|
235 |
+
if 'Problem:' in pred_str:
|
236 |
+
pred_str = pred_str.split("Problem:", 1)[0]
|
237 |
+
pred_str = pred_str.replace("choice is", "answer is")
|
238 |
+
patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower())
|
239 |
+
if patt is not None:
|
240 |
+
return patt.group('ans').upper()
|
241 |
+
return 'placeholder'
|
242 |
+
|
243 |
+
|
244 |
+
def extract_answer(pred_str, data_name):
|
245 |
+
if data_name in ["mmlu_stem", "sat_math", "mathqa"]:
|
246 |
+
return extract_multi_choice_answer(pred_str)
|
247 |
+
|
248 |
+
if 'final answer is $' in pred_str and '$. I hope' in pred_str:
|
249 |
+
# minerva_math
|
250 |
+
tmp = pred_str.split('final answer is $', 1)[1]
|
251 |
+
pred = tmp.split('$. I hope', 1)[0].strip()
|
252 |
+
elif 'boxed' in pred_str:
|
253 |
+
ans = pred_str.split('boxed')[-1]
|
254 |
+
if len(ans) == 0:
|
255 |
+
return ""
|
256 |
+
elif ans[0] == '{':
|
257 |
+
stack = 1
|
258 |
+
a = ''
|
259 |
+
for c in ans[1:]:
|
260 |
+
if (c == '{'):
|
261 |
+
stack += 1
|
262 |
+
a += c
|
263 |
+
elif (c == '}'):
|
264 |
+
stack -= 1
|
265 |
+
if (stack == 0): break
|
266 |
+
a += c
|
267 |
+
else:
|
268 |
+
a += c
|
269 |
+
else:
|
270 |
+
a = ans.split('$')[0].strip()
|
271 |
+
pred = a
|
272 |
+
elif ('he answer is' in pred_str):
|
273 |
+
pred = pred_str.split('he answer is')[-1].strip()
|
274 |
+
elif ('final answer is' in pred_str):
|
275 |
+
pred = pred_str.split('final answer is')[-1].strip()
|
276 |
+
# elif extract_program_output(pred_str) != "":
|
277 |
+
# fall back to program
|
278 |
+
# pred = extract_program_output(pred_str)
|
279 |
+
else: # use the last number
|
280 |
+
pattern = '-?\d*\.?\d+'
|
281 |
+
pred = re.findall(pattern, pred_str.replace(",", ""))
|
282 |
+
if(len(pred) >= 1):
|
283 |
+
pred = pred[-1]
|
284 |
+
else: pred = ''
|
285 |
+
|
286 |
+
# multiple line
|
287 |
+
# pred = pred.split("\n")[0]
|
288 |
+
pred = re.sub(r"\n\s*", "", pred)
|
289 |
+
if pred != "" and pred[0] == ":":
|
290 |
+
pred = pred[1:]
|
291 |
+
if pred != "" and pred[-1] == ".":
|
292 |
+
pred = pred[:-1]
|
293 |
+
if pred != "" and pred[-1] == "/":
|
294 |
+
pred = pred[:-1]
|
295 |
+
pred = strip_string(pred)
|
296 |
+
return pred
|
297 |
+
|
298 |
+
|
299 |
+
def parse_ground_truth(example: Dict[str, Any], data_name):
|
300 |
+
# parse ground truth
|
301 |
+
if data_name in ["MATH", "math", "math_oai", "minerva_math", "ocw", "amps", "hungarian_exam"]:
|
302 |
+
gt_ans = example['answer']
|
303 |
+
elif data_name == "gsm8k":
|
304 |
+
gt_ans = example['answer'].split("####")[-1]
|
305 |
+
elif data_name == "mmlu_stem":
|
306 |
+
abcd = 'ABCD'
|
307 |
+
gt_ans = abcd[example['answer']]
|
308 |
+
elif data_name == "gpqa":
|
309 |
+
gt_ans = example['correct answer']
|
310 |
+
else:
|
311 |
+
raise NotImplementedError(f"`{data_name}`")
|
312 |
+
# post process
|
313 |
+
gt_ans = strip_string(gt_ans)
|
314 |
+
return gt_ans
|
315 |
+
|
316 |
+
|
317 |
+
def parse_question(example, data_name):
|
318 |
+
question = ""
|
319 |
+
if data_name == "mmlu_stem":
|
320 |
+
options = example['choices']
|
321 |
+
assert len(options) == 4
|
322 |
+
for i, (label, option) in enumerate(zip('ABCD', options)):
|
323 |
+
options[i] = f"({label}) {str(option).strip()}"
|
324 |
+
options = ", ".join(options)
|
325 |
+
question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
|
326 |
+
else:
|
327 |
+
for key in ['question', 'problem', 'Question', 'input']:
|
328 |
+
if key in example:
|
329 |
+
question = example[key]
|
330 |
+
break
|
331 |
+
assert question != ""
|
332 |
+
# Yes or No question
|
333 |
+
gt_ans = parse_ground_truth(example, data_name)
|
334 |
+
gt_lower = gt_ans.lower()
|
335 |
+
if gt_lower in ["true", "false"]:
|
336 |
+
question += " (True or False)"
|
337 |
+
if gt_lower in ["yes", "no"]:
|
338 |
+
question += " (Yes or No)"
|
339 |
+
return question.strip()
|
340 |
+
|
341 |
+
|
342 |
+
def _test_extract_answer():
|
343 |
+
text= """
|
344 |
+
The answer is $\\boxed{\left(
|
345 |
+
\\begin{array}{ccc}
|
346 |
+
-13 & 4 & -2 \\\\
|
347 |
+
7 & 8 & -3 \\\\
|
348 |
+
0 & 18 & -7 \\\\
|
349 |
+
6 & 12 & 5 \\\\
|
350 |
+
\\end{array}
|
351 |
+
\\right)}$.
|
352 |
+
"""
|
353 |
+
print(extract_answer(text, "math"))
|
354 |
+
# should output a dict
|
355 |
+
|
356 |
+
|
357 |
+
if __name__ == "__main__":
|
358 |
+
_test_extract_answer()
|
evaluate.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py
|
3 |
+
"""
|
4 |
+
import argparse
|
5 |
+
import json
|
6 |
+
from concurrent.futures import TimeoutError
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from pebble import ProcessPool
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from grader import math_equal_process
|
13 |
+
|
14 |
+
|
15 |
+
def evaluate(samples: list=None, file_path: str=None):
|
16 |
+
assert samples or file_path, "samples or file_path must be provided"
|
17 |
+
if not samples:
|
18 |
+
with open(file_path, 'r') as f:
|
19 |
+
samples = [json.loads(line) for line in f]
|
20 |
+
|
21 |
+
# dedup by idx
|
22 |
+
if 'idx' in samples[0]:
|
23 |
+
samples = {sample['idx']: sample for sample in samples}.values()
|
24 |
+
samples = sorted(samples, key=lambda x: x['idx'])
|
25 |
+
else:
|
26 |
+
samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)]
|
27 |
+
|
28 |
+
params = [(idx, sample['pred'], sample['gt']) for idx, sample in enumerate(samples)]
|
29 |
+
|
30 |
+
scores = []
|
31 |
+
timeout_cnt = 0
|
32 |
+
|
33 |
+
with ProcessPool() as pool:
|
34 |
+
future = pool.map(math_equal_process, params, timeout=3)
|
35 |
+
iterator = future.result()
|
36 |
+
with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
|
37 |
+
while True:
|
38 |
+
try:
|
39 |
+
result = next(iterator)
|
40 |
+
scores.append(result)
|
41 |
+
except StopIteration:
|
42 |
+
break
|
43 |
+
except TimeoutError as error:
|
44 |
+
print(error)
|
45 |
+
scores.append(False)
|
46 |
+
timeout_cnt += 1
|
47 |
+
except Exception as error:
|
48 |
+
print(error.traceback)
|
49 |
+
exit()
|
50 |
+
progress_bar.update(1)
|
51 |
+
|
52 |
+
assert len(samples) == len(scores)
|
53 |
+
|
54 |
+
for i in range(len(samples)):
|
55 |
+
samples[i]['score'] = scores[i]
|
56 |
+
|
57 |
+
mean_score = np.round(np.mean([score for score in scores if score is not False]), decimals=2)
|
58 |
+
|
59 |
+
result_json = {
|
60 |
+
"num_samples": len(samples),
|
61 |
+
"num_scores": len(scores),
|
62 |
+
"timeout_samples": timeout_cnt,
|
63 |
+
"acc": mean_score
|
64 |
+
}
|
65 |
+
|
66 |
+
return samples, result_json
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
samples, results_json = evaluate(file_path="output/MATH.jsonl")
|
71 |
+
print('test')
|
grader.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py
|
3 |
+
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
|
4 |
+
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
|
5 |
+
- https://github.com/openai/prm800k
|
6 |
+
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
|
7 |
+
- https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
|
8 |
+
"""
|
9 |
+
import re
|
10 |
+
import regex
|
11 |
+
import multiprocessing
|
12 |
+
from math import isclose
|
13 |
+
from typing import Union
|
14 |
+
|
15 |
+
from sympy import simplify, N
|
16 |
+
from sympy.parsing.sympy_parser import parse_expr
|
17 |
+
from sympy.parsing.latex import parse_latex
|
18 |
+
from latex2sympy2 import latex2sympy
|
19 |
+
|
20 |
+
|
21 |
+
def parse_digits(num):
|
22 |
+
num = regex.sub(',', '', str(num))
|
23 |
+
try:
|
24 |
+
return float(num)
|
25 |
+
except:
|
26 |
+
if num.endswith('%'):
|
27 |
+
num = num[:-1]
|
28 |
+
if num.endswith('\\'):
|
29 |
+
num = num[:-1]
|
30 |
+
try:
|
31 |
+
return float(num) / 100
|
32 |
+
except:
|
33 |
+
pass
|
34 |
+
return None
|
35 |
+
|
36 |
+
def is_digit(num):
|
37 |
+
# paired with parse_digits
|
38 |
+
return parse_digits(num) is not None
|
39 |
+
|
40 |
+
|
41 |
+
def str_to_pmatrix(input_str):
|
42 |
+
input_str = input_str.strip()
|
43 |
+
matrix_str = re.findall(r'\{.*,.*\}', input_str)
|
44 |
+
pmatrix_list = []
|
45 |
+
|
46 |
+
for m in matrix_str:
|
47 |
+
m = m.strip('{}')
|
48 |
+
pmatrix = r'\begin{pmatrix}' + m.replace(',', '\\') + r'\end{pmatrix}'
|
49 |
+
pmatrix_list.append(pmatrix)
|
50 |
+
|
51 |
+
return ', '.join(pmatrix_list)
|
52 |
+
|
53 |
+
|
54 |
+
def math_equal(prediction: Union[bool, float, str],
|
55 |
+
reference: Union[float, str],
|
56 |
+
include_percentage: bool = True,
|
57 |
+
is_close: bool = True,
|
58 |
+
timeout: bool = False,
|
59 |
+
) -> bool:
|
60 |
+
"""
|
61 |
+
Exact match of math if and only if:
|
62 |
+
1. numerical equal: both can convert to float and are equal
|
63 |
+
2. symbolic equal: both can convert to sympy expression and are equal
|
64 |
+
"""
|
65 |
+
# print("Judge:", prediction, reference)
|
66 |
+
if str(prediction) == str(reference):
|
67 |
+
return True
|
68 |
+
|
69 |
+
try: # 1. numerical equal
|
70 |
+
if is_digit(prediction) and is_digit(reference):
|
71 |
+
prediction = parse_digits(prediction)
|
72 |
+
reference = parse_digits(reference)
|
73 |
+
# number questions
|
74 |
+
if include_percentage:
|
75 |
+
gt_result = [reference / 100, reference, reference * 100]
|
76 |
+
else:
|
77 |
+
gt_result = [reference]
|
78 |
+
for item in gt_result:
|
79 |
+
try:
|
80 |
+
if is_close:
|
81 |
+
if numeric_equal(prediction, item):
|
82 |
+
return True
|
83 |
+
else:
|
84 |
+
if item == prediction:
|
85 |
+
return True
|
86 |
+
except Exception:
|
87 |
+
continue
|
88 |
+
return False
|
89 |
+
except:
|
90 |
+
pass
|
91 |
+
|
92 |
+
if not prediction and prediction not in [0, False]:
|
93 |
+
return False
|
94 |
+
# print("try math_eval")
|
95 |
+
|
96 |
+
# 2. symbolic equal
|
97 |
+
reference = str(reference).strip()
|
98 |
+
prediction = str(prediction).strip()
|
99 |
+
|
100 |
+
## pmatrix (amps)
|
101 |
+
if "pmatrix" in prediction and not 'pmatrix' in reference:
|
102 |
+
reference = str_to_pmatrix(reference)
|
103 |
+
|
104 |
+
## deal with [], (), {}
|
105 |
+
pred_str, ref_str = prediction, reference
|
106 |
+
if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
|
107 |
+
(prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
|
108 |
+
pred_str = pred_str.strip("[]()")
|
109 |
+
ref_str = ref_str.strip("[]()")
|
110 |
+
for s in ['{', "}", "(", ")"]:
|
111 |
+
ref_str = ref_str.replace(s, "")
|
112 |
+
pred_str = pred_str.replace(s, "")
|
113 |
+
if pred_str.lower() == ref_str.lower():
|
114 |
+
return True
|
115 |
+
|
116 |
+
## [a, b] vs. [c, d], return a==c and b==d
|
117 |
+
if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None:
|
118 |
+
pred_parts = prediction[1:-1].split(",")
|
119 |
+
ref_parts = reference[1:-1].split(",")
|
120 |
+
if len(pred_parts) == len(ref_parts):
|
121 |
+
if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
|
122 |
+
return True
|
123 |
+
if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \
|
124 |
+
(reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")):
|
125 |
+
pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
|
126 |
+
ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
|
127 |
+
matched = True
|
128 |
+
if len(pred_lines) == len(ref_lines):
|
129 |
+
for pred_line, ref_line in zip(pred_lines, ref_lines):
|
130 |
+
pred_parts = pred_line.split("&")
|
131 |
+
ref_parts = ref_line.split("&")
|
132 |
+
if len(pred_parts) == len(ref_parts):
|
133 |
+
if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
|
134 |
+
matched = False
|
135 |
+
break
|
136 |
+
else:
|
137 |
+
matched = False
|
138 |
+
if not matched:
|
139 |
+
break
|
140 |
+
else:
|
141 |
+
matched = False
|
142 |
+
if matched:
|
143 |
+
return True
|
144 |
+
|
145 |
+
if prediction.count('=') == 1 and reference.count('=') == 1:
|
146 |
+
pred = prediction.split('=')
|
147 |
+
pred = f"{pred[0].strip()} - ({pred[1].strip()})"
|
148 |
+
ref = reference.split('=')
|
149 |
+
ref = f"{ref[0].strip()} - ({ref[1].strip()})"
|
150 |
+
if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
|
151 |
+
return True
|
152 |
+
elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
|
153 |
+
if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
|
154 |
+
return True
|
155 |
+
elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
|
156 |
+
if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
|
157 |
+
return True
|
158 |
+
|
159 |
+
# print("try final")
|
160 |
+
# symbolic equal with sympy
|
161 |
+
if timeout:
|
162 |
+
if call_with_timeout(symbolic_equal_process, prediction, reference):
|
163 |
+
return True
|
164 |
+
else:
|
165 |
+
if symbolic_equal(prediction, reference):
|
166 |
+
return True
|
167 |
+
|
168 |
+
return False
|
169 |
+
|
170 |
+
|
171 |
+
def math_equal_process(param):
|
172 |
+
return math_equal(param[-2], param[-1])
|
173 |
+
|
174 |
+
|
175 |
+
def numeric_equal(prediction: float, reference: float):
|
176 |
+
# Note that relative tolerance has significant impact
|
177 |
+
# on the result of the synthesized gsm_hard dataset
|
178 |
+
# if reference.is_integer():
|
179 |
+
# return isclose(reference, round(prediction), abs_tol=1e-4)
|
180 |
+
# else:
|
181 |
+
# prediction = round(prediction, len(str(reference).split(".")[-1]))
|
182 |
+
return isclose(reference, prediction, rel_tol=1e-4)
|
183 |
+
|
184 |
+
|
185 |
+
def symbolic_equal(a, b):
|
186 |
+
def _parse(s):
|
187 |
+
for f in [parse_latex, parse_expr, latex2sympy]:
|
188 |
+
try:
|
189 |
+
return f(s.replace("\\\\", "\\"))
|
190 |
+
except:
|
191 |
+
try:
|
192 |
+
return f(s)
|
193 |
+
except:
|
194 |
+
pass
|
195 |
+
return s
|
196 |
+
a = _parse(a)
|
197 |
+
b = _parse(b)
|
198 |
+
|
199 |
+
# direct equal
|
200 |
+
try:
|
201 |
+
if str(a) == str(b) or a == b:
|
202 |
+
return True
|
203 |
+
except:
|
204 |
+
pass
|
205 |
+
|
206 |
+
# print("try simplify")
|
207 |
+
# simplify equal
|
208 |
+
try:
|
209 |
+
if a.equals(b) or simplify(a-b) == 0:
|
210 |
+
return True
|
211 |
+
except:
|
212 |
+
pass
|
213 |
+
|
214 |
+
# print("try equation")
|
215 |
+
# equation equal
|
216 |
+
try:
|
217 |
+
if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
|
218 |
+
return True
|
219 |
+
except:
|
220 |
+
pass
|
221 |
+
|
222 |
+
try:
|
223 |
+
if numeric_equal(float(N(a)), float(N(b))):
|
224 |
+
return True
|
225 |
+
except:
|
226 |
+
pass
|
227 |
+
|
228 |
+
# matrix
|
229 |
+
try:
|
230 |
+
# if a and b are matrix
|
231 |
+
if a.shape == b.shape:
|
232 |
+
_a = a.applyfunc(lambda x: round(x, 3))
|
233 |
+
_b = b.applyfunc(lambda x: round(x, 3))
|
234 |
+
if _a.equals(_b):
|
235 |
+
return True
|
236 |
+
except:
|
237 |
+
pass
|
238 |
+
|
239 |
+
return False
|
240 |
+
|
241 |
+
|
242 |
+
def symbolic_equal_process(a, b, output_queue):
|
243 |
+
result = symbolic_equal(a, b)
|
244 |
+
output_queue.put(result)
|
245 |
+
|
246 |
+
|
247 |
+
def call_with_timeout(func, *args, timeout=1, **kwargs):
|
248 |
+
output_queue = multiprocessing.Queue()
|
249 |
+
process_args = args + (output_queue,)
|
250 |
+
process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
|
251 |
+
process.start()
|
252 |
+
process.join(timeout)
|
253 |
+
|
254 |
+
if process.is_alive():
|
255 |
+
process.terminate()
|
256 |
+
process.join()
|
257 |
+
return False
|
258 |
+
|
259 |
+
return output_queue.get()
|
260 |
+
|
261 |
+
|
262 |
+
def _test_math_equal():
|
263 |
+
# print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
|
264 |
+
# print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
|
265 |
+
# print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
|
266 |
+
# print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True))
|
267 |
+
# print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True))
|
268 |
+
|
269 |
+
# pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
|
270 |
+
# gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'
|
271 |
+
|
272 |
+
# pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
|
273 |
+
# gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'
|
274 |
+
|
275 |
+
# pred = '-34x-45y+20z-100=0'
|
276 |
+
# gt = '34x+45y-20z+100=0'
|
277 |
+
|
278 |
+
# pred = '\\frac{100}{3}'
|
279 |
+
# gt = '33.3'
|
280 |
+
|
281 |
+
# pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
|
282 |
+
# gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'
|
283 |
+
|
284 |
+
# pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
|
285 |
+
# gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'
|
286 |
+
|
287 |
+
# pred = '(+5)(b+2)'
|
288 |
+
# gt = '(a+5)(b+2)'
|
289 |
+
|
290 |
+
# pred = '\\frac{1+\\sqrt{5}}{2}'
|
291 |
+
# gt = '2'
|
292 |
+
|
293 |
+
# pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4'
|
294 |
+
# pred = '1', gt = '1\\\\sqrt{19}'
|
295 |
+
|
296 |
+
pred = '(0.6,2.6667]'
|
297 |
+
gt = '(\\frac{3}{5},\\frac{8}{3}]'
|
298 |
+
|
299 |
+
print(math_equal(pred, gt, timeout=True))
|
300 |
+
|
301 |
+
|
302 |
+
if __name__ == "__main__":
|
303 |
+
_test_math_equal()
|
304 |
+
|
305 |
+
|
main.py
ADDED
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import multiprocessing
|
6 |
+
import os
|
7 |
+
import re
|
8 |
+
from abc import ABC, abstractmethod
|
9 |
+
|
10 |
+
import hjson
|
11 |
+
import numpy as np
|
12 |
+
import openai
|
13 |
+
from tqdm import tqdm
|
14 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
15 |
+
|
16 |
+
from data_loader import load_data
|
17 |
+
from code_executor import PythonExecutor
|
18 |
+
from utils import (Agent, LLMClient, PromptTemplate, api_configs,
|
19 |
+
extract_and_parse_markup, setup_logging)
|
20 |
+
from data_utils import parse_question, parse_ground_truth
|
21 |
+
from evaluate import evaluate
|
22 |
+
|
23 |
+
|
24 |
+
logger = setup_logging()
|
25 |
+
|
26 |
+
class RetrievalAugmentation:
|
27 |
+
# TODO: implement the retrieval augmentation later
|
28 |
+
def __init__(self, dataset, embeddings):
|
29 |
+
self.dataset = dataset
|
30 |
+
self.embeddings = embeddings
|
31 |
+
|
32 |
+
def get_similar_examples(self, query_embedding, n=3):
|
33 |
+
similarities = cosine_similarity([query_embedding], self.embeddings)[0]
|
34 |
+
top_indices = similarities.argsort()[-n:][::-1]
|
35 |
+
return [self.dataset[i] for i in top_indices]
|
36 |
+
|
37 |
+
class SwiftAgent(Agent):
|
38 |
+
def __init__(self, prompt_template, llm_client, retrieval_augmentation=None):
|
39 |
+
super().__init__(prompt_template, llm_client)
|
40 |
+
self.retrieval_augmentation = retrieval_augmentation
|
41 |
+
self.plans = {}
|
42 |
+
self.codes = {}
|
43 |
+
|
44 |
+
def generate_response(self, prompt, reasoning, current_solution, plan, critical_feedback, prefill=True):
|
45 |
+
logger.info("SwiftAgent generating response")
|
46 |
+
if self.retrieval_augmentation:
|
47 |
+
query_embedding = self.get_query_embedding(prompt)
|
48 |
+
similar_examples = self.retrieval_augmentation.get_similar_examples(query_embedding)
|
49 |
+
examples_text = "\n".join(similar_examples) # TODO: add more context to the prompt
|
50 |
+
else:
|
51 |
+
examples_text = "No similar examples available."
|
52 |
+
|
53 |
+
swift_prompt = self.prompt_template.format(
|
54 |
+
"swift",
|
55 |
+
prompt=prompt,
|
56 |
+
current_reasoning=reasoning, # TODO: check if this is needed
|
57 |
+
examples=examples_text,
|
58 |
+
current_solution=current_solution,
|
59 |
+
critical_feedback=critical_feedback,
|
60 |
+
revised_plan=plan
|
61 |
+
)
|
62 |
+
# logger.info(f"SwiftAgent prompt:\n{swift_prompt}")
|
63 |
+
|
64 |
+
messages = [
|
65 |
+
{"role": "system", "content": ''},
|
66 |
+
{"role": "user", "content": swift_prompt}
|
67 |
+
]
|
68 |
+
if prefill:
|
69 |
+
messages.append({"role": "assistant", "content": "<plan>"}) # prefix-filling
|
70 |
+
|
71 |
+
response = self.llm_client.generate_response(messages)
|
72 |
+
if prefill:
|
73 |
+
response = "<plan>" + response
|
74 |
+
|
75 |
+
try:
|
76 |
+
parsed_response = extract_and_parse_markup(response)
|
77 |
+
return parsed_response
|
78 |
+
except json.JSONDecodeError:
|
79 |
+
logger.error("Error: Swift's response was not in valid JSON format. Returning raw response.")
|
80 |
+
return response
|
81 |
+
|
82 |
+
def get_query_embedding(self, query):
|
83 |
+
# Implement query embedding generation
|
84 |
+
return np.random.rand(768) # Placeholder, replace with actual embedding
|
85 |
+
|
86 |
+
class SageAgent(Agent):
|
87 |
+
def __init__(self, prompt_template, llm_client):
|
88 |
+
super().__init__(prompt_template, llm_client)
|
89 |
+
self.feedbacks = {}
|
90 |
+
self.plans = {}
|
91 |
+
|
92 |
+
|
93 |
+
def generate_response(self, prompt, reasoning, current_solution, prefill=True):
|
94 |
+
logger.info("SageAgent generating response")
|
95 |
+
sage_prompt = self.prompt_template.format(
|
96 |
+
"sage",
|
97 |
+
prompt=prompt,
|
98 |
+
reasoning=reasoning,
|
99 |
+
current_solution=current_solution
|
100 |
+
)
|
101 |
+
# logger.info(f"SageAgent prompt:\n{sage_prompt}")
|
102 |
+
|
103 |
+
messages = [
|
104 |
+
{"role": "system", "content": ""},
|
105 |
+
{"role": "user", "content": sage_prompt}
|
106 |
+
]
|
107 |
+
if prefill:
|
108 |
+
messages.append({"role": "assistant", "content": "<solved>"}) # prefix-filling
|
109 |
+
|
110 |
+
response = self.llm_client.generate_response(messages)
|
111 |
+
# logger.info(f"SageAgent raw response:\n{response}")
|
112 |
+
if prefill:
|
113 |
+
response = "<solved>" + response
|
114 |
+
try:
|
115 |
+
parsed_response = extract_and_parse_markup(response)
|
116 |
+
return parsed_response
|
117 |
+
except json.JSONDecodeError:
|
118 |
+
logger.error("Error: Sage's response was not in valid JSON format. Returning raw response.")
|
119 |
+
return response
|
120 |
+
|
121 |
+
class RewardModel:
|
122 |
+
def __init__(self, prompt_template, llm_client):
|
123 |
+
self.prompt_template = prompt_template
|
124 |
+
self.llm_client = llm_client
|
125 |
+
self.scores = []
|
126 |
+
self.feedbacks = []
|
127 |
+
self.stagnant_count = 0
|
128 |
+
|
129 |
+
def calculate_reward(self, problem, reasoning, current_solution, prefill=True):
|
130 |
+
reward_prompt = self.prompt_template.format(
|
131 |
+
"reward",
|
132 |
+
problem=problem,
|
133 |
+
reasoning= reasoning,
|
134 |
+
current_solution=current_solution
|
135 |
+
)
|
136 |
+
# logger.info(f"RewardModel prompt:\n{reward_prompt}")
|
137 |
+
|
138 |
+
messages = [
|
139 |
+
{"role": "system", "content": ""},
|
140 |
+
{"role": "user", "content": reward_prompt}
|
141 |
+
]
|
142 |
+
if prefill:
|
143 |
+
messages.append({"role": "assistant", "content": "<feedback>"}) # prefix-filling
|
144 |
+
|
145 |
+
reward_response = self.llm_client.generate_response(messages)
|
146 |
+
if prefill:
|
147 |
+
reward_response = "<feedback>" + reward_response
|
148 |
+
|
149 |
+
try:
|
150 |
+
parsed_response = extract_and_parse_markup(reward_response)
|
151 |
+
score = int(parsed_response["score"])
|
152 |
+
|
153 |
+
# Update stagnant_count based on score comparison
|
154 |
+
if len(self.scores) > 0 and score <= self.scores[-1]:
|
155 |
+
self.stagnant_count += 1
|
156 |
+
else:
|
157 |
+
self.stagnant_count = 0
|
158 |
+
|
159 |
+
return parsed_response
|
160 |
+
except json.JSONDecodeError:
|
161 |
+
logger.error("Error: Reward model's response was not in valid JSON format. Returning raw response.")
|
162 |
+
return reward_response
|
163 |
+
|
164 |
+
def should_consult_sage(self):
|
165 |
+
# This method remains unchanged
|
166 |
+
return self.stagnant_count >= 1 or (len(self.scores) > 0 and self.scores[-1] < 5)
|
167 |
+
|
168 |
+
class SwiftSage:
|
169 |
+
def __init__(self, dataset, embeddings, prompt_template_dir, swift_config, sage_config, reward_config, use_retrieval=True, start_with_sage=False):
|
170 |
+
prompt_template = PromptTemplate(prompt_template_dir)
|
171 |
+
retrieval_augmentation = RetrievalAugmentation(dataset, embeddings) if use_retrieval else None
|
172 |
+
|
173 |
+
# add logger to the following LLMClient
|
174 |
+
swift_llm = LLMClient(**swift_config, logger=logger)
|
175 |
+
sage_llm = LLMClient(**sage_config, logger=logger)
|
176 |
+
reward_llm = LLMClient(**reward_config, logger=logger)
|
177 |
+
|
178 |
+
self.swift = SwiftAgent(prompt_template, swift_llm, retrieval_augmentation)
|
179 |
+
self.sage = SageAgent(prompt_template, sage_llm)
|
180 |
+
self.reward_model = RewardModel(prompt_template, reward_llm)
|
181 |
+
self.start_with_sage = start_with_sage
|
182 |
+
# self.executor = PythonExecutor(get_answer_from_stdout=True)
|
183 |
+
|
184 |
+
def solve(self, problem, max_iterations=10, reward_threshold=8):
|
185 |
+
logger.info(f"Starting to solve problem: {problem}")
|
186 |
+
current_solution = "No current solution yet." # final answer
|
187 |
+
current_reasoning = "No reasoning steps yet." # reasoning steps
|
188 |
+
plan = "Initial plan: Take a deep breath and think step by step."
|
189 |
+
critical_feedback = "No critical feedback yet." # Initialize critical_feedback
|
190 |
+
solved = False
|
191 |
+
for i in range(max_iterations):
|
192 |
+
logger.info(f"Iteration {i+1}")
|
193 |
+
|
194 |
+
|
195 |
+
# Use the Sage Agent
|
196 |
+
if (i == 0 and self.start_with_sage) or self.reward_model.should_consult_sage():
|
197 |
+
sage_parsed = self.sage.generate_response(problem, current_reasoning, current_solution)
|
198 |
+
critical_feedback = sage_parsed["critical_feedback"]
|
199 |
+
# plan = "\n - " + "\n - ".join(sage_parsed["revised_plan"])
|
200 |
+
current_reasoning = sage_parsed["reasoning_steps"]
|
201 |
+
current_code = sage_parsed["code"]
|
202 |
+
|
203 |
+
solved = sage_parsed["solved"].lower() == "true" if i != 0 else sage_parsed["solved"]
|
204 |
+
if solved:
|
205 |
+
return current_reasoning, current_solution
|
206 |
+
logger.info(f"Sage's feedback (iteration {i+1}):\n{critical_feedback}")
|
207 |
+
# logger.info(f"Sage's reasoning steps:\n{current_reasoning}")
|
208 |
+
self.sage.feedbacks[i] = critical_feedback
|
209 |
+
|
210 |
+
# run the code
|
211 |
+
executor = PythonExecutor(get_answer_from_stdout=True)
|
212 |
+
code_result, code_report = executor.apply(current_code)
|
213 |
+
logger.info(f"Sage Code execution report: {code_report}")
|
214 |
+
logger.info(f"Sage Code execution result: {code_result}")
|
215 |
+
current_reasoning = current_reasoning + f"\n\nThe generated code is:\n\n```python\n{current_code}\n```"
|
216 |
+
current_solution = "Answer (from running the code):\n " + code_result
|
217 |
+
|
218 |
+
# current_solution = sage_parsed["final_answer"]
|
219 |
+
logger.info("Activated Sage, so we should return the reasoning and solution from Sage.")
|
220 |
+
return current_reasoning, current_solution
|
221 |
+
|
222 |
+
if not solved:
|
223 |
+
# Use the Swift Agent
|
224 |
+
swift_parsed = self.swift.generate_response(problem, current_reasoning, current_solution, plan, critical_feedback)
|
225 |
+
|
226 |
+
if "code" not in swift_parsed and "final_answer" not in swift_parsed:
|
227 |
+
logger.info("Swift's response does not contain the 'final_answer' or 'code' field. Returning raw response.")
|
228 |
+
self.reward_model.scores.append(0)
|
229 |
+
self.reward_model.feedbacks.append("No feedback")
|
230 |
+
self.reward_model.stagnant_count += max_iterations # force to use Sage Agent
|
231 |
+
continue
|
232 |
+
|
233 |
+
current_plan = swift_parsed["plan"]
|
234 |
+
current_code = swift_parsed["code"]
|
235 |
+
current_answer = swift_parsed.get("final_answer", None)
|
236 |
+
|
237 |
+
self.swift.plans[i] = current_plan
|
238 |
+
self.swift.codes[i] = current_code
|
239 |
+
|
240 |
+
logger.info(f"Swift's plan:\n{current_plan}")
|
241 |
+
logger.info(f"Swift's code:\n{current_code}")
|
242 |
+
|
243 |
+
# Call sandbox to run the code and get the result
|
244 |
+
executor = PythonExecutor(get_answer_from_stdout=True)
|
245 |
+
code_result, code_report = executor.apply(current_code)
|
246 |
+
logger.info(f"Code execution report: {code_report}")
|
247 |
+
logger.info(f"Code execution result: {code_result}")
|
248 |
+
|
249 |
+
current_reasoning = current_plan + f"\nThe generated code is:\n```python\n{current_code}\n```"
|
250 |
+
current_solution = "Answer (from running the code):\n " + code_result
|
251 |
+
|
252 |
+
# Calling the reward model to provide feedback and score
|
253 |
+
reward_parsed = self.reward_model.calculate_reward(problem, current_reasoning, current_solution)
|
254 |
+
score = int(reward_parsed["score"])
|
255 |
+
feedback = reward_parsed["feedback"]
|
256 |
+
prev_score = self.reward_model.scores[-1] if len(self.reward_model.scores) > 0 else 0
|
257 |
+
self.reward_model.scores.append(score)
|
258 |
+
self.reward_model.feedbacks.append(feedback)
|
259 |
+
|
260 |
+
# detect if the score is lower than the previous score
|
261 |
+
logger.info(f"Reward for iteration {i+1}: {score}/10")
|
262 |
+
logger.info(f"Feedback: {feedback}")
|
263 |
+
|
264 |
+
if False and score < prev_score:
|
265 |
+
logger.info("Score is lower than the previous score. Stopping the iteration. Reverting to the previous solution and reasoning.")
|
266 |
+
# revert to the previous solution and reasoning
|
267 |
+
current_solution = self.swift.codes[i-1]
|
268 |
+
current_reasoning = self.swift.plans[i-1]
|
269 |
+
continue
|
270 |
+
|
271 |
+
|
272 |
+
critical_feedback = feedback
|
273 |
+
|
274 |
+
|
275 |
+
if score >= reward_threshold or solved:
|
276 |
+
logger.info("Perfect solution found!")
|
277 |
+
return current_reasoning, current_solution
|
278 |
+
|
279 |
+
|
280 |
+
if self.reward_model.should_consult_sage():
|
281 |
+
logger.info("Reward model: The solution quality hasn't improved recently. Consulting Sage for the next iteration.")
|
282 |
+
|
283 |
+
logger.info("Max iterations reached without finding a perfect solution.")
|
284 |
+
logger.info("Problem solving completed")
|
285 |
+
return current_reasoning, current_solution
|
286 |
+
|
287 |
+
|
288 |
+
def run_test(swiftsage, problem, max_iterations=5, reward_threshold=8):
|
289 |
+
logger.info(f"Testing problem: {problem}")
|
290 |
+
reasoning, solution = swiftsage.solve(problem, max_iterations, reward_threshold)
|
291 |
+
logger.info(f"Final reasoning:\n{reasoning}")
|
292 |
+
logger.info(f"Final solution:\n{solution}")
|
293 |
+
logger.info("=" * 50)
|
294 |
+
|
295 |
+
|
296 |
+
def run_benchmark(swiftsage, args, max_iterations=5, reward_threshold=8):
|
297 |
+
examples = load_data(args.dataset_name, args.split, args.data_dir, args.num_test_sample)
|
298 |
+
|
299 |
+
res = []
|
300 |
+
skip_ids = []
|
301 |
+
|
302 |
+
output_path = os.path.join(args.output_path, f"{args.dataset_name}.jsonl")
|
303 |
+
if os.path.exists(output_path):
|
304 |
+
with open(output_path) as fr:
|
305 |
+
model_responses = fr.readlines()
|
306 |
+
|
307 |
+
for item in model_responses:
|
308 |
+
item = json.loads(item)
|
309 |
+
res.append(item)
|
310 |
+
skip_ids.append(item["idx"])
|
311 |
+
|
312 |
+
for example in tqdm(examples, desc=args.dataset_name):
|
313 |
+
if example["idx"] in skip_ids:
|
314 |
+
continue
|
315 |
+
question = parse_question(example, args.dataset_name)
|
316 |
+
gt_ans = parse_ground_truth(example, args.dataset_name)
|
317 |
+
reasoning, solution = swiftsage.solve(question, max_iterations, reward_threshold)
|
318 |
+
|
319 |
+
# TODO: extract answer from solution
|
320 |
+
|
321 |
+
cur_res = {
|
322 |
+
"idx": example["idx"],
|
323 |
+
"question": question,
|
324 |
+
"gt": gt_ans,
|
325 |
+
"pred": solution,
|
326 |
+
"reasoning": reasoning,
|
327 |
+
}
|
328 |
+
res.append(cur_res)
|
329 |
+
|
330 |
+
with open(output_path, "a") as fw:
|
331 |
+
fw.write(json.dumps(res[-1]) + "\n")
|
332 |
+
|
333 |
+
# Evaluate the results
|
334 |
+
res, result_metric = evaluate(res)
|
335 |
+
with open(args.output_path, f"{args.dataset_name}_score.jsonl", "w") as fw:
|
336 |
+
for item in res:
|
337 |
+
fw.write(json.dumps(item) + "\n")
|
338 |
+
with open(args.output_path, f"{args.dataset_name}_metric.jsonl", "w") as fw:
|
339 |
+
fw.write(json.dumps(result_metric) + "\n")
|
340 |
+
|
341 |
+
|
342 |
+
def main(args):
|
343 |
+
|
344 |
+
# TODO: for retrieval augmentation (not implemented yet now)
|
345 |
+
# dataset = ["Example problem 1: ...", "Example problem 2: ...", "Example problem 3: ..."]
|
346 |
+
# embeddings = np.random.rand(len(dataset), 768) # Placeholder, replace with actual embeddings
|
347 |
+
|
348 |
+
|
349 |
+
# Configuration for each LLM
|
350 |
+
# swift_config = {
|
351 |
+
# "model_id": "Meta-Llama-3.1-8B-Instruct",
|
352 |
+
# "api_config": api_configs['SambaNova']
|
353 |
+
# }
|
354 |
+
|
355 |
+
# reward_config = {
|
356 |
+
# "model_id": "Meta-Llama-3.1-70B-Instruct",
|
357 |
+
# "api_config": api_configs['SambaNova']
|
358 |
+
# }
|
359 |
+
|
360 |
+
# sage_config = {
|
361 |
+
# "model_id": "Meta-Llama-3.1-405B-Instruct",
|
362 |
+
# "api_config": api_configs['SambaNova']
|
363 |
+
# }
|
364 |
+
|
365 |
+
swift_config = {
|
366 |
+
"model_id": args.swift_model_id,
|
367 |
+
"api_config": api_configs[args.api_provider]
|
368 |
+
}
|
369 |
+
|
370 |
+
reward_config = {
|
371 |
+
"model_id": args.reward_model_id,
|
372 |
+
"api_config": api_configs[args.api_provider]
|
373 |
+
}
|
374 |
+
|
375 |
+
sage_config = {
|
376 |
+
"model_id": args.sage_model_id,
|
377 |
+
"api_config": api_configs[args.api_provider]
|
378 |
+
}
|
379 |
+
|
380 |
+
# specify the path to the prompt templates
|
381 |
+
prompt_template_dir = args.prompt_template_dir
|
382 |
+
dataset = []
|
383 |
+
embeddings = [] # TODO: for retrieval augmentation (not implemented yet now)
|
384 |
+
s2 = SwiftSage(
|
385 |
+
dataset,
|
386 |
+
embeddings,
|
387 |
+
prompt_template_dir,
|
388 |
+
swift_config,
|
389 |
+
sage_config,
|
390 |
+
reward_config,
|
391 |
+
use_retrieval=args.use_retrieval,
|
392 |
+
start_with_sage=args.start_with_sage,
|
393 |
+
)
|
394 |
+
|
395 |
+
if args.eval_mode == "test":
|
396 |
+
test_problems = [
|
397 |
+
"Solve the equation: 2x + 5 = 13", # 0
|
398 |
+
"If h(x)=x-4 and g(h(x))=x^2-8x+10, find g(x)? show the formula for g(x)", # 1
|
399 |
+
"Solve the equation: 6y + 5 = 29", # 2
|
400 |
+
"Who lives longer, Lowell Sherman or Jonathan Kaplan?", # 3
|
401 |
+
"9.9 or 9.11 -- which is bigger?", # 4
|
402 |
+
"How can you solve the quadratic equation 3x^2 + 7x + 4 = 0 using the quadratic formula?", # 5
|
403 |
+
"Explain why sound waves cannot travel in a vacuum?", # 6
|
404 |
+
"How many grams of hydrogen (H) are present in 23.5 grams of water (H2O)?", # 7
|
405 |
+
"What is the distance between the points (2, 3) and (5, 8)?", # 8
|
406 |
+
"Why can the Hubble telescope capture clear images of distant stars and galaxies, but not a detailed image of Pluto?", # 9
|
407 |
+
"""A rectangular band formation is a formation with $m$ band members in each of $r$ rows, where $m$ and $r$ are integers. A particular band has less than 100 band members. The director arranges them in a rectangular formation and finds that he has two members left over. If he increases the number of members in each row by 1 and reduces the number of rows by 2, there are exactly enough places in the new formation for each band member. What is the largest number of members the band could have?""",
|
408 |
+
"""Tim wants to invest some money in a bank which compounds quarterly with an annual interest rate of $7\%$. To the nearest dollar, how much money should he invest if he wants a total of $\$60,\!000$ at the end of $5$ years?""",
|
409 |
+
"""In an SR latch built from NOR gates, which condition is not allowed
|
410 |
+
|
411 |
+
Options:
|
412 |
+
[ "S=0, R=2", "S=2, R=2", "S=1, R=1", "S=1, R=-1", "S=1, R=2", "S=0, R=0", "S=2, R=0", "S=1, R=0", "S=2, R=1", "S=0, R=1" ]
|
413 |
+
|
414 |
+
Which one is the correct answer?""",
|
415 |
+
# ... add other problems here ...
|
416 |
+
"""How many letter r are there in the word "strawberry"?"""
|
417 |
+
]
|
418 |
+
|
419 |
+
# for problem in test_problems:
|
420 |
+
pid = 7
|
421 |
+
print(f"Problem {pid}: {test_problems[pid]}")
|
422 |
+
run_test(s2, test_problems[pid], args.max_iterations, args.reward_threshold)
|
423 |
+
elif args.eval_mode == "benchmark":
|
424 |
+
run_benchmark(s2, args, args.max_iterations, args.reward_threshold)
|
425 |
+
|
426 |
+
|
427 |
+
if __name__ == '__main__':
|
428 |
+
parser = argparse.ArgumentParser()
|
429 |
+
parser.add_argument("--eval_mode", default="test", choices=["test", "benchmark"], type=str)
|
430 |
+
|
431 |
+
parser.add_argument("--dataset_name", default="MATH", type=str)
|
432 |
+
parser.add_argument("--data_dir", default="./data", type=str)
|
433 |
+
parser.add_argument("--split", default="test", type=str)
|
434 |
+
parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data
|
435 |
+
|
436 |
+
parser.add_argument("--api_provider", default="Together", choices=["Together", "SambaNova"], type=str)
|
437 |
+
parser.add_argument("--swift_model_id", default="meta-llama/Meta-Llama-3-8B-Instruct-Turbo", type=str)
|
438 |
+
parser.add_argument("--reward_model_id", default="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type=str)
|
439 |
+
parser.add_argument("--sage_model_id", default="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", type=str)
|
440 |
+
|
441 |
+
parser.add_argument("--prompt_template_dir", default='./prompt_templates', type=str)
|
442 |
+
parser.add_argument("--use_retrieval", action="store_true")
|
443 |
+
parser.add_argument("--start_with_sage", action="store_true")
|
444 |
+
|
445 |
+
parser.add_argument("--max_iterations", default=5, type=int)
|
446 |
+
parser.add_argument("--reward_threshold", default=8, type=int)
|
447 |
+
|
448 |
+
parser.add_argument("--save_outputs", action="store_true")
|
449 |
+
parser.add_argument("--output_path", default="./output", type=str)
|
450 |
+
parser.add_argument("--overwrite", action="store_true")
|
451 |
+
|
452 |
+
args = parser.parse_args()
|
453 |
+
|
454 |
+
# remove console output for benchmark evaluation
|
455 |
+
if args.eval_mode != "test":
|
456 |
+
root_logger = logging.getLogger("")
|
457 |
+
for handler in root_logger.handlers:
|
458 |
+
if isinstance(handler, logging.StreamHandler):
|
459 |
+
root_logger.removeHandler(handler)
|
460 |
+
break
|
461 |
+
|
462 |
+
if args.api_provider == "SambaNova":
|
463 |
+
args.swift_model_id = args.swift_model_id.split("/")[-1][:-len("Turbo")]
|
464 |
+
args.reward_model_id = args.reward_model_id.split("/")[-1][:-len("Turbo")]
|
465 |
+
args.sage_model_id = args.sage_model_id.split("/")[-1][:-len("Turbo")]
|
466 |
+
|
467 |
+
multiprocessing.set_start_method('spawn')
|
468 |
+
main(args)
|
prompt_templates/reward_template.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Instruction
|
2 |
+
|
3 |
+
You are a reward model. You will be given a problem, a solution. You will then evaluate the solution based on the criteria provided.
|
4 |
+
|
5 |
+
## Problem
|
6 |
+
<problem>
|
7 |
+
|
8 |
+
## Current Solution
|
9 |
+
|
10 |
+
### Reasoning Steps
|
11 |
+
<reasoning>
|
12 |
+
|
13 |
+
### Final Answer
|
14 |
+
<current_solution>
|
15 |
+
|
16 |
+
|
17 |
+
## Your Evaluation
|
18 |
+
|
19 |
+
We are not sure if the current solution is correct. Please evaluate the current solution based on the following criteria:
|
20 |
+
|
21 |
+
1. Correctness
|
22 |
+
2. Completeness
|
23 |
+
|
24 |
+
Provide a score from 1 to 10 and a brief explanation.
|
25 |
+
If you are not sure about the final answer, provide a score between 1 to 7 and explain why you are not sure about the final answer.
|
26 |
+
Take care and do not give false information in the critical feedback.
|
27 |
+
|
28 |
+
|
29 |
+
## Output Format
|
30 |
+
|
31 |
+
Remember to present your output in the following format:
|
32 |
+
|
33 |
+
<feedback>
|
34 |
+
Your critical feedback here.
|
35 |
+
</feedback>
|
36 |
+
|
37 |
+
|
38 |
+
<score>
|
39 |
+
Your score here.
|
40 |
+
</score>
|
41 |
+
|
42 |
+
# Important Notes
|
43 |
+
|
44 |
+
You must follow the format strictly, do not miss any field. Start your output by "<feedback>" and end your output by "</score>".
|
prompt_templates/sage_template.md
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Instruction
|
2 |
+
|
3 |
+
You are a high-level problem-solving agent. You will be given a problem and a current solution. You will then provide a critical feedback on the current solution and suggest a revised plan if needed.
|
4 |
+
If the current solution is correct and complete, you will suggest the problem is solved and no further action is needed.
|
5 |
+
|
6 |
+
## Problem
|
7 |
+
<prompt>
|
8 |
+
|
9 |
+
## Current Solution
|
10 |
+
|
11 |
+
### Reasoning Steps
|
12 |
+
<reasoning>
|
13 |
+
|
14 |
+
### Final Answer
|
15 |
+
<current_solution>
|
16 |
+
|
17 |
+
|
18 |
+
## Critical Feedback
|
19 |
+
|
20 |
+
We are not sure if the current solution is correct, can you provide a critical feedback on the current solution and suggest a revised plan for the next steps. Consider any challenges or improvements needed.
|
21 |
+
|
22 |
+
If the solution and answer are correct, please set `solved` to `"True"`, and leave `critical_feedback` and `reasoning_steps` empty.
|
23 |
+
Please point out the errors in the current solution if there are any in the `critical_feedback` field, and then provide the revised plan in the `reasoning_steps` field, and finally provide the final answer in the `final_answer` field.
|
24 |
+
|
25 |
+
|
26 |
+
Format your response in the following format:
|
27 |
+
|
28 |
+
|
29 |
+
<solved>
|
30 |
+
[True or False]
|
31 |
+
</solved>
|
32 |
+
|
33 |
+
<critical_feedback>
|
34 |
+
[Your critical feedback here.]
|
35 |
+
</critical_feedback>
|
36 |
+
|
37 |
+
<reasoning_steps>
|
38 |
+
[Put your reasoning steps here to revise the previous solution. Use additional knowledge if needed and then we will write the code to solve the problem in the next field.]
|
39 |
+
</reasoning_steps>
|
40 |
+
|
41 |
+
<code>
|
42 |
+
[Put your updated code here to solve the problem.]
|
43 |
+
</code>
|
44 |
+
|
45 |
+
|
prompt_templates/swift_template.md
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Instruction
|
2 |
+
|
3 |
+
## Similar Examples with Solutions
|
4 |
+
|
5 |
+
### Example Task 1
|
6 |
+
|
7 |
+
<task>
|
8 |
+
Convert the point $(0, -3 \sqrt{3}, 3)$ in rectangular coordinates to spherical coordinates. Enter your answer in the form $(\rho,\theta,\phi),$ where $\rho > 0,$ $0 \le \theta < 2 \pi,$ and $0 \le \phi \le \pi.$
|
9 |
+
</task>
|
10 |
+
|
11 |
+
<plan>
|
12 |
+
Step 1. Recall the formulas for converting from rectangular coordinates $(x, y, z)$ to spherical coordinates $(\rho, \theta, \phi)$:
|
13 |
+
- $\rho = \sqrt{x^2 + y^2 + z^2}$
|
14 |
+
- $\theta = \arctan2(y, x)$
|
15 |
+
- $\phi = \arccos\left(\frac{z}{\rho}\right)$
|
16 |
+
|
17 |
+
Step 2. Given point: $(0, -3\sqrt{3}, 3)$
|
18 |
+
$x = 0$
|
19 |
+
$y = -3\sqrt{3}$
|
20 |
+
$z = 3$
|
21 |
+
|
22 |
+
Step 3. Calculate $\rho$ using the formula.
|
23 |
+
|
24 |
+
Step 4. Calculate $\theta$:
|
25 |
+
- Since $x = 0$, we need to handle this special case.
|
26 |
+
- When $x = 0$ and $y < 0$, $\theta = \frac{3\pi}{2}$
|
27 |
+
|
28 |
+
Step 5. Calculate $\phi$ using the formula.
|
29 |
+
|
30 |
+
Step 6. Ensure $\theta$ is in the range $[0, 2\pi)$ and $\phi$ is in the range $[0, \pi]$.
|
31 |
+
</plan>
|
32 |
+
|
33 |
+
<code>
|
34 |
+
from sympy import sqrt, atan2, acos, pi
|
35 |
+
|
36 |
+
def rectangular_to_spherical():
|
37 |
+
x, y, z = 0, -3*sqrt(3), 3
|
38 |
+
rho = sqrt(x**2 + y**2 + z**2)
|
39 |
+
theta = atan2(y, x)
|
40 |
+
phi = acos(z/rho)
|
41 |
+
return rho, theta, phi
|
42 |
+
|
43 |
+
spherical_coordinates = rectangular_to_spherical()
|
44 |
+
print(spherical_coordinates)
|
45 |
+
</code>
|
46 |
+
|
47 |
+
|
48 |
+
<final_answer>
|
49 |
+
(6, -pi/2, pi/3)
|
50 |
+
</final_answer>
|
51 |
+
|
52 |
+
### Example Task 2
|
53 |
+
|
54 |
+
<task>
|
55 |
+
Determine who lived longer between Lowell Sherman and Jonathan Kaplan.
|
56 |
+
</task>
|
57 |
+
|
58 |
+
<plan>
|
59 |
+
Step 1: Research the birth and death dates of Lowell Sherman.
|
60 |
+
Step 2: Research the birth and death dates of Jonathan Kaplan.
|
61 |
+
Step 3: Calculate the lifespan of each person in years.
|
62 |
+
Step 4: Compare the lifespans to determine who lived longer.
|
63 |
+
</plan>
|
64 |
+
|
65 |
+
<code>
|
66 |
+
from datetime import datetime
|
67 |
+
|
68 |
+
def calculate_lifespan(birth_date, death_date):
|
69 |
+
birth = datetime.strptime(birth_date, "%Y-%m-%d")
|
70 |
+
death = datetime.strptime(death_date, "%Y-%m-%d")
|
71 |
+
return (death - birth).days / 365.25
|
72 |
+
|
73 |
+
def compare_lifespans():
|
74 |
+
lowell_sherman = calculate_lifespan("1885-10-11", "1934-12-28")
|
75 |
+
jonathan_kaplan = calculate_lifespan("1947-11-25", "2021-01-03")
|
76 |
+
|
77 |
+
if lowell_sherman > jonathan_kaplan:
|
78 |
+
return "Lowell Sherman"
|
79 |
+
elif jonathan_kaplan > lowell_sherman:
|
80 |
+
return "Jonathan Kaplan"
|
81 |
+
else:
|
82 |
+
return "They lived equally long"
|
83 |
+
|
84 |
+
result = compare_lifespans()
|
85 |
+
print(f"{result} lived longer.")
|
86 |
+
</code>
|
87 |
+
|
88 |
+
<final_answer>
|
89 |
+
Jonathan Kaplan lived longer.
|
90 |
+
</final_answer>
|
91 |
+
|
92 |
+
|
93 |
+
---
|
94 |
+
|
95 |
+
## Important Notes
|
96 |
+
|
97 |
+
Note that the above are some example tasks and output formats. You need to solve the current problem below.
|
98 |
+
|
99 |
+
---
|
100 |
+
|
101 |
+
## Current problem that we want to solve
|
102 |
+
<task>
|
103 |
+
<prompt>
|
104 |
+
</task>
|
105 |
+
|
106 |
+
## Previous Solution
|
107 |
+
|
108 |
+
### Previous Reasoning Steps
|
109 |
+
<plan>
|
110 |
+
<current_reasoning>
|
111 |
+
</plan>
|
112 |
+
|
113 |
+
### Previous Answer
|
114 |
+
<final_answer>
|
115 |
+
<current_solution>
|
116 |
+
</final_answer>
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
---
|
121 |
+
|
122 |
+
## Critical Feedback
|
123 |
+
<critical_feedback>
|
124 |
+
|
125 |
+
### Suggested Plan
|
126 |
+
<revised_plan>
|
127 |
+
|
128 |
+
---
|
129 |
+
|
130 |
+
## Your Final Solution
|
131 |
+
|
132 |
+
Read the current problem in <task>...</task> again.
|
133 |
+
|
134 |
+
<task>
|
135 |
+
<prompt>
|
136 |
+
</task>
|
137 |
+
|
138 |
+
To solve the current problem, you should first write the overall plan in <plan>...</plan> to solve the problem. Then, write python code in <code>...</code> tags to solve the problem. If there is critical feedback and suggested plan, please revise your previous solution (if any) and provide the new plan and solution to solve the problem based on the critical feedback and suggested plan.
|
139 |
+
|
140 |
+
## Remember to present your output in the following format:
|
141 |
+
|
142 |
+
<plan>
|
143 |
+
[Your general plan to solve the problem by using code. You can recall the required knowledge that you can use in the code, such as the facts, formulas, etc.]
|
144 |
+
</plan>
|
145 |
+
|
146 |
+
<code>
|
147 |
+
[Your python code to solve the current problem (instead of the example problems). Please print the final answer at the end of the code.]
|
148 |
+
</code>
|
149 |
+
|
150 |
+
You must follow the format strictly, do not miss any field.
|
151 |
+
Start your output by "<plan>...</plan>" and end your output by "<code> ... </code>".
|
152 |
+
|
run_eval.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DEBUG_MODE="-m debugpy --listen 127.0.0.1:5679 --wait-for-client"
|
2 |
+
|
3 |
+
python $DEBUG_MODE main.py \
|
4 |
+
--eval_mode benchmark \
|
5 |
+
--dataset_name MATH \
|
6 |
+
--num_test_sample 4 \
|
test.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from code_executor import PythonExecutor
|
2 |
+
import multiprocess
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
multiprocess.set_start_method('spawn')
|
6 |
+
|
7 |
+
current_code = """
|
8 |
+
```python
|
9 |
+
def calculate_hydrogen_mass(mass_of_water_grams):
|
10 |
+
mass_of_hydrogen = 1.00794 # g/mol
|
11 |
+
mass_of_water = 18.01528 # g/mol
|
12 |
+
ratio = (2 * mass_of_hydrogen) / mass_of_water
|
13 |
+
return ratio * mass_of_water_grams
|
14 |
+
|
15 |
+
mass_of_water = 23.5 # grams
|
16 |
+
hydrogen_mass = calculate_hydrogen_mass(mass_of_water)
|
17 |
+
|
18 |
+
print(hydrogen_mass)
|
19 |
+
```
|
20 |
+
"""
|
21 |
+
executor = PythonExecutor(get_answer_from_stdout=True)
|
22 |
+
result, report = executor.apply(current_code)
|
23 |
+
print("Result:", result)
|
24 |
+
print("Report:", report)
|
25 |
+
|
26 |
+
# Make sure to close the pool when done
|
27 |
+
executor.pool.close()
|
28 |
+
executor.pool.join()
|
utils.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
|
8 |
+
import dirtyjson
|
9 |
+
import hjson
|
10 |
+
import numpy as np
|
11 |
+
import openai
|
12 |
+
from fuzzywuzzy import process
|
13 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
14 |
+
|
15 |
+
api_configs = {
|
16 |
+
"SambaNova": {
|
17 |
+
"api_key": os.environ.get("SAMBANOVA_API_KEY"),
|
18 |
+
"url_base": "https://api.sambanova.ai/v1"
|
19 |
+
},
|
20 |
+
"Together": {
|
21 |
+
"api_key": os.environ.get("TOGETHER_API_KEY"),
|
22 |
+
"url_base": "https://api.together.xyz/v1"
|
23 |
+
}
|
24 |
+
# You can add more API configurations here for other providers
|
25 |
+
}
|
26 |
+
|
27 |
+
class Agent(ABC):
|
28 |
+
def __init__(self, prompt_template, llm_client):
|
29 |
+
self.prompt_template = prompt_template
|
30 |
+
self.llm_client = llm_client
|
31 |
+
|
32 |
+
@abstractmethod
|
33 |
+
def generate_response(self, prompt):
|
34 |
+
pass
|
35 |
+
|
36 |
+
|
37 |
+
def setup_logging():
|
38 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
39 |
+
log_filename = f"logs/swiftsage_log_{timestamp}.txt"
|
40 |
+
|
41 |
+
logging.basicConfig(
|
42 |
+
level=logging.INFO,
|
43 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
44 |
+
filename=log_filename,
|
45 |
+
filemode='w'
|
46 |
+
)
|
47 |
+
|
48 |
+
# Also print to console
|
49 |
+
console = logging.StreamHandler()
|
50 |
+
console.setLevel(logging.INFO)
|
51 |
+
formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
|
52 |
+
console.setFormatter(formatter)
|
53 |
+
logging.getLogger('').addHandler(console)
|
54 |
+
|
55 |
+
return logging.getLogger('SwiftSage')
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
def extract_and_parse_markup(text):
|
60 |
+
keys = ["reasoning_steps", "final_answer", "feedback", "score", "critical_feedback", "revised_plan", "solved", "plan", "code"]
|
61 |
+
result = {}
|
62 |
+
if "<final_answer>" in text and "</final_answer>" not in text:
|
63 |
+
text = text + "</final_answer>"
|
64 |
+
|
65 |
+
for key in keys:
|
66 |
+
# Create a pattern for each key
|
67 |
+
pattern = f'<{key}>(.*?)</{key}>'
|
68 |
+
|
69 |
+
# Search for the pattern in the text
|
70 |
+
match = re.search(pattern, text, re.DOTALL)
|
71 |
+
|
72 |
+
if match:
|
73 |
+
# Extract the content, strip whitespace, and add to the result
|
74 |
+
content = match.group(1).strip()
|
75 |
+
result[key] = content
|
76 |
+
|
77 |
+
if "code" in result.keys():
|
78 |
+
result["code"] = result["code"].replace("```python", "").replace("```", "").strip()
|
79 |
+
|
80 |
+
return result
|
81 |
+
|
82 |
+
|
83 |
+
class PromptTemplate:
|
84 |
+
def __init__(self, template_dir):
|
85 |
+
self.template_dir = template_dir
|
86 |
+
self.templates = {}
|
87 |
+
self.load_templates()
|
88 |
+
|
89 |
+
def load_templates(self):
|
90 |
+
for filename in ['swift_template.md', 'sage_template.md', 'reward_template.md']:
|
91 |
+
with open(os.path.join(self.template_dir, filename), 'r') as f:
|
92 |
+
key = filename.split('_')[0]
|
93 |
+
self.templates[key] = f.read()
|
94 |
+
|
95 |
+
def format(self, key, **kwargs):
|
96 |
+
template = self.templates.get(key, "")
|
97 |
+
for k, v in kwargs.items():
|
98 |
+
template = template.replace("<" + k + ">", str(v))
|
99 |
+
return template
|
100 |
+
|
101 |
+
|
102 |
+
class LLMClient:
|
103 |
+
def __init__(self, model_id, api_config, temperature=0.3, top_p=1.0, max_tokens=3000, logger=None):
|
104 |
+
self.client = openai.OpenAI(
|
105 |
+
api_key=api_config['api_key'],
|
106 |
+
base_url=api_config['url_base']
|
107 |
+
)
|
108 |
+
self.model_id = model_id
|
109 |
+
self.temperature = temperature
|
110 |
+
self.top_p = top_p
|
111 |
+
self.max_tokens = max_tokens
|
112 |
+
self.logger = logger
|
113 |
+
|
114 |
+
def generate_response(self, messages):
|
115 |
+
self.logger.info(f"Sending request to {self.model_id}")
|
116 |
+
self.logger.info(f"Messages: {messages}")
|
117 |
+
response = self.client.chat.completions.create(
|
118 |
+
model=self.model_id,
|
119 |
+
messages=messages,
|
120 |
+
temperature=self.temperature,
|
121 |
+
top_p=self.top_p,
|
122 |
+
max_tokens=self.max_tokens
|
123 |
+
)
|
124 |
+
content = response.choices[0].message.content
|
125 |
+
self.logger.info(f"Response from {self.model_id}:\n{content}")
|
126 |
+
return content
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == "__main__":
|
133 |
+
test_text = "test"
|
134 |
+
|
135 |
+
print(extract_and_parse_markup(test_text))
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
"""
|
140 |
+
|
141 |
+
def extract_and_parse_json(text):
|
142 |
+
|
143 |
+
keys_and_types = [
|
144 |
+
("reasoning_steps", list),
|
145 |
+
("final_answer", str),
|
146 |
+
("feedback", str),
|
147 |
+
("score", str),
|
148 |
+
("score", int),
|
149 |
+
("feedback", str),
|
150 |
+
("solved", str),
|
151 |
+
("critical_feedback", str),
|
152 |
+
("revised_plan", list),
|
153 |
+
]
|
154 |
+
|
155 |
+
# Try to parse the JSON first
|
156 |
+
try:
|
157 |
+
# find the first and last curly braces and parse the json
|
158 |
+
first_brace = text.find("{")
|
159 |
+
last_brace = text.rfind("}")
|
160 |
+
if last_brace == -1:
|
161 |
+
text = text + "\"}"
|
162 |
+
if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
|
163 |
+
data = json.loads(text[first_brace:last_brace+1])
|
164 |
+
return data
|
165 |
+
except Exception as e:
|
166 |
+
data = {}
|
167 |
+
try:
|
168 |
+
data = dirtyjson.loads(text)
|
169 |
+
except Exception as e:
|
170 |
+
pass
|
171 |
+
# If JSON parsing fails, use regex to extract key-value pairs
|
172 |
+
|
173 |
+
for key, _ in keys_and_types:
|
174 |
+
# pattern = rf'"{key}"\s*:\s*([\[{{].*?[\]}}]|".*?")'
|
175 |
+
pattern = rf'"{key}"\s*:\s*([\[{{].*?[\]}}]|".*?"|[-+]?\d+)'
|
176 |
+
match = re.search(pattern, text, re.DOTALL)
|
177 |
+
if match:
|
178 |
+
try:
|
179 |
+
value = json.loads(match.group(1))
|
180 |
+
except Exception as e:
|
181 |
+
value = match.group(1).strip('"')
|
182 |
+
data[key] = value
|
183 |
+
|
184 |
+
result = {}
|
185 |
+
for key, expected_type in keys_and_types:
|
186 |
+
if key in result.keys() and result[key] is not None:
|
187 |
+
continue
|
188 |
+
# Use fuzzy matching to find the closest key
|
189 |
+
try:
|
190 |
+
closest_key, score = process.extractOne(key, data.keys())
|
191 |
+
except Exception as e:
|
192 |
+
continue
|
193 |
+
if score > 80: # You can adjust this threshold
|
194 |
+
value = data[closest_key]
|
195 |
+
|
196 |
+
# Type checking and conversion
|
197 |
+
if expected_type == list and isinstance(value, str):
|
198 |
+
value = [item.strip() for item in value.strip('[]').split(',')]
|
199 |
+
elif expected_type == str and isinstance(value, list):
|
200 |
+
value = ', '.join(value)
|
201 |
+
elif expected_type == int and value is not None:
|
202 |
+
try:
|
203 |
+
value = int(value)
|
204 |
+
except ValueError:
|
205 |
+
value = None
|
206 |
+
|
207 |
+
result[key] = value
|
208 |
+
else:
|
209 |
+
result[key] = None
|
210 |
+
|
211 |
+
for key in list(result.keys()):
|
212 |
+
if result[key] is None:
|
213 |
+
del result[key]
|
214 |
+
return result
|
215 |
+
|
216 |
+
def extract_and_parse_json_v1(text):
|
217 |
+
def find_json_objects(s):
|
218 |
+
# Find all substrings that look like JSON objects
|
219 |
+
json_like_strs = re.findall(r'\{(?:[^{}]|\{[^{}]*\})*\}', s)
|
220 |
+
return json_like_strs
|
221 |
+
|
222 |
+
def try_parse_json(s):
|
223 |
+
try:
|
224 |
+
return json.loads(s)
|
225 |
+
except json.JSONDecodeError:
|
226 |
+
try:
|
227 |
+
s = s.replace("\n", "")
|
228 |
+
return hjson.loads(s)
|
229 |
+
except json.JSONDecodeError:
|
230 |
+
return None
|
231 |
+
return None
|
232 |
+
|
233 |
+
# First, try to find JSON within code blocks
|
234 |
+
code_block_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
|
235 |
+
code_blocks = re.findall(code_block_pattern, text, re.IGNORECASE)
|
236 |
+
|
237 |
+
all_json_candidates = []
|
238 |
+
|
239 |
+
# Add JSON candidates from code blocks
|
240 |
+
for block in code_blocks:
|
241 |
+
all_json_candidates.extend(find_json_objects(block))
|
242 |
+
|
243 |
+
# Add JSON candidates from the entire text
|
244 |
+
all_json_candidates.extend(find_json_objects(text))
|
245 |
+
|
246 |
+
# Sort candidates by length, descending
|
247 |
+
all_json_candidates.sort(key=len, reverse=True)
|
248 |
+
|
249 |
+
# Try to parse each candidate
|
250 |
+
for candidate in all_json_candidates:
|
251 |
+
parsed_json = try_parse_json(candidate)
|
252 |
+
if parsed_json is not None:
|
253 |
+
return parsed_json
|
254 |
+
|
255 |
+
raise ValueError("No valid JSON object found in the text")
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
"""
|