Spaces:
Runtime error
Runtime error
"""Base class for encoders and generic multi encoders.""" | |
import torch.nn as nn | |
from .misc import aeq | |
class EncoderBase(nn.Module):
    """
    Base encoder class. Specifies the interface used by different encoder types
    and required by :class:`onmt.Models.NMTModel`.

    Subclasses are expected to override :meth:`from_opt` and :meth:`forward`;
    the base implementations only raise ``NotImplementedError``.

    .. mermaid::

       graph BT
          A[Input]
          subgraph RNN
            C[Pos 1]
            D[Pos 2]
            E[Pos N]
          end
          F[Memory_Bank]
          G[Final]
          A-->C
          A-->D
          A-->E
          C-->F
          D-->F
          E-->F
          E-->G
    """

    # Must be a classmethod: it is an alternate constructor called on the
    # class itself (``SomeEncoder.from_opt(opt)``). Without the decorator the
    # call binds ``opt`` to ``cls`` and fails with a TypeError instead of
    # reaching the NotImplementedError below.
    @classmethod
    def from_opt(cls, opt, embeddings=None):
        """Alternate constructor: build an encoder from an options object.

        Args:
            opt: options object describing the encoder (presumably parsed
                command-line/config options — confirm against subclasses).
            embeddings: optional embeddings module to pass to the encoder.

        Raises:
            NotImplementedError: always; concrete encoders must override.
        """
        raise NotImplementedError

    def _check_args(self, src, lengths=None, hidden=None):
        """Check that ``src`` and ``lengths`` agree on the batch size.

        Args:
            src: tensor whose dimension 1 is the batch dimension
                (``(src_len, batch, nfeat)`` per :meth:`forward`).
            lengths: optional 1-D tensor ``(batch,)``. When given, its single
                dimension must equal ``src.size(1)``.
            hidden: accepted for interface compatibility; not used here.
        """
        n_batch = src.size(1)
        if lengths is not None:
            # Unpacking into a 1-tuple also enforces that lengths is 1-D.
            n_batch_, = lengths.size()
            # aeq (from .misc) presumably asserts all arguments are equal.
            aeq(n_batch, n_batch_)

    def forward(self, src, lengths=None):
        """
        Args:
            src (LongTensor):
                padded sequences of sparse indices ``(src_len, batch, nfeat)``
            lengths (LongTensor): length of each sequence ``(batch,)``

        Returns:
            (FloatTensor, FloatTensor):

            * final encoder state, used to initialize decoder
            * memory bank for attention, ``(src_len, batch, hidden)``

        Raises:
            NotImplementedError: always; concrete encoders must override.
        """
        raise NotImplementedError