"""
Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py

This logic is largely copied from the Hendrycks MATH release (math_equivalence), and borrowed from:
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
- https://github.com/openai/prm800k
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
- https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
"""
import multiprocessing
import re
from math import isclose
from queue import Empty
from typing import Union

import regex
from latex2sympy2 import latex2sympy
from sympy import simplify, N
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
def parse_digits(num):
    """Convert a number-like value to ``float``, or return ``None``.

    Thousands separators (``,``) are dropped before parsing.  Text that is
    not directly parseable but ends with ``%`` (optionally LaTeX-escaped as
    ``\\%``) is read as a percentage and divided by 100.

    Args:
        num: Any object; it is converted with ``str()`` before parsing.

    Returns:
        The parsed ``float``, or ``None`` when the text is not numeric.
    """
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        pass

    if not text.endswith("%"):
        return None
    text = text[:-1]
    # LaTeX writes a literal percent sign as "\%"; drop the escape as well.
    if text.endswith("\\"):
        text = text[:-1]
    try:
        return float(text) / 100
    except ValueError:
        return None
def is_digit(num):
    """Return ``True`` when ``parse_digits`` can turn *num* into a float."""
    # Counterpart of parse_digits: "numeric" means exactly "parseable there".
    parsed = parse_digits(num)
    return parsed is not None
def str_to_pmatrix(input_str):
    r"""Rewrite brace groups such as ``{1,2,3}`` as LaTeX ``pmatrix`` blocks.

    Every comma inside a group becomes the LaTeX row separator ``\\``, so
    ``{1,2}`` turns into ``\begin{pmatrix}1\\2\end{pmatrix}``.

    Args:
        input_str: Text to scan for ``{...,...}`` groups.

    Returns:
        The converted groups joined by ``", "``; an empty string when
        *input_str* contains no brace group with a comma.
    """
    input_str = input_str.strip()
    # The pattern is greedy: a match runs from an opening brace to the last
    # closing brace on the line, so nested braces (e.g. \frac{1}{2}) stay
    # inside a single group.
    groups = re.findall(r'\{.*,.*\}', input_str)
    # The row separator must be two backslashes; a single one is not valid
    # pmatrix syntax and cannot be split into rows on "\\" afterwards.
    return ', '.join(
        r'\begin{pmatrix}' + group.strip('{}').replace(',', r'\\') + r'\end{pmatrix}'
        for group in groups
    )
