# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Base classes for various fairseq models.
"""

import logging
from argparse import Namespace
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    gen_parser_from_dataclass,
)
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig
from torch import Tensor

logger = logging.getLogger(__name__)

def check_type(module, expected_type):
    if hasattr(module, "unwrapped_module"):
        assert isinstance(
            module.unwrapped_module, expected_type
        ), f"{type(module.unwrapped_module)} != {expected_type}"
    else:
        assert isinstance(module, expected_type), f"{type(module)} != {expected_type}"


class BaseFairseqModel(nn.Module):
    """Base class for fairseq models."""

    def __init__(self):
        super().__init__()
        self._is_generation_fast = False

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            # do not set defaults here, so that defaults from the various
            # registered architectures can still be applied later
            gen_parser_from_dataclass(parser, dc(), delete_default=True)
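
    # Usage sketch (hypothetical model/config names): a subclass usually picks
    # up its `__dataclass` attribute by registering with a config dataclass,
    # e.g. @register_model("my_model", dataclass=MyModelConfig); add_args then
    # generates matching CLI flags from the dataclass fields:
    #
    #   parser = argparse.ArgumentParser()
    #   MyModel.add_args(parser)  # adds --dropout etc. from MyModelConfig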

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        raise NotImplementedError("Model must implement the build_model method")

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output."""
        return sample["target"]

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    # TorchScript doesn't support super() calls, so a scriptable subclass
    # can't reach the base class implementation. The current workaround is
    # this separately named helper, which scriptable subclasses call directly.
    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel"""
        if hasattr(self, "decoder"):
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)
        elif torch.is_tensor(net_output):
            # syntactic sugar for simple models which don't have a decoder
            # (e.g., the classification tutorial)
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError
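
    # Override sketch: a scriptable subclass re-exposes this helper instead of
    # calling super() (the pattern used by e.g. TransformerModel):
    #
    #   @torch.jit.export
    #   def get_normalized_probs(self, net_output, log_probs, sample=None):
    #       return self.get_normalized_probs_scriptable(net_output, log_probs, sample)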

    def extract_features(self, *args, **kwargs):
        """Similar to *forward* but only return features."""
        return self(*args, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return None

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        if model_cfg is None and args is not None:
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)
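
    # Minimal usage sketch (assumes "checkpoint.pt" is a fairseq checkpoint
    # whose "model" entry holds the state dict and "cfg" the saved config):
    #
    #   state = torch.load("checkpoint.pt")
    #   model.load_state_dict(state["model"], model_cfg=state["cfg"].model)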

    def upgrade_state_dict(self, state_dict):
        """Upgrade old state dicts to work with newer code."""
        self.upgrade_state_dict_named(state_dict, "")

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        Args:
            state_dict (dict): state dictionary to upgrade, in place
            name (str): the state dict key corresponding to the current module
        """
        assert state_dict is not None

        def do_upgrade(m, prefix):
            if len(prefix) > 0:
                prefix += "."
            for n, c in m.named_children():
                name = prefix + n
                if hasattr(c, "upgrade_state_dict_named"):
                    c.upgrade_state_dict_named(state_dict, name)
                elif hasattr(c, "upgrade_state_dict"):
                    c.upgrade_state_dict(state_dict)
                do_upgrade(c, name)

        do_upgrade(self, name)
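
    # Override sketch: a submodule can hook into this traversal to migrate
    # old checkpoint keys in place (key names below are hypothetical):
    #
    #   def upgrade_state_dict_named(self, state_dict, name):
    #       old_key = name + ".in_proj_weight"
    #       if old_key in state_dict:
    #           state_dict[name + ".q_proj.weight"] = state_dict.pop(old_key)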

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""
        for m in self.modules():
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

    def prepare_for_inference_(self, cfg: DictConfig):
        """Prepare model for inference."""
        kwargs = {}
        kwargs["beamable_mm_beam_size"] = (
            None
            if getattr(cfg.generation, "no_beamable_mm", False)
            else getattr(cfg.generation, "beam", 5)
        )
        kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
        if getattr(cfg.generation, "retain_dropout", False):
            kwargs["retain_dropout"] = cfg.generation.retain_dropout
            kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
        self.make_generation_fast_(**kwargs)

    def make_generation_fast_(self, **kwargs):
        """
        Legacy entry point to optimize model for faster generation.
        Prefer prepare_for_inference_.
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):  # this module didn't have weight norm
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0:
                prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                if (
                    m != self
                    and hasattr(m, "make_generation_fast_")
                    # don't call this implementation again, e.g., if
                    # children modules also inherit from BaseFairseqModel
                    and m.make_generation_fast_.__func__ is not base_func
                ):
                    name = prefix + n
                    m.make_generation_fast_(name=name, **kwargs)

        apply_make_generation_fast_(self, "")

        def train(mode=True):
            if mode:
                raise RuntimeError("cannot train after make_generation_fast")

        # this model should no longer be used for training
        self.eval()
        self.train = train
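
    # Effect sketch: after this call the model is locked into eval mode and
    # switching back to training raises:
    #
    #   model.make_generation_fast_()
    #   model.train()  # raises RuntimeError: cannot train after make_generation_fast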

    def prepare_for_onnx_export_(self, **kwargs):
        """Make model exportable via ONNX trace."""
        seen = set()

        def apply_prepare_for_onnx_export_(module):
            if (
                module != self
                and hasattr(module, "prepare_for_onnx_export_")
                and module not in seen
            ):
                seen.add(module)
                module.prepare_for_onnx_export_(**kwargs)

        self.apply(apply_prepare_for_onnx_export_)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        **kwargs,
    ):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a
        :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
        generate translations or sample from language models. The underlying
        :class:`~fairseq.models.FairseqModel` can be accessed via the
        *generator.models* attribute.

        Other models may override this to implement custom hub interfaces.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **kwargs,
        )
        logger.info(x["args"])
        return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])
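
    # Usage sketch ("transformer.wmt19.en-de" is one archive name registered
    # by TransformerModel.hub_models(); substitute your own model):
    #
    #   en2de = TransformerModel.from_pretrained(
    #       "transformer.wmt19.en-de", checkpoint_file="model1.pt"
    #   )
    #   en2de.translate("Hello world!")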

    @classmethod
    def hub_models(cls):
        return {}


class FairseqEncoderDecoderModel(BaseFairseqModel):
    """Base class for encoder-decoder models.

    Args:
        encoder (FairseqEncoder): the encoder
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        check_type(self.encoder, FairseqEncoder)
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Run the forward pass for an encoder-decoder model.

        First feed a batch of source tokens through the encoder. Then, feed the
        encoder output and previous decoder outputs (i.e., teacher forcing) to
        the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        features = self.decoder.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return features

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder.max_positions(), self.decoder.max_positions())

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()


class FairseqModel(FairseqEncoderDecoderModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        utils.deprecation_warning(
            "FairseqModel is deprecated, please use FairseqEncoderDecoderModel "
            "or BaseFairseqModel instead",
            stacklevel=4,
        )


class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models."""

    def __init__(self, encoders, decoders):
        super().__init__()
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            check_type(encoders[key], FairseqEncoder)
            check_type(decoders[key], FairseqDecoder)

        self.models = nn.ModuleDict(
            {
                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
                for key in self.keys
            }
        )

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Helper function to build shared embeddings for a set of languages after
        checking that all dicts corresponding to those languages are equivalent.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: callable function to actually build the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings
        """
        shared_dict = dicts[langs[0]]
        if any(dicts[lang] != shared_dict for lang in langs):
            raise ValueError(
                "--share-*-embeddings requires a joined dictionary: "
                "--share-encoder-embeddings requires a joined source "
                "dictionary, --share-decoder-embeddings requires a joined "
                "target dictionary, and --share-all-embeddings requires a "
                "joint source + target dictionary."
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)
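
    # Usage sketch (hypothetical helper): share one embedding table across
    # languages whose dictionaries are identical:
    #
    #   def build_embedding(dictionary, embed_dim, path=None):
    #       return nn.Embedding(len(dictionary), embed_dim, dictionary.pad())
    #
    #   shared = FairseqMultiModel.build_shared_embeddings(
    #       dicts, langs=["en", "de"], embed_dim=512, build_embedding=build_embedding
    #   )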

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model."""
        return {
            key: (
                self.models[key].encoder.max_positions(),
                self.models[key].decoder.max_positions(),
            )
            for key in self.keys
        }

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return min(model.decoder.max_positions() for model in self.models.values())

    @property
    def encoder(self):
        return self.models[self.keys[0]].encoder

    @property
    def decoder(self):
        return self.models[self.keys[0]].decoder

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        if model_cfg is None and args is not None:
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)


class FairseqLanguageModel(BaseFairseqModel):
    """Base class for decoder-only models.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
        Run the forward pass for a decoder-only model.

        Feeds a batch of tokens through the decoder to predict the next tokens.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, seq_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder(src_tokens, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, seq_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder.extract_features(src_tokens, **kwargs)

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.decoder.max_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    @property
    def supported_targets(self):
        return {"future"}


class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        check_type(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for an encoder-only model.

        Feeds a batch of tokens through the encoder to generate features.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            the encoder's output, typically of shape `(batch, src_len, features)`
        """
        return self.encoder(src_tokens, src_lengths, **kwargs)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        encoder_out = net_output["encoder_out"]
        if torch.is_tensor(encoder_out):
            logits = encoder_out.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.encoder.max_positions()