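"""Plan2Explore exploration agent built on top of the Dreamer agent.

An ensemble of one-step models predicts the next observation embedding from the
current latent feature and action; the variance of the ensemble's predictions
(its "disagreement") is used as an intrinsic reward that drives exploration.
"""
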
import torch
import torch.nn as nn
import torch.nn.functional as F

from agent.dreamer import DreamerAgent, stop_gradient
import agent.dreamer_utils as common


class Disagreement(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden_dim, n_models=5, pred_dim=None):
        super().__init__()
        if pred_dim is None:
            pred_dim = obs_dim
        # Ensemble of one-step forward models: each maps (obs, action) to a
        # prediction of the next observation (or embedding, if pred_dim differs).
        self.ensemble = nn.ModuleList([
            nn.Sequential(nn.Linear(obs_dim + action_dim, hidden_dim),
                          nn.ReLU(),
                          nn.Linear(hidden_dim, pred_dim))
            for _ in range(n_models)
        ])

    def forward(self, obs, action, next_obs):
        assert obs.shape[0] == next_obs.shape[0]
        assert obs.shape[0] == action.shape[0]

        # L2 prediction error of each ensemble member; one column per model.
        errors = []
        for model in self.ensemble:
            next_obs_hat = model(torch.cat([obs, action], dim=-1))
            model_error = torch.norm(next_obs - next_obs_hat,
                                     dim=-1, p=2, keepdim=True)
            errors.append(model_error)
        return torch.cat(errors, dim=1)

    def get_disagreement(self, obs, action):
        assert obs.shape[0] == action.shape[0]

        # Disagreement = variance of the ensemble predictions, averaged over
        # the predicted feature dimensions.
        preds = []
        for model in self.ensemble:
            next_obs_hat = model(torch.cat([obs, action], dim=-1))
            preds.append(next_obs_hat)
        preds = torch.stack(preds, dim=0)
        return torch.var(preds, dim=0).mean(dim=-1)
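

# Plan2Explore (Sekar et al., 2020): the agent explores by seeking transitions
# where the ensemble above disagrees most, i.e. where the world model is most
# uncertain about its predictions.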
class Plan2Explore(DreamerAgent):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        in_dim = self.wm.inp_size
        pred_dim = self.wm.embed_dim
        self.hidden_dim = pred_dim
        self.reward_free = True

        # The ensemble predicts the world model's observation embedding from the
        # model's latent features and the action.
        self.disagreement = Disagreement(in_dim, self.act_dim, self.hidden_dim,
                                         pred_dim=pred_dim).to(self.device)

        # optimizers
        self.disagreement_opt = common.Optimizer(
            'disagreement', self.disagreement.parameters(),
            **self.cfg.model_opt, use_amp=self._use_amp)
        self.disagreement.train()
        # Gradients are disabled globally and re-enabled per module during updates
        # (see common.RequiresGrad in update()).
        self.requires_grad_(requires_grad=False)

    def update_disagreement(self, obs, action, next_obs, step):
        metrics = dict()

        error = self.disagreement(obs, action, next_obs)
        loss = error.mean()

        metrics.update(self.disagreement_opt(loss, self.disagreement.parameters()))
        metrics['disagreement_loss'] = loss.item()
        return metrics
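
    # Intrinsic reward for an imagined sequence: the disagreement computed from the
    # features at step t and the action taken at step t+1 is assigned to step t+1;
    # the first step keeps a reward of zero.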
    def compute_intr_reward(self, seq):
        obs, action = seq['feat'][:-1], stop_gradient(seq['action'][1:])
        intr_rew = torch.zeros(list(seq['action'].shape[:-1]) + [1], device=self.device)
        if len(action.shape) > 2:
            B, T, _ = action.shape
            obs = obs.reshape(B * T, -1)
            action = action.reshape(B * T, -1)
            reward = self.disagreement.get_disagreement(obs, action).reshape(B, T, 1)
        else:
            reward = self.disagreement.get_disagreement(obs, action).unsqueeze(-1)
        intr_rew[1:] = reward
        return intr_rew
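
    # One agent update: train the world model, then (if reward-free) fit the
    # disagreement ensemble on (feat_t, action_{t+1}) -> embed_{t+1} pairs from the
    # posterior and train the acting behavior in imagination with the intrinsic
    # reward; otherwise, use the learned reward head as the reward function.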
    def update(self, data, step):
        metrics = {}
        B, T, _ = data['action'].shape
        state, outputs, mets = self.wm.update(data, state=None)
        metrics.update(mets)
        start = outputs['post']
        start = {k: stop_gradient(v) for k, v in start.items()}
        if self.reward_free:
            T = T - 1
            inp = stop_gradient(outputs['feat'][:, :-1]).reshape(B * T, -1)
            action = data['action'][:, 1:].reshape(B * T, -1)
            out = stop_gradient(outputs['embed'][:, 1:]).reshape(B * T, -1)
            with common.RequiresGrad(self.disagreement):
                with torch.cuda.amp.autocast(enabled=self._use_amp):
                    metrics.update(
                        self.update_disagreement(inp, action, out, step))
            metrics.update(self._acting_behavior.update(
                self.wm, start, data['is_terminal'], reward_fn=self.compute_intr_reward))
        else:
            reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean
            metrics.update(self._acting_behavior.update(
                self.wm, start, data['is_terminal'], reward_fn))
        return state, metrics
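

if __name__ == '__main__':
    # Minimal standalone sketch exercising the Disagreement ensemble on random
    # tensors. The dimensions and batch size are illustrative, not from any config.
    obs_dim, action_dim, hidden_dim, batch = 32, 6, 64, 16
    ensemble = Disagreement(obs_dim, action_dim, hidden_dim)
    obs = torch.randn(batch, obs_dim)
    action = torch.randn(batch, action_dim)
    next_obs = torch.randn(batch, obs_dim)
    print(ensemble(obs, action, next_obs).shape)         # per-model errors: (batch, n_models)
    print(ensemble.get_disagreement(obs, action).shape)  # intrinsic reward signal: (batch,)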