# flake8: noqa # code from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py def _fix_fracs(string): substrs = string.split('\\frac') new_str = substrs[0] if len(substrs) > 1: substrs = substrs[1:] for substr in substrs: new_str += '\\frac' if substr[0] == '{': new_str += substr else: try: assert len(substr) >= 2 except: return string a = substr[0] b = substr[1] if b != '{': if len(substr) > 2: post_substr = substr[2:] new_str += '{' + a + '}{' + b + '}' + post_substr else: new_str += '{' + a + '}{' + b + '}' else: if len(substr) > 2: post_substr = substr[2:] new_str += '{' + a + '}' + b + post_substr else: new_str += '{' + a + '}' + b string = new_str return string def _fix_a_slash_b(string): if len(string.split('/')) != 2: return string a = string.split('/')[0] b = string.split('/')[1] try: a = int(a) b = int(b) assert string == '{}/{}'.format(a, b) new_string = '\\frac{' + str(a) + '}{' + str(b) + '}' return new_string except: return string def _remove_right_units(string): # "\\text{ " only ever occurs (at least in the val set) when describing units if '\\text{ ' in string: splits = string.split('\\text{ ') assert len(splits) == 2 return splits[0] else: return string def _fix_sqrt(string): if '\\sqrt' not in string: return string splits = string.split('\\sqrt') new_string = splits[0] for split in splits[1:]: if split[0] != '{': a = split[0] new_substr = '\\sqrt{' + a + '}' + split[1:] else: new_substr = '\\sqrt' + split new_string += new_substr return new_string def _strip_string(string): # linebreaks string = string.replace('\n', '') # print(string) # remove inverse spaces string = string.replace('\\!', '') # print(string) # replace \\ with \ string = string.replace('\\\\', '\\') # print(string) # replace tfrac and dfrac with frac string = string.replace('tfrac', 'frac') string = string.replace('dfrac', 'frac') # print(string) # remove \left and \right string = string.replace('\\left', '') string = string.replace('\\right', '') # print(string) # Remove circ (degrees) string = string.replace('^{\\circ}', '') string = string.replace('^\\circ', '') # remove dollar signs string = string.replace('\\$', '') # remove units (on the right) string = _remove_right_units(string) # remove percentage string = string.replace('\\%', '') string = string.replace('\%', '') # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string string = string.replace(' .', ' 0.') string = string.replace('{.', '{0.') # if empty, return empty string if len(string) == 0: return string if string[0] == '.': string = '0' + string # to consider: get rid of e.g. "k = " or "q = " at beginning if len(string.split('=')) == 2: if len(string.split('=')[0]) <= 2: string = string.split('=')[1] # fix sqrt3 --> sqrt{3} string = _fix_sqrt(string) # remove spaces string = string.replace(' ', '') # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} string = _fix_fracs(string) # manually change 0.5 --> \frac{1}{2} if string == '0.5': string = '\\frac{1}{2}' # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y string = _fix_a_slash_b(string) return string def is_equiv(str1, str2, verbose=False): if str1 is None and str2 is None: print('WARNING: Both None') return True if str1 is None or str2 is None: return False try: ss1 = _strip_string(str1) ss2 = _strip_string(str2) if verbose: print(ss1, ss2) return ss1 == ss2 except: return str1 == str2