Spaces:
Runtime error
Runtime error
from dataclasses import dataclass | |
import torch | |
from transformers import BatchEncoding, DataCollatorForSeq2Seq | |
@dataclass
class DataCollatorForEnClapBart(DataCollatorForSeq2Seq):
    """Collate EnClap-BART examples into a padded ``BatchEncoding``.

    Reuses ``DataCollatorForSeq2Seq`` padding twice:

    1. Multi-codebook ``input_ids`` / ``mcm_labels`` are split into one
       sequence per codebook column, padded, then reassembled.
    2. ``encodec_mask`` / ``attention_mask`` / ``labels`` are padded as
       ordinary 1-D sequences.

    For each call, the tokenizer's ``pad_token_id`` is temporarily
    overridden. It is always restored afterwards, even if padding raises.

    Attributes:
        input_pad_token_id: Padding value used for the per-codebook
            ``input_ids`` sequences.
        num_rvq: Number of codebook columns every example's ``input_ids`` /
            ``mcm_labels`` must have. Used to reshape the padded batch.
        encodec_mask_pad_value: Padding value used for ``encodec_mask``
            (fed to the parent collator in the ``input_ids`` slot).
    """

    input_pad_token_id: int = 1024
    num_rvq: int = 16
    encodec_mask_pad_value: int = 1

    def __call__(self, features, return_tensors=None):
        """Pad and stack a list of examples.

        Args:
            features: List of dict-like examples. Each must provide
                ``input_ids`` and ``mcm_labels`` indexable as ``[:, i]`` with
                ``shape[-1] == num_rvq`` (presumably ``(seq_len, num_rvq)``;
                TODO confirm against the dataset). Each must also provide
                ``encodec_mask``, ``attention_mask``, ``labels`` and
                ``clap_embedding``.
            return_tensors: Tensor type forwarded to the parent collator;
                defaults to ``self.return_tensors``. The ``reshape`` /
                ``transpose(1, 2)`` below rely on torch semantics, so this
                effectively needs to be ``"pt"``.

        Returns:
            ``BatchEncoding`` with ``input_ids`` and ``mcm_labels`` of shape
            ``(batch, padded_seq, num_rvq)``, plus the padded
            ``attention_mask``, ``encodec_mask``, ``labels`` and the stacked
            ``clap_embedding``.

        Raises:
            ValueError: If an example's codebook count differs from
                ``num_rvq``. The check runs before the tokenizer is modified.
        """
        if return_tensors is None:
            return_tensors = self.return_tensors
        batch_size = len(features)

        # Validate up front: a mismatch could still satisfy the final reshape
        # and silently mix codebooks from different examples.
        for feature in features:
            num_codebooks = feature["input_ids"].shape[-1]
            if num_codebooks != self.num_rvq:
                raise ValueError(
                    f"expected input_ids with num_rvq={self.num_rvq} codebook "
                    f"columns, got {num_codebooks}"
                )

        # Stack per-example embeddings; works for lists, ndarrays and tensors
        # and keeps the default float dtype that torch.Tensor(...) produced.
        clap_embedding = torch.stack(
            [
                torch.as_tensor(
                    feature["clap_embedding"], dtype=torch.get_default_dtype()
                )
                for feature in features
            ]
        )

        # One row per (example, codebook), example-major, so the padded result
        # can be reshaped to (batch, num_rvq, seq) and transposed below.
        codebook_rows = [
            {
                "input_ids": feature["input_ids"][:, i],
                "labels": feature["mcm_labels"][:, i],
            }
            for feature in features
            for i in range(self.num_rvq)
        ]
        sequence_rows = [
            {
                # encodec_mask rides in the input_ids slot so that it is
                # padded with the tokenizer's (temporarily overridden) pad id.
                "input_ids": feature["encodec_mask"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
            }
            for feature in features
        ]

        original_pad_token_id = self.tokenizer.pad_token_id
        try:
            self.tokenizer.pad_token_id = self.input_pad_token_id
            codebook_batch = super().__call__(codebook_rows, return_tensors)
            self.tokenizer.pad_token_id = self.encodec_mask_pad_value
            sequence_batch = super().__call__(sequence_rows, return_tensors)
        finally:
            # The tokenizer is shared state; never leave it with a swapped id.
            self.tokenizer.pad_token_id = original_pad_token_id

        # Only the keys listed here are returned; anything else the parent
        # collator adds is dropped.
        return BatchEncoding(
            {
                "input_ids": codebook_batch["input_ids"]
                .reshape(batch_size, self.num_rvq, -1)
                .transpose(1, 2),
                "mcm_labels": codebook_batch["labels"]
                .reshape(batch_size, self.num_rvq, -1)
                .transpose(1, 2),
                "attention_mask": sequence_batch["attention_mask"],
                "encodec_mask": sequence_batch["input_ids"],
                "labels": sequence_batch["labels"],
                "clap_embedding": clap_embedding,
            }
        )