# (notebook-export artifact removed: "Spaces: / Running / Running" status lines)
import re | |
import math | |
import random | |
from collections import defaultdict | |
def naive_parse(answer):
    """Return the last contiguous run of ASCII digits in *answer*.

    Scans backwards from the end, skips any trailing non-digit characters,
    then collects the digit run that precedes them.  Returns '' when the
    text contains no digits at all.
    """
    digits = '0123456789'
    chars = list(answer)
    hi = len(chars) - 1
    # Skip trailing non-digit characters.
    while hi >= 0 and chars[hi] not in digits:
        hi -= 1
    # Walk left across the digit run that ends at `hi`.
    lo = hi
    while lo >= 0 and chars[lo] in digits:
        lo -= 1
    return ''.join(chars[lo + 1:hi + 1])
import re | |
import sys | |
import subprocess | |
def return_last_print(output, n):
    """Return line *n* (supports negative indices) of *output*, or '' if absent.

    The original guard `if lines:` was dead code — `''.strip().split('\n')`
    always yields at least one element — and `lines[n]` could still raise
    IndexError (e.g. n=-2 on single-line output), defeating the intended
    empty-string fallback.  Use EAFP and catch the IndexError instead.
    """
    lines = output.strip().split('\n')
    try:
        return lines[n]
    except IndexError:
        # Not enough lines of output: report "nothing" rather than crash.
        return ""
def process_code(code, return_shell_output=False):
    """Patch, save and execute an LLM-generated sympy snippet in a subprocess.

    The snippet is rewritten so every `symbols(...)` call declares real
    symbols, wrapped in a try/except that prints 'FAIL' on error, written to
    `code.py`, and run with a 7-second timeout under the current interpreter.

    Returns:
        If return_shell_output is True: (last_or_penultimate_output_line, CODE_STATUS)
        where CODE_STATUS is False when the script printed 'FAIL' or crashed.
        Otherwise: the last printed value reduced via round(float(eval(...))) % 1000,
        or -1 on any failure.

    SECURITY NOTE(review): this executes model-generated code via a shell
    string and `eval()`s its output — acceptable only in a sandboxed
    competition environment; never run on untrusted input elsewhere.
    """
    def repl(match):
        # Force sympy `symbols(...)` calls to create real-valued symbols,
        # unless the call already passes a `real` keyword.
        if "real" not in match.group():
            return "{}{}".format(match.group()[:-1], ', real=True)')
        else:
            return "{}{}".format(match.group()[:-1], ')')
    code = re.sub(r"symbols\([^)]+\)", repl, code)
    if return_shell_output:
        # Indent the snippet one level so it nests inside the try: block below.
        code = code.replace('\n', '\n    ')
        # Add a try...except block
        code = "\ntry:\n    from sympy import *\n{}\nexcept Exception as e:\n    print(e)\n    print('FAIL')\n".format(code)
    if not return_shell_output:
        print(code)
    with open('code.py', 'w') as fout:
        fout.write(code)
    # `timeout 7` (coreutils) kills runaway generated code after 7 seconds.
    batcmd = 'timeout 7 ' + sys.executable + ' code.py'
    try:
        shell_output = subprocess.check_output(batcmd, shell=True).decode('utf8')
        return_value = return_last_print(shell_output, -1)
        print(shell_output)
        if return_shell_output:
            if return_value=='FAIL':
                CODE_STATUS = False
                # On failure the last line is 'FAIL'; the line before it holds
                # the exception message printed by the wrapper.
                return_value = return_last_print(shell_output, -2)
                if "not defined" in return_value:
                    return_value+='\nTry checking the formatting and imports'
            else:
                CODE_STATUS = True
            return return_value, CODE_STATUS
        # Reduce the printed result to the competition answer format (mod 1000).
        code_output = round(float(eval(return_value))) % 1000
    except Exception as e:
        # Covers timeout (non-zero exit), decode errors, and eval failures.
        print(e,'shell_output')
        code_output = -1
    if return_shell_output:
        # Only reachable when the try block raised before the early return.
        if code_output==-1:
            CODE_STATUS = False
        else:
            CODE_STATUS = True
        return code_output, CODE_STATUS
    return code_output
def process_text_output(output):
    """Extract the final integer answer (mod 1000) from generated text.

    Prefers the last LaTeX `\\boxed{...}` integer in *output*; when none is
    present, falls back to `naive_parse` (the trailing digit run).  Returns
    -1 when no parsable number is found or evaluation fails.
    """
    try:
        boxed_matches = re.findall(r'\\boxed\{(\d+)\}', output)
        print('BOXED', boxed_matches)
        # Last boxed value wins; otherwise fall back to the trailing digits.
        candidate = boxed_matches[-1] if boxed_matches else naive_parse(output)
        print('BOXED FINAL', candidate)
        if not len(candidate):
            return -1
        return round(float(eval(candidate))) % 1000
    except Exception as e:
        print(e)
        print('ERROR PARSING TEXT')
        return -1
