import torch
import torch.nn as nn
from torch.nn import LayerNorm
import torchaudio.compliance.kaldi as ta_kaldi

from beats.backbone import (
    TransformerEncoder,
)

import logging
from typing import Optional

logger = logging.getLogger(__name__)


class BEATsConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep norm in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feature extraction)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # label predictor
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model
        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)
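
# A minimal usage sketch (values here are assumptions, not from this file):
# released checkpoints for this codebase typically store the config as a
# plain dict, e.g. under checkpoint["cfg"], which BEATsConfig merges in via
# update():
#
#   checkpoint = torch.load("/path/to/BEATs.pt")   # hypothetical path
#   cfg = BEATsConfig(checkpoint["cfg"])
#   cfg.update({"finetuned_model": False})         # override individual fields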


class BEATs(nn.Module):
    def __init__(
        self,
        cfg: BEATsConfig,
    ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        # Project patch embeddings to the encoder dimension when they differ.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        # Non-overlapping square patches over the (time, frequency) fbank plane.
        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(
            1,
            self.embed,
            kernel_size=self.input_patch_size,
            stride=self.input_patch_size,
            bias=cfg.conv_bias,
        )

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first are mutually exclusive.
        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Downsample the sample/frame-level padding mask to the feature
        # resolution: trim positions that do not fill a whole feature step,
        # then mark a step as padded only if every position inside it is padded.
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        fbanks = []
        for waveform in source:
            # Kaldi-compatible fbank expects 16-bit-range input, so scale the
            # [-1, 1] waveform by 2**15.
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(
                waveform,
                num_mel_bins=128,
                sample_frequency=16000,
                frame_length=25,
                frame_shift=10,
            )
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        # Normalize with dataset-level statistics (the defaults above).
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
        feature_only: bool = False,
    ):
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # Patchify the fbank: (B, T, F) -> (B, 1, T, F) -> (B, C, T', F')
        # -> (B, T' * F', C), i.e. one embedding per spectrogram patch.
        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if not feature_only and self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            if padding_mask is not None and padding_mask.any():
                # Mean-pool logits over non-padded positions only.
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            # Multi-label classification: independent per-class probabilities.
            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, padding_mask
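

if __name__ == "__main__":
    # A minimal smoke-test sketch, assuming randomly initialized weights and
    # random audio; the patch size below is a hypothetical choice, and real
    # use would instead pass a released checkpoint's stored config dict to
    # BEATsConfig and load its weights with load_state_dict.
    cfg = BEATsConfig({"input_patch_size": 16})
    model = BEATs(cfg).eval()

    # A batch of two 2-second clips of random 16 kHz "audio"; the second clip
    # is half padding. padding_mask is True at padded samples.
    waveforms = torch.randn(2, 32000)
    padding_mask = torch.zeros(2, 32000, dtype=torch.bool)
    padding_mask[1, 16000:] = True

    with torch.no_grad():
        features, feature_padding_mask = model.extract_features(
            waveforms, padding_mask=padding_mask, feature_only=True
        )
    print(features.shape)  # (batch, num_patches, encoder_embed_dim)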