|
|
|
import argparse |
|
import os |
|
import random |
|
import time |
|
from distutils.util import strtobool |
|
|
|
import flax |
|
import flax.linen as nn |
|
import gymnasium as gym |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import optax |
|
from flax.training.train_state import TrainState |
|
from stable_baselines3.common.buffers import ReplayBuffer |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
def _str_to_bool(value):
    """Parse a command-line truth value into a ``bool``.

    Drop-in replacement for ``distutils.util.strtobool``. ``distutils`` was
    deprecated by PEP 632 and removed in Python 3.12. Accepts
    y/yes/t/true/on/1 and n/no/f/false/off/0, case-insensitively.

    Raises:
        ValueError: if ``value`` is not a recognised truth value; argparse
            reports this as a usage error.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_args(argv=None):
    """Build the CLI parser for the TD3 experiment and parse arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the experiment and algorithm settings.
    """
    parser = argparse.ArgumentParser()
    # splitext removes the ".py" suffix; str.rstrip(".py") would strip any
    # trailing '.', 'p' or 'y' characters (e.g. "td3_policy.py" -> "td3_polic").
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    # Boolean flags: a bare "--flag" means True (const); "--flag false" etc. is parsed.
    parser.add_argument("--track", type=_str_to_bool, default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=_str_to_bool, default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=_str_to_bool, default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=_str_to_bool, default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

    # Algorithm-specific arguments
    parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=1000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=3e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=int(1e6),
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--tau", type=float, default=0.005,
        help="target smoothing coefficient (default: 0.005)")
    parser.add_argument("--policy-noise", type=float, default=0.2,
        help="the scale of policy noise")
    parser.add_argument("--batch-size", type=int, default=256,
        help="the batch size of sample from the replay memory")
    parser.add_argument("--exploration-noise", type=float, default=0.1,
        help="the scale of exploration noise")
    # argparse does not run `type` on non-string defaults, so the default must
    # already be an int (25e3 would stay the float 25000.0).
    parser.add_argument("--learning-starts", type=int, default=int(25e3),
        help="timestep to start learning")
    parser.add_argument("--policy-frequency", type=int, default=2,
        help="the frequency of training policy (delayed)")
    parser.add_argument("--noise-clip", type=float, default=0.5,
        help="noise clip parameter of the Target Policy Smoothing Regularization")
    return parser.parse_args(argv)
|
|
|
|
|
def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one gymnasium environment.

    The factory creates ``env_id`` and wraps it in ``RecordEpisodeStatistics``.
    When ``capture_video`` is truthy and this is sub-environment ``idx == 0``,
    the env is created with ``render_mode="rgb_array"`` and wrapped in
    ``RecordVideo`` writing to ``videos/{run_name}``. The action space is
    seeded with ``seed``.
    """

    def build_env():
        record_video = capture_video and idx == 0
        make_kwargs = {"render_mode": "rgb_array"} if record_video else {}
        environment = gym.make(env_id, **make_kwargs)
        if record_video:
            environment = gym.wrappers.RecordVideo(environment, f"videos/{run_name}")
        environment = gym.wrappers.RecordEpisodeStatistics(environment)
        environment.action_space.seed(seed)
        return environment

    return build_env
|
|
|
|
|
|
|
class QNetwork(nn.Module):
    """State-action value network Q(s, a).

    The observation and action are concatenated along the last axis, passed
    through two 256-unit ReLU layers, and projected to a single output unit.
    """

    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        hidden = jnp.concatenate([x, a], -1)
        # Dense layers are created in the same order as before, so Flax's
        # auto-generated parameter names (Dense_0, Dense_1, Dense_2) are unchanged.
        for width in (256, 256):
            hidden = nn.relu(nn.Dense(width)(hidden))
        return nn.Dense(1)(hidden)
|
|
|
|
|
class Actor(nn.Module):
    """Deterministic policy network mapping observations to actions.

    Two 256-unit ReLU layers feed a ``tanh``-squashed output of size
    ``action_dim``. That output, which lies in [-1, 1], is mapped affinely as
    ``out * action_scale + action_bias``.
    """

    # Number of action dimensions produced by the final Dense layer.
    action_dim: int
    # Multiplier applied to the tanh output (set by the caller).
    action_scale: jnp.ndarray
    # Offset added after scaling (set by the caller).
    action_bias: jnp.ndarray

    @nn.compact
    def __call__(self, x):
        hidden = x
        # Same Dense creation order as before, so parameter names are unchanged.
        for width in (256, 256):
            hidden = nn.relu(nn.Dense(width)(hidden))
        squashed = nn.tanh(nn.Dense(self.action_dim)(hidden))
        return squashed * self.action_scale + self.action_bias
|
|
|
|
|
class TrainState(TrainState):
    """Flax ``TrainState`` extended with a target-network parameter set.

    This class deliberately reuses (shadows) the name of the imported
    ``flax.training.train_state.TrainState`` base class. The script builds
    its states through ``TrainState.create(..., target_params=...)``.
    """

    # Lagged copy of ``params`` used to compute TD targets. The training code
    # moves it towards ``params`` with ``optax.incremental_update``.
    target_params: flax.core.FrozenDict
