import torch
import torch.nn as nn
import torch.nn.functional as F
import agent.dreamer_utils as common
from collections import defaultdict
import numpy as np


# Linear residual block: Linear + Norm + activation, with an identity (or linear) skip connection.
class ResidualLinear(nn.Module):
    def __init__(self, in_channels, out_channels, norm='layer', act='SiLU', prenorm=False):
        super().__init__()
        self.norm_layer = common.NormLayer(norm, in_channels if prenorm else out_channels)
        self.act = common.get_act(act)
        self.layer = nn.Linear(in_channels, out_channels)
        self.prenorm = prenorm
        self.res_proj = nn.Identity() if in_channels == out_channels else nn.Linear(in_channels, out_channels)

    def forward(self, x):
        if self.prenorm:
            h = self.norm_layer(x)
            h = self.layer(h)
        else:
            h = self.layer(x)
            h = self.norm_layer(h)
        h = self.act(h)
        return h + self.res_proj(x)
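
# Minimal usage sketch for ResidualLinear (illustrative only; the 512/256 widths are arbitrary
# assumptions, and this helper is never called by the rest of the module):
def _residual_linear_example():
    block = ResidualLinear(512, 256, norm='layer', act='SiLU', prenorm=False)
    x = torch.randn(32, 512)
    # Since in_channels != out_channels, the skip connection uses a Linear projection,
    # so the output shape is (32, 256).
    return block(x).shape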

# MLP "U-Net" over embedding vectors: a down path, a bottleneck stack, and an up path
# that re-uses the down activations through skip connections.
class UNetDenoiser(nn.Module):
    def __init__(self, in_channels: int, mid_channels: int, n_layers: int, norm='layer', act='SiLU'):
        super().__init__()
        out_channels = in_channels
        self.down = nn.ModuleList()
        for i in range(n_layers):
            if i == (n_layers - 1):
                self.down.append(ResidualLinear(in_channels, mid_channels, norm=norm, act=act))
            else:
                self.down.append(ResidualLinear(in_channels, in_channels, norm=norm, act=act))
        self.mid = nn.ModuleList()
        for i in range(n_layers):
            self.mid.append(ResidualLinear(mid_channels, mid_channels, norm=norm, act=act))
        self.up = nn.ModuleList()
        for i in range(n_layers):
            if i == 0:
                self.up.append(ResidualLinear(mid_channels * 2, out_channels, norm='none', act='Identity'))
            else:
                self.up.append(ResidualLinear(out_channels * 2, out_channels, norm=norm, act=act))

    def forward(self, x):
        down_res = []
        for down_layer in self.down:
            x = down_layer(x)
            down_res.append(x)
        for mid_layer in self.mid:
            x = mid_layer(x)
        down_res.reverse()
        for up_layer, res in zip(self.up, down_res):
            x = up_layer(torch.cat([x, res], dim=-1))
        return x
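
# Minimal usage sketch for UNetDenoiser (illustrative only; 512/256 are assumed sizes, whereas
# in VideoSSM the input width is the ViCLIP embedding dimension; never called by this module):
def _unet_denoiser_example():
    denoiser = UNetDenoiser(in_channels=512, mid_channels=256, n_layers=2, norm='layer', act='SiLU')
    noisy = F.normalize(torch.randn(16, 512), dim=-1)
    # The symmetric down/up path with skip connections keeps the input width, so the output is (16, 512).
    return denoiser(noisy).shape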

# EnsembleRSSM variant whose "actions" are video (ViCLIP) embeddings, optionally augmented
# with a one-hot frame index, so the latent dynamics can be driven by video embeddings.
class VideoSSM(common.EnsembleRSSM):
    def __init__(self, *args,
                 connector_kl={}, temporal_embeds=False, detached_post=True, n_frames=8,
                 token_dropout=0., loss_scale=1, clip_add_noise=0, clip_lafite_noise=0,
                 rescale_embeds=False, denoising_ae=False, learn_initial=True, **kwargs):
        super().__init__(*args, **kwargs)
        #
        self.n_frames = n_frames
        # By default, the n_frames one-hot frame index is included in the actions
        # (it doesn't hurt, and it makes it easy to test whether it is useful or not).
        self.viclip_emb_dim = kwargs['action_dim'] - self.n_frames
        #
        self.temporal_embeds = temporal_embeds
        self.detached_post = detached_post
        self.connector_kl = connector_kl
        self.token_dropout = token_dropout
        self.loss_scale = loss_scale
        self.rescale_embeds = rescale_embeds
        self.clip_add_noise = clip_add_noise
        self.clip_lafite_noise = clip_lafite_noise
        self.clip_const = np.sqrt(self.viclip_emb_dim).item()
        self.denoising_ae = denoising_ae
        if self.denoising_ae:
            self.aligner = UNetDenoiser(self.viclip_emb_dim, self.viclip_emb_dim // 2, n_layers=2, norm='layer', act='SiLU')
        self.learn_initial = learn_initial
        if self.learn_initial:
            self.initial_state_pred = nn.Sequential(
                nn.Linear(kwargs['action_dim'], kwargs['hidden']),
                common.NormLayer(kwargs['norm'], kwargs['hidden']), common.get_act('SiLU'),
                nn.Linear(kwargs['hidden'], kwargs['hidden']),
                common.NormLayer(kwargs['norm'], kwargs['hidden']), common.get_act('SiLU'),
                nn.Linear(kwargs['hidden'], kwargs['deter'])
            )
        # Delete sub-modules inherited from EnsembleRSSM that this connector does not use
        del self._obs_out
        del self._obs_dist

    def initial(self, batch_size, init_embed=None, ignore_learned=False):
        init = super().initial(batch_size)
        if self.learn_initial and not ignore_learned and hasattr(self, 'initial_state_pred'):
            assert init_embed is not None
            # If the embed has no temporal one-hot attached, pad it with zeros to avoid edge cases
            if init_embed.shape[-1] == self.viclip_emb_dim:
                patcher = torch.zeros((*init_embed.shape[:-1], self.n_frames), device=self.device)
                init_embed = torch.cat([init_embed, patcher], dim=-1)
            init['deter'] = self.initial_state_pred(init_embed)
            stoch, stats = self.get_stoch_stats_from_deter_state(init)
            init['stoch'] = stoch
            init.update(stats)
        return init
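
    # Usage sketch (illustrative): video_imagine calls this as
    #   self.initial(batch_size=B, init_embed=action[:, 0])
    # so the initial 'deter' state is predicted from the first video-embed action via
    # initial_state_pred, and 'stoch'/stats are derived from it; with ignore_learned=True the
    # method falls back to the parent EnsembleRSSM's default initial state.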

    def get_action(self, video_embed):
        n_frames = self.n_frames
        B, T = video_embed.shape[:2]
        if self.rescale_embeds:
            video_embed = video_embed * self.clip_const
        temporal_embeds = F.one_hot(torch.arange(T).to(video_embed.device) % n_frames, n_frames).reshape(1, T, n_frames).repeat(B, 1, 1)
        if not self.temporal_embeds:
            temporal_embeds *= 0
        return torch.cat([video_embed, temporal_embeds], dim=-1)
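
    # Shape sketch (illustrative, assumed sizes): with B=4, T=16, n_frames=8 and a 512-d ViCLIP
    # embedding, get_action returns a (4, 16, 520) tensor: the (optionally rescaled) embedding
    # concatenated with an 8-d one-hot frame index, which is zeroed out when temporal_embeds is False.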

    def update(self, video_embed, wm_post):
        n_frames = self.n_frames
        B, T = video_embed.shape[:2]
        loss = 0
        metrics = {}
        # Keep only the embedding of the last frame of each n_frames chunk,
        # then repeat it across the whole chunk
        video_embed = video_embed[:, n_frames-1::n_frames]
        video_embed = video_embed.to(self.device)
        video_embed = video_embed.reshape(B, T // n_frames, 1, -1).repeat(1, 1, n_frames, 1).reshape(B, T, -1)
        orig_video_embed = video_embed
        if self.clip_add_noise > 0:
            video_embed = video_embed + torch.randn_like(video_embed, device=video_embed.device) * self.clip_add_noise
            video_embed = nn.functional.normalize(video_embed, dim=-1)
        if self.clip_lafite_noise > 0:
            normed_noise = F.normalize(torch.randn_like(video_embed, device=video_embed.device), dim=-1)
            video_embed = (1 - self.clip_lafite_noise) * video_embed + self.clip_lafite_noise * normed_noise
            video_embed = nn.functional.normalize(video_embed, dim=-1)
        if self.denoising_ae:
            assert (self.clip_lafite_noise + self.clip_add_noise) > 0, "Nothing to denoise"
            denoised_embed = self.aligner(video_embed)
            denoised_embed = F.normalize(denoised_embed, dim=-1)
            # Works the same as F.mse_loss(denoised_embed, orig_video_embed).mean()
            denoising_loss = 1 - F.cosine_similarity(denoised_embed, orig_video_embed, dim=-1).mean()
            loss += denoising_loss
            metrics['aligner_cosine_distance'] = denoising_loss
        # If a denoiser is used, it is the denoiser's job to clean up the video embed
        video_embed = orig_video_embed  # could also be denoised_embed, for end-to-end training

        embed_actions = self.get_action(video_embed)
        if self.detached_post:
            wm_post = {k: v.reshape(B, T, *v.shape[2:]).detach() for k, v in wm_post.items()}
        else:
            wm_post = {k: v.reshape(B, T, *v.shape[2:]) for k, v in wm_post.items()}
        # Get prior states
        prior_states = defaultdict(list)
        for t in range(T):
            # Get video action
            action = embed_actions[:, t]
            if t == 0:
                prev_state = self.initial(batch_size=wm_post['stoch'].shape[0], init_embed=action)
            else:
                # Get deter from prior, get stoch from wm_post
                prev_state = prior
                prev_state[self.cell_input] = wm_post[self.cell_input][:, t-1]
                if self.token_dropout > 0:
                    prev_state['stoch'] = torch.einsum('b...,b->b...', prev_state['stoch'], (torch.rand(B, device=action.device) > self.token_dropout).float())
            prior = self.img_step(prev_state, action)
            for k in prior:
                prior_states[k].append(prior[k])
        # Aggregate
        for k in prior_states:
            prior_states[k] = torch.stack(prior_states[k], dim=1)
        # Compute loss
        prior = prior_states
        kl_loss, kl_value = self.kl_loss(wm_post, prior, **self.connector_kl)
        video_loss = self.loss_scale * kl_loss
        metrics['connector_kl'] = kl_value.mean()
        loss += video_loss

        # Compute initial KL: keep only the first frame of each chunk and skip the first chunk
        video_embed = video_embed.reshape(B, T // n_frames, n_frames, -1)[:, 1:, 0].reshape(B * (T // n_frames - 1), 1, -1)
        embed_actions = self.get_action(video_embed)
        wm_post = {k: v.reshape(B, T // n_frames, n_frames, *v.shape[2:])[:, 1:, 0].reshape(B * (T // n_frames - 1), *v.shape[2:]) for k, v in wm_post.items()}
        action = embed_actions[:, 0]
        prev_state = self.initial(batch_size=wm_post['stoch'].shape[0], init_embed=action)
        prior = self.img_step(prev_state, action)
        kl_loss, kl_value = self.kl_loss(wm_post, prior, **self.connector_kl)
        metrics['connector_initial_kl'] = kl_value.mean()
        return loss, metrics
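
    # Training-side sketch (illustrative; `connector`, `viclip_embeds` and `wm_post` are
    # hypothetical names for the caller's objects):
    #   loss, mets = connector.update(viclip_embeds, wm_post)
    # `loss` adds the optional denoising loss to the scaled connector KL between the
    # video-conditioned priors and the (optionally detached) world-model posteriors; `mets`
    # reports 'connector_kl', 'connector_initial_kl' and, if enabled, 'aligner_cosine_distance'.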

    def video_imagine(self, video_embed, dreamer_init=None, sample=True, reset_every_n_frames=True, denoise=False):
        n_frames = self.n_frames
        B, T = video_embed.shape[:2]
        if self.denoising_ae and denoise:
            denoised_embed = self.aligner(video_embed)
            video_embed = F.normalize(denoised_embed, dim=-1)
        action = self.get_action(video_embed)
        # Imagine
        init = self.initial(batch_size=B, init_embed=action[:, 0])  # -> this ensures only stoch is used from the current frame
        if dreamer_init is not None:
            init[self.cell_input] = dreamer_init[self.cell_input]
        if reset_every_n_frames:
            prior_states = defaultdict(list)
            for action_chunk in torch.chunk(action, T // n_frames, dim=1):
                prior = self.imagine(action_chunk, init, sample=sample)
                for k in prior:
                    prior_states[k].append(prior[k])
                # -> this ensures only stoch is used from the current frame
                init = self.initial(batch_size=B, ignore_learned=True)
                init[self.cell_input] = prior[self.cell_input][:, -1]
            # Aggregate
            for k in prior_states:
                prior_states[k] = torch.cat(prior_states[k], dim=1)
            prior = prior_states
        else:
            prior = self.imagine(action, init, sample=sample)
        return prior
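
# Inference-side sketch (illustrative; `connector`, `viclip_embeds` and `wm_state` are
# hypothetical names, and the constructor arguments follow the EnsembleRSSM config, omitted here):
#   # viclip_embeds: (B, T, viclip_emb_dim) video embeddings, with T a multiple of n_frames
#   prior = connector.video_imagine(viclip_embeds, dreamer_init=wm_state, sample=True,
#                                   reset_every_n_frames=True, denoise=connector.denoising_ae)
# `prior` holds the imagined latent sequence (e.g. 'stoch'/'deter') driven by the video
# embeddings, which a world-model decoder can map back to observations.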