import os
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from transformers import AutoModelForCausalLM, AutoTokenizer

from . import utils


def perplexity_from_logits(am_gen_logits, om_gen_logits):
    """ Calculate perplexity of a set of logits against reference tokens.

    am_gen_logits: (seq_len, vocab_size) logits from the model under evaluation.
    om_gen_logits: reference token ids of shape (seq_len,), or reference logits
        of shape (seq_len, vocab_size), which are reduced to token ids by argmax.
    """
    if len(om_gen_logits.squeeze().shape) > 1:
        om_gen_logits = torch.argmax(om_gen_logits.squeeze(), dim=-1)

    # log-probability of each reference token under the evaluated logits
    m = nn.LogSoftmax(dim=1)
    log_probs = torch.gather(
        m(am_gen_logits.float()), 1, om_gen_logits[:, None])[:, 0]

    # perplexity = exp(-1/N * sum_i log p(t_i))
    return torch.exp(-1 / om_gen_logits.size(0) * log_probs.sum()).item()


def set_perplexity_from_logits(am_set, om_set, prompt_lens):
    """ Calculate perplexity for a set of samples, skipping each prompt's tokens """
    perplexities = np.zeros(len(om_set))
    for i in range(len(om_set)):
        perplexities[i] = perplexity_from_logits(
            am_set[i][prompt_lens[i]:],
            om_set[i][prompt_lens[i]:]
        )
    return perplexities


def generation_ppl(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompts: List[str],
    tokens_true: Optional[torch.Tensor] = None,
    token_window: int = 30,
    batch_size: int = 32,
    verbose: bool = False
):
    """ Run generation and calculate perplexity of the generated continuations """
    from . import generate

    texts = []
    preds = []
    perplexity = []

    # duplicate a lone prompt so generation always runs on a proper batch
    if len(prompts) == 1:
        prompts = prompts * 2

    # drop prompts that leave no room for generated tokens in the window
    prompt_lens = [len(tok.encode(p)) for p in prompts]
    prompt_mask = np.array(prompt_lens) < (token_window - 1)
    if np.sum(prompt_mask) != len(prompts):
        print('Removed prompts that are too long for the token window')
        prompts = list(np.array(prompts)[prompt_mask])
        prompt_lens = list(np.array(prompt_lens)[prompt_mask])

    # find number of batches (after filtering, so no batch is empty)
    num_batches = int(np.ceil(len(prompts) / batch_size))

    for i in tqdm(range(num_batches), disable=(not verbose)):
        # run greedy generation for this batch
        gen_texts, gen_logits = generate.generate_fast(
            model,
            tok,
            prompts=prompts[i * batch_size:(i + 1) * batch_size],
            n_gen_per_prompt=1,
            top_k=1,
            max_out_len=token_window,
            return_logits=True,
        )
        pred_tokens = torch.argmax(gen_logits.squeeze(), dim=-1)

        # reference tokens: the model's own greedy predictions unless provided
        if tokens_true is None:
            subset_tokens_true = pred_tokens
        else:
            subset_tokens_true = tokens_true[i * batch_size:(i + 1) * batch_size]
            if isinstance(subset_tokens_true, np.ndarray):
                subset_tokens_true = torch.from_numpy(subset_tokens_true)

        # calculate perplexity over the generated (non-prompt) positions
        ppl = set_perplexity_from_logits(
            gen_logits,
            subset_tokens_true,
            prompt_lens[i * batch_size:(i + 1) * batch_size])

        texts = texts + gen_texts
        preds.append(pred_tokens.cpu().numpy())
        perplexity.append(ppl)

    texts = np.array(texts)
    preds = np.concatenate(preds)
    perplexity = np.concatenate(perplexity)

    return texts, preds, perplexity
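
# ---------------------------------------------------------------------------
# Illustrative sanity check: `perplexity_from_logits` should agree with the
# textbook definition ppl = exp(-1/N * sum_i log p(t_i)). This is a minimal
# sketch; `_example_perplexity_check` is a hypothetical helper not used by
# the pipeline, and the toy tensors are fabricated for demonstration only.
# ---------------------------------------------------------------------------
def _example_perplexity_check(seq_len=10, vocab_size=100):
    logits = torch.randn(seq_len, vocab_size)          # fabricated model logits
    tokens = torch.randint(0, vocab_size, (seq_len,))  # fabricated reference ids
    ppl = perplexity_from_logits(logits, tokens)

    # reference computation straight from the definition
    log_probs = F.log_softmax(logits.float(), dim=1)[torch.arange(seq_len), tokens]
    ref = torch.exp(-log_probs.mean()).item()
    assert abs(ppl - ref) < 1e-3 * ref, (ppl, ref)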
def cache_ppl(
    model,
    tok,
    dataset,
    cache_ppl_file,
    token_window=50,
    batch_size=64,
    static_context='',
    selection=None,
    reverse_selection=False,
    verbose=True
):
    """ Load cached perplexity measures, or compute and cache them """
    if os.path.exists(cache_ppl_file):
        print('Loaded cached perplexity file: ', cache_ppl_file)
        cache_ppl_contents = utils.loadpickle(cache_ppl_file)
        raw_case_ids = cache_ppl_contents['case_ids']
    else:
        # find raw requests and case_ids
        raw_ds, _, _ = utils.load_dataset(tok, ds_name=dataset)
        raw_requests = utils.extract_requests(raw_ds)
        raw_case_ids = np.array([r['case_id'] for r in raw_requests])

        print('Running perplexity evaluation for original model and prompts...')
        texts, preds, ppl_values = generation_ppl(
            model,
            tok,
            prompts=[static_context + r['prompt'].format(r['subject'])
                     for r in raw_requests],
            tokens_true=None,
            token_window=token_window,
            batch_size=batch_size,
            verbose=verbose
        )

        cache_ppl_contents = {
            'texts': texts,
            'preds': preds,
            'requests': raw_requests,
            'perplexity': ppl_values,
            'case_ids': raw_case_ids,
            'token_window': token_window,
            'batch_size': batch_size,
            'static_context': static_context
        }
        utils.assure_path_exists(os.path.dirname(cache_ppl_file))
        utils.savepickle(cache_ppl_file, cache_ppl_contents)
        print('Saved perplexity cache file: ', cache_ppl_file)

    # filter cache_ppl_contents for selected samples
    if selection is not None:
        # json file containing a dict whose 'case_ids' key lists selected samples
        select_case_ids = utils.loadjson(selection)['case_ids']

        # boolean mask for selected samples w.r.t. all cached samples
        matching = utils.generate_mask(raw_case_ids, np.array(select_case_ids))
        if reverse_selection:
            matching = ~matching

        cache_ppl_contents = utils.filter_for_selection(cache_ppl_contents, matching)

    return cache_ppl_contents
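
# ---------------------------------------------------------------------------
# Example usage (a minimal sketch; the model name, dataset name, and cache
# path below are placeholder assumptions, and `dataset` must be a name
# understood by `utils.load_dataset`). Because this module uses relative
# imports, run it with `python -m <package>.<module>`, not as a plain script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = AutoModelForCausalLM.from_pretrained('gpt2-xl')
    tok = AutoTokenizer.from_pretrained('gpt2-xl')
    contents = cache_ppl(
        model,
        tok,
        dataset='mcf',                                  # placeholder dataset name
        cache_ppl_file='cache/ppl_gpt2-xl_mcf.pickle',  # placeholder path
        token_window=50,
        batch_size=64,
        verbose=True,
    )
    print('mean perplexity:', contents['perplexity'].mean())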