# Spaces:
# Runtime error
# Runtime error
import numpy as np | |
class BaseProbInference:
    """Base class whose hooks a concrete task is expected to override.

    Every hook except ``dataset_part`` raises ``NotImplementedError`` here,
    so the base class only supplies the constructor defaults and the
    ``dataset_part`` lookup.
    """

    def __init__(self, prompt_version):
        """Store the prompt version and reset the task state to its defaults.

        Args:
            prompt_version: The literal string ``"default"`` selects whatever
                ``self.default_prompt_version()`` returns; any other value is
                stored unchanged.

        Raises:
            NotImplementedError: If ``"default"`` is requested and
                ``default_prompt_version`` has not been overridden.
        """
        # Resolved before any other attribute is assigned, so a failing hook
        # leaves the instance without partially initialised state.
        if prompt_version == "default":
            self.prompt_version = self.default_prompt_version()
        else:
            self.prompt_version = prompt_version

        # Placeholders; nothing in this class assigns them again.
        # NOTE(review): presumably filled by subclasses or loader code.
        self.raw_data_result = None
        self.raw_data_sample = None
        self.raw_data_dev = None

        self.can_be_stratified = False
        self.CHOICES = None
        self.num_base_shot = 1

    def default_prompt_version(self):
        """Return the prompt version used when ``"default"`` is requested."""
        raise NotImplementedError

    def dataset_signature(self):
        """Return a mapping that describes which datasets the task uses.

        Expected layout::

            {
                # dataset that produces the final result
                "result": (dataset_name, subset, split),
                # dataset from which ICL few-shot examples are sampled
                "sample": (dataset_name, subset, split),
            }
        """
        raise NotImplementedError

    def dataset_part(self, part):
        """Return the ``dataset_signature()`` entry stored under ``part``."""
        return self.dataset_signature()[part]

    def dataset_preprocess(self, raw_data):
        """Subclass hook receiving ``raw_data``; not implemented here."""
        raise NotImplementedError

    def handcrafted_exemplars(self):
        """Subclass hook; not implemented here."""
        raise NotImplementedError

    def exemplar_seperator(self):
        """Subclass hook; not implemented here.

        The misspelling in the name is kept because it is part of the
        interface that subclasses and callers use.
        """
        raise NotImplementedError

    def multiple_choice_promptify(self, query, choice):
        """Subclass hook receiving ``query`` and ``choice``; not implemented here."""
        raise NotImplementedError
def merge_choice_info(choice_info):
    """Collect the per-choice metric values into one list per metric.

    NOTE(review): indentation was lost in this file. The function takes no
    ``self`` and has no decorator, so it is written as a module-level
    function; confirm it was not meant to live on ``BaseProbInference``.

    Args:
        choice_info: Iterable of mappings, each providing the keys
            ``"lm_log_p"`` and ``"norm_lm_log_p"``. It is traversed once per
            metric, so it must be re-iterable (e.g. a list).

    Returns:
        A dict with the keys ``"lm_log_p"`` and ``"norm_lm_log_p"``; each
        value is a list holding that metric for every entry, in input order.
        An empty input yields two empty lists.

    Raises:
        KeyError: If an entry lacks one of the two metric keys.
    """
    return {
        metric: [info[metric] for info in choice_info]
        for metric in ("lm_log_p", "norm_lm_log_p")
    }
def choice_info_to_predictions(info):
    """Return, for each metric, the index of its largest value.

    NOTE(review): indentation was lost in this file. The function takes no
    ``self`` and has no decorator, so it is written as a module-level
    function; confirm it was not meant to live on ``BaseProbInference``.

    Args:
        info: Mapping with the keys ``"lm_log_p"`` and ``"norm_lm_log_p"``,
            each holding a non-empty sequence of comparable numbers.

    Returns:
        A dict with the same two keys whose values are plain ``int`` indices
        as chosen by ``np.argmax``.

    Raises:
        KeyError: If either metric key is missing from ``info``.
        ValueError: If a metric sequence is empty (raised by ``np.argmax``).
    """
    # int() turns the NumPy integer into a built-in int.
    return {
        metric: int(np.argmax(info[metric]))
        for metric in ("lm_log_p", "norm_lm_log_p")
    }