import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer


class TokenizedForMCRightPad(Dataset):
    """Multiple-choice dataset that right-pads every (query + choice) prompt to a shared max length."""

    def __init__(self, data, tok: PreTrainedTokenizer, prompt_fn):
        # data: list of examples, each {"query": str, "choices": list[str]}
        # prompt_fn(query, choice) -> (query-only prompt, full prompt with the choice appended)
        self.tok = tok
        self.prompt_fn = prompt_fn
        self.max_length = self._find_max_length(data)
        self.data = self._build_mc_data(data)

    def _find_max_length(self, data):
        # Longest tokenized full prompt over all (query, choice) pairs; used as the padding length.
        max_len = 0

        def tok_len(t):
            return len(self.tok.encode(t))

        for ex in data:
            query = ex["query"]
            len_choices = [tok_len(self.prompt_fn(query, c)[1]) for c in ex["choices"]]
            max_len = max(max_len, *len_choices)
        return max_len

    def _build_mc_data(self, data):
        processed = []
        num_choices = {len(e["choices"]) for e in data}
        if len(num_choices) != 1:
            raise ValueError(f"Queries have different numbers of choices, which is not supported! #choices: {num_choices}")
        for ex in data:
            query, choices = ex["query"], ex["choices"]
            processed_input = [self.prompt_fn(query, choice) for choice in choices]
            processed_input = [self.tokenize(t_query, t_full) for t_query, t_full in processed_input]
            processed.append(processed_input)
        return processed

    def tokenize_demonstration(self, demonstration):
        e = self.tok(demonstration)
        return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"])  # no padding

    def tokenize(self, only_query, full_text):
        tok_only_query = self.tok(only_query, add_special_tokens=False)
        tok_full_no_padding = self.tok(full_text, add_special_tokens=False)
        tok_full = self.tok(
            full_text,
            padding="max_length",
            max_length=self.max_length,
            add_special_tokens=False,
        )  # no special tokens; padded to the shared max length (right padding expected)

        # choice_start/choice_end delimit the choice tokens inside the full prompt:
        # the choice spans [len(query tokens), len(unpadded full-prompt tokens)).
        len_full = len(tok_full_no_padding.input_ids)
        len_query = len(tok_only_query.input_ids)
        e = {
            "input_ids": tok_full.input_ids,
            "attention_mask": tok_full.attention_mask,
            "choice_start": len_query,
            "choice_end": len_full,
        }
        return e

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        def _get_one_item(e):
            return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"]), e["choice_start"], e["choice_end"]

        es = self.data[idx]
        # num_choices * (input_ids, attn, start_idx, end_idx)
        # after default DataLoader collation: input_ids, attn: [B, L]; start_idx, end_idx: [B]
        return [_get_one_item(e) for e in es]
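

# A minimal usage sketch, not part of the original class. It assumes a Hugging Face
# causal-LM tokenizer ("gpt2" is used here only as a stand-in), a hypothetical
# prompt_fn that returns (query-only prompt, full prompt), and a hypothetical toy
# dataset; adapt the names and prompt format to your own setup.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # gpt2 has no pad token; reuse EOS for padding
    tok.padding_side = "right"     # the class expects right padding

    def prompt_fn(query, choice):
        # Hypothetical prompt builder: (prompt without the choice, prompt with the choice appended).
        prefix = f"Question: {query}\nAnswer:"
        return prefix, f"{prefix} {choice}"

    toy_data = [
        {"query": "What color is the sky on a clear day?", "choices": ["blue", "green"]},
        {"query": "How many legs does a spider have?", "choices": ["eight", "six"]},
    ]

    ds = TokenizedForMCRightPad(toy_data, tok, prompt_fn)
    # First query, first choice: padded ids/mask plus the token span of the choice.
    input_ids, attention_mask, choice_start, choice_end = ds[0][0]
    print(input_ids.shape, attention_mask.shape, choice_start, choice_end)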