File size: 4,251 Bytes
b9e43f7
 
 
 
 
 
 
 
 
 
 
 
 
9744ddc
b9e43f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9744ddc
b9e43f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9744ddc
 
b9e43f7
 
 
 
 
9744ddc
b9e43f7
 
 
9744ddc
 
 
 
 
 
 
 
 
 
 
b9e43f7
9744ddc
 
 
 
 
 
 
b9e43f7
 
 
 
 
 
 
 
 
 
 
9744ddc
b9e43f7
9744ddc
 
 
 
 
b9e43f7
 
 
 
9744ddc
 
 
 
 
b9e43f7
 
 
 
 
9744ddc
b9e43f7
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from dataclasses import astuple
from typing import Optional

import gym
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter

from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv


def make_microrts_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build the wrapped gym-microrts vectorized environment described by ``hparams``.

    The raw MicroRTS grid-mode vec env is constructed from ``hparams.make_kwargs``
    after filling in ``num_selfplay_envs``, ``num_bot_envs``, ``reward_weight``
    and ``ai2s``. It is then wrapped with ``HwcToChwObservation``,
    ``IsVectorEnv``, ``MicrortsMaskWrapper`` and an optional ``SelfPlayWrapper``.
    Episode-statistics wrappers are added last, plus an ``EpisodeStatsWriter``
    when ``training``.

    Args:
        config: Run configuration. Supplies the seed, ``algo_hyperparams``
            (``gamma``, defaulting to 0.99) and ``additional_keys_to_log``.
        hparams: Environment hyperparameters, unpacked positionally via
            ``astuple``, so the dataclass field order matters.
        training: Selects the training seed and enables ``EpisodeStatsWriter``.
        render: Not used by this factory.
        normalize_load_path: Not used by this factory.
        tb_writer: TensorBoard writer; required when ``training`` is True.

    Returns:
        The fully wrapped vectorized environment.

    Raises:
        AttributeError: If a bot name in ``hparams.bots`` that is needed to
            fill the bot envs is not an attribute of ``gym_microrts.microrts_ai``.
        AssertionError: If ``training`` is True and ``tb_writer`` is falsy.
    """
    import gym_microrts
    from gym_microrts import microrts_ai

    from rl_algo_impls.shared.vec_env.microrts_compat import (
        MicroRTSGridModeSharedMemVecEnvCompat,
        MicroRTSGridModeVecEnvCompat,
    )

    (
        _,  # env_type
        n_envs,
        _,  # frame_stack
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        _,  # normalize
        _,  # normalize_kwargs,
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
        _,  # normalize_type
        _,  # mask_actions
        bots,
        self_play_kwargs,
        selfplay_bots,
    ) = astuple(hparams)

    seed = config.seed(training=training)

    # astuple() rebuilds nested dicts/lists, so filling in defaults below does
    # not mutate the caller's hparams.
    make_kwargs = make_kwargs or {}
    self_play_kwargs = self_play_kwargs or {}
    make_kwargs.setdefault("num_selfplay_envs", 0)
    if "num_bot_envs" not in make_kwargs:
        num_selfplay_envs = make_kwargs["num_selfplay_envs"]
        if num_selfplay_envs:
            # NOTE(review): presumably the env slots driven by old policies and
            # selfplay bots are not exposed to the learner by SelfPlayWrapper,
            # so they are added back when sizing the bot envs — confirm.
            num_bot_envs = (
                n_envs
                - num_selfplay_envs
                + self_play_kwargs.get("num_old_policies", 0)
                + (len(selfplay_bots) if selfplay_bots else 0)
            )
        else:
            num_bot_envs = n_envs
        make_kwargs["num_bot_envs"] = num_bot_envs
    num_bot_envs = make_kwargs["num_bot_envs"]

    if "reward_weight" in make_kwargs:
        # Reward Weights:
        # WinLossRewardFunction
        # ResourceGatherRewardFunction
        # ProduceWorkerRewardFunction
        # ProduceBuildingRewardFunction
        # AttackRewardFunction
        # ProduceCombatUnitRewardFunction
        make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])

    if bots:
        # Fill opponents in the order given by `bots`, truncating once every
        # bot env has one.
        ai2s = []
        for ai_name, n in bots.items():
            remaining = num_bot_envs - len(ai2s)
            if remaining <= 0:
                break
            if n <= 0:
                continue
            # Use a default so an unknown name gets a descriptive error; a
            # plain getattr would raise before any message could be attached.
            ai = getattr(microrts_ai, ai_name, None)
            if not ai:
                raise AttributeError(f"{ai_name} not in microrts_ai")
            ai2s.extend([ai] * min(n, remaining))
    else:
        ai2s = [microrts_ai.randomAI] * num_bot_envs
    make_kwargs["ai2s"] = ai2s

    # NOTE(review): the shared-memory env is used when fewer than two maps are
    # configured — presumably it only supports a single map; confirm.
    if len(make_kwargs.get("map_paths", [])) < 2:
        EnvClass = MicroRTSGridModeSharedMemVecEnvCompat
    else:
        EnvClass = MicroRTSGridModeVecEnvCompat
    envs = EnvClass(**make_kwargs)
    envs = HwcToChwObservation(envs)
    envs = IsVectorEnv(envs)
    envs = MicrortsMaskWrapper(envs)

    if self_play_kwargs:
        if selfplay_bots:
            self_play_kwargs["selfplay_bots"] = selfplay_bots
        envs = SelfPlayWrapper(envs, config, **self_play_kwargs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99), bots)
    if training:
        assert tb_writer, "tb_writer is required when training=True"
        envs = EpisodeStatsWriter(
            envs,
            tb_writer,
            training=training,
            rolling_length=rolling_length,
            additional_keys_to_log=config.additional_keys_to_log,
        )

    return envs