# deep-thinking/tasks/loader.py
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer


class TokenizedForMCRightPad(Dataset):
    """Multiple-choice dataset that right-pads every prompt to a shared max length."""

    def __init__(self, data, tok: PreTrainedTokenizer, prompt_fn):
        # data: list of {"query": str, "choices": list[str]}
        # prompt_fn(query, choice) -> (query_only_text, full_text)
        self.tok = tok
        self.prompt_fn = prompt_fn
        self.max_length = self._find_max_length(data)
        self.data = self._build_mc_data(data)

    def _find_max_length(self, data):
        # Longest tokenized full prompt over all (query, choice) pairs;
        # used as the fixed right-padding length in tokenize().
        max_len = 0

        def tok_len(t):
            return len(self.tok.encode(t))

        for ex in data:
            query = ex["query"]
            len_choices = [tok_len(self.prompt_fn(query, c)[1]) for c in ex["choices"]]
            max_len = max(max_len, *len_choices)
        return max_len

    def _build_mc_data(self, data):
        processed = []
        num_choices = set(len(e["choices"]) for e in data)
        if len(num_choices) != 1:
            raise ValueError(f"Queries have different numbers of choices, which is not supported! #choices: {num_choices}")
        for ex in data:
            query, choices = ex["query"], ex["choices"]
            # prompt_fn yields a (query_only_text, full_text) pair for each choice
            processed_input = [self.prompt_fn(query, choice) for choice in choices]
            processed_input = [self.tokenize(t_query, t_full) for t_query, t_full in processed_input]
            processed.append(processed_input)
        return processed

    def tokenize_demonstration(self, demonstration):
        e = self.tok(demonstration)
        return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"])  # no padding

    def tokenize(self, only_query, full_text):
        # Tokenize the query alone and the full prompt (query + choice) so the
        # choice span can be located by the length difference.
        tok_only_query = self.tok(only_query, add_special_tokens=False)
        tok_full_no_padding = self.tok(full_text, add_special_tokens=False)
        tok_full = self.tok(
            full_text,
            padding="max_length",
            max_length=self.max_length,
            add_special_tokens=False,
        )  # <pad> is not a special token
        len_full = len(tok_full_no_padding.input_ids)
        len_query = len(tok_only_query.input_ids)
        e = {
            "input_ids": tok_full.input_ids,
            "attention_mask": tok_full.attention_mask,
            # token indices delimiting the choice span within the right-padded input_ids
            "choice_start": len_query,
            "choice_end": len_full,
        }
        return e

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        def _get_one_item(e):
            return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"]), e["choice_start"], e["choice_end"]

        es = self.data[idx]
        # Returns num_choices tuples of (input_ids, attention_mask, choice_start, choice_end).
        # After default DataLoader collation each tuple holds:
        #   input_ids, attention_mask: [B, L]; choice_start, choice_end: [B]
        return [_get_one_item(e) for e in es]
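

# --- Usage sketch (illustrative), assuming a GPT-2 tokenizer and a toy prompt_fn ---
# The prompt_fn contract is inferred from _build_mc_data/tokenize: it must return a
# (query_only_text, full_text) pair. The checkpoint name, prompt format, toy data, and
# the example_prompt_fn helper below are illustrative choices, not part of this module.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 has no pad token; reuse EOS for right padding

    def example_prompt_fn(query, choice):
        query_only = f"Question: {query}\nAnswer:"
        return query_only, f"{query_only} {choice}"

    toy_data = [
        {"query": "What color is the sky on a clear day?", "choices": ["blue", "green"]},
        {"query": "How many legs does a spider have?", "choices": ["eight", "six"]},
    ]

    ds = TokenizedForMCRightPad(toy_data, tok, example_prompt_fn)
    input_ids, attention_mask, start, end = ds[0][0]  # first query, first choice
    print(input_ids.shape, attention_mask.shape, start, end)
    print(tok.decode(input_ids[start:end]))  # should decode to the choice text, e.g. " blue"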