CodeLATS / lats /lats.py
Ron
initial commit
41d1bc5
raw
history blame
8.28 kB
from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count
from executors import executor_factory
from generators import generator_factory, model_factory
from typing import List, Dict, Any
import math
from typing import Tuple
import sys
import random
sys.set_int_max_str_digits(100000) # Increase the limit to 10000 digits
react_prompt_header = "Here are some previous solutions and the corresponding test results.\n"
react_prompt_starter = "\n\nYour solution:\n"
extra_header = "\n\nName the function answer()"
class Node:
def __init__(self, solution: str, parent=None, context="", depth=0):
self.solution = solution
self.parent = parent
self.children = []
self.value = 0
self.visits = 0
self.context = ""
self.depth = depth
self.reflection = ""
self.test_feedback = ""
def uct(self, exploration_weight=1.0):
if self.visits == 0:
#return float('inf')
return self.value
return (self.value / self.visits) + exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)
def best_child(self):
if not self.children: # Check if children list is empty
return None
return max(self.children, key=lambda child: child.uct())
def best_child_value(self):
if not self.children: # Check if children list is empty
return None
return max(self.children, key=lambda child: child.value)
def update(self, reward: float):
self.visits += 1
self.value += reward
def prune_context_blocks(context: str, max_length: int) -> str:
"""Prune the context to fit within the specified max_length by removing entire blocks of content using 'trial' as a delimiter."""
if len(context) <= max_length:
return context
# Split by the block delimiter "trial".
blocks = context.split('Previous Trial')
# Remove the earliest blocks until the context fits within max_length.
while len('trial'.join(blocks)) > max_length and blocks:
blocks.pop(0)
return 'trial'.join(blocks)
def gather_context_from_tree(node: Node) -> Tuple[List[str], List[str]]:
"""
Given a node, walk up its tree and gather the feedback and reflections
from each parent node until the root is reached.
Args:
node (Node): The node to start gathering context from.
Returns:
Tuple[List[str], List[str]]: Two lists containing the accumulated feedback and reflections.
"""
accumulated_feedback = []
accumulated_reflection = []
num_nodes = 0
while node and num_nodes < 2:
num_nodes += 1
if node.test_feedback:
accumulated_feedback.append(node.test_feedback)
if node.reflection:
accumulated_reflection.append(node.reflection)
node = node.parent
# Reverse the lists so that the context from the earliest nodes is first
return accumulated_feedback[::-1], accumulated_reflection[::-1]
def sample_n_random(items: List[str], n: int) -> List[str]:
"""Sample min(n, len(items)) random items from a list"""
assert n >= 0
if n >= len(items):
return items
return random.sample(items, n)
def run_lats(
model_name: str,
language: str,
max_iters: int,
verbose: bool,
instruction: str = "Write some code to print Hello World in Python",
n_samples: int = 3,
depth: int = 5,
) -> None:
exe = executor_factory(language)
gen = generator_factory(language)
model = model_factory(model_name)
num_success = 0 # Counter for successful solutions
cur_func_impl = None
item = {}
#for idx, item in enumerate(dataset):
tests = gen.internal_tests(instruction + extra_header, model, 1)
tests_i = sample_n_random(tests, 1)
while cur_func_impl is None:
cur_func_impl = gen.func_impl(instruction + extra_header, model, "simple")
root = Node(cur_func_impl) # initial solution (for pass@1 metric)
# Lists for logging
reflections = []
implementations = []
test_feedback = []
is_solved = False
# first attempt
implementations.append(cur_func_impl)
assert isinstance(cur_func_impl, str)
is_passing, feedback, _ = exe.execute(cur_func_impl, tests_i)
test_feedback.append(feedback)
# if solved, exit early
if is_passing:
num_success += 1
return cur_func_impl # GET SOLUTION
reflection = gen.self_reflection(cur_func_impl, feedback, model)
reflections += [reflection]
root.test_feedback = feedback
root.reflection = reflection
max_iters = int(max_iters)
for cur_iter in range(max_iters):
# Selection
tests_i = sample_n_random(tests, 1)
node = root
trajectory = {
'solutions': [],
'feedbacks': []
}
while node.children:
node = node.best_child()
trajectory['solutions'].append(node.solution)
# Expansion
for _ in range(n_samples):
new_solution = None
strategy = "mcts"
prev_func_impl = node.solution
feedback = node.test_feedback
reflection = node.reflection
acc_feedback, acc_reflection = gather_context_from_tree(node)
while new_solution is None:
new_solution = gen.func_impl(
func_sig=instruction+extra_header,
model=model,
strategy=strategy,
prev_func_impl=prev_func_impl,
feedback=feedback,
self_reflection=reflection,
acc_feedback = acc_feedback,
acc_reflection = acc_reflection
)
combined_context = "\nPrevious Trial\n\n" + new_solution
child = Node(new_solution, parent=node, context=combined_context, depth=node.depth + 1)
node.children.append(child)
# Simulation
reward_real = 0
for child in node.children:
is_passing_internal, feedback_internal, _ = exe.execute(child.solution, tests_i)
if not is_passing_internal:
reflection = gen.self_reflection(child.solution, feedback_internal, model)
reflections.append(reflection)
child.reflection = reflection
child.test_feedback = feedback_internal
child.context += "\n\nPrevious Trial\n\n" + child.solution + "\n\nTest results: \n" + feedback_internal + "\n\nSelf-reflection: " + reflection
else:
child.context += "\n\nPrevious Trial\n\n" + child.solution + "\n\nTest results: \n" + feedback_internal
child.reflection = ""
child.test_feedback = feedback_internal
if "Tested passed:" in feedback_internal:
# Split at "Tests failed:" and get the part before it (which contains the passed tests)
passed_section = feedback_internal.split("Tests failed:")[0]
# Split at "Tested passed:" and get the part after it, then count the non-empty lines
reward_internal = len([line for line in passed_section.split("Tested passed:")[1].splitlines() if line.strip() != ''])
reward_internal = reward_internal / len(tests_i)
else:
reward_internal = 0
if is_passing_internal or cur_iter == max_iters - 1:
item["solution"] = child.solution
break
if is_solved:
break
reward = reward_internal + reward_real
child.update(reward)
# Backpropagation
temp = child
while temp.parent:
temp = temp.parent
temp.update(reward)
# Choose the best solution after all iterations
if is_solved:
best_solution = item["solution"]
else:
best_solution = root.best_child_value().solution
item["solution"] = best_solution
return best_solution