|
|
|
|
|
if __name__ == "__main__":
    import stable_baselines3 as sb3

    # Compare the integer major version. A plain string comparison
    # (`__version__ < "2.0"`) would misorder e.g. "10.0" before "2.0".
    if int(sb3.__version__.split(".")[0]) < 2:
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
        )
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # Seeding: Python and NumPy global RNGs, plus a JAX key that is split for
    # network initialisation and then threaded through update_critic.
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4)

    # Environment setup: a single synchronous sub-environment.
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    # Per-dimension std-dev of the Gaussian exploration noise: half the action
    # range times --exploration-noise. This matches the per-dimension scaling
    # update_critic applies to target-policy noise (actor.action_scale). For
    # spaces symmetric about zero with equal bounds on every axis, it reduces to
    # the scalar high[0] * exploration_noise. float64 avoids float32 rounding
    # of the std-dev.
    exploration_std = (
        (envs.single_action_space.high.astype(np.float64) - envs.single_action_space.low.astype(np.float64))
        / 2.0
        * args.exploration_noise
    )

    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device="cpu",
        handle_timeout_termination=False,
    )

    # Start the game.
    obs, _ = envs.reset(seed=args.seed)

    actor = Actor(
        action_dim=np.prod(envs.single_action_space.shape),
        action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0),
        action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0),
    )
    # Online and target parameters are initialised from the same key.
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, obs),
        target_params=actor.init(actor_key, obs),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf = QNetwork()
    qf1_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf1_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf1_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf2_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf2_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf2_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    actor.apply = jax.jit(actor.apply)
    qf.apply = jax.jit(qf.apply)

    @jax.jit
    def update_critic(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
        actions: np.ndarray,
        next_observations: np.ndarray,
        rewards: np.ndarray,
        terminations: np.ndarray,
        key: jnp.ndarray,
    ):
        """Take one TD gradient step on both critics.

        Returns the new critic states, their losses, their mean Q-values and
        the advanced PRNG key. ``args``, ``envs``, ``actor`` and ``qf`` are
        closed over and captured when the function is traced.
        """
        # Target policy smoothing: clipped Gaussian noise scaled per action dimension.
        key, noise_key = jax.random.split(key, 2)
        clipped_noise = (
            jnp.clip(
                (jax.random.normal(noise_key, actions.shape) * args.policy_noise),
                -args.noise_clip,
                args.noise_clip,
            )
            * actor.action_scale
        )
        next_state_actions = jnp.clip(
            actor.apply(actor_state.target_params, next_observations) + clipped_noise,
            envs.single_action_space.low,
            envs.single_action_space.high,
        )
        # Clipped double-Q: bootstrap from the smaller of the two target critics.
        qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)
        qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1)
        min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target)
        next_q_value = (rewards + (1 - terminations) * args.gamma * (min_qf_next_target)).reshape(-1)

        def mse_loss(params):
            qf_a_values = qf.apply(params, observations, actions).squeeze()
            return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()

        (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)
        (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params)
        qf1_state = qf1_state.apply_gradients(grads=grads1)
        qf2_state = qf2_state.apply_gradients(grads=grads2)

        return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key

    @jax.jit
    def update_actor(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
    ):
        """Run the delayed actor step, then soft-update all three target networks by ``args.tau``."""

        def actor_loss(params):
            # Maximise Q1 of the actor's own actions.
            return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean()

        actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params)
        actor_state = actor_state.apply_gradients(grads=grads)
        actor_state = actor_state.replace(
            target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)
        )

        qf1_state = qf1_state.replace(
            target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau)
        )
        qf2_state = qf2_state.replace(
            target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau)
        )
        return actor_state, (qf1_state, qf2_state), actor_loss_value

    start_time = time.time()
    # Only set once the first delayed actor update has run. Logging below is
    # guarded so a --policy-frequency that misses the first logging step cannot
    # raise NameError.
    actor_loss_value = None
    for global_step in range(args.total_timesteps):
        # Act: uniform random actions during warm-up, then policy + Gaussian noise.
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions = actor.apply(actor_state.params, obs)
            actions = np.array(
                [
                    (
                        jax.device_get(actions)[0]
                        + np.random.normal(0, exploration_std, size=envs.single_action_space.shape)
                    ).clip(envs.single_action_space.low, envs.single_action_space.high)
                ]
            )

        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # Record episodic statistics for the first finished sub-env.
        if "final_info" in infos:
            for info in infos["final_info"]:
                # Skip None slots and entries that carry no episode statistics.
                if info is None or "episode" not in info:
                    continue
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

        # For truncated episodes, store the true final observation rather than
        # the observation returned after the automatic reset.
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        obs = next_obs

        # Training.
        if global_step > args.learning_starts:
            data = rb.sample(args.batch_size)

            (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic(
                actor_state,
                qf1_state,
                qf2_state,
                data.observations.numpy(),
                data.actions.numpy(),
                data.next_observations.numpy(),
                data.rewards.flatten().numpy(),
                data.dones.flatten().numpy(),
                key,
            )

            if global_step % args.policy_frequency == 0:
                actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor(
                    actor_state,
                    qf1_state,
                    qf2_state,
                    data.observations.numpy(),
                )

            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step)
                writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step)
                if actor_loss_value is not None:
                    writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        with open(model_path, "wb") as f:
            f.write(
                flax.serialization.to_bytes(
                    [
                        actor_state.params,
                        qf1_state.params,
                        qf2_state.params,
                    ]
                )
            )
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.td3_jax_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=(Actor, QNetwork),
            exploration_noise=args.exploration_noise,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "TD3", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()
|
|