dqn-CartPole-v1 / shared /trajectory.py
sgoodfriend's picture
DQN playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/1d4094fbcc9082de7f53f4348dd4c7c354152907
ff8c6a7
raw
history blame
634 Bytes
import numpy as np
import torch
from dataclasses import dataclass
from typing import List
@dataclass
class Trajectory:
    """Append-only, per-step record of a single trajectory.

    Four parallel lists are kept in lockstep: each call to ``add`` appends
    exactly one entry to every list, so index ``i`` in each list refers to
    the same step.

    The ``@dataclass`` decorator is used for the generated ``__repr__`` and
    ``__eq__``; construction goes through the explicit zero-argument
    ``__init__`` below, which the decorator leaves in place.
    """

    obs: List[np.ndarray]  # one observation per recorded step
    act: List[np.ndarray]  # one action per recorded step
    rew: List[float]  # one reward per recorded step
    v: List[float]  # one scalar per step; presumably a value estimate -- confirm with callers
    # Never changed by this class after construction; presumably set by the
    # code that collects the trajectory -- confirm with callers.
    terminated: bool

    def __init__(self) -> None:
        """Create an empty trajectory with ``terminated`` set to ``False``."""
        # Each instance gets its own fresh lists; nothing is shared.
        self.obs, self.act, self.rew, self.v = [], [], [], []
        self.terminated = False

    def add(self, obs: np.ndarray, act: np.ndarray, rew: float, v: float):
        """Record one step by appending each argument to its matching list.

        Args:
            obs: Observation to append to ``self.obs``.
            act: Action to append to ``self.act``.
            rew: Reward to append to ``self.rew``.
            v: Scalar to append to ``self.v``.
        """
        step_fields = (
            (self.obs, obs),
            (self.act, act),
            (self.rew, rew),
            (self.v, v),
        )
        for buffer, value in step_fields:
            buffer.append(value)

    def __len__(self) -> int:
        """Return the number of recorded steps (the length of ``self.obs``)."""
        num_steps = len(self.obs)
        return num_steps