# Source: OFA-Image_Caption / models/ofa/unify_transformer.py
# (web file-viewer capture; last commit "update", 75ba0e0, by JustinLin610; 71.1 kB)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import random
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import (
AdaptiveSoftmax,
BaseLayer,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
SinusoidalPositionalEmbedding,
GradMultiply
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from .unify_transformer_layer import TransformerEncoderLayer, TransformerDecoderLayer
from .resnet import ResNet
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
def BatchNorm2d(out_chan, momentum=0.1, eps=1e-3):
    """Build a 2-D batch-norm layer and convert it to its synchronized variant.

    Args:
        out_chan: number of feature channels to normalize.
        momentum: momentum forwarded to ``nn.BatchNorm2d``.
        eps: epsilon forwarded to ``nn.BatchNorm2d``.

    Returns:
        The module produced by ``nn.SyncBatchNorm.convert_sync_batchnorm``.
    """
    plain_bn = nn.BatchNorm2d(out_chan, momentum=momentum, eps=eps)
    return nn.SyncBatchNorm.convert_sync_batchnorm(plain_bn)
def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS):
    """Build the token relative-position bucket index matrix.

    Offsets whose magnitude is below ``bucket_size // 2`` keep their own
    bucket; larger offsets are mapped through a ceil-of-log expression.
    The result is shifted by ``bucket_size - 1`` before being returned.

    Args:
        bucket_size: bucket-size hyper-parameter; half of it is the threshold
            between the linear and logarithmic regimes.
        max_position: number of positions along each axis of the matrix.

    Returns:
        A ``(max_position, max_position)`` long tensor of bucket indices.
    """
    positions = torch.arange(max_position, dtype=torch.long)
    # Entry [i, j] is the signed offset i - j.
    relative_pos = positions[:, None] - positions[None, :]
    sign = torch.sign(relative_pos)

    mid = bucket_size // 2
    is_near = (relative_pos < mid) & (relative_pos > -mid)
    # Near offsets are replaced by mid - 1 so that the log below stays finite
    # for them; their log value is discarded by the final ``where``.
    abs_pos = torch.where(is_near, mid - 1, torch.abs(relative_pos))

    # Keep the exact operation order (divide, then multiply) so that the
    # ceil() lands on the same integers as before.
    log_ratio = torch.log(abs_pos / mid) / math.log((max_position - 1) / mid)
    log_pos = (torch.ceil(log_ratio * (mid - 1)) + mid).int()

    bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
    return bucket_pos + bucket_size - 1
def make_image_bucket_position(bucket_size, num_relative_distance):
    """Build the 2-D relative-position index table for image patches.

    Patches live on a ``bucket_size x bucket_size`` grid flattened in
    row-major order. Entry ``[1 + i, 1 + j]`` encodes the (row, column) offset
    between patch ``i`` and patch ``j`` as a single integer. Row 0 and column 0
    hold three reserved indices taken from the top of the
    ``num_relative_distance`` range.

    Args:
        bucket_size: side length of the patch grid.
        num_relative_distance: total number of distinct relative indices; the
            last three are used for the reserved row/column/corner.

    Returns:
        A ``(bucket_size**2 + 1, bucket_size**2 + 1)`` long tensor.
    """
    span = 2 * bucket_size - 1
    axis = torch.arange(bucket_size)
    # Row / column coordinate of every flattened grid position.
    rows = axis.repeat_interleave(bucket_size)
    cols = axis.repeat(bucket_size)

    # Shift the offsets so they start from 0, then fold (row, col) into one id.
    row_offset = rows[:, None] - rows[None, :] + (bucket_size - 1)
    col_offset = cols[:, None] - cols[None, :] + (bucket_size - 1)
    pair_index = row_offset * span + col_offset

    side = bucket_size * bucket_size + 1
    relative_position_index = torch.zeros(size=(side, side), dtype=pair_index.dtype)
    relative_position_index[1:, 1:] = pair_index
    # Assignment order matters: the corner is written last and wins.
    relative_position_index[0, :] = num_relative_distance - 3
    relative_position_index[:, 0] = num_relative_distance - 2
    relative_position_index[0, 0] = num_relative_distance - 1
    return relative_position_index
