File size: 7,569 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import argparse
import importlib
import os
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import merge_with_parent, populate_dataclass
from hydra.core.config_store import ConfigStore
from .composite_encoder import CompositeEncoder
from .distributed_fairseq_model import DistributedFairseqModel
from .fairseq_decoder import FairseqDecoder
from .fairseq_encoder import FairseqEncoder
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import (
BaseFairseqModel,
FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqLanguageModel,
FairseqModel,
FairseqMultiModel,
)
# Module-level registries, filled in by the `register_model` and
# `register_model_architecture` decorators defined in this file.

# model name -> model class
MODEL_REGISTRY = {}
# model name -> config dataclass (only for models registered with a dataclass)
MODEL_DATACLASS_REGISTRY = {}
# architecture name -> model class
ARCH_MODEL_REGISTRY = {}
# architecture name -> model name
ARCH_MODEL_NAME_REGISTRY = {}
# model name -> list of architecture names registered for that model
ARCH_MODEL_INV_REGISTRY = {}
# architecture name -> the decorated architecture function
ARCH_CONFIG_REGISTRY = {}
# Names re-exported as the public API of this package.
__all__ = [
"BaseFairseqModel",
"CompositeEncoder",
"DistributedFairseqModel",
"FairseqDecoder",
"FairseqEncoder",
"FairseqEncoderDecoderModel",
"FairseqEncoderModel",
"FairseqIncrementalDecoder",
"FairseqLanguageModel",
"FairseqModel",
"FairseqMultiModel",
]
def build_model(cfg: FairseqDataclass, task):
    """Build the model selected by *cfg* for the given *task*.

    The model type is read from ``cfg._name`` or, failing that, ``cfg.arch``.
    If neither is set and *cfg* holds exactly one entry, that entry's key is
    used as the model type and its value becomes the effective config.

    When the model type has a registered dataclass, *cfg* is combined with a
    fresh instance of that dataclass before the model is built (via
    ``populate_dataclass`` for an :class:`argparse.Namespace`, otherwise via
    ``merge_with_parent``).

    Args:
        cfg: the model configuration.
        task: passed through unchanged to the model class' ``build_model``.

    Returns:
        The result of ``build_model(cfg, task)`` on the resolved model class.

    Raises:
        Exception: if the model type inferred from a single-entry config is
            not present in ``MODEL_DATACLASS_REGISTRY``.
        AssertionError: if the model type matches neither a registered
            architecture nor a registered dataclass model.
    """
    model = None
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    if not model_type and len(cfg) == 1:
        # this is hit if config object is nested in directory that is named after model type
        model_type = next(iter(cfg))
        if model_type not in MODEL_DATACLASS_REGISTRY:
            # f-string formatting keeps the message usable even when the key
            # is not a plain str.
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type. "
                f"Available models: {MODEL_DATACLASS_REGISTRY.keys()}"
                f" Requested model type: {model_type}"
            )
        cfg = cfg[model_type]

    if model_type in ARCH_MODEL_REGISTRY:
        # case 1: legacy models
        model = ARCH_MODEL_REGISTRY[model_type]
    elif model_type in MODEL_DATACLASS_REGISTRY:
        # case 2: config-driven models
        model = MODEL_REGISTRY[model_type]

    if model_type in MODEL_DATACLASS_REGISTRY:
        # set defaults from dataclass. note that arch name and model name can be the same
        dc = MODEL_DATACLASS_REGISTRY[model_type]
        if isinstance(cfg, argparse.Namespace):
            cfg = populate_dataclass(dc(), cfg)
        else:
            cfg = merge_with_parent(dc(), cfg)

    if model is None:
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O`; the exception type stays AssertionError for callers.
        # model_type may be None here, so it is formatted, not concatenated.
        raise AssertionError(
            f"Could not infer model type from {cfg}. "
            f"Available models: {MODEL_DATACLASS_REGISTRY.keys()}"
            f" Requested model type: {model_type}"
        )

    return model.build_model(cfg, task)
def register_model(name, dataclass=None):
    """
    New model types can be added to fairseq with the :func:`register_model`
    function decorator.

    For example::

        @register_model('lstm')
        class LSTM(FairseqEncoderDecoderModel):
            (...)

    .. note:: All models must implement the :class:`BaseFairseqModel` interface.
        Typically you will extend :class:`FairseqEncoderDecoderModel` for
        sequence-to-sequence tasks or :class:`FairseqLanguageModel` for
        language modeling tasks.

    When *dataclass* is given, it is additionally recorded in
    ``MODEL_DATACLASS_REGISTRY``, a default instance of it is stored in the
    hydra ``ConfigStore`` under the ``model`` group, and an architecture with
    the same name as the model is registered with a no-op function.

    Args:
        name (str): the name of the model
        dataclass (optional): the config dataclass for the model; must be a
            subclass of :class:`FairseqDataclass` when provided

    Returns:
        A class decorator that registers the decorated class and returns it
        unchanged.

    Raises:
        ValueError: (from the returned decorator) if *name* is already
            registered, if the class does not extend
            :class:`BaseFairseqModel`, or if *dataclass* does not extend
            :class:`FairseqDataclass`.
    """

    def register_model_cls(cls):
        # Run every validation before touching any registry, so a rejected
        # registration leaves no half-registered model behind.
        if name in MODEL_REGISTRY:
            raise ValueError("Cannot register duplicate model ({})".format(name))
        if not issubclass(cls, BaseFairseqModel):
            raise ValueError(
                "Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__)
            )
        if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
            raise ValueError(
                "Dataclass {} must extend FairseqDataclass".format(dataclass)
            )

        MODEL_REGISTRY[name] = cls
        cls.__dataclass = dataclass

        if dataclass is not None:
            MODEL_DATACLASS_REGISTRY[name] = dataclass

            # Expose the dataclass defaults to hydra as `model/<name>`.
            cs = ConfigStore.instance()
            node = dataclass()
            node._name = name
            cs.store(name=name, group="model", node=node, provider="fairseq")

            # Register an architecture named after the model itself; its
            # function does nothing.
            @register_model_architecture(name, name)
            def noop(_):
                pass

        return cls

    return register_model_cls
def register_model_architecture(model_name, arch_name):
    """
    Decorator factory that registers a named architecture for an already
    registered model. According to the original documentation, registered
    architectures can then be selected with the ``--arch`` command-line
    argument.

    For example::

        @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
        def lstm_luong_wmt_en_de(cfg):
            cfg.encoder_embed_dim = getattr(cfg, 'encoder_embed_dim', 1000)
            (...)

    The decorated function should take a single argument *cfg* (documented
    upstream as an :class:`omegaconf.DictConfig`) and modify it in-place to
    match the desired architecture. This decorator only stores the function;
    it does not call it.

    Args:
        model_name (str): the name of the Model (Model must already be
            registered)
        arch_name (str): the name of the model architecture (``--arch``)

    Returns:
        A decorator that records the function in the architecture registries
        and returns it unchanged.

    Raises:
        ValueError: (from the returned decorator) if *model_name* is not a
            registered model, if *arch_name* is already registered, or if the
            decorated object is not callable.
    """

    def decorator(arch_fn):
        # Validation order matters: it decides which error is reported first.
        if model_name not in MODEL_REGISTRY:
            unknown_model = (
                "Cannot register model architecture for unknown model type ({})"
            )
            raise ValueError(unknown_model.format(model_name))

        if arch_name in ARCH_MODEL_REGISTRY:
            duplicate_arch = "Cannot register duplicate model architecture ({})"
            raise ValueError(duplicate_arch.format(arch_name))

        if not callable(arch_fn):
            not_callable = "Model architecture must be callable ({})"
            raise ValueError(not_callable.format(arch_name))

        model_cls = MODEL_REGISTRY[model_name]

        # arch -> model class / model name, model -> [archs], arch -> function
        ARCH_MODEL_REGISTRY[arch_name] = model_cls
        ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name
        ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name)
        ARCH_CONFIG_REGISTRY[arch_name] = arch_fn
        return arch_fn

    return decorator
def import_models(models_dir, namespace):
    """Import every model module or sub-directory found directly in *models_dir*.

    Entries whose name starts with ``_`` or ``.`` are skipped, as is anything
    that is neither a ``.py`` file nor a directory. Each remaining entry is
    imported as ``<namespace>.<name>``.

    If, after the import, ``<name>`` is a key of ``MODEL_REGISTRY``, an
    :class:`argparse.ArgumentParser` describing that model's architectures and
    arguments is stored in this module's globals as ``<name>_parser``.

    Args:
        models_dir (str): directory to scan (not recursive).
        namespace (str): dotted package name the entries are imported under.
    """
    py_suffix = ".py"

    for entry in os.listdir(models_dir):
        if entry.startswith(("_", ".")):
            continue

        if entry.endswith(py_suffix):
            # Strip the trailing suffix only; searching for the first ".py"
            # would truncate names that contain ".py" earlier on.
            model_name = entry[: -len(py_suffix)]
        elif os.path.isdir(os.path.join(models_dir, entry)):
            model_name = entry
        else:
            continue

        importlib.import_module(namespace + "." + model_name)

        # extra `model_parser` for sphinx
        if model_name not in MODEL_REGISTRY:
            continue

        parser = argparse.ArgumentParser(add_help=False)
        group_archs = parser.add_argument_group("Named architectures")
        group_archs.add_argument(
            # A model may be registered without any named architecture;
            # fall back to an empty choice list instead of raising KeyError.
            "--arch",
            choices=ARCH_MODEL_INV_REGISTRY.get(model_name, []),
        )
        group_args = parser.add_argument_group("Additional command-line arguments")
        MODEL_REGISTRY[model_name].add_args(group_args)
        globals()[model_name + "_parser"] = parser
# Automatically import any Python files (and sub-directories) in the models/
# directory. This runs at import time of this package; importing each module
# executes its top-level code, including any registration decorators it uses.
models_dir = os.path.dirname(__file__)
import_models(models_dir, "fairseq.models")
|