File size: 1,768 Bytes
2a8bf2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os

# Opt in to PyTorch's MPS fallback before the rl_algo_impls imports below run.
# NOTE(review): presumably the flag has to be in the environment before torch
# is first imported (likely pulled in by those imports) — confirm before
# reordering these lines.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from rl_algo_impls.runner.running_utils import base_parser
from rl_algo_impls.runner.selfplay_evaluate import SelfplayEvalArgs, selfplay_evaluate


def _str_to_bool(value: str) -> bool:
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter recognises the
    usual spellings (case-insensitively) instead.

    Args:
        value: Raw string taken from the command line.

    Returns:
        ``True`` for true/t/yes/y/1; ``False`` for false/f/no/n/0 and for the
        empty string (the only input ``bool()`` itself mapped to ``False``).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling;
            argparse turns this into a normal usage error.
    """
    import argparse  # local import: only needed to report bad CLI values

    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false), got {value!r}"
    )


def selfplay_enjoy() -> None:
    """Parse command-line options and run a self-play evaluation.

    Starts from ``base_parser(multiple=False)``, adds the self-play specific
    options, replaces ``algo``, ``env`` and ``seed`` with their first element,
    packs everything into ``SelfplayEvalArgs`` and passes it to
    ``selfplay_evaluate`` together with the current working directory.
    """
    parser = base_parser(multiple=False)
    parser.add_argument(
        "--wandb-run-paths",
        type=str,
        nargs="*",
        help="WandB run paths to load players from. Must be 0 or 2",
    )
    parser.add_argument(
        "--model-file-paths",
        type=str,
        # Without nargs this option yields a single string, which contradicts
        # the "0 or 2" contract and the list form used by --wandb-run-paths.
        nargs="*",
        help="File paths to load players from. Must be 0 or 2",
    )
    parser.add_argument("--render", action="store_true")
    parser.add_argument("--n-envs", default=1, type=int)
    parser.add_argument("--n-episodes", default=1, type=int)
    # Default stays None when the flag is omitted; an explicit value is parsed
    # by _str_to_bool so that "--deterministic-eval False" really means False.
    parser.add_argument("--deterministic-eval", default=None, type=_str_to_bool)
    parser.add_argument(
        "--no-print-returns", action="store_true", help="Limit printing"
    )
    parser.add_argument(
        "--video-path", type=str, help="Path to save video of all plays"
    )
    # Example defaults for local debugging (uncomment to use):
    # parser.set_defaults(
    #     algo=["ppo"],
    #     env=["Microrts-selfplay-unet-decay"],
    #     n_episodes=10,
    #     model_file_paths=[
    #         "downloaded_models/ppo-Microrts-selfplay-unet-decay-S3-best",
    #         "downloaded_models/ppo-Microrts-selfplay-unet-decay-S2-best",
    #     ],
    #     video_path="/Users/sgoodfriend/Desktop/decay3-vs-decay2",
    # )
    args = parser.parse_args()
    # The parser stores these three as sequences; only the first entry is
    # forwarded to SelfplayEvalArgs.
    args.algo = args.algo[0]
    args.env = args.env[0]
    args.seed = args.seed[0]
    args = SelfplayEvalArgs(**vars(args))

    selfplay_evaluate(args, os.getcwd())


# Run the self-play CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    selfplay_enjoy()