@register_model("unify_transformer")
class TransformerModel(FairseqEncoderDecoderModel):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        encoder (TransformerEncoder): the encoder
        decoder (TransformerDecoder): the decoder

    The Transformer model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.transformer_parser
        :prog:
    """

    def __init__(self, args, encoder, decoder):
        """Store the parsed arguments next to the encoder/decoder pair."""
        super().__init__(encoder, decoder)
        self.args = args
        self.supports_align_args = True

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim)')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--layernorm-embedding', action='store_true',
                            help='add layernorm to embedding')
        parser.add_argument('--no-scale-embedding', action='store_true',
                            help='if True, dont scale embeddings')
        parser.add_argument('--checkpoint-activations', action='store_true',
                            help='checkpoint activations at each layer, which saves GPU '
                                 'memory usage at the cost of some additional compute')
        parser.add_argument('--offload-activations', action='store_true',
                            help='checkpoint activations at each layer, then offload them to CPU. '
                                 'Sets --checkpoint-activations.')
        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
        parser.add_argument('--no-cross-attention', default=False, action='store_true',
                            help='do not perform cross-attention')
        parser.add_argument('--cross-self-attention', default=False, action='store_true',
                            help='perform cross+self-attention')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for encoder')
        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')
        parser.add_argument('--encoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        parser.add_argument('--decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
        parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
                            help='iterative PQ quantization noise at training time')
        parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
                            help='block size of quantization noise at training time')
        parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                            help='scalar quantization noise and scalar quantization at training time')
        # args for Fully Sharded Data Parallel (FSDP) training
        parser.add_argument(
            '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
            help=(
                'minimum number of params for a layer to be wrapped with FSDP() when '
                'training with --ddp-backend=fully_sharded. Smaller values will '
                'improve memory efficiency, but may make torch.distributed '
                'communication less efficient due to smaller input sizes. This option '
                'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
                '--offload-activations are passed.'
            )
        )

        parser.add_argument('--resnet-drop-path-rate', type=float,
                            help='resnet drop path rate')
        parser.add_argument('--encoder-drop-path-rate', type=float,
                            help='encoder drop path rate')
        parser.add_argument('--decoder-drop-path-rate', type=float,
                            help='decoder drop path rate')

        parser.add_argument('--token-bucket-size', type=int,
                            help='token bucket size')
        parser.add_argument('--image-bucket-size', type=int,
                            help='image bucket size')

        parser.add_argument('--attn-scale-factor', type=float,
                            help='attention scale factor')
        parser.add_argument('--freeze-resnet', action='store_true',
                            help='freeze resnet')
        parser.add_argument('--freeze-encoder-embedding', action='store_true',
                            help='freeze encoder token embedding')
        parser.add_argument('--freeze-decoder-embedding', action='store_true',
                            help='freeze decoder token embedding')
        parser.add_argument('--add-type-embedding', action='store_true',
                            help='add source/region/patch type embedding')

        parser.add_argument('--resnet-type', choices=['resnet50', 'resnet101', 'resnet152'],
                            help='resnet type')
        parser.add_argument('--resnet-model-path', type=str, metavar='STR',
                            help='path to load resnet')
        parser.add_argument('--code-image-size', type=int,
                            help='code image size')
        parser.add_argument('--patch-layernorm-embedding', action='store_true',
                            help='add layernorm to patch embedding')
        parser.add_argument('--code-layernorm-embedding', action='store_true',
                            help='add layernorm to code embedding')
        parser.add_argument('--entangle-position-embedding', action='store_true',
                            help='entangle position embedding')
        parser.add_argument('--disable-entangle', action='store_true',
                            help='disable entangle')
        parser.add_argument('--sync-bn', action='store_true',
                            help='sync batchnorm')

        parser.add_argument('--scale-attn', action='store_true',
                            help='scale attn')
        parser.add_argument('--scale-fc', action='store_true',
                            help='scale fc')
        parser.add_argument('--scale-heads', action='store_true',
                            help='scale heads')
        parser.add_argument('--scale-resids', action='store_true',
                            help='scale resids')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Fills in architecture defaults, builds (optionally shared) token
        embeddings for the task's dictionaries, then constructs the encoder
        and decoder.

        Raises:
            ValueError: if ``--share-all-embeddings`` is set but the
                dictionaries, embedding dimensions or embedding paths of the
                encoder and decoder are incompatible.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            # One embedding table serves encoder input, decoder input and output.
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        if getattr(args, "freeze_encoder_embedding", False):
            encoder_embed_tokens.weight.requires_grad = False
        if getattr(args, "freeze_decoder_embedding", False):
            decoder_embed_tokens.weight.requires_grad = False
        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        if not args.share_all_embeddings:
            min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Create the token embedding for *dictionary*, optionally preloaded from *path*."""
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Construct the encoder; subclasses may override to swap the implementation."""
        return TransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Construct the decoder, disabling cross-attention if ``--no-cross-attention`` is set."""
        return TransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )

    # TorchScript doesn't support optional arguments with variable length (**kwargs).
    # Current workaround is to add union of all arguments in child classes.
    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Run the forward pass for an encoder-decoder model.

        Copied from the base class, but without ``**kwargs``,
        which are not supported by TorchScript.
        """
        encoder_out = self.encoder(
            src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out

    # Since get_normalized_probs is in the Fairseq Model which is not scriptable,
    # I rewrite the get_normalized_probs from Base Class to call the
    # helper function in the Base Class.
    @torch.jit.export
    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens):
    """Build the encoder: token/image embedders, position tables and layers.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding

    Raises:
        NotImplementedError: if ``args.resnet_type`` is not a supported
            backbone name.
    """
    self.args = args
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))

    self.dropout_module = FairseqDropout(
        args.dropout, module_name=self.__class__.__name__
    )
    self.encoder_layerdrop = args.encoder_layerdrop

    embed_dim = embed_tokens.embedding_dim
    self.padding_idx = embed_tokens.padding_idx
    self.max_source_positions = args.max_source_positions
    self.num_attention_heads = args.encoder_attention_heads

    self.embed_tokens = embed_tokens

    self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

    if getattr(args, "layernorm_embedding", False):
        self.layernorm_embedding = LayerNorm(embed_dim)
    else:
        self.layernorm_embedding = None

    if getattr(args, "add_type_embedding", False):
        # NOTE(review): forward_embedding looks up type id 2 for the second
        # image input, which would be out of range for a 2-row table —
        # confirm before using patch_images_2 together with type embeddings.
        self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
    else:
        self.type_embedding = None

    if getattr(args, "sync_bn", False):
        norm_layer = BatchNorm2d
    else:
        norm_layer = None

    # Block counts handed to the local ResNet for each supported backbone.
    # The resnet50 entry follows the same pattern as the two pre-existing
    # ones (assumes ResNet takes the block counts of its first three stages).
    resnet_stage_blocks = {
        'resnet50': [3, 4, 6],
        'resnet101': [3, 4, 23],
        'resnet152': [3, 8, 36],
    }
    if args.resnet_type not in resnet_stage_blocks:
        raise NotImplementedError(
            "unsupported resnet type: {!r} (expected one of {})".format(
                args.resnet_type, sorted(resnet_stage_blocks)
            )
        )
    self.embed_images = ResNet(
        resnet_stage_blocks[args.resnet_type],
        norm_layer=norm_layer,
        drop_path_rate=args.resnet_drop_path_rate,
    )
    self.image_proj = Linear(1024, embed_dim)
    if getattr(args, "resnet_model_path", None):
        print("load resnet {}".format(args.resnet_model_path))
        # NOTE: torch.load unpickles the file; only point this at trusted
        # checkpoints.
        resnet_state_dict = torch.load(self.args.resnet_model_path)
        self.embed_images.load_state_dict(resnet_state_dict)
    if getattr(args, "patch_layernorm_embedding", False):
        self.patch_layernorm_embedding = LayerNorm(embed_dim)
    else:
        self.patch_layernorm_embedding = None

    # Absolute position tables for text tokens and image patches.
    self.embed_positions = Embedding(args.max_source_positions + 2, embed_dim)
    self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
    self.pos_ln = LayerNorm(embed_dim)
    self.image_pos_ln = LayerNorm(embed_dim)
    self.pos_scaling = float(embed_dim / args.encoder_attention_heads * args.attn_scale_factor) ** -0.5
    self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
    self.pos_k_linear = nn.Linear(embed_dim, embed_dim)

    if not args.adaptive_input and args.quant_noise_pq > 0:
        self.quant_noise = apply_quant_noise_(
            nn.Linear(embed_dim, embed_dim, bias=False),
            args.quant_noise_pq,
            args.quant_noise_pq_block_size,
        )
    else:
        self.quant_noise = None

    if self.encoder_layerdrop > 0.0:
        self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
    else:
        self.layers = nn.ModuleList([])

    # Drop-path rate grows linearly from 0 to the configured maximum.
    dpr = [x.item() for x in torch.linspace(0, args.encoder_drop_path_rate, args.encoder_layers)]
    self.layers.extend(
        [self.build_encoder_layer(args, drop_path_rate=dpr[i]) for i in range(args.encoder_layers)]
    )
    self.num_layers = len(self.layers)

    if args.encoder_normalize_before:
        self.layer_norm = LayerNorm(embed_dim)
    else:
        self.layer_norm = None

    # One relative-position bias table per layer, for tokens and for patches.
    token_bucket_size = args.token_bucket_size
    token_num_rel_dis = 2 * token_bucket_size - 1
    token_rp_bucket = make_token_bucket_position(token_bucket_size)
    self.token_rel_pos_table_list = nn.ModuleList(
        [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
    )

    image_bucket_size = args.image_bucket_size
    image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3
    image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis)
    self.image_rel_pos_table_list = nn.ModuleList(
        [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
    )

    self.register_buffer("token_rp_bucket", token_rp_bucket)
    self.register_buffer("image_rp_bucket", image_rp_bucket)
    self.entangle_position_embedding = args.entangle_position_embedding
def train(self, mode=True):
    """Switch train/eval mode, keeping a frozen ResNet's batch norms in eval.

    When ``args.freeze_resnet`` is set, every ``nn.BatchNorm2d`` inside the
    image embedder is put back into eval mode and its affine parameters
    stop receiving gradients, regardless of *mode*.

    Returns:
        self, matching the ``nn.Module.train`` contract so calls can be chained.
    """
    super(TransformerEncoder, self).train(mode)
    if getattr(self.args, "freeze_resnet", False):
        # NOTE(review): only nn.BatchNorm2d instances are matched; layers of
        # another norm class (e.g. from the --sync-bn factory) may not be
        # covered — confirm.
        for m in self.embed_images.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False
    return self
def build_encoder_layer(self, args, drop_path_rate=0.0):
    """Create one encoder layer, with optional checkpointing and FSDP wrapping.

    Args:
        args: parsed arguments; ``checkpoint_activations``,
            ``offload_activations`` and ``min_params_to_wrap`` are consulted.
        drop_path_rate: drop-path rate forwarded to the layer.

    Returns:
        The (possibly wrapped) layer.
    """
    use_checkpoint = getattr(args, "checkpoint_activations", False)
    layer = TransformerEncoderLayer(args, drop_path_rate=drop_path_rate)
    if use_checkpoint:
        layer = checkpoint_wrapper(
            layer, offload_to_cpu=getattr(args, "offload_activations", False)
        )
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        wrap_threshold = 0
    else:
        wrap_threshold = getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
    return fsdp_wrap(layer, min_num_params=wrap_threshold)
def get_rel_pos_bias(self, x, idx):
    """Return layer *idx*'s token relative-position bias for batch *x*.

    Args:
        x: tensor whose first two dimensions are read as (batch, length).
        idx: index of the encoder layer whose bias table is used.

    Returns:
        A contiguous tensor of shape (batch, heads, length, length).
    """
    batch, length = x.size(0), x.size(1)
    buckets = self.token_rp_bucket[:length, :length]
    table = self.token_rel_pos_table_list[idx].weight
    # (length, length, heads) -> (batch, heads, length, length)
    bias = F.embedding(buckets, table)
    bias = bias.unsqueeze(0).expand(batch, -1, -1, -1).permute([0, 3, 1, 2])
    return bias.contiguous()
def get_image_rel_pos_bias(self, image_position_ids, idx):
    """Return layer *idx*'s image relative-position bias.

    Args:
        image_position_ids: (batch, num_patches) position id of each patch,
            used to index into ``self.image_rp_bucket``.
        idx: index of the encoder layer whose bias table is used.

    Returns:
        A tensor of shape (batch, heads, num_patches, num_patches).
    """
    bsz, seq_len = image_position_ids.shape
    table_side = self.image_rp_bucket.size(1)

    # Select the rows, then the columns, that belong to the given patches.
    row_index = image_position_ids[:, :, None].expand(bsz, seq_len, table_side)
    col_index = image_position_ids[:, None, :].expand(bsz, seq_len, seq_len)
    full_table = self.image_rp_bucket.unsqueeze(0).expand(bsz, table_side, table_side)
    buckets = full_table.gather(1, row_index).gather(2, col_index)

    bias = F.embedding(buckets, self.image_rel_pos_table_list[idx].weight)
    return bias.permute(0, 3, 1, 2)
def get_patch_images_info(self, patch_images, sample_patch_num, device):
    """Embed images into patch features plus their masks and position data.

    Args:
        patch_images: image batch fed to ``self.embed_images``.
        sample_patch_num: if not None, this many patches are randomly sampled
            (without replacement) per image and only those are kept.
        device: device on which the index tensors are placed.

    Returns:
        Tuple ``(image_embed, image_num_patches, image_padding_mask,
        image_position_ids, image_pos_embed)``.
    """
    batch = patch_images.size(0)
    feature_map = self.embed_images(patch_images)
    h, w = feature_map.shape[-2:]
    image_num_patches = h * w
    image_padding_mask = patch_images.new_zeros((batch, image_num_patches)).bool()

    # Position id of grid cell (r, c) is r * image_bucket_size + c + 1.
    grid_ids = (
        torch.arange(h).unsqueeze(1) * self.args.image_bucket_size
        + torch.arange(w).unsqueeze(0)
        + 1
    )
    flat_ids = grid_ids.view(-1).to(device)
    image_position_ids = flat_ids[None, :].expand(batch, image_num_patches)

    # (B, C, H, W) -> (B, H*W, C)
    image_embed = feature_map.flatten(2).transpose(1, 2)

    if sample_patch_num is not None:
        sampled = [
            random.sample(range(image_num_patches), k=sample_patch_num)
            for _ in range(batch)
        ]
        patch_orders = torch.LongTensor(sampled).to(device)
        feature_index = patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))
        image_embed = image_embed.gather(1, feature_index)
        image_num_patches = sample_patch_num
        image_padding_mask = image_padding_mask.gather(1, patch_orders)
        image_position_ids = image_position_ids.gather(1, patch_orders)

    image_pos_embed = self.embed_image_positions(image_position_ids)
    return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
def forward_embedding(
    self,
    src_tokens,
    image_embed: Optional[torch.Tensor] = None,
    image_embed_2: Optional[torch.Tensor] = None,
    token_embedding: Optional[torch.Tensor] = None,
    pos_embed: Optional[torch.Tensor] = None,
    image_pos_embed: Optional[torch.Tensor] = None,
    image_pos_embed_2: Optional[torch.Tensor] = None
):
    """Embed text tokens and optional image patches into one sequence.

    Image features, when given, are projected and prepended to the text
    (the second image goes in front of the first).

    Returns:
        Tuple ``(x, embed)``: the processed input sequence and the
        concatenated scaled embeddings.
    """
    # embed tokens and positions
    if token_embedding is None:
        token_embedding = self.embed_tokens(src_tokens)
    # NOTE: ``x`` and ``embed`` are the same tensor here, so the in-place
    # additions below are visible through ``embed`` as well.
    x = embed = self.embed_scale * token_embedding
    if self.entangle_position_embedding and pos_embed is not None:
        x += pos_embed
    if self.type_embedding is not None:
        x += self.type_embedding(src_tokens.new_zeros(x.shape[:2]))
    if self.layernorm_embedding is not None:
        x = self.layernorm_embedding(x)
    x = self.dropout_module(x)
    if self.quant_noise is not None:
        x = self.quant_noise(x)

    # embed raw images
    if image_embed is not None:
        image_embed = self.image_proj(image_embed)
        patch_x = image_embed = self.embed_scale * image_embed
        if self.entangle_position_embedding and image_pos_embed is not None:
            patch_x += image_pos_embed
        if self.type_embedding is not None:
            patch_x += self.type_embedding(src_tokens.new_ones(patch_x.shape[:2]))
        if self.patch_layernorm_embedding is not None:
            patch_x = self.patch_layernorm_embedding(patch_x)
        patch_x = self.dropout_module(patch_x)
        if self.quant_noise is not None:
            patch_x = self.quant_noise(patch_x)
        x = torch.cat([patch_x, x], dim=1)
        embed = torch.cat([image_embed, embed], dim=1)

    if image_embed_2 is not None:
        assert self.type_embedding is not None
        image_embed_2 = self.image_proj(image_embed_2)
        patch_x_2 = image_embed_2 = self.embed_scale * image_embed_2
        if self.entangle_position_embedding and image_pos_embed_2 is not None:
            patch_x_2 += image_pos_embed_2
        if self.type_embedding is not None:
            # NOTE(review): type id 2 is looked up here; confirm the type
            # embedding table has at least three rows.
            patch_x_2 += self.type_embedding(src_tokens.new_full(patch_x_2.shape[:2], fill_value=2))
        if self.patch_layernorm_embedding is not None:
            patch_x_2 = self.patch_layernorm_embedding(patch_x_2)
        patch_x_2 = self.dropout_module(patch_x_2)
        if self.quant_noise is not None:
            patch_x_2 = self.quant_noise(patch_x_2)
        x = torch.cat([patch_x_2, x], dim=1)
        embed = torch.cat([image_embed_2, embed], dim=1)

    return x, embed
def forward(
    self,
    src_tokens,
    src_lengths,
    patch_images: Optional[torch.Tensor] = None,
    patch_images_2: Optional[torch.Tensor] = None,
    patch_masks: Optional[torch.Tensor] = None,
    code_masks: Optional[torch.Tensor] = None,
    return_all_hiddens: bool = False,
    token_embeddings: Optional[torch.Tensor] = None,
    sample_patch_num: Optional[int] = None
):
    """
    Thin entry point that delegates to :meth:`forward_scriptable`.

    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        patch_images (torch.Tensor, optional): first image batch
        patch_images_2 (torch.Tensor, optional): second image batch
        patch_masks (torch.Tensor, optional): per-sample flags; samples
            whose flag is False get their patches marked as padding
        code_masks (torch.Tensor, optional): accepted but not used here
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).
        token_embeddings (torch.Tensor, optional): precomputed embeddings
            default `None` will recompute embeddings
        sample_patch_num (int, optional): number of patches to sample per
            image

    Returns:
        dict: the output of :meth:`forward_scriptable`.
    """
    return self.forward_scriptable(
        src_tokens,
        src_lengths,
        patch_images=patch_images,
        patch_images_2=patch_images_2,
        patch_masks=patch_masks,
        return_all_hiddens=return_all_hiddens,
        token_embeddings=token_embeddings,
        sample_patch_num=sample_patch_num,
    )
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
# call the helper function from scriptable Subclass.
def forward_scriptable(
    self,
    src_tokens,
    src_lengths,
    patch_images: Optional[torch.Tensor] = None,
    patch_images_2: Optional[torch.Tensor] = None,
    patch_masks: Optional[torch.Tensor] = None,
    return_all_hiddens: bool = False,
    token_embeddings: Optional[torch.Tensor] = None,
    sample_patch_num: Optional[int] = None
):
    """
    Encode text tokens together with up to two optional image batches.

    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        patch_images (torch.Tensor, optional): first image batch
        patch_images_2 (torch.Tensor, optional): second image batch
        patch_masks (torch.Tensor, optional): per-sample boolean flags;
            samples whose flag is False get all their patches marked as
            padding. When omitted, no patch is masked.
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).
        token_embeddings (torch.Tensor, optional): precomputed embeddings
            default `None` will recompute embeddings
        sample_patch_num (int, optional): number of patches to sample per
            image

    Returns:
        dict:
            - **encoder_out** (List[Tensor]): the last encoder layer's
              output of shape `(seq_len, batch, embed_dim)`
            - **encoder_padding_mask** (List[Tensor]): the positions of
              padding elements of shape `(batch, seq_len)`
            - **encoder_embedding**: always an empty list here
            - **encoder_states** (List[Tensor]): all intermediate
              hidden states of shape `(seq_len, batch, embed_dim)`.
              Only populated if *return_all_hiddens* is True.
            - **src_tokens**, **src_lengths**: always empty lists here
            - **position_embeddings** (List[Tensor]): the layer-normed
              position embeddings of shape `(batch, seq_len, embed_dim)`
    """
    image_embed = None
    image_embed_2 = None
    image_pos_embed = None
    image_pos_embed_2 = None
    if patch_images is not None:
        image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
            self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device)
        # patch_masks is optional; without it every patch stays unmasked.
        if patch_masks is not None:
            image_padding_mask[~patch_masks] = True
    if patch_images_2 is not None:
        image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
            self.get_patch_images_info(patch_images_2, sample_patch_num, src_tokens.device)
        if patch_masks is not None:
            image_padding_mask_2[~patch_masks] = True

    # Image patches are prepended to the text, so their masks go in front.
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    if patch_images is not None:
        encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1)
    if patch_images_2 is not None:
        encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1)
    has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any())

    pos_embed = self.embed_positions(utils.new_arange(src_tokens))
    x, encoder_embedding = self.forward_embedding(
        src_tokens, image_embed, image_embed_2, token_embeddings,
        pos_embed, image_pos_embed, image_pos_embed_2
    )

    # account for padding while computing the representation
    if has_pads:
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    pos_embed = self.pos_ln(pos_embed)
    if patch_images is not None:
        image_pos_embed = self.image_pos_ln(image_pos_embed)
        pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
    if patch_images_2 is not None:
        image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
        pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)

    # Absolute-position attention bias shared by all layers: B x H x T x T.
    pos_q = self.pos_q_linear(pos_embed).view(
        x.size(1), x.size(0), self.num_attention_heads, -1
    ).transpose(1, 2) * self.pos_scaling
    pos_k = self.pos_k_linear(pos_embed).view(
        x.size(1), x.size(0), self.num_attention_heads, -1
    ).transpose(1, 2)
    abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))

    encoder_states = []

    if return_all_hiddens:
        encoder_states.append(x)

    # encoder layers
    src_len = src_tokens.size(1)
    for idx, layer in enumerate(self.layers):
        # Each layer adds its own relative-position bias to a fresh copy.
        self_attn_bias = abs_pos_bias.clone()
        self_attn_bias[:, :, -src_len:, -src_len:] += self.get_rel_pos_bias(src_tokens, idx)
        if patch_images_2 is not None:
            both = image_num_patches_2 + image_num_patches
            self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
                self.get_image_rel_pos_bias(image_position_ids_2, idx)
            self_attn_bias[:, :, image_num_patches_2:both, image_num_patches_2:both] += \
                self.get_image_rel_pos_bias(image_position_ids, idx)
        elif patch_images is not None:
            self_attn_bias[:, :, :x.size(0) - src_len, :x.size(0) - src_len] += \
                self.get_image_rel_pos_bias(image_position_ids, idx)
        self_attn_bias = self_attn_bias.reshape(-1, x.size(0), x.size(0))

        x = layer(
            x, encoder_padding_mask=encoder_padding_mask if has_pads else None, self_attn_bias=self_attn_bias
        )
        if return_all_hiddens:
            assert encoder_states is not None
            encoder_states.append(x)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
    # `forward` so we use a dictionary instead.
    # TorchScript does not support mixed values so the values are all lists.
    # The empty list is equivalent to None.
    return {
        "encoder_out": [x],  # T x B x C
        "encoder_padding_mask": [encoder_padding_mask],  # B x T
        "encoder_embedding": [],  # B x T x C
        "encoder_states": encoder_states,  # List[T x B x C]
        "src_tokens": [],
        "src_lengths": [],
        "position_embeddings": [pos_embed],  # B x T x C
    }
@torch.jit.export
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
    """
    Reorder encoder output according to *new_order*.

    Entries that are empty lists stay empty. The ``encoder_states`` list
    of the input dict is updated in place.

    Args:
        encoder_out: output from the ``forward()`` method
        new_order (LongTensor): desired order

    Returns:
        *encoder_out* rearranged according to *new_order*
    """
    # encoder_out is T x B x C, so the batch is dimension 1.
    if len(encoder_out["encoder_out"]) > 0:
        new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
    else:
        new_encoder_out = []

    # Everything below is batch-first.
    if len(encoder_out["encoder_padding_mask"]) > 0:
        new_encoder_padding_mask = [
            encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
        ]
    else:
        new_encoder_padding_mask = []

    if len(encoder_out["encoder_embedding"]) > 0:
        new_encoder_embedding = [
            encoder_out["encoder_embedding"][0].index_select(0, new_order)
        ]
    else:
        new_encoder_embedding = []

    if len(encoder_out["src_tokens"]) > 0:
        new_src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
    else:
        new_src_tokens = []

    if len(encoder_out["src_lengths"]) > 0:
        new_src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
    else:
        new_src_lengths = []

    if len(encoder_out["position_embeddings"]) > 0:
        new_position_embeddings = [
            (encoder_out["position_embeddings"][0]).index_select(0, new_order)
        ]
    else:
        new_position_embeddings = []

    encoder_states = encoder_out["encoder_states"]
    if len(encoder_states) > 0:
        for layer_idx, hidden in enumerate(encoder_states):
            encoder_states[layer_idx] = hidden.index_select(1, new_order)

    return {
        "encoder_out": new_encoder_out,  # T x B x C
        "encoder_padding_mask": new_encoder_padding_mask,  # B x T
        "encoder_embedding": new_encoder_embedding,  # B x T x C
        "encoder_states": encoder_states,  # List[T x B x C]
        "src_tokens": new_src_tokens,  # B x T
        "src_lengths": new_src_lengths,  # B x 1
        "position_embeddings": new_position_embeddings,  # B x T x C
    }
def max_positions(self):
    """Maximum input length supported by the encoder.

    Returns:
        ``self.max_source_positions``; the limit does not depend on
        whether a positional embedding is present.
    """
    return self.max_source_positions
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade a (possibly old) state dict for new versions of fairseq.

    The checkpoint dictionary is edited in place:

    * with sinusoidal positions, the stored ``embed_positions.weights``
      entry is dropped and a placeholder ``_float_tensor`` is installed;
    * every layer gets the chance to upgrade its own entries;
    * entries this module owns but the checkpoint lacks are filled in
      from the module's current state;
    * if the checkpoint's image position table has fewer rows than the
      module's, randomly initialised rows are appended.

    Args:
        state_dict (dict): checkpoint state, modified in place
        name (str): prefix under which this module's entries are stored

    Returns:
        the same *state_dict* object
    """
    if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
        weights_key = "{}.embed_positions.weights".format(name)
        if weights_key in state_dict:
            print("deleting {0}".format(weights_key))
            del state_dict[weights_key]
        state_dict[
            "{}.embed_positions._float_tensor".format(name)
        ] = torch.FloatTensor(1)

    for i in range(self.num_layers):
        # update layer norms
        self.layers[i].upgrade_state_dict_named(
            state_dict, "{}.layers.{}".format(name, i)
        )

    prefix = name + "." if name != "" else ""
    # Snapshot once: every state_dict() call rebuilds the whole mapping.
    own_state = self.state_dict()
    for param_name, param_tensor in own_state.items():
        if prefix + param_name not in state_dict:
            state_dict[prefix + param_name] = param_tensor

    # Address the table through the prefix so the lookup follows *name*
    # instead of assuming the module is stored under "encoder.".
    pos_key = prefix + "embed_image_positions.weight"
    loaded_table = state_dict[pos_key]
    num_posids_to_add = len(own_state["embed_image_positions.weight"]) - len(loaded_table)
    if num_posids_to_add > 0:
        embed_dim = loaded_table.size(1)
        new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim)
        nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5)
        # Match dtype and device of the checkpoint tensor so the
        # concatenation below cannot fail on a mismatch.
        new_pos_embed_to_add = new_pos_embed_to_add.to(
            dtype=loaded_table.dtype, device=loaded_table.device
        )
        state_dict[pos_key] = torch.cat([loaded_table, new_pos_embed_to_add])
    return state_dict
