File size: 2,723 Bytes
41d1bc5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import os
import argparse
from lats import run_lats
def get_args(argv=None):
    """Parse the command-line options for a LATS run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default, as used by the script entry point) argparse
            falls back to ``sys.argv[1:]``, so existing callers behave
            exactly as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, help="The name of the run")
    parser.add_argument("--root_dir", type=str,
                        help="The root logging directory", default="root")
    # NOTE(review): this default is identical to --root_dir's and looks
    # copy-pasted; confirm the intended dataset location.
    parser.add_argument("--dataset_path", type=str,
                        help="The path to the benchmark dataset", default="root")
    parser.add_argument("--strategy", type=str,
                        help="Strategy: `simple`, `reflexion`")
    parser.add_argument("--language", type=str, help="Language: `py` or `rs`")
    parser.add_argument(
        "--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
    parser.add_argument("--pass_at_k", type=int,
                        help="Pass@k metric", default=1)
    parser.add_argument("--max_iters", type=int,
                        help="The maximum number of self-improvement iterations", default=10)
    parser.add_argument("--expansion_factor", type=int,
                        help="The expansion factor for the reflexion UCS and A* strategy", default=3)
    parser.add_argument("--verbose", action='store_true',
                        help="To print live logs")
    parser.add_argument("--instruction", type=str,
                        help="Instruction text forwarded to the LATS run", default="")
    parser.add_argument("--n_samples", type=int,
                        help="The number of nodes added during expansion", default=3)
    parser.add_argument("--depth", type=int,
                        help="Tree depth", default=5)
    # TODO: implement resume support, e.g.
    # parser.add_argument("--is_resume", action='store_true', help="To resume run")
    # parser.add_argument("--resume_dir", type=str, help="If resume, the logging directory", default="")
    return parser.parse_args(argv)
def strategy_factory(strategy: str):
    """Return the callable that executes a search run.

    Args:
        strategy: Requested strategy name. It is currently not consulted:
            every value yields a wrapper around ``run_lats``.

    Returns:
        A function accepting keyword arguments and forwarding them to
        ``run_lats``.
    """
    def kwargs_wrapper_gen(func, delete_keys=()):
        # Immutable default (was ``[]``) so the default can never be shared
        # and mutated across calls.
        def kwargs_wrapper(**kwargs):
            # Drop keys the wrapped function does not accept; raises
            # KeyError if a listed key was not supplied.
            for key in delete_keys:
                del kwargs[key]
            return func(**kwargs)
        return kwargs_wrapper
    # No keys are removed for LATS: every keyword is forwarded unchanged.
    return kwargs_wrapper_gen(run_lats, delete_keys=())
def lats_main(args):
    """Run the search configured by a parsed CLI namespace.

    Only the options listed below are forwarded to the runner. ``run_name``,
    ``root_dir``, ``dataset_path``, ``pass_at_k`` and ``expansion_factor``
    are parsed by the CLI but not passed on here.
    NOTE(review): confirm whether the runner should receive them.

    Args:
        args: Namespace produced by ``get_args``.

    Returns:
        Whatever the strategy runner returns.
    """
    runner = strategy_factory(args.strategy)
    run_kwargs = dict(
        model_name=args.model,
        language=args.language,
        max_iters=args.max_iters,
        verbose=args.verbose,
        instruction=args.instruction,
        n_samples=args.n_samples,
        depth=args.depth,
    )
    return runner(**run_kwargs)
def main(args):
    """Script entry point: run LATS for an already-parsed namespace.

    The runner's result is discarded, so this always returns ``None``.
    """
    lats_main(args)
    return None
if __name__ == "__main__":
    # Parse the command line and hand the namespace straight to main().
    main(get_args())
|