from typing import Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn
from torch import Tensor
EncoderOut = NamedTuple(
    "EncoderOut",
    [
        ("encoder_out", Tensor),  # T x B x C
        ("encoder_padding_mask", Optional[Tensor]),  # B x T
        ("encoder_embedding", Optional[Tensor]),  # B x T x C
        ("encoder_states", Optional[List[Tensor]]),  # List[T x B x C]
        ("src_tokens", Optional[Tensor]),  # B x T
        ("src_lengths", Optional[Tensor]),  # B x 1
    ],
)
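# All fields except ``encoder_out`` are typed Optional and may be left as None
# by encoders that do not produce them; a construction sketch is shown in the
# illustrative subclass at the end of this file.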


class FairseqEncoder(nn.Module):
    """Base class for encoders."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`
        """
        raise NotImplementedError

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if torch.jit.is_scripting():
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )
        else:
            return self.forward_non_torchscript(net_input)
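    # Caller-side sketch (not part of this class): generation code typically
    # builds a ``net_input`` dict from the sample and calls this method so the
    # same call site works in both scripted and eager mode, e.g.
    #   net_input = {"src_tokens": src_tokens, "src_lengths": src_lengths}
    #   encoder_out = encoder.forward_torchscript(net_input)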

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        # Drop decoder-side inputs (e.g. prev_output_tokens) and forward the
        # remaining keys to ``forward()``.
        encoder_input = {
            k: v for k, v in net_input.items() if k != "prev_output_tokens"
        }
        return self.forward(**encoder_input)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        raise NotImplementedError
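    # NOTE: concrete encoders typically implement this by calling
    # ``index_select`` along the batch dimension of every batch-indexed tensor
    # in ``encoder_out``; beam search relies on it to rearrange hypotheses.
    # See the illustrative subclass at the end of this file.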

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code."""
        return state_dict

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        def _apply(m):
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

        self.apply(_apply)
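

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of fairseq): a minimal concrete encoder
# showing how a subclass is expected to fill in ``forward()`` and
# ``reorder_encoder_out()``. The class name and layer sizes are hypothetical;
# ``dictionary`` is assumed to be a fairseq Dictionary exposing ``__len__``
# and ``pad()``.
# ---------------------------------------------------------------------------
class _ExampleLSTMEncoder(FairseqEncoder):
    def __init__(self, dictionary, embed_dim=128, hidden_dim=256):
        super().__init__(dictionary)
        self.embed = nn.Embedding(
            len(dictionary), embed_dim, padding_idx=dictionary.pad()
        )
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        x, _ = self.lstm(self.embed(src_tokens))  # B x T x C
        return EncoderOut(
            encoder_out=x.transpose(0, 1),  # T x B x C
            encoder_padding_mask=src_tokens.eq(self.dictionary.pad()),  # B x T
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        # Rearrange every batch-indexed tensor; beam search calls this to
        # reorder hypotheses within the batch.
        return encoder_out._replace(
            encoder_out=encoder_out.encoder_out.index_select(1, new_order),
            encoder_padding_mask=(
                encoder_out.encoder_padding_mask.index_select(0, new_order)
                if encoder_out.encoder_padding_mask is not None
                else None
            ),
        )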