class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False). The flag is handed unchanged to every decoder
            layer that is built.
        output_projection (optional): module used to project features to
            the vocabulary. When None, the constructor builds one via
            ``build_output_projection`` (default: None).
    """
def __init__(
    self,
    args,
    dictionary,
    embed_tokens,
    no_encoder_attn=False,
    output_projection=None,
):
    """
    Assemble the decoder: token-embedding plumbing, absolute and relative
    position tables, the stack of decoder layers and the output projection.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary: decoding dictionary, handed to the parent constructor
            and to ``build_output_projection``
        embed_tokens: token embedding module; its ``embedding_dim`` and
            ``padding_idx`` are read here
        no_encoder_attn (bool, optional): forwarded to every
            ``build_decoder_layer`` call (default: False)
        output_projection (optional): ready-made output projection; when
            None, ``build_output_projection`` is invoked instead
    """
    # NOTE: sub-modules are created in a fixed order on purpose; the order
    # determines both registration order and the sequence of random inits.
    self.args = args
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    self._future_mask = torch.empty(0)

    self.dropout_module = FairseqDropout(
        args.dropout, module_name=self.__class__.__name__
    )
    self.decoder_layerdrop = args.decoder_layerdrop
    self.share_input_output_embed = args.share_decoder_input_output_embed
    self.num_attention_heads = args.decoder_attention_heads

    input_embed_dim = embed_tokens.embedding_dim
    embed_dim = args.decoder_embed_dim
    self.embed_dim = embed_dim
    self.output_embed_dim = args.decoder_output_dim

    self.padding_idx = embed_tokens.padding_idx
    self.max_target_positions = args.max_target_positions

    self.embed_tokens = embed_tokens
    self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

    # Optional quantisation noise on the embedded input.
    self.quant_noise = (
        apply_quant_noise_(
            nn.Linear(embed_dim, embed_dim, bias=False),
            args.quant_noise_pq,
            args.quant_noise_pq_block_size,
        )
        if not args.adaptive_input and args.quant_noise_pq > 0
        else None
    )

    # Bridge the embedding width to the model width only when they differ.
    if embed_dim != input_embed_dim:
        self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False)
    else:
        self.project_in_dim = None

    self.layernorm_embedding = (
        LayerNorm(embed_dim) if getattr(args, "layernorm_embedding", False) else None
    )

    self.window_size = args.code_image_size // 8

    # Absolute position embeddings (token and image) plus the projections
    # that turn them into attention biases.
    self.embed_positions = Embedding(args.max_target_positions + 2, embed_dim)
    self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
    self.pos_ln = LayerNorm(embed_dim)
    self.image_pos_ln = LayerNorm(embed_dim)
    self.pos_scaling = float(embed_dim / self.num_attention_heads * args.attn_scale_factor) ** -0.5
    self.self_pos_q_linear = nn.Linear(embed_dim, embed_dim)
    self.self_pos_k_linear = nn.Linear(embed_dim, embed_dim)
    self.cross_pos_q_linear = nn.Linear(embed_dim, embed_dim)
    self.cross_pos_k_linear = nn.Linear(embed_dim, embed_dim)

    self.code_layernorm_embedding = (
        LayerNorm(embed_dim)
        if getattr(args, "code_layernorm_embedding", False)
        else None
    )

    self.cross_self_attention = getattr(args, "cross_self_attention", False)

    # Layer stack; drop-path rates grow linearly from 0 to the configured
    # maximum across the layers.
    self.layers = (
        LayerDropModuleList(p=self.decoder_layerdrop)
        if self.decoder_layerdrop > 0.0
        else nn.ModuleList([])
    )
    drop_path_rates = [
        rate.item()
        for rate in torch.linspace(0, args.decoder_drop_path_rate, args.decoder_layers)
    ]
    self.layers.extend(
        [
            self.build_decoder_layer(
                args, no_encoder_attn, drop_path_rate=drop_path_rates[layer_no]
            )
            for layer_no in range(args.decoder_layers)
        ]
    )
    self.num_layers = len(self.layers)

    self.layer_norm = LayerNorm(embed_dim) if args.decoder_normalize_before else None

    if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights:
        self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False)
    else:
        self.project_out_dim = None

    self.adaptive_softmax = None
    self.output_projection = output_projection
    if output_projection is None:
        self.build_output_projection(args, dictionary, embed_tokens)

    # Relative position bias for tokens: one zero-initialised table per layer.
    token_bucket_size = args.token_bucket_size
    token_num_rel_dis = 2 * token_bucket_size - 1
    token_rp_bucket = make_token_bucket_position(token_bucket_size)
    self.token_rel_pos_table_list = nn.ModuleList(
        [
            Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True)
            for _ in range(args.decoder_layers)
        ]
    )

    # Relative position bias for image codes.
    image_bucket_size = args.image_bucket_size
    image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3
    image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis)
    column_ids = torch.arange(self.window_size).unsqueeze(0).expand(
        self.window_size, self.window_size
    )
    row_ids = torch.arange(self.window_size).unsqueeze(1)
    grid_ids = column_ids + row_ids * image_bucket_size + 1
    # Layout: a leading 0, the flattened window grid, then 768 entries of 1024.
    # NOTE(review): 1024 looks like a filler position id for slots beyond the
    # grid and 768 the number of such slots - confirm.
    image_position_idx = torch.cat(
        [torch.tensor([0]), grid_ids.view(-1), torch.tensor([1024] * 768)]
    )
    self.image_rel_pos_table_list = nn.ModuleList(
        [
            Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True)
            for _ in range(args.decoder_layers)
        ]
    )

    self.register_buffer("token_rp_bucket", token_rp_bucket)
    self.register_buffer("image_rp_bucket", image_rp_bucket)
    self.register_buffer("image_position_idx", image_position_idx)
    self.entangle_position_embedding = args.entangle_position_embedding
def build_output_projection(self, args, dictionary, embed_tokens):
    """
    Create the module that maps decoder features to vocabulary scores.

    Exactly one of three heads is built: an adaptive softmax (when
    ``args.adaptive_softmax_cutoff`` is set), a bias-free linear layer
    sharing its weight with ``self.embed_tokens``, or a freshly
    initialised bias-free linear layer. Afterwards ``args.base_layers``
    (default 0) ``BaseLayer`` instances are inserted into ``self.layers``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary: decoding dictionary; only its length is used
        embed_tokens: embedding handed to the adaptive softmax when
            ``args.tie_adaptive_weights`` is set
    """
    if args.adaptive_softmax_cutoff is not None:
        cutoffs = utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
        tied_inputs = embed_tokens if args.tie_adaptive_weights else None
        self.adaptive_softmax = AdaptiveSoftmax(
            len(dictionary),
            self.output_embed_dim,
            cutoffs,
            dropout=args.adaptive_softmax_dropout,
            adaptive_inputs=tied_inputs,
            factor=args.adaptive_softmax_factor,
            tie_proj=args.tie_adaptive_proj,
        )
    elif self.share_input_output_embed:
        # Tie the projection to the input embedding matrix.
        shared_weight = self.embed_tokens.weight
        self.output_projection = nn.Linear(
            shared_weight.shape[1], shared_weight.shape[0], bias=False
        )
        self.output_projection.weight = shared_weight
    else:
        self.output_projection = nn.Linear(
            self.output_embed_dim, len(dictionary), bias=False
        )
        init_std = self.output_embed_dim ** -0.5
        nn.init.normal_(self.output_projection.weight, mean=0, std=init_std)

    num_base_layers = getattr(args, "base_layers", 0)
    for nth in range(1, num_base_layers + 1):
        # Spread the base layers evenly through the stack.
        insert_at = (nth * args.decoder_layers) // (num_base_layers + 1)
        self.layers.insert(insert_at, BaseLayer(args))
def build_decoder_layer(self, args, no_encoder_attn=False, drop_path_rate=0.0):
    """
    Construct one decoder layer, optionally wrapped for activation
    checkpointing, and pass it through ``fsdp_wrap``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): forwarded to the layer
        drop_path_rate (float, optional): forwarded to the layer

    Returns:
        the (possibly wrapped) decoder layer
    """
    layer = TransformerDecoderLayer(args, no_encoder_attn, drop_path_rate=drop_path_rate)
    if getattr(args, "checkpoint_activations", False):
        layer = checkpoint_wrapper(
            layer, offload_to_cpu=getattr(args, "offload_activations", False)
        )
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = 0
    else:
        min_params_to_wrap = getattr(
            args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
        )
    return fsdp_wrap(layer, min_num_params=min_params_to_wrap)
def get_rel_pos_bias(self, x, idx):
    """
    Look up the token relative-position bias of layer *idx*.

    Args:
        x: tensor whose second dimension gives the sequence length
        idx (int): index into ``self.token_rel_pos_table_list``

    Returns:
        contiguous tensor with the table's embedding dimension first,
        followed by ``seq_len x seq_len``
    """
    length = x.size(1)
    buckets = self.token_rp_bucket[:length, :length]
    table = self.token_rel_pos_table_list[idx].weight
    bias = F.embedding(buckets, table)
    # Move the per-bucket vector in front of the two position axes.
    return bias.permute([2, 0, 1]).contiguous()
def get_image_rel_pos_bias(self, x, idx):
    """
    Look up the image relative-position bias of layer *idx*.

    Args:
        x: tensor whose second dimension gives the sequence length
        idx (int): index into ``self.image_rel_pos_table_list``

    Returns:
        permuted (not made contiguous) tensor with the table's embedding
        dimension first, followed by ``seq_len x seq_len``
    """
    positions = self.image_position_idx[: x.size(1)]
    # Select the bucket rows, then the bucket columns, of those positions.
    buckets = self.image_rp_bucket[positions][:, positions]
    table = self.image_rel_pos_table_list[idx].weight
    return F.embedding(buckets, table).permute(2, 0, 1)
def get_pos_info(self, tokens, tgt_pos_embed, src_pos_embed=None, use_image=False):
    """
    Compute an attention bias from absolute position embeddings.

    The target position embedding is layer-normalised (``image_pos_ln``
    when *use_image* is set, ``pos_ln`` otherwise), projected to queries
    and scaled by ``self.pos_scaling``. Keys come from *src_pos_embed*
    through the cross-attention projections when it is given, otherwise
    from the normalised target embedding through the self-attention
    projections.

    Args:
        tokens: tensor whose first two dimensions give batch size and
            target length
        tgt_pos_embed: target position embeddings
        src_pos_embed (optional): source position embeddings
        use_image (bool, optional): pick the image layer norm

    Returns:
        query-key product of shape ``(batch, heads, tgt_len, src_len)``
    """
    batch_size, tgt_len = tokens.size(0), tokens.size(1)
    heads = self.num_attention_heads

    normalise = self.image_pos_ln if use_image else self.pos_ln
    tgt_pos_embed = normalise(tgt_pos_embed)

    if src_pos_embed is None:
        q_proj, k_proj = self.self_pos_q_linear, self.self_pos_k_linear
        key_input = tgt_pos_embed
    else:
        q_proj, k_proj = self.cross_pos_q_linear, self.cross_pos_k_linear
        key_input = src_pos_embed
    src_len = key_input.size(1)

    # Split the projected vectors into heads: (batch, heads, len, head_dim).
    pos_q = q_proj(tgt_pos_embed).view(batch_size, tgt_len, heads, -1).transpose(1, 2)
    pos_q = pos_q * self.pos_scaling
    pos_k = k_proj(key_input).view(batch_size, src_len, heads, -1).transpose(1, 2)
    return torch.matmul(pos_q, pos_k.transpose(2, 3))
def forward(
    self,
    prev_output_tokens,
    code_masks: Optional[torch.Tensor] = None,
    encoder_out: Optional[Dict[str, List[Tensor]]] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    features_only: bool = False,
    full_context_alignment: bool = False,
    alignment_layer: Optional[int] = None,
    alignment_heads: Optional[int] = None,
    src_lengths: Optional[Any] = None,
    return_all_hiddens: bool = False,
):
    """
    Run the decoder and, unless *features_only* is set, project the
    features with ``output_layer``.

    Args:
        prev_output_tokens (LongTensor): previous decoder outputs of shape
            `(batch, tgt_len)`, for teacher forcing
        code_masks (Tensor, optional): passed through to
            ``extract_features``
        encoder_out (optional): output from the encoder, used for
            encoder-side attention, should be of size T x B x C
        incremental_state (dict): dictionary used for storing state during
            :ref:`Incremental decoding`
        features_only (bool, optional): only return features without
            applying output layer (default: False).
        full_context_alignment (bool, optional): don't apply
            auto-regressive mask to self-attention (default: False).
        alignment_layer (int, optional): passed through to
            ``extract_features``
        alignment_heads (int, optional): passed through to
            ``extract_features``
        src_lengths (optional): accepted but not used here
        return_all_hiddens (bool, optional): accepted but not used here

    Returns:
        tuple:
            - the decoder's output of shape `(batch, tgt_len, vocab)`
            - a dictionary with any model-specific outputs
    """
    features, extra = self.extract_features(
        prev_output_tokens,
        code_masks=code_masks,
        encoder_out=encoder_out,
        incremental_state=incremental_state,
        full_context_alignment=full_context_alignment,
        alignment_layer=alignment_layer,
        alignment_heads=alignment_heads,
    )
    if features_only:
        return features, extra
    return self.output_layer(features), extra
def extract_features(
    self,
    prev_output_tokens,
    code_masks: Optional[torch.Tensor],
    encoder_out: Optional[Dict[str, List[Tensor]]],
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    full_context_alignment: bool = False,
    alignment_layer: Optional[int] = None,
    alignment_heads: Optional[int] = None,
):
    """
    Thin entry point that hands every argument, positionally and in the
    same order, to ``extract_features_scriptable`` and returns its result.
    """
    result = self.extract_features_scriptable(
        prev_output_tokens,
        code_masks,
        encoder_out,
        incremental_state,
        full_context_alignment,
        alignment_layer,
        alignment_heads,
    )
    return result
"""
A scriptable subclass of this class has an extract_features method and calls
super().extract_features, but super() is not supported in torchscript. A copy of
this function is made to be used in the subclass instead.
"""
def extract_features_scriptable(
    self,
    prev_output_tokens,
    code_masks: Optional[torch.Tensor],
    encoder_out: Optional[Dict[str, List[Tensor]]],
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    full_context_alignment: bool = False,
    alignment_layer: Optional[int] = None,
    alignment_heads: Optional[int] = None,
):
    """
    Similar to *forward* but only return features.

    Includes several features from "Jointly Learning to Align and
    Translate with Transformer Models" (Garg et al., EMNLP 2019).

    Args:
        prev_output_tokens: 2-D tensor of previous output tokens (its size
            is unpacked into batch size and target length)
        code_masks (Tensor, optional): mask used to index the leading
            dimension of the position embeddings, attention biases and
            token embeddings; selected entries take the image-specific
            path. Presumably a per-sample boolean mask - TODO confirm.
        encoder_out (optional): encoder output dictionary; the keys
            ``encoder_out``, ``encoder_padding_mask`` and
            ``position_embeddings`` are read.
            NOTE(review): although annotated Optional,
            ``encoder_out['position_embeddings']`` is read unconditionally,
            so passing None fails.
        incremental_state (dict, optional): when given, only the last
            target position is embedded and fed through the layers, and
            no future mask is built.
        full_context_alignment (bool, optional): don't apply
            auto-regressive mask to self-attention (default: False).
        alignment_layer (int, optional): return mean alignment over
            heads at this layer (default: last layer).
        alignment_heads (int, optional): only average alignment over
            this many heads (default: all heads).

    Returns:
        tuple:
            - the decoder's features of shape `(batch, tgt_len, embed_dim)`
            - a dictionary with any model-specific outputs: ``attn`` (a
              one-element list) and ``inner_states`` (the embedded input
              followed by every layer's output)
    """
    bs, slen = prev_output_tokens.size()
    if alignment_layer is None:
        alignment_layer = self.num_layers - 1

    # Pull the encoder states and padding mask out of their one-element lists.
    enc: Optional[Tensor] = None
    padding_mask: Optional[Tensor] = None
    if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
        enc = encoder_out["encoder_out"][0]
        assert (
            enc.size()[1] == bs
        ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
    if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
        padding_mask = encoder_out["encoder_padding_mask"][0]

    # Absolute position embeddings for the target; entries selected by
    # code_masks are overwritten with image position embeddings.
    bsz, tgt_len = prev_output_tokens.shape
    # presumably yields 0..tgt_len-1 for every row - verify against utils.new_arange
    token_position_idx = utils.new_arange(prev_output_tokens)
    tgt_pos_embed = self.embed_positions(token_position_idx)
    if code_masks is not None and torch.any(code_masks):
        image_position_idx = self.image_position_idx[:prev_output_tokens.size(1)].unsqueeze(0).expand(bsz, tgt_len)
        tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks]

    # self attn position bias
    # (computed with the text layer norm first; masked entries are then
    # replaced by the variant computed with the image layer norm)
    self_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=False)
    if code_masks is not None and torch.any(code_masks):
        self_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=True)
        self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks]
    # cross attn position bias
    src_pos_embed = encoder_out['position_embeddings'][0]
    cross_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed)
    if code_masks is not None and torch.any(code_masks):
        cross_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True)
        cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks]
    # Collapse every leading dimension so the bias becomes 3-D.
    cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:])

    # Keep the full token history: the per-layer relative position bias
    # below is still sized from it after the incremental slicing.
    all_prev_output_tokens = prev_output_tokens.clone()
    if incremental_state is not None:
        # Incremental decoding: keep only the newest target position.
        prev_output_tokens = prev_output_tokens[:, -1:]
        cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :]
        tgt_pos_embed = tgt_pos_embed[:, -1:, :]

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)

    if self.quant_noise is not None:
        x = self.quant_noise(x)

    if self.project_in_dim is not None:
        x = self.project_in_dim(x)

    # NOTE(review): the constructor copies entangle_position_embedding
    # straight from args; if that is a boolean flag, the `is not None`
    # test is always true and only disable_entangle decides - confirm intent.
    if self.entangle_position_embedding is not None and not self.args.disable_entangle:
        x += tgt_pos_embed

    # Embedding layer norm: text norm, code norm, or each applied to its
    # own subset of entries when code_masks selects only some of them.
    if self.layernorm_embedding is not None:
        if code_masks is None or not code_masks.any() or not getattr(self, "code_layernorm_embedding", False):
            x = self.layernorm_embedding(x)
        elif code_masks is not None and code_masks.all():
            x = self.code_layernorm_embedding(x)
        else:
            x[~code_masks] = self.layernorm_embedding(x[~code_masks])
            x[code_masks] = self.code_layernorm_embedding(x[code_masks])

    x = self.dropout_module(x)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # A padding mask is only built when cross-self attention is enabled
    # or some token equals the padding index.
    self_attn_padding_mask: Optional[Tensor] = None
    if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
        self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

    # decoder layers
    attn: Optional[Tensor] = None
    inner_states: List[Optional[Tensor]] = [x]
    for idx, layer in enumerate(self.layers):
        # The future mask is skipped for incremental decoding and for
        # full-context alignment.
        if incremental_state is None and not full_context_alignment:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # Layer-specific relative position bias is added on top of a copy
        # of the absolute bias: token table, image table, or each on the
        # entries selected / not selected by code_masks.
        self_attn_bias = self_abs_pos_bias.clone()
        if code_masks is None or not code_masks.any():
            self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
        elif code_masks is not None and code_masks.all():
            self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
        else:
            self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
            self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
        # Collapse every leading dimension so the bias becomes 3-D.
        self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:])
        if incremental_state is not None:
            # Only the bias row of the newest position is needed.
            self_attn_bias = self_attn_bias[:, -1:, :]

        x, layer_attn, _ = layer(
            x,
            enc,
            padding_mask,
            incremental_state,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
            need_attn=bool((idx == alignment_layer)),
            need_head_weights=bool((idx == alignment_layer)),
            self_attn_bias=self_attn_bias,
            cross_attn_bias=cross_abs_pos_bias
        )
        inner_states.append(x)
        # Remember the attention of the alignment layer only.
        if layer_attn is not None and idx == alignment_layer:
            attn = layer_attn.float().to(x)

    if attn is not None:
        if alignment_heads is not None:
            attn = attn[:alignment_heads]

        # average probabilities over heads
        attn = attn.mean(dim=0)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if self.project_out_dim is not None:
        x = self.project_out_dim(x)

    return x, {"attn": [attn], "inner_states": inner_states}
def output_layer(self, features):
    """Project features to the vocabulary size.

    When an adaptive softmax is configured the features are returned
    unchanged; otherwise they go through ``self.output_projection``.
    """
    if self.adaptive_softmax is not None:
        return features
    # project back to size of vocabulary
    return self.output_projection(features)
def max_positions(self):
    """Maximum output length supported by the decoder.

    Returns:
        ``self.max_target_positions``; the limit does not depend on
        whether a positional embedding is present.
    """
    return self.max_target_positions
def buffered_future_mask(self, tensor):
    """
    Return a ``dim x dim`` mask, ``dim = tensor.size(0)``, sliced from a
    cached strictly-upper-triangular mask. The cache is rebuilt when it is
    empty, lives on another device, or is too small, and is then converted
    with ``.to(tensor)``.
    """
    dim = tensor.size(0)
    cached = self._future_mask
    # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
    reusable = (
        cached.size(0) != 0
        and cached.device == tensor.device
        and cached.size(0) >= dim
    )
    if not reusable:
        blank = torch.zeros([dim, dim])
        self._future_mask = torch.triu(utils.fill_with_neg_inf(blank), 1)
    self._future_mask = self._future_mask.to(tensor)
    return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade a (possibly old) state dict for new versions of fairseq.

    The checkpoint dictionary is edited in place:

    * with sinusoidal positions, the stored ``embed_positions.weights``
      entry is dropped and a placeholder ``_float_tensor`` is installed;
    * a missing ``output_projection.weight`` is taken from the token
      embedding (shared embeddings) or from the legacy ``embed_out``
      entry, which is then removed;
    * every layer gets the chance to upgrade its own entries;
    * ``image_position_idx`` is always overwritten with this module's
      buffer, and entries the checkpoint lacks are filled in from the
      module's current state;
    * if the checkpoint's image position table has fewer rows than the
      module's, randomly initialised rows are appended.

    Args:
        state_dict (dict): checkpoint state, modified in place
        name (str): prefix under which this module's entries are stored

    Returns:
        the same *state_dict* object
    """
    if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
        weights_key = "{}.embed_positions.weights".format(name)
        if weights_key in state_dict:
            del state_dict[weights_key]
        state_dict[
            "{}.embed_positions._float_tensor".format(name)
        ] = torch.FloatTensor(1)

    if f"{name}.output_projection.weight" not in state_dict:
        if self.share_input_output_embed:
            embed_out_key = f"{name}.embed_tokens.weight"
        else:
            embed_out_key = f"{name}.embed_out"
        if embed_out_key in state_dict:
            state_dict[f"{name}.output_projection.weight"] = state_dict[
                embed_out_key
            ]
            if not self.share_input_output_embed:
                del state_dict[embed_out_key]

    for i in range(self.num_layers):
        # update layer norms
        self.layers[i].upgrade_state_dict_named(
            state_dict, "{}.layers.{}".format(name, i)
        )

    prefix = name + "." if name != "" else ""
    # Snapshot once: every state_dict() call rebuilds the whole mapping.
    own_state = self.state_dict()

    # These buffers always come from the current module, even when the
    # checkpoint carries its own copy.
    image_params = ["image_position_idx"]
    for image_param in image_params:
        state_dict[prefix + image_param] = own_state[image_param]

    for param_name, param_tensor in own_state.items():
        if prefix + param_name not in state_dict:
            state_dict[prefix + param_name] = param_tensor

    # Address the table through the prefix so the lookup follows *name*
    # instead of assuming the module is stored under "decoder.".
    pos_key = prefix + "embed_image_positions.weight"
    loaded_table = state_dict[pos_key]
    num_posids_to_add = len(own_state["embed_image_positions.weight"]) - len(loaded_table)
    if num_posids_to_add > 0:
        embed_dim = loaded_table.size(1)
        new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim)
        nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5)
        # Match dtype and device of the checkpoint tensor so the
        # concatenation below cannot fail on a mismatch.
        new_pos_embed_to_add = new_pos_embed_to_add.to(
            dtype=loaded_table.dtype, device=loaded_table.device
        )
        state_dict[pos_key] = torch.cat([loaded_table, new_pos_embed_to_add])
    return state_dict
