from collections import defaultdict from typing import Dict from typing import List import torch from espnet.nets.pytorch_backend.rnn.attentions import AttAdd from espnet.nets.pytorch_backend.rnn.attentions import AttCov from espnet.nets.pytorch_backend.rnn.attentions import AttCovLoc from espnet.nets.pytorch_backend.rnn.attentions import AttDot from espnet.nets.pytorch_backend.rnn.attentions import AttForward from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA from espnet.nets.pytorch_backend.rnn.attentions import AttLoc from espnet.nets.pytorch_backend.rnn.attentions import AttLoc2D from espnet.nets.pytorch_backend.rnn.attentions import AttLocRec from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadAdd from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadDot from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadLoc from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadMultiResLoc from espnet.nets.pytorch_backend.rnn.attentions import NoAtt from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention from espnet2.train.abs_espnet_model import AbsESPnetModel @torch.no_grad() def calculate_all_attentions( model: AbsESPnetModel, batch: Dict[str, torch.Tensor] ) -> Dict[str, List[torch.Tensor]]: """Derive the outputs from the all attention layers Args: model: batch: same as forward Returns: return_dict: A dict of a list of tensor. key_names x batch x (D1, D2, ...) """ bs = len(next(iter(batch.values()))) assert all(len(v) == bs for v in batch.values()), { k: v.shape for k, v in batch.items() } # 1. Register forward_hook fn to save the output from specific layers outputs = {} handles = {} for name, modu in model.named_modules(): def hook(module, input, output, name=name): if isinstance(module, MultiHeadedAttention): # NOTE(kamo): MultiHeadedAttention doesn't return attention weight # attn: (B, Head, Tout, Tin) outputs[name] = module.attn.detach().cpu() elif isinstance(module, AttLoc2D): c, w = output # w: previous concate attentions # w: (B, nprev, Tin) att_w = w[:, -1].detach().cpu() outputs.setdefault(name, []).append(att_w) elif isinstance(module, (AttCov, AttCovLoc)): c, w = output assert isinstance(w, list), type(w) # w: list of previous attentions # w: nprev x (B, Tin) att_w = w[-1].detach().cpu() outputs.setdefault(name, []).append(att_w) elif isinstance(module, AttLocRec): # w: (B, Tin) c, (w, (att_h, att_c)) = output att_w = w.detach().cpu() outputs.setdefault(name, []).append(att_w) elif isinstance( module, ( AttMultiHeadDot, AttMultiHeadAdd, AttMultiHeadLoc, AttMultiHeadMultiResLoc, ), ): c, w = output # w: nhead x (B, Tin) assert isinstance(w, list), type(w) att_w = [_w.detach().cpu() for _w in w] outputs.setdefault(name, []).append(att_w) elif isinstance( module, ( AttAdd, AttDot, AttForward, AttForwardTA, AttLoc, NoAtt, ), ): c, w = output att_w = w.detach().cpu() outputs.setdefault(name, []).append(att_w) handle = modu.register_forward_hook(hook) handles[name] = handle # 2. Just forward one by one sample. # Batch-mode can't be used to keep requirements small for each models. keys = [] for k in batch: if not k.endswith("_lengths"): keys.append(k) return_dict = defaultdict(list) for ibatch in range(bs): # *: (B, L, ...) -> (1, L2, ...) _sample = { k: batch[k][ibatch, None, : batch[k + "_lengths"][ibatch]] if k + "_lengths" in batch else batch[k][ibatch, None] for k in keys } # *_lengths: (B,) -> (1,) _sample.update( { k + "_lengths": batch[k + "_lengths"][ibatch, None] for k in keys if k + "_lengths" in batch } ) model(**_sample) # Derive the attention results for name, output in outputs.items(): if isinstance(output, list): if isinstance(output[0], list): # output: nhead x (Tout, Tin) output = torch.stack( [ # Tout x (1, Tin) -> (Tout, Tin) torch.cat([o[idx] for o in output], dim=0) for idx in range(len(output[0])) ], dim=0, ) else: # Tout x (1, Tin) -> (Tout, Tin) output = torch.cat(output, dim=0) else: # output: (1, NHead, Tout, Tin) -> (NHead, Tout, Tin) output = output.squeeze(0) # output: (Tout, Tin) or (NHead, Tout, Tin) return_dict[name].append(output) outputs.clear() # 3. Remove all hooks for _, handle in handles.items(): handle.remove() return dict(return_dict)