from collections import defaultdict | |
from collections import Counter | |
def predict(problem):
    """Answer a math *problem* by repeated tool-integrated LM sampling.

    Each repetition samples a prompt template (code-style vs. chain-of-thought,
    weighted by which type has produced answers so far), generates text, and
    whenever generation stops at a code fence, executes the accumulated code
    and feeds its output back into the prompt.  Parsed answers are majority-
    voted; returns the most frequent answer (mod 1000), or -1.

    Relies on module-level state declared `global` below (model, tokenizer,
    generation settings, prompt templates) plus tqdm/np/torch/gc/time/choice/
    stop_words/stopping_criteria expected in the enclosing namespace.
    """
    # Sampling hyper-parameters for the reasoning and the code-writing phases.
    temperature = 0.9
    top_p = 3.0  # NOTE(review): top_p > 1.0 is outside the usual (0,1] range — confirm the sampler clamps it
    temperature_coding = 0.9
    top_p_coding = 3.0
    # Per-question accumulators (keyed by `i`, which is always 0 here —
    # presumably a leftover from a multi-question loop).
    total_results = {}
    total_answers = {}
    best_stats = {}
    total_outputs = {}
    question_type_counts = {}
    # Prior pseudo-counts: (text_answers, code_answers) for the prompt draw.
    starting_counts = (2,3)
    i = 0
    global n_repetitions,TOTAL_TOKENS,model,tokenizer,USE_PAST_KEY,NOTEBOOK_START_TIME,promplt_options,code,cot
    for jj in tqdm(range(n_repetitions)):
        best, best_count = best_stats.get(i,(-1,-1))
        # Early-exit heuristic: once an answer's vote count beats sqrt(#reps
        # so far), stop spending compute on this question.
        if best_count>np.sqrt(jj):
            print("SKIPPING CAUSE ALREADY FOUND BEST")
            continue
        outputs = total_outputs.get(i,[])
        text_answers, code_answers = question_type_counts.get(i,starting_counts)
        results = total_results.get(i,[])
        answers = total_answers.get(i,[])
        # Up to 5 attempts per repetition; most iterations exit via the
        # bottom break once a clear majority answer emerges.
        for _ in range(5):
            torch.cuda.empty_cache()
            gc.collect()
            time.sleep(0.2)
            try:
                ALREADY_GEN = 0
                code_error = None
                code_error_count = 0
                code_output = -1
                #initail_message = problem + tool_instruction
                # Draw a prompt template, weighted toward whichever style
                # (text vs. code) has produced more answers so far.
                counts = np.array([text_answers,code_answers])
                draw = choice(promplt_options, 1,
                              p=counts/counts.sum())
                initail_message = draw[0].format(problem,"{}")
                prompt = f"User: {initail_message}"
                # Tracks how much of the decoded transcript has been printed.
                current_printed = len(prompt)
                print(f"{jj}_{prompt}\n")
                model_inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
                input_len = len(model_inputs['input_ids'][0])
                generation_output = model.generate(**model_inputs,
                                                   max_new_tokens=TOTAL_TOKENS-ALREADY_GEN,
                                                   return_dict_in_generate=USE_PAST_KEY,
                                                   do_sample = True,
                                                   temperature = temperature,
                                                   top_p = top_p,
                                                   num_return_sequences=1, stopping_criteria = stopping_criteria)
                if USE_PAST_KEY:
                    output_ids = generation_output.sequences[0]
                else:
                    output_ids = generation_output[0]
                decoded_output = tokenizer.decode(output_ids, skip_special_tokens=True)
                print(f"{decoded_output[current_printed:]}\n")
                current_printed += len(decoded_output[current_printed:])
                # Code from successive ```python fences is concatenated so
                # later snippets can use earlier definitions.
                cummulative_code = ""
                # Did generation stop at one of the tool stop words?
                stop_word_cond = False
                for stop_word in stop_words:
                    stop_word_cond = stop_word_cond or (decoded_output[-len(stop_word):]==stop_word)
                # Tool loop: run generated code, append its output, resume
                # generation — until no stop word fires or the token budget
                # is exhausted.
                while (stop_word_cond) and (ALREADY_GEN<(TOTAL_TOKENS)):
                    if (decoded_output[-len("```python"):]=="```python"):
                        # Entering a code block: switch to coding sampler settings.
                        temperature_inner=temperature_coding
                        top_p_inner = top_p_coding
                        prompt = decoded_output
                    else:
                        temperature_inner=temperature
                        top_p_inner = top_p
                        try:
                            # Extract the most recent code fence contents.
                            if (decoded_output[-len("``````output"):]=="``````output"):
                                code_text = decoded_output.split('```python')[-1].split("``````")[0]
                            else:
                                code_text = decoded_output.split('```python')[-1].split("```")[0]
                            cummulative_code+=code_text
                            code_output, CODE_STATUS = process_code(cummulative_code, return_shell_output=True)
                            print('CODE RESULTS', code_output)
                            # Track repeated identical errors to bail out of loops.
                            if code_error==code_output:
                                code_error_count+=1
                            else:
                                code_error=code_output
                                code_error_count = 0
                            if not CODE_STATUS:
                                # Failed snippet: drop it from the cumulative code.
                                cummulative_code = cummulative_code[:-len(code_text)]
                                if code_error_count>=1:
                                    print("REPEATED ERRORS")
                                    break
                        except Exception as e:
                            print(e)
                            print('ERROR PARSING CODE')
                            code_output = -1
                        # Feed the execution result back as an ```output block.
                        if code_output!=-1:
                            if (decoded_output[-len(")\n```"):]==")\n```"):
                                prompt = decoded_output+'```output\n'+str(code_output)+'\n```\n'
                            else:
                                prompt = decoded_output+'\n'+str(code_output)+'\n```\n'
                        else:
                            prompt = decoded_output
                            cummulative_code=""
                    model_inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
                    ALREADY_GEN = len(model_inputs['input_ids'][0])-input_len
                    if USE_PAST_KEY:
                        # Reuse the KV cache so resumption doesn't re-encode the prefix.
                        old_values = generation_output.past_key_values
                    else:
                        old_values = None
                    generation_output = model.generate(**model_inputs,
                                                       max_new_tokens=TOTAL_TOKENS-ALREADY_GEN,
                                                       return_dict_in_generate=USE_PAST_KEY,
                                                       past_key_values=old_values,
                                                       do_sample = True,
                                                       temperature = temperature_inner,
                                                       top_p = top_p_inner,
                                                       num_return_sequences=1, stopping_criteria = stopping_criteria)
                    if USE_PAST_KEY:
                        output_ids = generation_output.sequences[0]
                    else:
                        output_ids = generation_output[0]
                    decoded_output = tokenizer.decode(output_ids, skip_special_tokens=True)
                    print(f"\nINTERMEDIATE OUT :\n{decoded_output[current_printed:]}\n")
                    current_printed+=len(decoded_output[current_printed:])
                    stop_word_cond = False
                    for stop_word in stop_words:
                        stop_word_cond = stop_word_cond or (decoded_output[-len(stop_word):]==stop_word)
                if USE_PAST_KEY:
                    output_ids = generation_output.sequences[0]
                else:
                    output_ids = generation_output[0]
                # Only the newly generated portion (past the prompt) is parsed.
                raw_output = tokenizer.decode(output_ids[input_len:], skip_special_tokens=True)
                #print(f"\n\nOutput :\n{raw_output}\n")
                result_output = process_text_output(raw_output)
                try:
                    # `code_output` is the string returned by process_code in
                    # the tool loop; reduce it to an int mod 1000 (eval raises
                    # if it is already the int -1, which the except maps to -1).
                    code_output = round(float(eval(code_output))) % 1000
                except Exception as e:
                    print(e,'final_eval')
                    code_output = -1
            except Exception as e:
                print(e,"5")
                result_output, code_output = -1, -1
            # Vote bookkeeping: -1 means "no answer from this channel".
            if code_output!=-1:
                outputs.append(code_output)
                code_answers+=1
            if result_output!=-1:
                outputs.append(result_output)
                text_answers+=1
            if len(outputs) > 0:
                occurances = Counter(outputs).most_common()
                print(occurances)
                if occurances[0][1] > best_count:
                    print("GOOD ANSWER UPDATED!")
                    best = occurances[0][0]
                    best_count = occurances[0][1]
                if occurances[0][1] > 5:
                    print("ANSWER FOUND!")
                    break
            results.append(result_output)
            answers.append(code_output)
        # Persist this repetition's state for the next pass over question i.
        best_stats[i] = (best, best_count)
        question_type_counts[i] = (text_answers, code_answers)
        total_outputs[i] = outputs
        total_results[i] = results
        total_answers[i] = answers
        print("code_answers",code_answers-starting_counts[1],"text_answers",text_answers-starting_counts[0])
    return best_stats[0][0]