PPO playing MicrortsDefeatCoacAIShaped-v3 from https://github.com/sgoodfriend/rl-algo-impls/tree/342013343b316412ba3aff97b0430343c69c8364
b05d1d6
import numpy as np | |
import optuna | |
from time import perf_counter | |
from torch.utils.tensorboard.writer import SummaryWriter | |
from typing import NamedTuple, Union | |
from rl_algo_impls.shared.callbacks.callback import Callback | |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate | |
from rl_algo_impls.shared.policy.policy import Policy | |
from rl_algo_impls.shared.stats import EpisodesStats | |
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter | |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, find_wrapper | |
class Evaluation(NamedTuple):
    """Result bundle produced by one call to ``evaluation``.

    Being a ``NamedTuple`` it can be unpacked positionally as
    ``(eval_stat, train_stat, score)`` or read by field name.
    """

    # Statistics gathered from the dedicated evaluation episodes.
    eval_stat: EpisodesStats
    # Statistics built from the episodes recorded by the training env's
    # EpisodeStatsWriter.
    train_stat: EpisodesStats
    # Average of the eval and train mean scores.
    score: float
class OptimizeCallback(Callback):
    """Callback that periodically scores the policy for an Optuna trial.

    Each time the elapsed timesteps cross the next multiple of ``step_freq``
    the policy is evaluated, the resulting score is reported to the trial,
    and the trial is asked whether it should be pruned.
    """

    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        trial: optuna.Trial,
        tb_writer: SummaryWriter,
        step_freq: Union[int, float] = 50_000,
        n_episodes: int = 10,
        deterministic: bool = True,
    ) -> None:
        """Store the evaluation configuration and reset the progress state.

        Args:
            policy: Policy to evaluate; ``policy.env`` must be wrapped by an
                ``EpisodeStatsWriter``.
            env: Environment the evaluation episodes are run in.
            trial: Optuna trial that receives the intermediate scores.
            tb_writer: TensorBoard writer handed to ``evaluation``.
            step_freq: Timesteps between consecutive evaluations.
            n_episodes: Number of episodes per evaluation.
            deterministic: Forwarded to the evaluation run.

        Raises:
            AssertionError: If no ``EpisodeStatsWriter`` wrapper is found on
                ``policy.env``.
        """
        super().__init__()

        # Fail early when the training env lacks the stats wrapper. The
        # handle is kept on the instance but is not read again in this class.
        writer = find_wrapper(policy.env, EpisodeStatsWriter)
        assert writer
        self.stats_writer = writer

        # Evaluation configuration.
        self.policy = policy
        self.env = env
        self.trial = trial
        self.tb_writer = tb_writer
        self.step_freq = step_freq
        self.n_episodes = n_episodes
        self.deterministic = deterministic

        # Progress state: index of the next evaluation (1-based), the prune
        # flag, and the results of the most recent evaluation.
        self.eval_step = 1
        self.is_pruned = False
        self.last_eval_stat = None
        self.last_train_stat = None
        self.last_score = -np.inf

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Advance the step counter and evaluate when the next threshold is hit.

        Returns:
            ``True`` when no evaluation ran on this call; otherwise
            ``not self.is_pruned``.
            NOTE(review): presumably ``False`` tells the caller to stop
            training -- confirm against the ``Callback`` contract.
        """
        # The base class is expected to maintain ``self.timesteps_elapsed``.
        super().on_step(timesteps_elapsed)

        next_eval_at = self.eval_step * self.step_freq
        if self.timesteps_elapsed < next_eval_at:
            return True

        self.evaluate()
        return not self.is_pruned

    def evaluate(self) -> None:
        """Run one evaluation, report it to the trial and update the state."""
        outcome = evaluation(
            policy=self.policy,
            env=self.env,
            tb_writer=self.tb_writer,
            n_episodes=self.n_episodes,
            deterministic=self.deterministic,
            timesteps_elapsed=self.timesteps_elapsed,
        )
        self.last_eval_stat = outcome.eval_stat
        self.last_train_stat = outcome.train_stat
        self.last_score = outcome.score

        # The evaluation index serves as the trial's intermediate step.
        self.trial.report(outcome.score, self.eval_step)
        if self.trial.should_prune():
            # The flag is only ever raised here, never cleared.
            self.is_pruned = True
        self.eval_step += 1
def evaluation(
    policy: Policy,
    env: VecEnv,
    tb_writer: SummaryWriter,
    n_episodes: int,
    deterministic: bool,
    timesteps_elapsed: int,
) -> Evaluation:
    """Evaluate ``policy`` and combine the result with its training stats.

    Runs ``n_episodes`` evaluation episodes in ``env``, logs throughput and
    evaluation statistics to TensorBoard, switches the policy back to
    training mode, and averages the eval mean score with the mean score of
    the episodes recorded by the ``EpisodeStatsWriter`` on ``policy.env``.

    Args:
        policy: Policy under evaluation.
        env: Environment used for the evaluation episodes.
        tb_writer: TensorBoard writer that receives the scalars.
        n_episodes: Number of evaluation episodes to run.
        deterministic: Forwarded to ``evaluate``.
        timesteps_elapsed: Global step used for every TensorBoard entry.

    Returns:
        ``Evaluation(eval_stat, train_stat, score)``.

    Raises:
        AssertionError: If ``policy.env`` has no ``EpisodeStatsWriter``.
    """
    # Time only the evaluation run itself.
    started_at = perf_counter()
    eval_stat = evaluate(
        env,
        policy,
        n_episodes,
        deterministic=deterministic,
        print_returns=False,
    )
    elapsed = perf_counter() - started_at

    steps_per_second = eval_stat.length.sum() / elapsed
    tb_writer.add_scalar("eval/steps_per_second", steps_per_second, timesteps_elapsed)

    # Put the policy back into training mode once the evaluation run is over.
    policy.train()

    print(f"Eval Timesteps: {timesteps_elapsed} | {eval_stat}")
    eval_stat.write_to_tensorboard(tb_writer, "eval", timesteps_elapsed)

    # Training-side statistics come from the stats wrapper on the policy's env.
    writer = find_wrapper(policy.env, EpisodeStatsWriter)
    assert writer
    train_stat = EpisodesStats(writer.episodes)
    print(f" Train Stat: {train_stat}")

    # Final score: plain average of the eval and train mean scores.
    eval_mean = eval_stat.score.mean
    train_mean = train_stat.score.mean
    score = (eval_mean + train_mean) / 2
    print(f" Score: {round(score, 2)}")
    tb_writer.add_scalar("eval/score", score, timesteps_elapsed)

    return Evaluation(eval_stat, train_stat, score)