from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count
from executors import executor_factory
from generators import generator_factory, model_factory
from typing import List, Dict, Any
import math
from typing import Tuple
import sys
import random
sys.set_int_max_str_digits(100000)  # Raise the int-to-str digit limit to 100,000

react_prompt_header = "Here are some previous solutions and the corresponding test results.\n"
react_prompt_starter = "\n\nYour solution:\n"
extra_header = "\n\nName the function answer()"

class Node:
    def __init__(self, solution: str, parent=None, context="", depth=0):
        self.solution = solution
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.context = context  # store the passed-in context rather than discarding it
        self.depth = depth
        self.reflection = ""
        self.test_feedback = ""

    def uct(self, exploration_weight=1.0):
        # Unvisited nodes fall back to their raw value instead of float('inf').
        if self.visits == 0:
            return self.value
        return (self.value / self.visits) + exploration_weight * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )

    def best_child(self):
        if not self.children:  # no children to select from
            return None
        return max(self.children, key=lambda child: child.uct())

    def best_child_value(self):
        if not self.children:  # no children to select from
            return None
        return max(self.children, key=lambda child: child.value)

    def update(self, reward: float):
        self.visits += 1
        self.value += reward
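
# Illustrative sketch (not part of the search flow): how UCT-based selection
# behaves on a tiny hand-built tree. All names below are local to this example.
#
#   root = Node("root")
#   a = Node("a", parent=root)
#   b = Node("b", parent=root)
#   root.children = [a, b]
#   a.update(1.0)     # a: value 1.0, 1 visit
#   b.update(0.0)     # b: value 0.0, 1 visit
#   root.update(0.0)  # give root visits > 0 so the UCT log term is defined
#   root.update(1.0)
#   assert root.best_child() is a        # higher mean value, same exploration bonus
#   assert root.best_child_value() is a  # highest raw value
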
def prune_context_blocks(context: str, max_length: int) -> str:
    """Prune the context to fit within max_length by dropping the oldest
    blocks of content, using 'Previous Trial' as the block delimiter."""
    if len(context) <= max_length:
        return context
    # Split on the block delimiter.
    blocks = context.split('Previous Trial')
    # Drop the earliest blocks until the rejoined context fits.
    while len('Previous Trial'.join(blocks)) > max_length and blocks:
        blocks.pop(0)
    return 'Previous Trial'.join(blocks)
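
# Illustrative sketch: pruning drops the oldest blocks first.
#
#   ctx = "old attempt" + "Previous Trial" + "newer attempt"  # 38 chars
#   prune_context_blocks(ctx, 30)  # -> "newer attempt" (oldest block dropped)
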
def gather_context_from_tree(node: Node) -> Tuple[List[str], List[str]]:
    """
    Given a node, walk up its tree and gather the feedback and reflections
    from its ancestors, looking at most two levels up.

    Args:
        node (Node): The node to start gathering context from.

    Returns:
        Tuple[List[str], List[str]]: Two lists containing the accumulated
            feedback and reflections.
    """
    accumulated_feedback = []
    accumulated_reflection = []
    num_nodes = 0
    # Cap at two nodes to keep the accumulated context short.
    while node and num_nodes < 2:
        num_nodes += 1
        if node.test_feedback:
            accumulated_feedback.append(node.test_feedback)
        if node.reflection:
            accumulated_reflection.append(node.reflection)
        node = node.parent
    # Reverse the lists so that the context from the earliest nodes is first.
    return accumulated_feedback[::-1], accumulated_reflection[::-1]
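
# Illustrative sketch: for a chain root -> mid -> leaf, gathering from `leaf`
# returns feedback from at most the two nearest nodes, earliest first.
#
#   root = Node("r"); mid = Node("m", parent=root); leaf = Node("l", parent=mid)
#   mid.test_feedback, leaf.test_feedback = "fb_mid", "fb_leaf"
#   gather_context_from_tree(leaf)  # -> (["fb_mid", "fb_leaf"], [])
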
def sample_n_random(items: List[str], n: int) -> List[str]:
    """Sample min(n, len(items)) random items from a list."""
    assert n >= 0
    if n >= len(items):
        return items
    return random.sample(items, n)

def run_lats(
    model_name: str,
    language: str,
    max_iters: int,
    verbose: bool,
    instruction: str = "Write some code to print Hello World in Python",
    n_samples: int = 3,
    depth: int = 5,
) -> str:
    exe = executor_factory(language)
    gen = generator_factory(language)
    model = model_factory(model_name)

    num_success = 0  # counter for successful solutions
    cur_func_impl = None
    item = {}

    # Generate internal tests for the instruction and sample one of them.
    tests = gen.internal_tests(instruction + extra_header, model, 1)
    tests_i = sample_n_random(tests, 1)

    # Retry until the generator produces an initial implementation.
    while cur_func_impl is None:
        cur_func_impl = gen.func_impl(instruction + extra_header, model, "simple")
    root = Node(cur_func_impl)  # initial solution (for pass@1 metric)

    # Lists for logging
    reflections = []
    implementations = []
    test_feedback = []
    is_solved = False

    # First attempt
    implementations.append(cur_func_impl)
    assert isinstance(cur_func_impl, str)
    is_passing, feedback, _ = exe.execute(cur_func_impl, tests_i)
    test_feedback.append(feedback)

    # If solved on the first try, exit early.
    if is_passing:
        num_success += 1
        return cur_func_impl

    reflection = gen.self_reflection(cur_func_impl, feedback, model)
    reflections += [reflection]
    root.test_feedback = feedback
    root.reflection = reflection
    max_iters = int(max_iters)
    for cur_iter in range(max_iters):
        # Selection: follow the highest-UCT child down to a leaf.
        tests_i = sample_n_random(tests, 1)
        node = root
        trajectory = {
            'solutions': [],
            'feedbacks': []
        }
        while node.children:
            node = node.best_child()
            trajectory['solutions'].append(node.solution)

        # Expansion: generate n_samples children from the selected node.
        for _ in range(n_samples):
            new_solution = None
            strategy = "mcts"
            prev_func_impl = node.solution
            feedback = node.test_feedback
            reflection = node.reflection
            acc_feedback, acc_reflection = gather_context_from_tree(node)

            # Retry until the generator produces a new solution.
            while new_solution is None:
                new_solution = gen.func_impl(
                    func_sig=instruction + extra_header,
                    model=model,
                    strategy=strategy,
                    prev_func_impl=prev_func_impl,
                    feedback=feedback,
                    self_reflection=reflection,
                    acc_feedback=acc_feedback,
                    acc_reflection=acc_reflection,
                )

            combined_context = "\nPrevious Trial\n\n" + new_solution
            child = Node(new_solution, parent=node, context=combined_context,
                         depth=node.depth + 1)
            node.children.append(child)
        # Simulation: run each child against the sampled tests and score it.
        reward_real = 0
        for child in node.children:
            is_passing_internal, feedback_internal, _ = exe.execute(child.solution, tests_i)
            if not is_passing_internal:
                reflection = gen.self_reflection(child.solution, feedback_internal, model)
                reflections.append(reflection)
                child.reflection = reflection
                child.test_feedback = feedback_internal
                child.context += ("\n\nPrevious Trial\n\n" + child.solution
                                  + "\n\nTest results: \n" + feedback_internal
                                  + "\n\nSelf-reflection: " + reflection)
            else:
                child.context += ("\n\nPrevious Trial\n\n" + child.solution
                                  + "\n\nTest results: \n" + feedback_internal)
                child.reflection = ""
                child.test_feedback = feedback_internal
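            # The parser below assumes executor feedback shaped roughly like
            # this (an assumption for illustration; the exact text comes from
            # `exe.execute`):
            #   Tested passed:
            #   assert answer() == 1
            #   Tests failed:
            #   assert answer() == 2 # output: 3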
            if "Tested passed:" in feedback_internal:
                # Keep the part before "Tests failed:" (the passed tests), then
                # count the non-empty lines after the "Tested passed:" marker.
                passed_section = feedback_internal.split("Tests failed:")[0]
                reward_internal = len([
                    line for line in passed_section.split("Tested passed:")[1].splitlines()
                    if line.strip() != ''
                ])
                reward_internal = reward_internal / len(tests_i)
            else:
                reward_internal = 0

            if is_passing_internal or cur_iter == max_iters - 1:
                item["solution"] = child.solution
                is_solved = is_passing_internal  # only mark solved on a real pass
                break

        if is_solved:
            break

        # `child` is the last child evaluated in the loop above.
        reward = reward_internal + reward_real
        child.update(reward)
        # Backpropagation: push the reward up through the ancestors.
        temp = child
        while temp.parent:
            temp = temp.parent
            temp.update(reward)

    # Choose the best solution after all iterations.
    if is_solved:
        best_solution = item["solution"]
    else:
        best_solution = root.best_child_value().solution
        item["solution"] = best_solution

    return best_solution
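
# Minimal usage sketch, assuming the executor/generator/model factories resolve
# for your setup. The model name and language tag below are illustrative
# assumptions, not values prescribed by this module.
#
# if __name__ == "__main__":
#     solution = run_lats(
#         model_name="gpt-4",
#         language="py",
#         max_iters=4,
#         verbose=True,
#         instruction="Write some code to print Hello World in Python",
#         n_samples=3,
#     )
#     print(solution)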