File size: 2,464 Bytes
41a6762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np

from dataclasses import dataclass, field
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
from typing import Generic, List, Optional, Type, TypeVar


@dataclass
class Trajectory:
    """Step-by-step record of one episode (possibly still in progress).

    Every call to :meth:`add` appends one step's observation, action, reward
    and value to the per-step lists. ``next_obs`` and ``terminated`` are not
    histories: they only describe the most recently added step.
    """

    # Observation at each recorded step, in order.
    obs: List[np.ndarray] = field(default_factory=list)
    # Action taken at each recorded step, aligned with ``obs``.
    act: List[np.ndarray] = field(default_factory=list)
    # Observation that followed the latest step; None before any step is added
    # and whenever the latest step terminated the trajectory.
    next_obs: Optional[np.ndarray] = None
    # Reward received at each recorded step.
    rew: List[float] = field(default_factory=list)
    # Whether the most recently added step ended the trajectory.
    terminated: bool = False
    # Per-step value passed in alongside each transition (presumably a value
    # estimate from the caller — confirm against the producer).
    v: List[float] = field(default_factory=list)

    def add(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        next_obs: np.ndarray,
        rew: float,
        terminated: bool,
        v: float,
    ) -> None:
        """Append a single transition to this trajectory.

        Args:
            obs: Observation at this step.
            act: Action taken at this step.
            next_obs: Observation after this step; discarded (stored as None)
                when ``terminated`` is truthy.
            rew: Reward for this step.
            terminated: Whether this step ended the trajectory.
            v: Value recorded for this step.
        """
        # Grow the per-step histories.
        self.obs.append(obs)
        self.act.append(act)
        self.rew.append(rew)
        self.v.append(v)
        # Overwrite the "latest step" state; a terminated trajectory keeps no
        # successor observation.
        self.terminated = terminated
        self.next_obs = None if terminated else next_obs

    def __len__(self) -> int:
        """Return how many steps have been recorded."""
        step_count = len(self.obs)
        return step_count


# Type variable for the trajectory type an accumulator produces. Bounding it
# to Trajectory lets TrajectoryAccumulator be parameterized by Trajectory or
# any subclass of it.
T = TypeVar("T", bound=Trajectory)


class TrajectoryAccumulator(Generic[T]):
    """Split a stream of vectorized-environment steps into per-env trajectories.

    One in-progress trajectory is kept per environment. Each :meth:`step` call
    appends that step's data to every env's current trajectory. When an env
    reports ``done``, the following happens in order:

    1. Its trajectory is moved to the completed list.
    2. The trajectory is replaced by a fresh ``trajectory_class()`` instance.
    3. :meth:`on_done` is invoked with the finished trajectory.

    Args:
        num_envs: Number of parallel environments; one current trajectory is
            created for each.
        trajectory_class: Trajectory type to instantiate (``Trajectory`` or a
            subclass). Its ``add`` must accept the per-env values forwarded by
            :meth:`step`.
    """

    def __init__(self, num_envs: int, trajectory_class: Type[T] = Trajectory) -> None:
        self.num_envs = num_envs
        self.trajectory_class = trajectory_class

        # Finished trajectories (env reported done), in completion order.
        self._trajectories: List[T] = []
        # One in-progress trajectory per environment, indexed by env index.
        self._current_trajectories: List[T] = [
            trajectory_class() for _ in range(num_envs)
        ]

    def step(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        next_obs: VecEnvObs,
        reward: np.ndarray,
        done: np.ndarray,
        val: np.ndarray,
        *args,
    ) -> None:
        """Record one vectorized step for every environment.

        Element ``i`` of each argument is treated as environment ``i``'s data
        and forwarded, in argument order, to that env's current trajectory's
        ``add``. Iteration stops at the shortest argument (``zip`` semantics).

        Args:
            obs: Observations before the step; must be an ``np.ndarray``.
            action: Actions taken.
            next_obs: Observations after the step; must be an ``np.ndarray``.
            reward: Rewards received.
            done: Per-env flags; a truthy entry completes that env's trajectory.
            val: Per-env values forwarded as ``add``'s ``v`` argument.
            *args: Extra per-env sequences, forwarded positionally after
                ``val`` (for trajectory subclasses whose ``add`` takes more).

        Raises:
            AssertionError: If ``obs`` or ``next_obs`` is not an ``np.ndarray``
                (only when assertions are enabled).
        """
        # Only plain ndarray observations are handled; other VecEnvObs forms
        # are rejected.
        assert isinstance(obs, np.ndarray)
        assert isinstance(next_obs, np.ndarray)
        # The loop target is deliberately not named ``args``: rebinding it
        # would shadow the extra per-env arrays passed via *args.
        per_env_steps = zip(obs, action, next_obs, reward, done, val, *args)
        for env_idx, step_fields in enumerate(per_env_steps):
            trajectory = self._current_trajectories[env_idx]
            # TODO: Eventually take advantage of terminated/truncated differentiation in
            # later versions of gym.
            trajectory.add(*step_fields)
            if done[env_idx]:
                self._trajectories.append(trajectory)
                # Replace before calling the hook so on_done observes a clean
                # slot for this env.
                self._current_trajectories[env_idx] = self.trajectory_class()
                self.on_done(env_idx, trajectory)

    @property
    def all_trajectories(self) -> List[T]:
        """Return completed trajectories followed by non-empty in-progress ones.

        A new list is built on every access; in-progress trajectories that have
        no steps yet are omitted.
        """
        in_progress = [t for t in self._current_trajectories if len(t)]
        return self._trajectories + in_progress

    def n_timesteps(self) -> int:
        """Return the total number of steps across :attr:`all_trajectories`."""
        return sum(len(t) for t in self.all_trajectories)

    def on_done(self, env_idx: int, trajectory: T) -> None:
        """Hook invoked after ``trajectory`` for env ``env_idx`` completes.

        Does nothing here; a subclass may override it to react to completed
        trajectories.
        """
        pass