def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False):
    """
    Build an ``nn.Embedding`` whose weight is drawn from a normal
    distribution with std ``embedding_dim ** -0.5``.

    Args:
        num_embeddings (int): number of rows in the table
        embedding_dim (int): size of each embedding vector
        padding_idx (int, optional): row that is reset to zero
        zero_init (bool, optional): reset the whole table to zero

    Returns:
        the initialised ``nn.Embedding``
    """
    table = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(table.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        nn.init.zeros_(table.weight[padding_idx])
    if zero_init:
        nn.init.zeros_(table.weight)
    return table
def Linear(in_features, out_features, bias=True):
    """
    Build an ``nn.Linear`` with Xavier-uniform weights and, when a bias
    is requested, a zero bias.

    Args:
        in_features (int): input width
        out_features (int): output width
        bias (bool, optional): whether the layer has a bias (default: True)

    Returns:
        the initialised ``nn.Linear``
    """
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if bias:
        nn.init.zeros_(layer.bias)
    return layer
@register_model_architecture("unify_transformer", "unify_transformer")
def base_architecture(args):
    """
    Populate *args* with the defaults of the ``unify_transformer``
    architecture.

    Every attribute keeps its value when *args* already has it; otherwise
    the listed default is installed. The only forced value is
    ``checkpoint_activations``, which is switched on whenever
    ``offload_activations`` is truthy.
    """

    def fill(name, fallback):
        # Keep an existing value; otherwise install the fallback.
        setattr(args, name, getattr(args, name, fallback))

    for name, fallback in (
        ("encoder_embed_path", None),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_layers", 6),
        ("encoder_attention_heads", 8),
        ("encoder_normalize_before", False),
        ("encoder_learned_pos", False),
        ("decoder_embed_path", None),
    ):
        fill(name, fallback)

    # Decoder widths default to the (now resolved) encoder widths.
    fill("decoder_embed_dim", args.encoder_embed_dim)
    fill("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)

    for name, fallback in (
        ("decoder_layers", 6),
        ("decoder_attention_heads", 8),
        ("decoder_normalize_before", False),
        ("decoder_learned_pos", False),
        ("attention_dropout", 0.0),
        ("activation_dropout", 0.0),
        ("activation_fn", "relu"),
        ("dropout", 0.1),
        ("adaptive_softmax_cutoff", None),
        ("adaptive_softmax_dropout", 0),
        ("share_decoder_input_output_embed", False),
        ("share_all_embeddings", False),
        ("no_token_positional_embeddings", False),
        ("adaptive_input", False),
        ("no_cross_attention", False),
        ("cross_self_attention", False),
    ):
        fill(name, fallback)

    # Input/output widths default to the (now resolved) decoder width.
    fill("decoder_output_dim", args.decoder_embed_dim)
    fill("decoder_input_dim", args.decoder_embed_dim)

    for name, fallback in (
        ("no_scale_embedding", False),
        ("layernorm_embedding", False),
        ("tie_adaptive_weights", False),
        ("checkpoint_activations", False),
        ("offload_activations", False),
    ):
        fill(name, fallback)

    # Offloading activations requires checkpointing them.
    if args.offload_activations:
        args.checkpoint_activations = True

    for name, fallback in (
        ("encoder_layers_to_keep", None),
        ("decoder_layers_to_keep", None),
        ("encoder_layerdrop", 0),
        ("decoder_layerdrop", 0),
        ("quant_noise_pq", 0),
        ("quant_noise_pq_block_size", 8),
        ("quant_noise_scalar", 0),
    ):
        fill(name, fallback)