# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    LayerNorm,
    SinusoidalPositionalEmbedding,
    TransformerSentenceEncoder,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params


logger = logging.getLogger(__name__)


@register_model("masked_lm")
class MaskedLMModel(FairseqEncoderModel):
    """
    Class for training a Masked Language Model. It also supports an
    additional sentence level prediction if the sent-loss argument is set.
    """

    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

        # if specified then apply bert initialization on the model. We need
        # to explicitly call this to make sure that the output embeddings
        # and projection layers are also correctly initialized
        if getattr(args, "apply_bert_init", False):
            self.apply(init_bert_params)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Arguments related to dropout
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for" " attention weights",
        )
        parser.add_argument(
            "--act-dropout",
            type=float,
            metavar="D",
            help="dropout probability after" " activation in FFN",
        )

        # Arguments related to hidden states and self-attention
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )

        # Arguments related to input and output embeddings
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--share-encoder-input-output-embed",
            action="store_true",
            help="share encoder input" " and output embeddings",
        )
        parser.add_argument(
            "--encoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the encoder",
        )
        parser.add_argument(
            "--no-token-positional-embeddings",
            action="store_true",
            help="if set, disables positional embeddings" " (outside self attention)",
        )
        parser.add_argument(
            "--num-segment", type=int, metavar="N", help="num segment in the input"
        )
        parser.add_argument(
            "--max-positions", type=int, help="number of positional embeddings to learn"
        )

        # Arguments related to sentence level prediction
        parser.add_argument(
            "--sentence-class-num",
            type=int,
            metavar="N",
            help="number of classes for sentence task",
        )
        parser.add_argument(
            "--sent-loss",
            action="store_true",
            help="if set," " calculate sentence level predictions",
        )

        # Arguments related to parameter initialization
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

        # misc params
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="Which activation function to use for pooler layer.",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )

    def forward(self, src_tokens, segment_labels=None, **kwargs):
        """Delegate to ``self.encoder``, forwarding segment labels and any
        extra keyword arguments (e.g. ``masked_tokens``)."""
        return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs)

    def max_positions(self):
        """Maximum input length supported by the model (taken from the encoder)."""
        return self.encoder.max_positions()

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present in older models
        base_architecture(args)

        # Fall back to the task's sample length when max_positions is either
        # missing or explicitly None (e.g. an argparse option left unset);
        # a plain hasattr() check would let None through to the encoder.
        if getattr(args, "max_positions", None) is None:
            args.max_positions = args.tokens_per_sample

        logger.info(args)

        encoder = MaskedLMEncoder(args, task.dictionary)
        return cls(args, encoder)


