# File size: 2,464 bytes (revision 41a6762)
import numpy as np
from dataclasses import dataclass, field
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
from typing import Generic, List, Optional, Type, TypeVar
@dataclass
class Trajectory:
    """Record of one (possibly still running) episode for a single env.

    Every call to :meth:`add` appends one transition, so ``obs``, ``act``,
    ``rew`` and ``v`` grow in lockstep. ``next_obs`` and ``terminated`` only
    describe the most recently added transition.
    """

    obs: List[np.ndarray] = field(default_factory=list)
    act: List[np.ndarray] = field(default_factory=list)
    # Observation that followed the latest transition; None once the
    # latest transition was flagged as terminated.
    next_obs: Optional[np.ndarray] = None
    rew: List[float] = field(default_factory=list)
    terminated: bool = False
    v: List[float] = field(default_factory=list)

    def add(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        next_obs: np.ndarray,
        rew: float,
        terminated: bool,
        v: float,
    ):
        """Append a single transition to this trajectory.

        Args:
            obs: observation the action was taken from.
            act: action taken.
            next_obs: observation after the action; not kept (``next_obs``
                becomes ``None``) when ``terminated`` is truthy.
            rew: reward for this transition.
            terminated: whether this transition ended the trajectory; stored
                as-is in ``self.terminated``.
            v: per-step scalar appended to ``self.v`` (presumably a value
                estimate; confirm with callers).
        """
        self.obs.append(obs)
        self.act.append(act)
        if terminated:
            # A finished trajectory has no meaningful successor observation.
            self.next_obs = None
        else:
            self.next_obs = next_obs
        self.rew.append(rew)
        self.terminated = terminated
        self.v.append(v)

    def __len__(self) -> int:
        """Return the number of transitions recorded so far."""
        n_steps = len(self.obs)
        return n_steps
# Trajectory type collected by TrajectoryAccumulator; bound to Trajectory, so
# any Trajectory subclass is accepted.
T = TypeVar("T", bound=Trajectory)
class TrajectoryAccumulator(Generic[T]):
    """Collect per-environment trajectories from a vectorized environment.

    One in-progress trajectory is kept per environment index. Each call to
    :meth:`step` appends one transition to every environment's trajectory.
    When an environment reports ``done``, its trajectory is moved to the
    finished list and replaced by a fresh ``trajectory_class()`` instance.
    :meth:`on_done` is then invoked.
    """

    def __init__(self, num_envs: int, trajectory_class: Type[T] = Trajectory) -> None:
        """
        Args:
            num_envs: number of parallel environments (one open trajectory
                is created for each).
            trajectory_class: zero-argument constructor used for every new
                trajectory.
        """
        self.num_envs = num_envs
        self.trajectory_class = trajectory_class
        # Trajectories whose episode ended, in the order they finished.
        self._trajectories: List[T] = []
        # The currently open trajectory for each environment index.
        self._current_trajectories: List[T] = [
            trajectory_class() for _ in range(num_envs)
        ]

    def step(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        next_obs: VecEnvObs,
        reward: np.ndarray,
        done: np.ndarray,
        val: np.ndarray,
        *args,
    ) -> None:
        """Record one vectorized environment step.

        Row ``i`` of every argument is forwarded positionally to
        ``add`` on environment ``i``'s open trajectory as
        ``(obs, act, next_obs, rew, terminated, v, *extra)``. Each extra
        positional argument corresponds to one array in ``*args``.

        Args:
            obs: batched observations; only a plain ``np.ndarray`` is
                supported.
            action: batched actions.
            next_obs: batched observations after the step; only a plain
                ``np.ndarray`` is supported.
            reward: per-env rewards.
            done: per-env done flags, forwarded as ``terminated``.
            val: per-env values, forwarded as ``v``.
            *args: extra per-env arrays forwarded to ``add`` after ``v``.

        Raises:
            AssertionError: if ``obs`` or ``next_obs`` is not an ndarray
                (asserts are skipped under ``python -O``).

        Note:
            ``zip`` stops at the shortest input, so per-env arrays of
            mismatched length are silently truncated.
        """
        # Restrict to plain ndarray observations; also narrows VecEnvObs
        # for type checkers.
        assert isinstance(obs, np.ndarray)
        assert isinstance(next_obs, np.ndarray)
        # The loop variable is deliberately NOT named ``args``: rebinding it
        # would shadow the ``*args`` parameter.
        per_env = zip(obs, action, next_obs, reward, done, val, *args)
        for env_idx, transition in enumerate(per_env):
            trajectory = self._current_trajectories[env_idx]
            # TODO: Eventually take advantage of terminated/truncated differentiation in
            # later versions of gym.
            trajectory.add(*transition)
            if done[env_idx]:
                self._trajectories.append(trajectory)
                self._current_trajectories[env_idx] = self.trajectory_class()
                self.on_done(env_idx, trajectory)

    @property
    def all_trajectories(self) -> List[T]:
        """Return finished trajectories followed by non-empty open ones.

        A new list is returned; the trajectory objects themselves are shared
        with the accumulator.
        """
        open_non_empty = [t for t in self._current_trajectories if len(t) > 0]
        return self._trajectories + open_non_empty

    def n_timesteps(self) -> int:
        """Return the total number of transitions across all trajectories."""
        return sum(len(t) for t in self.all_trajectories)

    def on_done(self, env_idx: int, trajectory: T) -> None:
        """Hook invoked after ``trajectory`` for ``env_idx`` finishes.

        It is called once that trajectory has been appended to the finished
        list and replaced with a fresh one. The default implementation does
        nothing; subclasses may override it.
        """
        pass