File size: 6,076 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Any, Dict
from fairseq import checkpoint_utils
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
TransformerModel,
base_architecture as transformer_base_architecture,
)
@register_model("transformer_from_pretrained_xlm")
class TransformerFromPretrainedXLMModel(TransformerModel):
    """Transformer model whose encoder and/or decoder are built from XLM weights.

    The encoder and decoder are created as
    ``TransformerEncoderFromPretrainedXLM`` and
    ``TransformerDecoderFromPretrainedXLM`` respectively; everything else is
    inherited from ``TransformerModel``.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--pretrained-xlm-checkpoint",
            type=str,
            metavar="STR",
            help="XLM model to use for initializing transformer encoder and/or decoder",
        )
        parser.add_argument(
            "--init-encoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into decoder",
        )
        parser.add_argument(
            "--init-decoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into encoder",
        )

    @classmethod
    def build_model(cls, args, task, cls_dictionary=MaskedLMDictionary):
        """Validate the XLM-specific options, then build the model.

        Args:
            args: parsed model arguments; must carry a non-None
                ``pretrained_xlm_checkpoint``.
            task: task whose ``source_dictionary`` and ``target_dictionary``
                must both be instances of ``cls_dictionary``.
            cls_dictionary: dictionary class required for both task
                dictionaries (``MaskedLMDictionary`` by default).

        Returns:
            Whatever ``TransformerModel.build_model(args, task)`` returns.

        Raises:
            AssertionError: if no checkpoint path was given, if either task
                dictionary has the wrong type, or if both
                ``--init-encoder-only`` and ``--init-decoder-only`` are set.
        """
        # The option is registered in add_args without a default, so an
        # argparse namespace always has the attribute (set to None when the
        # flag is omitted). Checking the value, not mere presence, is what
        # actually detects a missing --pretrained-xlm-checkpoint.
        assert getattr(args, "pretrained_xlm_checkpoint", None) is not None, (
            "You must specify a path for --pretrained-xlm-checkpoint to use "
            "--arch transformer_from_pretrained_xlm"
        )
        assert isinstance(task.source_dictionary, cls_dictionary) and isinstance(
            task.target_dictionary, cls_dictionary
        ), (
            "You should use a MaskedLMDictionary when using --arch "
            "transformer_from_pretrained_xlm because the pretrained XLM model "
            "was trained using data binarized with MaskedLMDictionary. "
            "For translation, you may want to use --task "
            "translation_from_pretrained_xlm"
        )
        # The two flags are mutually exclusive: each one disables loading on
        # the opposite side, so setting both would load nothing at all.
        assert not (
            getattr(args, "init_encoder_only", False)
            and getattr(args, "init_decoder_only", False)
        ), "Only one of --init-encoder-only and --init-decoder-only can be set."
        return super().build_model(args, task)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Return a ``TransformerEncoderFromPretrainedXLM`` for ``src_dict``."""
        return TransformerEncoderFromPretrainedXLM(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Return a ``TransformerDecoderFromPretrainedXLM`` for ``tgt_dict``."""
        return TransformerDecoderFromPretrainedXLM(args, tgt_dict, embed_tokens)
def upgrade_state_dict_with_xlm_weights(
    state_dict: Dict[str, Any], pretrained_xlm_checkpoint: str
) -> Dict[str, Any]:
    """
    Overwrite entries of a Transformer state dict with XLM checkpoint values.

    Every key of the checkpoint's ``"model"`` dict that contains
    ``embed_tokens``, ``embed_positions`` or ``layers`` is trimmed so that it
    starts at that marker, and the trimmed name is used as the destination
    key in ``state_dict``.

    Args:
        state_dict: state dict for either TransformerEncoder or
            TransformerDecoder; it is modified in place
        pretrained_xlm_checkpoint: checkpoint to load XLM weights from

    Returns:
        The same ``state_dict`` object, after the matching entries have been
        replaced.

    Raises:
        IOError: If ``pretrained_xlm_checkpoint`` does not exist on disk
        AssertionError: If a trimmed XLM key is absent from ``state_dict``,
            i.e. the architecture (num layers, attention heads, etc.) does
            not match between the current Transformer encoder or decoder
            and the pretrained_xlm_checkpoint
    """
    if not os.path.exists(pretrained_xlm_checkpoint):
        raise IOError("Model file not found: {}".format(pretrained_xlm_checkpoint))

    checkpoint = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
    xlm_entries = checkpoint["model"]
    markers = ("embed_tokens", "embed_positions", "layers")

    for xlm_name, value in xlm_entries.items():
        for marker in markers:
            start = xlm_name.find(marker)
            if start == -1:
                # This marker does not occur in the XLM key; try the next one.
                continue
            # Drop whatever prefix precedes the marker so the name lines up
            # with the keys of the encoder / decoder state dict.
            target_name = xlm_name[start:]
            assert target_name in state_dict, (
                "{} Transformer encoder / decoder "
                "state_dict does not contain {}. Cannot "
                "load {} from pretrained XLM checkpoint "
                "{} into Transformer.".format(
                    str(state_dict.keys()),
                    target_name,
                    xlm_name,
                    pretrained_xlm_checkpoint,
                )
            )
            state_dict[target_name] = value
    return state_dict
class TransformerEncoderFromPretrainedXLM(TransformerEncoder):
    """Transformer encoder that loads its weights from an XLM checkpoint.

    Loading is skipped entirely when ``args.init_decoder_only`` is truthy.
    """

    def __init__(self, args, dictionary, embed_tokens):
        """Build the encoder, then overwrite its weights with XLM weights.

        Args:
            args: model arguments; ``pretrained_xlm_checkpoint`` must be set
                unless ``init_decoder_only`` is truthy.
            dictionary: passed through to ``TransformerEncoder.__init__``.
            embed_tokens: passed through to ``TransformerEncoder.__init__``.

        Raises:
            AssertionError: if no checkpoint path is configured, or if the
                checkpoint contains a key this encoder does not have.
            IOError: if the checkpoint path does not exist.
        """
        super().__init__(args, dictionary, embed_tokens)
        if getattr(args, "init_decoder_only", False):
            # Don't load XLM weights for encoder if --init-decoder-only
            return

        # An argparse-built namespace has this attribute even when the flag
        # was omitted (its value is then None), so test the value rather than
        # just the attribute's presence.
        assert getattr(args, "pretrained_xlm_checkpoint", None) is not None, (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "encoder from pretrained XLM"
        )

        xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights(
            state_dict=self.state_dict(),
            pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint,
        )
        self.load_state_dict(xlm_loaded_state_dict, strict=True)
class TransformerDecoderFromPretrainedXLM(TransformerDecoder):
    """Transformer decoder that loads its weights from an XLM checkpoint.

    Loading is skipped entirely when ``args.init_encoder_only`` is truthy.
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        """Build the decoder, then overwrite its weights with XLM weights.

        Args:
            args: model arguments; ``pretrained_xlm_checkpoint`` must be set
                unless ``init_encoder_only`` is truthy.
            dictionary: passed through to ``TransformerDecoder.__init__``.
            embed_tokens: passed through to ``TransformerDecoder.__init__``.
            no_encoder_attn: passed through to ``TransformerDecoder.__init__``.

        Raises:
            AssertionError: if no checkpoint path is configured, or if the
                checkpoint contains a key this decoder does not have.
            IOError: if the checkpoint path does not exist.
        """
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
        if getattr(args, "init_encoder_only", False):
            # Don't load XLM weights for decoder if --init-encoder-only
            return

        # An argparse-built namespace has this attribute even when the flag
        # was omitted (its value is then None), so test the value rather than
        # just the attribute's presence.
        assert getattr(args, "pretrained_xlm_checkpoint", None) is not None, (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "decoder from pretrained XLM"
        )

        xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights(
            state_dict=self.state_dict(),
            pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint,
        )
        self.load_state_dict(xlm_loaded_state_dict, strict=True)
@register_model_architecture(
    model_name="transformer_from_pretrained_xlm",
    arch_name="transformer_from_pretrained_xlm",
) if False else register_model_architecture(
    "transformer_from_pretrained_xlm", "transformer_from_pretrained_xlm"
)
def base_architecture(args):
    """Architecture hook for ``transformer_from_pretrained_xlm``.

    Delegates unchanged to the ``base_architecture`` function imported from
    ``fairseq.models.transformer`` (aliased in this file as
    ``transformer_base_architecture``), so this model adds no architecture
    settings of its own.
    """
    transformer_base_architecture(args)
|