# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F

from fairseq.data import Dictionary
from fairseq.models import (
    FairseqDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)


@register_model("dummy_model")
class DummyModel(FairseqLanguageModel):
    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        parser.add_argument("--num-layers", type=int, default=24)
        parser.add_argument("--embed-dim", type=int, default=1024)

    @classmethod
    def build_model(cls, args, task):
        encoder = DummyEncoder(
            num_embed=len(task.target_dictionary),
            embed_dim=args.embed_dim,
            num_layers=args.num_layers,
        )
        return cls(args, encoder)

    def forward(self, src_tokens, masked_tokens=None, **kwargs):
        return self.decoder(src_tokens, masked_tokens=masked_tokens)


class DummyEncoder(FairseqDecoder):
    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__(Dictionary())
        self.embed = nn.Embedding(
            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
        )
        self.layers_a = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
                    nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
                    nn.Linear(embed_dim, embed_dim),  # output projection
                    nn.Dropout(),
                )
                for i in range(num_layers)
            ]
        )
        self.layers_b = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, 4 * embed_dim),  # FFN
                    nn.ReLU(),
                    nn.Linear(4 * embed_dim, embed_dim),  # FFN
                    nn.Dropout(0.1),
                )
                for i in range(num_layers)
            ]
        )
        self.out_proj = nn.Linear(embed_dim, num_embed)

    def forward(self, tokens, masked_tokens=None):
        x = self.embed(tokens)
        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
            x = x + layer_a(x)
            x = x + layer_b(x)
        x = self.out_proj(x)
        if masked_tokens is not None:
            x = x[masked_tokens]
        return (x,)

    def max_positions(self):
        return 1024

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)


@register_model_architecture("dummy_model", "dummy_model")
def base_architecture(args):
    pass
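

# --- Hedged usage sketch (not part of the original module) ---
# A minimal example of one way to build and run DummyModel outside of the
# fairseq-train CLI. The `_ToyTask` class and the argparse.Namespace args below
# are illustrative assumptions; in practice fairseq constructs the task and
# parses the --embed-dim / --num-layers flags registered in add_args().
if __name__ == "__main__":
    import argparse

    import torch

    # Build a tiny dictionary to stand in for a real target dictionary.
    dictionary = Dictionary()
    for i in range(100):
        dictionary.add_symbol(f"tok{i}")

    class _ToyTask:
        # DummyModel.build_model only reads `target_dictionary` from the task.
        target_dictionary = dictionary

    args = argparse.Namespace(embed_dim=64, num_layers=2)
    model = DummyModel.build_model(args, _ToyTask())

    # Fake batch of token ids with shape (batch, sequence length).
    tokens = torch.randint(1, len(dictionary), (2, 16))
    logits = model(tokens)[0]
    print(logits.shape)  # torch.Size([2, 16, len(dictionary)])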