from dataclasses import dataclass
from pathlib import Path
from random import randint
from typing import Optional, Tuple

import numpy as np
import torch
from transformers import BartTokenizerFast


@dataclass
class Preprocessor:
    encodec_base_path: Path
    clap_base_path: Path
    tokenizer: BartTokenizerFast = BartTokenizerFast.from_pretrained(
        "facebook/bart-base"
    )
    max_length: int = 1024
    mcm_masking_prob: float = 0.15
    mcm_masking_span: int = 10
    label_pad_token_id: int = -100
    mask_token_id: int = 1024
    num_eval_captions: int = 5

    def __post_init__(self):
        if isinstance(self.encodec_base_path, str):
            self.encodec_base_path = Path(self.encodec_base_path)
        if isinstance(self.clap_base_path, str):
            self.clap_base_path = Path(self.clap_base_path)
        if isinstance(self.tokenizer, str):
            self.tokenizer = BartTokenizerFast.from_pretrained(self.tokenizer)

    def preprocess_train(self, example):
        path = example["file_path"]
        encodec = np.load(self.encodec_base_path / path)
        clap_embedding = np.load(self.clap_base_path / path)

        encodec_mask = np.array(
            [0, 0] + [1] * min(encodec.shape[0], self.max_length - 3) + [0]
        )
        attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
            np.int64
        )
        target_text = self.tokenizer(text_target=example["caption"])

        # Randomly crop sequences that would not fit once the 2 BOS rows and
        # 1 EOS row are added below.
        if encodec.shape[0] + 3 > self.max_length:
            start = randint(0, encodec.shape[0] - self.max_length + 3)
            encodec = encodec[start : start + self.max_length - 3]

        num_rvq = encodec.shape[-1]
        mcm_labels = None
        if self.mcm_masking_prob > 0:
            mcm_mask, _ = _compute_mask_indices(
                encodec.T.shape, self.mcm_masking_prob, self.mcm_masking_span
            )
            mcm_mask = mcm_mask.T
            mcm_labels = np.where(mcm_mask, encodec, self.label_pad_token_id)
            mcm_labels = np.concatenate(
                [
                    np.ones((2, num_rvq), dtype=np.int64) * self.label_pad_token_id,
                    mcm_labels,
                    np.ones((1, num_rvq), dtype=np.int64) * self.label_pad_token_id,
                ],
                axis=0,
            )
            encodec[mcm_mask] = self.mask_token_id

        encodec = np.concatenate(
            [
                np.ones((2, num_rvq), dtype=np.int64) * self.tokenizer.bos_token_id,
                encodec,
                np.ones((1, num_rvq), dtype=np.int64) * self.tokenizer.eos_token_id,
            ],
            axis=0,
        )

        return {
            "input_ids": encodec,
            "clap_embedding": clap_embedding,
            "encodec_mask": encodec_mask,
            "attention_mask": attention_mask,
            "mcm_labels": mcm_labels,
            "labels": target_text["input_ids"],
        }

    def preprocess_eval(self, example):
        path = example["file_path"]
        encodec = np.load(self.encodec_base_path / path)
        clap_embedding = np.load(self.clap_base_path / path)

        attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
            np.int64
        )
        if encodec.shape[0] + 3 > self.max_length:
            encodec = encodec[: self.max_length - 3]

        captions = []
        for i in range(self.num_eval_captions):
            captions.append(example[f"caption_{i+1}"])

        return {
            "input_ids": encodec,
            "attention_mask": attention_mask,
            "clap": clap_embedding,
            "captions": captions,
        }
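# --- Hypothetical usage sketch (illustrative, not part of the original pipeline) ---
# A minimal, self-contained demo of `preprocess_train`, assuming EnCodec codes are
# stored as an integer array of shape (num_frames, num_rvq_codebooks) and the CLAP
# embedding as a float vector, both saved as .npy files named by `file_path`.
# The directory layout, shapes, and caption text below are illustrative assumptions,
# not the project's real data.
def _example_preprocess_train(tmp_dir: str = "/tmp/preprocessor_demo") -> dict:
    import os

    os.makedirs(os.path.join(tmp_dir, "encodec"), exist_ok=True)
    os.makedirs(os.path.join(tmp_dir, "clap"), exist_ok=True)

    # Fabricate features: 250 frames x 8 RVQ codebooks, plus a 512-d CLAP embedding.
    np.save(
        os.path.join(tmp_dir, "encodec", "sample.npy"),
        np.random.randint(0, 1024, size=(250, 8)),
    )
    np.save(
        os.path.join(tmp_dir, "clap", "sample.npy"),
        np.random.randn(512).astype(np.float32),
    )

    preprocessor = Preprocessor(
        encodec_base_path=Path(tmp_dir) / "encodec",
        clap_base_path=Path(tmp_dir) / "clap",
    )
    example = {"file_path": "sample.npy", "caption": "a dog barks in the distance"}
    # Returned "input_ids" has 2 BOS rows and 1 EOS row added: shape (250 + 3, 8).
    return preprocessor.preprocess_train(example)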
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method
    for ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be
    run on CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be a tuple of size 2 where the first element is
            the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
            independently generated mask spans of length `mask_length` is computed by
            `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
            actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of each batch
            dimension.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        # nothing to mask: return an all-False mask and an empty set of span indices
        return torch.from_numpy(spec_aug_mask), np.array(spec_aug_mask_idxs)

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad the vector
        # to ensure the same dimension for all batches due to probabilistic rounding.
        # Picking the first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length`, in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [
                spec_aug_mask_idx,
                np.ones(max_num_masked_span - num_masked_span, dtype=np.int32)
                * dummy_mask_idx,
            ]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(
        batch_size, max_num_masked_span * mask_length
    )

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(
        offsets, (batch_size, max_num_masked_span, mask_length)
    ).reshape(batch_size, max_num_masked_span * mask_length)
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = (
            sequence_length - 1
        )

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return torch.from_numpy(spec_aug_mask), spec_aug_mask_idxs
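

# --- Hypothetical sanity check (illustrative only) ----------------------------------
# Quick demo of `_compute_mask_indices`: build a SpecAugment-style mask for a toy
# (batch, time) grid and report how much of each row ended up masked. The realised
# fraction is usually a bit below `mask_prob` because spans of length `mask_length`
# may overlap. The shape and probability below are arbitrary choices for illustration.
if __name__ == "__main__":
    mask, span_starts = _compute_mask_indices((4, 200), mask_prob=0.15, mask_length=10)
    print("mask shape:", tuple(mask.shape))
    print("masked fraction per row:", mask.float().mean(dim=-1).tolist())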