from typing import Any, Dict, List, Tuple
import math
import random
import sys

from executors import executor_factory
from generators import generator_factory, model_factory
from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count

# Generated programs may print extremely large integers; raise the int-to-str
# conversion limit (Python 3.11+) so executing them does not abort.
sys.set_int_max_str_digits(100000)

react_prompt_header = "Here are some previous solutions and the corresponding test results.\n"
react_prompt_starter = "\n\nYour solution:\n"
extra_header = "\n\nName the function answer()"


class Node:
    """A node in the search tree: one candidate solution plus MCTS statistics."""

    def __init__(self, solution: str, parent=None, context="", depth=0):
        self.solution = solution
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.context = context
        self.depth = depth
        self.reflection = ""
        self.test_feedback = ""

    def uct(self, exploration_weight=1.0):
        """Upper Confidence bound for Trees: mean value plus an exploration bonus."""
        if self.visits == 0:
            # Unvisited nodes fall back to their raw value.
            return self.value
        return (self.value / self.visits) + exploration_weight * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )

    def best_child(self):
        """Return the child with the highest UCT score, or None for a leaf."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.uct())

    def best_child_value(self):
        """Return the child with the highest accumulated value, or None for a leaf."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.value)

    def update(self, reward: float):
        self.visits += 1
        self.value += reward
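
# A worked example of the UCT score above (illustrative numbers, not taken
# from a real run): a child with value 2.0 over 4 visits, under a parent with
# 10 visits, scores 2.0/4 + 1.0 * sqrt(ln(10)/4) ~= 0.5 + 0.759 = 1.259, so
# the exploration bonus still dominates for lightly visited children.
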
def prune_context_blocks(context: str, max_length: int) -> str:
    """Prune the context to fit within max_length by dropping whole
    'Previous Trial' blocks, oldest first."""
    if len(context) <= max_length:
        return context

    blocks = context.split('Previous Trial')

    # Drop the oldest blocks until the rejoined context fits.
    while len('Previous Trial'.join(blocks)) > max_length and blocks:
        blocks.pop(0)

    return 'Previous Trial'.join(blocks)
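
# For example (hypothetical sizes): three 'Previous Trial' blocks totalling
# 6000 characters with max_length=4000 would drop the first (oldest) block
# and rejoin the remaining two.
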
def gather_context_from_tree(node: Node) -> Tuple[List[str], List[str]]:
    """
    Given a node, walk up its tree and gather the test feedback and reflections
    from each node, stopping after at most two nodes (the node itself and its
    parent).

    Args:
        node (Node): The node to start gathering context from.

    Returns:
        Tuple[List[str], List[str]]: The accumulated feedback and reflections,
        ordered oldest to newest.
    """
    accumulated_feedback = []
    accumulated_reflection = []
    num_nodes = 0

    while node and num_nodes < 2:
        num_nodes += 1
        if node.test_feedback:
            accumulated_feedback.append(node.test_feedback)
        if node.reflection:
            accumulated_reflection.append(node.reflection)
        node = node.parent

    # Reverse so the context reads oldest-to-newest.
    return accumulated_feedback[::-1], accumulated_reflection[::-1]


def sample_n_random(items: List[str], n: int) -> List[str]:
    """Sample min(n, len(items)) random items from a list."""
    assert n >= 0
    if n >= len(items):
        return items
    return random.sample(items, n)


def run_lats(
    model_name: str,
    language: str,
    max_iters: int,
    verbose: bool,
    instruction: str = "Write some code to print Hello World in Python",
    n_samples: int = 3,
    depth: int = 5,
) -> str:
    """Run a LATS-style Monte Carlo tree search over candidate solutions and
    return the best implementation found.

    Note: `verbose` and `depth` are currently unused here but kept for
    interface compatibility.
    """
    exe = executor_factory(language)
    gen = generator_factory(language)
    model = model_factory(model_name)

    num_success = 0
    cur_func_impl = None
    item = {}

    # Generate internal unit tests for the instruction, then sample one to
    # evaluate candidates against.
    tests = gen.internal_tests(instruction + extra_header, model, 1)
    tests_i = sample_n_random(tests, 1)

    # Retry until the generator produces an initial (seed) implementation.
    while cur_func_impl is None:
        cur_func_impl = gen.func_impl(instruction + extra_header, model, "simple")
    root = Node(cur_func_impl)

    reflections = []
    implementations = []
    test_feedback = []
    is_solved = False

    # Evaluate the seed implementation against the sampled internal test.
    implementations.append(cur_func_impl)
    assert isinstance(cur_func_impl, str)
    is_passing, feedback, _ = exe.execute(cur_func_impl, tests_i)
    test_feedback.append(feedback)

    if is_passing:
        num_success += 1
        return cur_func_impl

    # The seed failed: reflect on the failure and store the context on the root.
    reflection = gen.self_reflection(cur_func_impl, feedback, model)
    reflections += [reflection]
    root.test_feedback = feedback
    root.reflection = reflection

    max_iters = int(max_iters)
    for cur_iter in range(max_iters):
        tests_i = sample_n_random(tests, 1)

        # Selection: descend from the root along the best UCT child.
        node = root
        trajectory = {
            'solutions': [],
            'feedbacks': []
        }

        while node.children:
            node = node.best_child()
            trajectory['solutions'].append(node.solution)
            trajectory['feedbacks'].append(node.test_feedback)

        # Expansion: sample n_samples new solutions from the selected node,
        # conditioning on its feedback, reflection, and ancestor context.
        for _ in range(n_samples):
            new_solution = None
            strategy = "mcts"
            prev_func_impl = node.solution
            feedback = node.test_feedback
            reflection = node.reflection
            acc_feedback, acc_reflection = gather_context_from_tree(node)

            # Retry until the generator returns a solution.
            while new_solution is None:
                new_solution = gen.func_impl(
                    func_sig=instruction + extra_header,
                    model=model,
                    strategy=strategy,
                    prev_func_impl=prev_func_impl,
                    feedback=feedback,
                    self_reflection=reflection,
                    acc_feedback=acc_feedback,
                    acc_reflection=acc_reflection,
                )

            combined_context = "\nPrevious Trial\n\n" + new_solution
            child = Node(new_solution, parent=node, context=combined_context, depth=node.depth + 1)
            node.children.append(child)

        # Evaluation: run each new child against the sampled test; reflect on
        # failures and score children by the fraction of tests passed.
        reward_real = 0  # placeholder for an external reward signal (unused here)
        for child in node.children:
            is_passing_internal, feedback_internal, _ = exe.execute(child.solution, tests_i)
            if not is_passing_internal:
                reflection = gen.self_reflection(child.solution, feedback_internal, model)
                reflections.append(reflection)
                child.reflection = reflection
                child.test_feedback = feedback_internal
                child.context += (
                    "\n\nPrevious Trial\n\n" + child.solution
                    + "\n\nTest results: \n" + feedback_internal
                    + "\n\nSelf-reflection: " + reflection
                )
            else:
                child.context += (
                    "\n\nPrevious Trial\n\n" + child.solution
                    + "\n\nTest results: \n" + feedback_internal
                )
                child.reflection = ""
                child.test_feedback = feedback_internal

            # Parse the number of passed tests out of the executor's feedback
            # string and normalize by the number of tests run.
            if "Tested passed:" in feedback_internal:
                passed_section = feedback_internal.split("Tests failed:")[0]
                reward_internal = len([
                    line
                    for line in passed_section.split("Tested passed:")[1].splitlines()
                    if line.strip() != ''
                ])
                reward_internal = reward_internal / len(tests_i)
            else:
                reward_internal = 0

            if is_passing_internal or cur_iter == max_iters - 1:
                # On a pass (or on the final iteration) record the solution and
                # stop evaluating further children.
                item["solution"] = child.solution
                is_solved = is_passing_internal
                break

        if is_solved:
            break

        # Backpropagation: push the reward from the evaluated child up through
        # every ancestor so UCT statistics stay consistent along the path.
        reward = reward_internal + reward_real
        temp = child
        while temp:
            temp.update(reward)
            temp = temp.parent

    # Final answer: the solved solution if one was found, otherwise the root
    # child with the highest accumulated value.
    if is_solved:
        best_solution = item["solution"]
    else:
        best_solution = root.best_child_value().solution
    item["solution"] = best_solution

    return best_solution
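
# A minimal usage sketch (an assumption, not part of the original module): the
# model name, language key, and instruction below are illustrative placeholders
# for whatever the surrounding project actually passes in.
if __name__ == "__main__":
    best = run_lats(
        model_name="gpt-4",  # hypothetical model identifier
        language="py",       # assumed executor/generator language key
        max_iters=3,
        verbose=True,
        instruction="Write a function answer() that returns the sum of 1 to 100",
        n_samples=2,
        depth=3,
    )
    print(best)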