from dataclasses import dataclass
import torch
from transformers import BatchEncoding, DataCollatorForSeq2Seq


@dataclass
class DataCollatorForEnClapBart(DataCollatorForSeq2Seq):
    """Seq2seq collator for EnClap's BART model: pads the multi-stream EnCodec
    inputs and their mcm_labels, pads the usual text labels / masks, and stacks
    the per-example CLAP embeddings into the batch."""

    input_pad_token_id: int = 1024
    num_rvq: int = 16

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        batch_size = len(features)

        # Stack the per-example CLAP embeddings into a (batch_size, dim) tensor.
        clap_embedding = torch.Tensor(
            [feature["clap_embedding"] for feature in features]
        )

        # Temporarily swap the tokenizer's pad id so the parent collator pads
        # the EnCodec token streams with the codec pad value instead of the
        # text pad token.
        pad_token_id = self.tokenizer.pad_token_id
        self.tokenizer.pad_token_id = self.input_pad_token_id

        # Each feature holds (seq_len, num_rvq) arrays; flatten them into one
        # pseudo-example per codebook stream so the parent collator pads each
        # stream like an ordinary 1-D sequence.
        keys = ["input_ids", "mcm_labels"]
        tmp_key_map = {"input_ids": "input_ids", "mcm_labels": "labels"}
        input_features = super().__call__(
            [
                {tmp_key_map[key]: feature[key][:, i] for key in keys}
                for feature in features
                for i in range(feature[keys[0]].shape[-1])
            ],
            return_tensors,
        )

        # Pad the encodec mask (routed through the input_ids slot) with 1;
        # attention_mask and labels are padded by the parent collator as usual.
        self.tokenizer.pad_token_id = 1
        keys = ["encodec_mask", "attention_mask", "labels"]
        tmp_key_map = {
            "encodec_mask": "input_ids",
            "attention_mask": "attention_mask",
            "labels": "labels",
        }
        other_features = super().__call__(
            [{tmp_key_map[key]: feature[key] for key in keys} for feature in features],
            return_tensors,
        )

        # Restore the original pad token id before returning.
        self.tokenizer.pad_token_id = pad_token_id

        return BatchEncoding(
            {
                # Un-flatten the per-codebook streams back to
                # (batch_size, seq_len, num_rvq).
                "input_ids": input_features["input_ids"]
                .reshape(batch_size, self.num_rvq, -1)
                .transpose(1, 2),
                "mcm_labels": input_features["labels"]
                .reshape(batch_size, self.num_rvq, -1)
                .transpose(1, 2),
                "attention_mask": other_features["attention_mask"],
                "encodec_mask": other_features["input_ids"],
                "labels": other_features["labels"],
                "clap_embedding": clap_embedding,
            }
        )
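

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The checkpoint
# name, feature shapes, and embedding size below are assumptions inferred from
# how __call__ indexes the features; the real training pipeline defines them.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    from transformers import AutoTokenizer

    # Assumed checkpoint; any BART-style tokenizer with a pad_token_id works.
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    collator = DataCollatorForEnClapBart(tokenizer=tokenizer, num_rvq=16)

    def fake_example(seq_len: int) -> dict:
        # EnCodec streams are (seq_len, num_rvq); masks and text labels are 1-D.
        return {
            "input_ids": np.random.randint(0, 1024, (seq_len, 16)),
            "mcm_labels": np.random.randint(0, 1024, (seq_len, 16)),
            "encodec_mask": np.ones(seq_len, dtype=np.int64),
            "attention_mask": np.ones(seq_len, dtype=np.int64),
            "labels": np.random.randint(0, tokenizer.vocab_size, (seq_len,)),
            "clap_embedding": np.random.randn(512).astype(np.float32),  # assumed dim
        }

    batch = collator([fake_example(10), fake_example(7)], return_tensors="pt")
    print(batch["input_ids"].shape)   # -> (2, padded_len, 16)
    print(batch["mcm_labels"].shape)  # -> (2, padded_len, 16)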