|
import os |
|
import argparse |
|
|
|
from lats import run_lats |
|
|
|
|
|
def get_args(argv=None):
    """Build the command-line parser and parse the LATS run options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            script's command-line use; passing an explicit list lets other
            code and tests call this without touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.

    NOTE(review): ``run_name``, ``root_dir``, ``dataset_path``, ``pass_at_k``
    and ``expansion_factor`` are parsed but not forwarded by ``lats_main`` in
    this file; they are kept so existing command lines keep working.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--run_name", type=str, help="The name of the run")
    parser.add_argument("--root_dir", type=str,
                        help="The root logging directory", default="root")
    parser.add_argument("--dataset_path", type=str,
                        help="The path to the benchmark dataset", default="root")
    parser.add_argument("--strategy", type=str,
                        help="Strategy: `simple`, `reflexion`")
    # Previously mislabelled "Strategy:" (copy-paste from --strategy).
    parser.add_argument("--language", type=str, help="Language: `py` or `rs`")
    parser.add_argument(
        "--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
    parser.add_argument("--pass_at_k", type=int,
                        help="Pass@k metric", default=1)
    parser.add_argument("--max_iters", type=int,
                        help="The maximum number of self-improvement iterations", default=10)
    parser.add_argument("--expansion_factor", type=int,
                        help="The expansion factor for the reflexion UCS and A* strategy", default=3)
    parser.add_argument("--verbose", action='store_true',
                        help="To print live logs")
    parser.add_argument("--instruction", type=str,
                        help="Instruction text passed to the LATS run", default="")
    parser.add_argument("--n_samples", type=int,
                        help="The number of nodes added during expansion", default=3)
    parser.add_argument("--depth", type=int,
                        help="Tree depth", default=5)

    # parse_args(None) falls back to sys.argv[1:], preserving CLI behaviour.
    args = parser.parse_args(argv)
    return args
|
|
|
|
|
def strategy_factory(strategy: str):
    """Return a callable that runs the LATS search with keyword arguments.

    Args:
        strategy: Name of the requested strategy. NOTE(review): this value is
            currently ignored -- every name maps to ``run_lats``.

    Returns:
        A function taking keyword arguments only. It drops any keys listed in
        ``delete_keys`` (none at present) and forwards the rest to
        ``run_lats``, returning its result.
    """
    import functools

    # Immutable default instead of a shared mutable ``[]`` default argument.
    def kwargs_wrapper_gen(func, delete_keys=()):

        # Copy the wrapped function's name/docstring onto the wrapper for
        # clearer tracebacks and introspection.
        @functools.wraps(func)
        def kwargs_wrapper(**kwargs):
            # ``kwargs`` is a fresh dict per call, so deleting is safe here.
            for key in delete_keys:
                del kwargs[key]
            return func(**kwargs)

        return kwargs_wrapper

    return kwargs_wrapper_gen(run_lats, delete_keys=())
|
|
|
|
|
def lats_main(args):
    """Launch a LATS run configured from a parsed CLI namespace.

    The strategy callable is chosen from ``args.strategy``. Only these
    options are forwarded to it: ``model`` (as ``model_name``), ``language``,
    ``max_iters``, ``verbose``, ``instruction``, ``n_samples`` and ``depth``.

    Returns:
        Whatever the strategy callable returns.
    """
    runner = strategy_factory(args.strategy)

    forwarded = {
        "model_name": args.model,
        "language": args.language,
        "max_iters": args.max_iters,
        "verbose": args.verbose,
        "instruction": args.instruction,
        "n_samples": args.n_samples,
        "depth": args.depth,
    }
    return runner(**forwarded)
|
|
|
|
|
|
|
def main(args):
    """Script entry point: run LATS for the already-parsed ``args``.

    The value produced by the run is not passed back to the caller.
    """
    _ = lats_main(args)  # result intentionally discarded
    return None
|
|
|
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    # Parse command-line flags from sys.argv, then hand them to main().
    args = get_args()
    main(args)
|
|