# enclap/data/collator.py
from dataclasses import dataclass

import torch
from transformers import BatchEncoding, DataCollatorForSeq2Seq


@dataclass
class DataCollatorForEnClapBart(DataCollatorForSeq2Seq):
    """Collate EnCodec token streams, masks, and CLAP embeddings for the
    EnClap BART model, reusing DataCollatorForSeq2Seq's padding logic."""

    input_pad_token_id: int = 1024
    num_rvq: int = 16

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        batch_size = len(features)
        # stacked_features = {k: [f[k] for f in features] for k in features[0]}

        # Stack the per-example CLAP embeddings into a single float tensor.
        clap_embedding = torch.Tensor(
            [feature["clap_embedding"] for feature in features]
        )

        # Temporarily swap in the dedicated EnCodec pad id so the parent class
        # pads the codebook token streams with it.
        pad_token_id = self.tokenizer.pad_token_id
        self.tokenizer.pad_token_id = self.input_pad_token_id
keys = ["input_ids", "mcm_labels"]
tmp_key_map = {"input_ids": "input_ids", "mcm_labels": "labels"}
input_features = super().__call__(
[
{tmp_key_map[key]: feature[key][:, i] for key in keys}
for feature in features
for i in range(feature[keys[0]].shape[-1])
],
return_tensors,
)
        # The remaining sequence-level features are padded with pad id 1
        # (BART's default pad token), then the original pad id is restored.
        self.tokenizer.pad_token_id = 1
        keys = ["encodec_mask", "attention_mask", "labels"]
        tmp_key_map = {
            "encodec_mask": "input_ids",
            "attention_mask": "attention_mask",
            "labels": "labels",
        }
        other_features = super().__call__(
            [{tmp_key_map[key]: feature[key] for key in keys} for feature in features],
            return_tensors,
        )
        self.tokenizer.pad_token_id = pad_token_id
        # Reshape the flattened codebook sequences back to
        # [batch, seq_len, num_rvq] and assemble the final batch.
        return BatchEncoding(
            {
                "input_ids": input_features["input_ids"]
                .reshape(batch_size, self.num_rvq, -1)
                .transpose(1, 2),
                "mcm_labels": input_features["labels"]
                .reshape(batch_size, self.num_rvq, -1)
                .transpose(1, 2),
                "attention_mask": other_features["attention_mask"],
                "encodec_mask": other_features["input_ids"],
                "labels": other_features["labels"],
                "clap_embedding": clap_embedding,
            }
        )
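

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the checkpoint name and feature
# shapes below are assumptions inferred from how this collator indexes its
# inputs, not something defined in this file):
#
#     from transformers import BartTokenizer
#
#     tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
#     collator = DataCollatorForEnClapBart(
#         tokenizer=tokenizer, input_pad_token_id=1024, num_rvq=16
#     )
#     # Each feature dict is expected to provide:
#     #   "input_ids", "mcm_labels": [seq_len, num_rvq] EnCodec codebook tokens
#     #   "encodec_mask", "attention_mask", "labels": 1-D sequences
#     #   "clap_embedding": a fixed-size CLAP embedding vector
#     batch = collator(features)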