import numpy as np


class BaseProbInference:
    """Base class for probability-based multiple-choice inference tasks.

    Subclasses supply the dataset description, preprocessing, exemplar
    handling and prompt construction by overriding the methods that raise
    ``NotImplementedError`` here.  The two static helpers aggregate
    per-choice scores and turn them into predicted choice indices.
    """

    # Score names handled by ``merge_choice_info`` and
    # ``choice_info_to_predictions``; kept in one place so the two helpers
    # always agree on which metrics (and in which order) they process.
    _METRIC_KEYS = ("lm_log_p", "norm_lm_log_p")

    def __init__(self, prompt_version):
        """Initialise shared task state.

        Args:
            prompt_version: Prompt version identifier. The literal string
                ``"default"`` is resolved through ``default_prompt_version()``,
                which subclasses must implement.
        """
        if prompt_version == "default":
            self.prompt_version = self.default_prompt_version()
        else:
            self.prompt_version = prompt_version

        # Raw data placeholders; presumably filled in later by the subclass or
        # the caller (not shown here) for the "result", "sample" and dev parts.
        self.raw_data_result = None
        self.raw_data_sample = None
        self.raw_data_dev = None

        self.can_be_stratified = False
        self.CHOICES = None
        self.num_base_shot = 1

    def default_prompt_version(self):
        """Return the prompt version used when ``"default"`` is requested."""
        raise NotImplementedError

    def dataset_signature(self):
        """Describe which dataset splits the task uses.

        Expected return shape::

            {
                "result": (dataset_name, subset, split),  # produces the final result
                "sample": (dataset_name, subset, split),  # source of ICL few-shot examples
            }
        """
        raise NotImplementedError

    def dataset_part(self, part):
        """Return the ``(dataset_name, subset, split)`` entry for *part*.

        *part* is a key of ``dataset_signature()``, e.g. ``"result"`` or
        ``"sample"``; an unknown key raises ``KeyError``.
        """
        return self.dataset_signature()[part]

    def dataset_preprocess(self, raw_data):
        """Convert *raw_data* into the task's internal format (subclass hook)."""
        raise NotImplementedError

    def handcrafted_exemplars(self):
        """Return hand-written few-shot exemplars (subclass hook)."""
        raise NotImplementedError

    def exemplar_seperator(self):
        """Return the separator placed between exemplars (subclass hook).

        NOTE: the method name's spelling is kept as-is because subclasses
        override it under this name.
        """
        raise NotImplementedError

    def multiple_choice_promptify(self, query, choice):
        """Build the prompt for *query* paired with one candidate *choice*.

        Subclass hook; the exact return format is defined by implementations.
        """
        raise NotImplementedError

    @staticmethod
    def merge_choice_info(choice_info):
        """Transpose per-choice score dicts into per-metric lists.

        Args:
            choice_info: Sequence of dicts, one per choice, each containing
                the keys ``"lm_log_p"`` and ``"norm_lm_log_p"``.

        Returns:
            Dict mapping each metric name to the list of that metric's values,
            in the same order as *choice_info*.
        """
        return {
            key: [info[key] for info in choice_info]
            for key in BaseProbInference._METRIC_KEYS
        }

    @staticmethod
    def choice_info_to_predictions(info):
        """Pick the highest-scoring choice index for each metric.

        Args:
            info: Dict shaped like the output of ``merge_choice_info``, i.e.
                metric name -> sequence of per-choice scores.

        Returns:
            Dict mapping ``"lm_log_p"`` and ``"norm_lm_log_p"`` to a Python
            ``int`` index. ``np.argmax`` returns the first index on ties.
        """
        return {
            key: int(np.argmax(info[key]))
            for key in BaseProbInference._METRIC_KEYS
        }