jx-yang's picture
<ADD> +app
9d21d47
import numpy as np
class BaseProbInference:
    """Base class for probability-based multiple-choice inference tasks.

    Subclasses are expected to override the hook methods below that raise
    ``NotImplementedError``. The two static helpers work on per-choice
    score dictionaries keyed by ``"lm_log_p"`` and ``"norm_lm_log_p"``.
    """

    def __init__(self, prompt_version):
        """Store the prompt version and initialise task state to defaults.

        Args:
            prompt_version: The prompt version to use. The literal string
                ``"default"`` is replaced by the value returned from
                ``self.default_prompt_version()``.

        Raises:
            NotImplementedError: If ``prompt_version`` is ``"default"`` and
                ``default_prompt_version`` has not been overridden.
        """
        if prompt_version == "default":
            self.prompt_version = self.default_prompt_version()
        else:
            self.prompt_version = prompt_version

        # Data slots start empty; nothing in this class fills them.
        # NOTE(review): presumably populated by a loader elsewhere - confirm.
        self.raw_data_result = None
        self.raw_data_sample = None
        self.raw_data_dev = None

        # Task-level defaults that subclasses may overwrite.
        self.can_be_stratified = False
        self.CHOICES = None
        self.num_base_shot = 1

    def default_prompt_version(self):
        """Return the prompt version used when ``"default"`` is requested."""
        raise NotImplementedError

    def dataset_signature(self):
        """Return a mapping that describes where each data part comes from.

        Expected shape of the returned dictionary::

            {
                # the data that produces the final result
                "result": (dataset_name, subset, split),
                # the data from which ICL few-shot examples are sampled
                "sample": (dataset_name, subset, split),
            }
        """
        raise NotImplementedError

    def dataset_part(self, part):
        """Return the ``dataset_signature()`` entry stored under ``part``."""
        return self.dataset_signature()[part]

    def dataset_preprocess(self, raw_data):
        """Preprocess ``raw_data``; must be provided by subclasses."""
        raise NotImplementedError

    def handcrafted_exemplars(self):
        """Return handcrafted exemplars; must be provided by subclasses."""
        raise NotImplementedError

    # The spelling "seperator" is part of the public interface and is kept
    # so that existing subclasses and callers continue to work.
    def exemplar_seperator(self):
        """Return the exemplar separator; must be provided by subclasses."""
        raise NotImplementedError

    def multiple_choice_promptify(self, query, choice):
        """Build a prompt from ``query`` and ``choice``; subclass hook."""
        raise NotImplementedError

    @staticmethod
    def merge_choice_info(choice_info):
        """Collect per-choice scores into one list per metric.

        Args:
            choice_info: An iterable of mappings, each containing the keys
                ``"lm_log_p"`` and ``"norm_lm_log_p"``. It is traversed once
                per metric, so it should be re-iterable (e.g. a list).

        Returns:
            A dict with the keys ``"lm_log_p"`` and ``"norm_lm_log_p"``;
            each value is a list holding that metric for every entry of
            ``choice_info``, in input order. Other keys are ignored.

        Raises:
            KeyError: If an entry lacks one of the two metric keys.
        """
        return {
            metric: [entry[metric] for entry in choice_info]
            for metric in ("lm_log_p", "norm_lm_log_p")
        }

    @staticmethod
    def choice_info_to_predictions(info):
        """Pick the index of the highest score for each metric.

        Args:
            info: A mapping with the keys ``"lm_log_p"`` and
                ``"norm_lm_log_p"``, each holding a sequence of scores.

        Returns:
            A dict with the same two keys; each value is the ``int`` index
            that ``np.argmax`` selects for that metric's scores.
        """
        return {
            metric: int(np.argmax(info[metric]))
            for metric in ("lm_log_p", "norm_lm_log_p")
        }