import os
import shutil
from dataclasses import dataclass
from typing import NamedTuple, Optional

from rl_algo_impls.shared.vec_env import make_eval_env
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
from rl_algo_impls.runner.running_utils import (
    load_hyperparams,
    set_seeds,
    get_device,
    make_policy,
)
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.shared.stats import EpisodesStats


@dataclass
class EvalArgs(RunArgs):
    """Arguments for `evaluate_model`, extending the base `RunArgs`."""

    # Forwarded as `render=` to both the eval environment and `evaluate`.
    render: bool = True
    # Forwarded as `best=` when resolving the model directory/archive name.
    best: bool = True
    # Forwarded as `override_n_envs=` when building the eval environment.
    n_envs: Optional[int] = 1
    # Number of episodes handed to `evaluate`.
    n_episodes: int = 3
    # When not None, takes precedence over the config's
    # `eval_params["deterministic"]` setting.
    deterministic_eval: Optional[bool] = None
    # When True, `evaluate` is called with `print_returns=False`.
    no_print_returns: bool = False
    # When set, the run config and model archive are fetched from this
    # Weights & Biases run instead of being resolved locally.
    wandb_run_path: Optional[str] = None


class Evaluation(NamedTuple):
    """Result of `evaluate_model`: the policy, its stats, and the config used."""

    policy: Policy
    stats: EpisodesStats
    config: Config


def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
    """Load a saved policy and evaluate it.

    If ``args.wandb_run_path`` is set, the run's config is read through the
    W&B API, the model archive is downloaded into the current working
    directory, unpacked into the model directory (replacing any existing
    one), and then deleted. Otherwise hyperparameters are loaded via
    ``load_hyperparams`` and the model directory is resolved from the config.

    Note: in the W&B case ``args`` is mutated in place (``algo``, ``env``,
    ``seed`` and ``use_deterministic_algorithms`` are overwritten from the
    run's config).

    Args:
        args: Evaluation arguments.
        root_dir: Root directory passed through to ``Config``.

    Returns:
        An ``Evaluation`` holding the policy (after ``.eval()``), the result
        of ``evaluate``, and the ``Config`` that was used.
    """
    if args.wandb_run_path:
        # Imported lazily so wandb is only required when actually used.
        import wandb

        api = wandb.Api()
        run = api.run(args.wandb_run_path)
        params = run.config

        args.algo = params["algo"]
        args.env = params["env"]
        args.seed = params.get("seed", None)
        args.use_deterministic_algorithms = params.get(
            "use_deterministic_algorithms", True
        )

        config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
        model_path = config.model_dir_path(best=args.best, downloaded=True)

        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
        # replace=True: a stale archive left by an earlier interrupted run
        # must not make the download raise.
        run.file(model_archive_name).download(replace=True)
        try:
            if os.path.isdir(model_path):
                shutil.rmtree(model_path)
            shutil.unpack_archive(model_archive_name, model_path)
        finally:
            # Always drop the downloaded archive, even if unpacking failed.
            os.remove(model_archive_name)
    else:
        hyperparams = load_hyperparams(args.algo, args.env)

        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=args.best)

    print(args)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_n_envs=args.n_envs,
        render=args.render,
        normalize_load_path=model_path,
    )
    device = get_device(config, env)
    policy = make_policy(
        args.algo,
        env,
        device,
        load_path=model_path,
        **config.policy_hyperparams,
    ).eval()

    # An explicit CLI/arg setting wins; otherwise fall back to the config's
    # eval params, defaulting to deterministic evaluation.
    deterministic = (
        args.deterministic_eval
        if args.deterministic_eval is not None
        else config.eval_params.get("deterministic", True)
    )
    return Evaluation(
        policy,
        evaluate(
            env,
            policy,
            args.n_episodes,
            render=args.render,
            deterministic=deterministic,
            print_returns=not args.no_print_returns,
        ),
        config,
    )