import os
from typing import Any, Dict

from fairseq import checkpoint_utils
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture as transformer_base_architecture,
)


@register_model("transformer_from_pretrained_xlm")
class TransformerFromPretrainedXLMModel(TransformerModel):
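    """
    Transformer that initializes its encoder and/or decoder parameters from a
    pretrained XLM (masked LM) checkpoint before training.

    A rough usage sketch (the checkpoint path is a placeholder): train with
    ``--arch transformer_from_pretrained_xlm`` and
    ``--pretrained-xlm-checkpoint /path/to/xlm_checkpoint.pt``, typically
    together with ``--task translation_from_pretrained_xlm``; add
    ``--init-encoder-only`` or ``--init-decoder-only`` to load the XLM weights
    into only one side of the model.
    """
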
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--pretrained-xlm-checkpoint",
            type=str,
            metavar="STR",
            help="XLM model to use for initializing transformer encoder and/or decoder",
        )
        parser.add_argument(
            "--init-encoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into decoder",
        )
        parser.add_argument(
            "--init-decoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into encoder",
        )

    @classmethod
    def build_model(cls, args, task, cls_dictionary=MaskedLMDictionary):
        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "You must specify a path for --pretrained-xlm-checkpoint to use "
            "--arch transformer_from_pretrained_xlm"
        )
        assert isinstance(task.source_dictionary, cls_dictionary) and isinstance(
            task.target_dictionary, cls_dictionary
        ), (
            "You should use a MaskedLMDictionary when using --arch "
            "transformer_from_pretrained_xlm because the pretrained XLM model "
            "was trained using data binarized with MaskedLMDictionary. "
            "For translation, you may want to use --task "
            "translation_from_pretrained_xlm"
        )
        assert not (
            getattr(args, "init_encoder_only", False)
            and getattr(args, "init_decoder_only", False)
        ), "Only one of --init-encoder-only and --init-decoder-only can be set."
        return super().build_model(args, task)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerEncoderFromPretrainedXLM(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerDecoderFromPretrainedXLM(args, tgt_dict, embed_tokens)


def upgrade_state_dict_with_xlm_weights(
    state_dict: Dict[str, Any], pretrained_xlm_checkpoint: str
) -> Dict[str, Any]:
    """
    Load XLM weights into a Transformer encoder or decoder model.

    Args:
        state_dict: state dict for either TransformerEncoder or
            TransformerDecoder
        pretrained_xlm_checkpoint: checkpoint to load XLM weights from

    Raises:
        IOError: if pretrained_xlm_checkpoint does not exist
        AssertionError: if the architecture (number of layers, attention heads,
            etc.) does not match between the current Transformer encoder or
            decoder and the pretrained_xlm_checkpoint
    """
    if not os.path.exists(pretrained_xlm_checkpoint):
        raise IOError("Model file not found: {}".format(pretrained_xlm_checkpoint))

    state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
    xlm_state_dict = state["model"]
    for key in xlm_state_dict.keys():
        for search_key in ["embed_tokens", "embed_positions", "layers"]:
            if search_key in key:
                # Strip the XLM-specific prefix so the key lines up with the
                # corresponding TransformerEncoder/TransformerDecoder parameter,
                # e.g. "...layers.0.self_attn.out_proj.weight" becomes
                # "layers.0.self_attn.out_proj.weight" (the exact prefix depends
                # on the pretrained XLM model).
                subkey = key[key.find(search_key) :]
                assert subkey in state_dict, (
                    "{} Transformer encoder / decoder "
                    "state_dict does not contain {}. Cannot "
                    "load {} from pretrained XLM checkpoint "
                    "{} into Transformer.".format(
                        str(state_dict.keys()), subkey, key, pretrained_xlm_checkpoint
                    )
                )

                state_dict[subkey] = xlm_state_dict[key]
    return state_dict


class TransformerEncoderFromPretrainedXLM(TransformerEncoder):
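    """
    TransformerEncoder that, unless ``--init-decoder-only`` is set, loads its
    token embeddings, positional embeddings, and layer weights from the
    checkpoint given by ``--pretrained-xlm-checkpoint``.
    """
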
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        if getattr(args, "init_decoder_only", False):
            # Don't load XLM weights for the encoder if --init-decoder-only is set
            return

        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "encoder from pretrained XLM"
        )
        xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights(
            state_dict=self.state_dict(),
            pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint,
        )
        self.load_state_dict(xlm_loaded_state_dict, strict=True)


class TransformerDecoderFromPretrainedXLM(TransformerDecoder):
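    """
    TransformerDecoder that, unless ``--init-encoder-only`` is set, loads its
    token embeddings, positional embeddings, and layer weights from the
    checkpoint given by ``--pretrained-xlm-checkpoint``.
    """
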
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
        if getattr(args, "init_encoder_only", False):
            # Don't load XLM weights for the decoder if --init-encoder-only is set
            return

        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "decoder from pretrained XLM"
        )
        xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights(
            state_dict=self.state_dict(),
            pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint,
        )
        self.load_state_dict(xlm_loaded_state_dict, strict=True)


@register_model_architecture(
    "transformer_from_pretrained_xlm", "transformer_from_pretrained_xlm"
)
def base_architecture(args):
    transformer_base_architecture(args)