class MaskedLMEncoder(FairseqEncoder):
    """
    Encoder for Masked Language Modelling.
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)

        self.padding_idx = dictionary.pad()
        self.vocab_size = len(dictionary)
        # Kept under a private name: assigning to ``self.max_positions`` would
        # shadow the ``max_positions()`` method with an int and make
        # ``encoder.max_positions()`` raise TypeError.
        self._max_positions = args.max_positions

        self.sentence_encoder = TransformerSentenceEncoder(
            padding_idx=self.padding_idx,
            vocab_size=self.vocab_size,
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.act_dropout,
            max_seq_len=self._max_positions,
            num_segments=args.num_segment,
            use_position_embeddings=not args.no_token_positional_embeddings,
            encoder_normalize_before=args.encoder_normalize_before,
            apply_bert_init=args.apply_bert_init,
            activation_fn=args.activation_fn,
            learned_pos_embedding=args.encoder_learned_pos,
        )

        self.share_input_output_embed = args.share_encoder_input_output_embed
        self.embed_out = None
        self.sentence_projection_layer = None
        self.sentence_out_dim = args.sentence_class_num
        self.lm_output_learned_bias = None

        # Remove head is set to true during fine-tuning
        self.load_softmax = not getattr(args, "remove_head", False)

        self.masked_lm_pooler = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn)

        self.lm_head_transform_weight = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.activation_fn = utils.get_activation_fn(args.activation_fn)
        self.layer_norm = LayerNorm(args.encoder_embed_dim)

        # The output heads (vocab bias, untied vocab projection, sentence
        # classifier) only exist when the head is kept.
        if self.load_softmax:
            self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size))

            if not self.share_input_output_embed:
                self.embed_out = nn.Linear(
                    args.encoder_embed_dim, self.vocab_size, bias=False
                )

            if args.sent_loss:
                self.sentence_projection_layer = nn.Linear(
                    args.encoder_embed_dim, self.sentence_out_dim, bias=False
                )

    def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused):
        """
        Forward pass for Masked LM encoder. This first computes the token
        embedding using the token embedding matrix, position embeddings (if
        specified) and segment embeddings (if specified).

        Here we assume that the sentence representation corresponds to the
        output of the classification_token (see bert_task or cross_lingual_lm
        task for more details).
        Args:
            - src_tokens: B x T matrix representing sentences
            - segment_labels: B x T matrix representing segment label for tokens
            - masked_tokens: optional index used to select positions of the
              final hidden states (after transposing them to batch-first)
              before the LM head is applied; presumably a boolean B x T mask.
              When given, logits are produced only for the selected positions.
        Returns:
            - a tuple of the following:
                - logits for predictions in format B x T x C to be used in
                  softmax afterwards (or one row per selected position when
                  ``masked_tokens`` is given)
                - a dictionary of additional data, where 'pooled_output' contains
                  the representation for classification_token and 'inner_states'
                  is a list of internal model states used to compute the
                  predictions (similar in ELMO). 'sentence_logits'
                  is the prediction logit for NSP task and is only computed if
                  this is specified in the input arguments.
        """

        inner_states, sentence_rep = self.sentence_encoder(
            src_tokens,
            segment_labels=segment_labels,
        )

        x = inner_states[-1].transpose(0, 1)
        # project masked tokens only
        if masked_tokens is not None:
            x = x[masked_tokens, :]
        x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))

        pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep))

        # project back to size of vocabulary
        if self.share_input_output_embed and hasattr(
            self.sentence_encoder.embed_tokens, "weight"
        ):
            x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
        elif self.embed_out is not None:
            x = self.embed_out(x)
        if self.lm_output_learned_bias is not None:
            x = x + self.lm_output_learned_bias

        sentence_logits = None
        if self.sentence_projection_layer is not None:
            sentence_logits = self.sentence_projection_layer(pooled_output)

        return x, {
            "inner_states": inner_states,
            "pooled_output": pooled_output,
            "sentence_logits": sentence_logits,
        }

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self._max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Adjust ``state_dict`` in place so it loads into this encoder.

        * With sinusoidal positions, insert a placeholder
          ``<name>.sentence_encoder.embed_positions._float_tensor`` entry.
        * When the head was removed (``remove_head``), drop every key that
          contains ``embed_out.weight``, ``sentence_projection_layer.weight``
          or ``lm_output_learned_bias``.

        Returns the same (mutated) ``state_dict``.
        """
        if isinstance(
            self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding
        ):
            state_dict[
                name + ".sentence_encoder.embed_positions._float_tensor"
            ] = torch.FloatTensor(1)
        if not self.load_softmax:
            # Iterate over a snapshot of the keys because entries are deleted.
            for k in list(state_dict.keys()):
                if (
                    "embed_out.weight" in k
                    or "sentence_projection_layer.weight" in k
                    or "lm_output_learned_bias" in k
                ):
                    del state_dict[k]
        return state_dict


@register_model_architecture("masked_lm", "masked_lm")
def base_architecture(args):
    """Fill in default hyper-parameters for any attribute not already set on ``args``."""
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.act_dropout = getattr(args, "act_dropout", 0.0)

    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)

    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.share_encoder_input_output_embed = getattr(
        args, "share_encoder_input_output_embed", False
    )
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.num_segment = getattr(args, "num_segment", 2)

    args.sentence_class_num = getattr(args, "sentence_class_num", 2)
    args.sent_loss = getattr(args, "sent_loss", False)

    args.apply_bert_init = getattr(args, "apply_bert_init", False)

    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)


@register_model_architecture("masked_lm", "bert_base")
def bert_base_architecture(args):
    """BERT-base defaults (768-d, 12 layers, 12 heads, tied embeddings,
    sentence loss on); remaining unset values come from ``base_architecture``."""
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.share_encoder_input_output_embed = getattr(
        args, "share_encoder_input_output_embed", True
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
    args.num_segment = getattr(args, "num_segment", 2)

    args.encoder_layers = getattr(args, "encoder_layers", 12)

    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)

    args.sentence_class_num = getattr(args, "sentence_class_num", 2)
    args.sent_loss = getattr(args, "sent_loss", True)

    args.apply_bert_init = getattr(args, "apply_bert_init", True)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    base_architecture(args)


@register_model_architecture("masked_lm", "bert_large")
def bert_large_architecture(args):
    """BERT-large sizes (1024-d, 24 layers, 16 heads, 4096 FFN) on top of
    ``bert_base_architecture`` defaults."""
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    bert_base_architecture(args)


@register_model_architecture("masked_lm", "xlm_base")
def xlm_architecture(args):
    """XLM-base defaults (1024-d, 6 layers, 8 heads, one segment, no sentence
    loss); remaining unset values come from ``base_architecture``."""
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.share_encoder_input_output_embed = getattr(
        args, "share_encoder_input_output_embed", True
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
    args.num_segment = getattr(args, "num_segment", 1)

    args.encoder_layers = getattr(args, "encoder_layers", 6)

    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)

    args.sent_loss = getattr(args, "sent_loss", False)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.apply_bert_init = getattr(args, "apply_bert_init", True)
    base_architecture(args)