def math_equal(prediction: Union[bool, float, str],
               reference: Union[float, str],
               include_percentage: bool = True,
               is_close: bool = True,
               timeout: bool = False,
               ) -> bool:
    """
    Decide whether ``prediction`` and ``reference`` denote the same answer.

    The checks run in this order and the first success returns ``True``:
    1. identical string forms;
    2. numerical equal: both can convert to float and are equal (when both
       are numeric this check is final and its verdict is returned);
    3. identical text, case-insensitively, once brackets and braces are
       ignored;
    4. element-wise equality of bracketed lists/intervals and of
       pmatrix/bmatrix matrices;
    5. equations with a single ``=`` compared as ``lhs - (rhs)``, or a short
       ``x = value`` form compared through its right-hand side;
    6. symbolic equal: both can convert to sympy expression and are equal.

    Args:
        prediction: Candidate answer.
        reference: Ground-truth answer.
        include_percentage: When true, a numeric reference also matches
            after being divided or multiplied by 100.
        is_close: When true, numbers are compared with ``numeric_equal``
            (relative tolerance); otherwise with ``==``.
        timeout: When true, every symbolic comparison (including those made
            for nested elements and for equations) runs through
            ``call_with_timeout`` in a separate process.

    Returns:
        ``True`` if any check above succeeds, ``False`` otherwise.
    """
    def _same(pred_part, ref_part):
        # Compare sub-expressions with the caller's options, timeout included.
        return math_equal(pred_part, ref_part, include_percentage, is_close, timeout)

    def _symbolic(a, b):
        # Single place where the timeout flag selects the sympy code path.
        if timeout:
            return call_with_timeout(symbolic_equal_process, a, b)
        return symbolic_equal(a, b)

    if str(prediction) == str(reference):
        return True

    # 1. numerical equal (parse_digits returns None instead of raising)
    pred_num = parse_digits(prediction)
    ref_num = parse_digits(reference)
    if pred_num is not None and ref_num is not None:
        if include_percentage:
            candidates = [ref_num / 100, ref_num, ref_num * 100]
        else:
            candidates = [ref_num]
        if is_close:
            return any(numeric_equal(pred_num, item) for item in candidates)
        return any(item == pred_num for item in candidates)

    # Empty string / None cannot match anything; 0 and False are real values.
    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if "pmatrix" in prediction and "pmatrix" not in reference:
        converted = str_to_pmatrix(reference)
        # Keep the reference untouched when it holds no "{a,b}" group;
        # replacing it with "" would defeat every later check (for example
        # a bmatrix reference against a pmatrix prediction).
        if converted:
            reference = converted

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
            (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ['{', "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    bracketed = r'(\(|\[).+(\)|\])'
    if regex.match(bracketed, prediction) is not None and regex.match(bracketed, reference) is not None:
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and \
                all(_same(p, r) for p, r in zip(pred_parts, ref_parts)):
            return True

    ## matrices: compare row by row, cell by cell
    matrix_open = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_close = ("\\end{pmatrix}", "\\end{bmatrix}")
    if prediction.startswith(matrix_open) and prediction.endswith(matrix_close) and \
            reference.startswith(matrix_open) and reference.endswith(matrix_close):
        # Both environment names have the same length, so one slice fits both.
        start, stop = len(matrix_open[0]), -len(matrix_close[0])
        pred_rows = [row.strip() for row in prediction[start:stop].split("\\\\") if row.strip()]
        ref_rows = [row.strip() for row in reference[start:stop].split("\\\\") if row.strip()]

        def _row_equal(pred_row, ref_row):
            pred_cells = pred_row.split("&")
            ref_cells = ref_row.split("&")
            return len(pred_cells) == len(ref_cells) and \
                all(_same(p, r) for p, r in zip(pred_cells, ref_cells))

        if len(pred_rows) == len(ref_rows) and \
                all(_row_equal(p, r) for p, r in zip(pred_rows, ref_rows)):
            return True

    if prediction.count('=') == 1 and reference.count('=') == 1:
        # Equation vs. equation: compare "lhs - (rhs)", accepting either sign.
        pred_lhs, pred_rhs = prediction.split('=')
        ref_lhs, ref_rhs = reference.split('=')
        pred = f"{pred_lhs.strip()} - ({pred_rhs.strip()})"
        ref = f"{ref_lhs.strip()} - ({ref_rhs.strip()})"
        if _symbolic(pred, ref) or _symbolic(f"-({pred})", ref):
            return True
    elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
        # "x = value" vs. bare value: compare the right-hand side only.
        if _same(prediction.split('=')[1], reference):
            return True
    elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
        if _same(prediction, reference.split('=')[1]):
            return True

    # symbolic equal with sympy
    return bool(_symbolic(prediction, reference))
def math_equal_process(param):
    """Call ``math_equal`` on the last two items of *param*.

    The second-to-last item is used as the prediction and the last item as
    the reference; ``math_equal`` runs with its default options.
    """
    prediction = param[-2]
    reference = param[-1]
    return math_equal(prediction, reference)
def numeric_equal(prediction: float, reference: float, rel_tol: float = 1e-4) -> bool:
    """Return whether two numbers agree within a relative tolerance.

    Args:
        prediction: Candidate value.
        reference: Ground-truth value.
        rel_tol: Relative tolerance forwarded to ``math.isclose``.  No
            absolute tolerance is applied, so a non-zero value never
            matches an exact zero.

    Returns:
        ``True`` when ``math.isclose`` considers the two values close.
    """
    # Note that relative tolerance has significant impact
    # on the result of the synthesized gsm_hard dataset
    return isclose(reference, prediction, rel_tol=rel_tol)
def symbolic_equal(a, b):
    """Return whether *a* and *b* are the same expression according to sympy.

    Each argument is parsed with the first of ``parse_latex``,
    ``parse_expr`` and ``latex2sympy`` that accepts it; an argument no
    parser accepts is kept unchanged.  The parsed values are then compared
    by, in order: string/``==`` equality, ``equals``/``simplify`` of the
    difference, equation form (``lhs - rhs`` up to sign), numeric
    evaluation with ``numeric_equal``, and element-wise matrix equality
    after rounding to 3 decimals.  A strategy that raises is skipped.

    Args:
        a: First expression (LaTeX or plain expression text).
        b: Second expression.

    Returns:
        ``True`` as soon as one strategy reports equality, else ``False``.
    """
    def _parse(s):
        # NOTE(security): parse_expr evaluates its input with eval(); the
        # strings handled here should be treated as untrusted, so run this
        # in a sandboxed process when grading arbitrary text.
        for parser in (parse_latex, parse_expr, latex2sympy):
            try:
                # First try with doubled backslashes collapsed to one.
                return parser(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return parser(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    # direct equal
    try:
        if str(a) == str(b) or a == b:
            return True
    except Exception:
        pass

    # simplify equal
    try:
        if a.equals(b) or simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # equation equal
    try:
        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
            return True
    except Exception:
        pass

    # numeric value equal
    try:
        if numeric_equal(float(N(a)), float(N(b))):
            return True
    except Exception:
        pass

    # matrix
    try:
        # if a and b are matrix
        if a.shape == b.shape:
            _a = a.applyfunc(lambda x: round(x, 3))
            _b = b.applyfunc(lambda x: round(x, 3))
            if _a.equals(_b):
                return True
    except Exception:
        pass

    return False
def symbolic_equal_process(a, b, output_queue):
    """Run ``symbolic_equal(a, b)`` and put its result on *output_queue*."""
    output_queue.put(symbolic_equal(a, b))
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run *func* in a child process and return the value it reports.

    *func* is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to put exactly one result on ``output_queue``.

    Args:
        func: Callable executed in the child process.
        *args: Positional arguments placed before the queue.
        timeout: Seconds to wait for the child process to finish.
        **kwargs: Keyword arguments forwarded to *func*.

    Returns:
        The value the child put on the queue, or ``False`` when the child
        had to be terminated after *timeout* seconds or exited without
        reporting a result.
    """
    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    try:
        # The child has exited, so its result (if any) is already queued.
        # A bounded wait keeps a child that died without reporting from
        # blocking the caller forever.
        return output_queue.get(timeout=0.1)
    except Empty:
        return False
def _test_math_equal():
    """Manual smoke check: print ``math_equal``'s verdict for one sample pair."""
    # Sample calls and pairs that were tried here before (kept for reference):
    #   math_equal("0.0833333333333333", "\\frac{1}{12}")
    #   math_equal("(1,4.5)", "(1,\\frac{9}{2})")
    #   math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True)
    #   math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True)
    #   math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True)
    #   pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
    #   gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'
    #   pred = '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
    #   gt = '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'
    #   pred = '-34x-45y+20z-100=0'
    #   gt = '34x+45y-20z+100=0'
    #   pred = '\\frac{100}{3}'
    #   gt = '33.3'
    #   pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
    #   gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'
    #   pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
    #   gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'
    #   pred = '(+5)(b+2)'
    #   gt = '(a+5)(b+2)'
    #   pred = '\\frac{1+\\sqrt{5}}{2}'
    #   gt = '2'
    #   pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4'
    #   pred = '1', gt = '1\\\\sqrt{19}'
    prediction = '(0.6,2.6667]'
    ground_truth = '(\\frac{3}{5},\\frac{8}{3}]'
    verdict = math_equal(prediction, ground_truth, timeout=True)
    print(verdict)
# Running this file directly executes the manual smoke check.
if __name__ == "__main__":
    _test_math_equal()