""" Inference CTC class derived from HubertForCTC. Author: Marcely Zanon Boito, 2024 """ from typing import Optional, Tuple, Union import torch from torch import nn from transformers import HubertPreTrainedModel, HubertModel from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput class VanillaNN(nn.Module): def __init__(self, input_dim, output_dim): """ simple NN with ReLU activation (no norm) """ super().__init__() self.linear = nn.Linear(input_dim, output_dim) self.act_fn = nn.ReLU() def forward(self, hidden_states: torch.FloatTensor): hidden_states = self.linear(hidden_states) hidden_states = self.act_fn(hidden_states) return hidden_states class mHubertForCTC(HubertPreTrainedModel): def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) self.hubert = HubertModel(config) self.dropout = nn.Dropout(config.final_dropout) output_hidden_size = config.hidden_size self.has_interface = config.add_interface # NN layers on top of the trainable stack if config.add_interface: self.interface = nn.ModuleList([VanillaNN(output_hidden_size,output_hidden_size) for i in range(config.num_interface_layers)]) self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) self.post_init() def forward( self, input_values: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, ) -> Union[Tuple, SequenceClassifierOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_hidden_states = self.config.output_hidden_states outputs = self.hubert( input_values, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] hidden_states = self.dropout(hidden_states) if self.has_interface: for layer in self.interface: hidden_states = layer(hidden_states) logits = self.lm_head(hidden_states) loss = None if labels is not None: if labels.max() >= self.config.vocab_size: raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") # retrieve loss input_lengths from attention_mask attention_mask = ( attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) ) input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) # assuming that padded tokens are filled with -100 # when not being attended to labels_mask = labels >= 0 target_lengths = labels_mask.sum(-1) flattened_targets = labels.masked_select(labels_mask) # ctc_loss doesn't support fp16 log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) with torch.backends.cudnn.flags(enabled=False): loss = nn.functional.ctc_loss( log_probs, flattened_targets, input_lengths, target_lengths, blank=self.config.ctc_token_id, reduction=self.config.ctc_loss_reduction, zero_infinity=self.config.ctc_zero_infinity, ) return CausalLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions )