# Source: HuBERT / fairseq / models / nat / fairseq_nat_model.py
# Commit d5175d3 (aliabd): "full working demo"
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import math

import torch
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
def ensemble_encoder(func):
    """Decorate an encoder ``forward`` so it can run an ensemble of encoders.

    When ``self.ensemble_models`` is ``None`` or holds a single model, the call
    is forwarded unchanged. Otherwise ``func`` is run once per model in
    ``self.ensemble_models`` (always with ``return_all_hiddens=True``). The
    per-model ``encoder_out``, ``encoder_embedding`` and ``encoder_states``
    tensors are then stacked along a new trailing dimension, so model ``i``'s
    result can be recovered with ``[..., i]``.

    Every other entry of the returned dict is taken from the first model's
    output. NOTE(review): this assumes all ensemble members see identical
    inputs (e.g. identical padding masks).
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if self.ensemble_models is None or len(self.ensemble_models) == 1:
            return func(self, *args, **kwargs)

        # Merge instead of passing ``return_all_hiddens`` next to ``**kwargs``:
        # a caller that already supplies it would otherwise get a TypeError
        # ("got multiple values for keyword argument").
        member_kwargs = {**kwargs, "return_all_hiddens": True}
        encoder_outs = [
            func(model, *args, **member_kwargs) for model in self.ensemble_models
        ]
        # Shallow copy: entries not overwritten below are shared with the
        # first model's output.
        _encoder_out = encoder_outs[0].copy()

        def stack(key):
            # Only the first element of each list-valued entry is used; if the
            # first model's value is None, the stacked result is None as well.
            outs = [e[key][0] for e in encoder_outs]
            return [torch.stack(outs, -1) if outs[0] is not None else None]

        _encoder_out["encoder_out"] = stack("encoder_out")
        _encoder_out["encoder_embedding"] = stack("encoder_embedding")

        num_layers = len(_encoder_out["encoder_states"])
        if num_layers > 0:
            _encoder_out["encoder_states"] = [
                torch.stack([e["encoder_states"][i] for e in encoder_outs], -1)
                for i in range(num_layers)
            ]
        return _encoder_out

    return wrapper
def ensemble_decoder(func):
    """Decorate a decoder method so it can run an ensemble of decoders.

    The decorated method is expected to take ``normalize`` and ``encoder_out``
    as its first two parameters after ``self``, mirroring ``wrapper`` below.

    When ``self.ensemble_models`` is ``None`` or holds a single model, the call
    is forwarded unchanged. Otherwise model ``i`` is called with a copy of
    ``encoder_out`` whose ``"encoder_out"`` entry is replaced by slice
    ``[..., i]`` of the stacked tensor, i.e. the trailing dimension added by
    ``ensemble_encoder``. The per-model results are then combined as follows:

    * The first returned value, when ``normalize`` is true, is combined as
      ``logsumexp(stacked, -1) - log(N)``. That is the log of the mean
      probability, assuming the value holds log-probabilities.
    * Every other non-``None`` value is stacked along a new trailing dimension.
    * A value that is ``None`` for the first model yields ``None``.

    A single combined value is returned as-is; several are returned as a tuple.
    """

    def _call(target, normalize, encoder_out, args, kwargs):
        # Writing ``func(t, normalize=..., encoder_out=..., *args)`` would bind
        # the extra positionals to ``normalize``/``encoder_out`` and raise
        # "got multiple values". Pass everything positionally in that case,
        # and keep keyword passing otherwise.
        if args:
            return func(target, normalize, encoder_out, *args, **kwargs)
        return func(target, normalize=normalize, encoder_out=encoder_out, **kwargs)

    @functools.wraps(func)
    def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs):
        if self.ensemble_models is None or len(self.ensemble_models) == 1:
            return _call(self, normalize, encoder_out, args, kwargs)

        def _replace(encoder_out, new_val):
            # Shallow copy so the caller's dict is left untouched.
            new_encoder_out = encoder_out.copy()
            new_encoder_out["encoder_out"] = [new_val]
            return new_encoder_out

        stacked = encoder_out["encoder_out"][0]
        action_outs = [
            _call(model, normalize, _replace(encoder_out, stacked[..., i]), args, kwargs)
            for i, model in enumerate(self.ensemble_models)
        ]
        # A single (non-tuple) return value is wrapped so single- and
        # multi-value methods go through the same combination loop.
        if not isinstance(action_outs[0], tuple):
            action_outs = [(a,) for a in action_outs]

        num_models = len(self.ensemble_models)
        ensembled_outs = []
        for i, per_model in enumerate(zip(*action_outs)):
            if i == 0 and normalize:
                ensembled_outs.append(
                    torch.logsumexp(torch.stack(per_model, -1), dim=-1)
                    - math.log(num_models)
                )
            elif per_model[0] is not None:
                ensembled_outs.append(torch.stack(per_model, -1))
            else:
                ensembled_outs.append(None)

        if len(ensembled_outs) == 1:
            return ensembled_outs[0]
        return tuple(ensembled_outs)

    return wrapper
class FairseqNATModel(TransformerModel):
    """
    Abstract class for all nonautoregressive-based models.

    Subclasses must override :meth:`forward`, :meth:`forward_decoder` and
    :meth:`initialize_output_tokens`. The base implementations raise
    :class:`NotImplementedError`.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        # Cache the target dictionary and its special-symbol indices.
        self.tgt_dict = decoder.dictionary
        self.bos = decoder.dictionary.bos()
        self.eos = decoder.dictionary.eos()
        self.pad = decoder.dictionary.pad()
        self.unk = decoder.dictionary.unk()
        self.ensemble_models = None

    @property
    def allow_length_beam(self):
        """Return ``False``; subclasses may override."""
        return False

    @property
    def allow_ensemble(self):
        """Return ``True``; subclasses may override."""
        return True

    def enable_ensemble(self, models):
        """Register ``models`` as an ensemble.

        Each model's encoder is stored in ``self.encoder.ensemble_models`` and
        each model's decoder in ``self.decoder.ensemble_models``. Those are the
        attributes read by the ensemble-aware encoder/decoder methods.
        """
        self.encoder.ensemble_models = [m.encoder for m in models]
        self.decoder.ensemble_models = [m.decoder for m in models]

    @staticmethod
    def add_args(parser):
        """Add the Transformer arguments plus ``--apply-bert-init`` to ``parser``."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Build a ``FairseqNATDecoder``.

        ``init_bert_params`` is applied when ``args.apply_bert_init`` is set.
        """
        decoder = FairseqNATDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Build a ``FairseqNATEncoder``.

        ``init_bert_params`` is applied when ``args.apply_bert_init`` is set.
        """
        encoder = FairseqNATEncoder(args, src_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            encoder.apply(init_bert_params)
        return encoder

    def forward_encoder(self, encoder_inputs):
        """Run the encoder with the items of ``encoder_inputs`` as positional args."""
        return self.encoder(*encoder_inputs)

    # The three stubs below previously *returned* the NotImplementedError
    # class, silently handing callers a non-result; they must raise.
    def forward_decoder(self, *args, **kwargs):
        """Decode one step; must be implemented by subclasses."""
        raise NotImplementedError

    def initialize_output_tokens(self, *args, **kwargs):
        """Create the initial decoder output; must be implemented by subclasses."""
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """Training forward pass; must be implemented by subclasses."""
        raise NotImplementedError
class FairseqNATEncoder(TransformerEncoder):
    """Transformer encoder whose ``forward`` is wrapped by ``ensemble_encoder``.

    ``ensemble_models`` starts as ``None`` (single-model mode). Assigning it a
    list of encoders (as ``FairseqNATModel.enable_ensemble`` does) switches
    ``forward`` to ensemble mode.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super(FairseqNATEncoder, self).__init__(args, dictionary, embed_tokens)
        self.ensemble_models = None

    @ensemble_encoder
    def forward(self, *args, **kwargs):
        # Plain TransformerEncoder forward; the decorator handles ensembling.
        return super(FairseqNATEncoder, self).forward(*args, **kwargs)
class FairseqNATDecoder(TransformerDecoder):
    """Transformer decoder with an ``ensemble_models`` slot.

    ``ensemble_models`` starts as ``None`` (single-model mode). Assigning it a
    list of decoders (as ``FairseqNATModel.enable_ensemble`` does) is what
    ``ensemble_decoder``-wrapped methods check for.
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super(FairseqNATDecoder, self).__init__(
            args, dictionary, embed_tokens, no_encoder_attn
        )
        self.ensemble_models = None