"""Basic Retriever.""" from abc import abstractmethod from typing import Dict, List, Optional from mmengine.dist import is_main_process from opencompass.openicl.icl_prompt_template import PromptTemplate from opencompass.utils.prompt import PromptList class BaseRetriever: """Base class for In-context Learning Example Retriever, without any retrieval method implemented. Args: dataset (`BaseDataset`): Any BaseDataset instances. Attributes of ``reader``, ``train`` and ``test`` will be used. ice_separator (`Optional[str]`): The separator between each in-context example template when origin `PromptTemplate` is provided. Defaults to '\n'. ice_eos_token (`Optional[str]`): The end of sentence token for in-context example template when origin `PromptTemplate` is provided. Defaults to '\n'. ice_num (`Optional[int]`): The number of in-context example template when origin `PromptTemplate` is provided. Defaults to 1. """ index_ds = None test_ds = None def __init__(self, dataset, ice_separator: Optional[str] = '\n', ice_eos_token: Optional[str] = '\n', ice_num: Optional[int] = 1) -> None: self.ice_separator = ice_separator self.ice_eos_token = ice_eos_token self.ice_num = ice_num self.is_main_process = is_main_process() self.dataset_reader = dataset.reader self.index_ds = dataset.train self.test_ds = dataset.test @abstractmethod def retrieve(self) -> List[List[int]]: """Retrieve the in-context example index for each test example.""" def get_labels( self, ice_template: Optional[PromptTemplate] = None, prompt_template: Optional[PromptTemplate] = None) -> List[str]: """Get the labels of the dataset, especially useful for ppl inferencer. If `ice_template` is provided, the labels will be the keys of the template. If `prompt_template` is provided, the labels will be the keys of the template. If neither of them is provided, the labels will be the unique values of the output column. Args: ice_template (`Optional[PromptTemplate]`): The template for in-context example. Defaults to None. prompt_template (`Optional[PromptTemplate]`): The template for prompt. Defaults to None. """ if prompt_template is not None and isinstance(prompt_template.template, Dict): labels = list(prompt_template.template.keys()) elif ice_template is not None and ice_template.ice_token is not None \ and isinstance(ice_template.template, Dict): labels = list(ice_template.template.keys()) else: labels = list(set(self.test_ds[self.dataset_reader.output_column])) return labels def generate_ice(self, idx_list: List[int], ice_template: Optional[PromptTemplate] = None) -> str: """Generate the in-context example for one test example. If `ice_template` is an instance of `PromptTemplate`, the `ice_separator` and `ice_eos_token` will be set as empty. Args: idx_list (`List[int]`): The index of in-context examples for the test example. ice_template (`Optional[PromptTemplate]`): The template for in-context example. Defaults to None. """ if ice_template is None: assert len( idx_list ) == 0, 'You have not specified ice_template while retrieving examples from train set! Please either specify ice_template or use `ZeroRetriever`.' 
        if ice_template is not None and ice_template.prompt_type == 'meta':
            ice_separator, ice_eos_token = '', ''
        else:
            ice_separator = self.ice_separator
            ice_eos_token = self.ice_eos_token

        generated_ice_list = []
        for idx in idx_list:
            generated_ice_list.append(
                ice_template.generate_ice_item(
                    self.index_ds[idx],
                    self.index_ds[idx][self.dataset_reader.output_column]))
        if len(generated_ice_list) > 0 and isinstance(generated_ice_list[0],
                                                      PromptList):
            generated_ice = []
            for ice in generated_ice_list:
                generated_ice += ice + ice_separator
            generated_ice.append(ice_eos_token)
        else:
            generated_ice = ice_separator.join(
                generated_ice_list) + ice_eos_token
        return generated_ice

    def generate_label_prompt(self,
                              idx: int,
                              ice: str,
                              label,
                              ice_template: Optional[PromptTemplate] = None,
                              prompt_template: Optional[PromptTemplate] = None,
                              remain_sep: Optional[bool] = False) -> str:
        """Generate the prompt for one test example in perplexity evaluation
        with `prompt_template`. If `prompt_template` is not provided, the
        `ice_template` will be used to generate the prompt.

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context examples for the test example.
            label (`str`): The label of the test example.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context examples. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                the prompt. Defaults to None.
            remain_sep (`Optional[bool]`): Whether to keep the sep token in
                the prompt. Defaults to False.
        """
        if prompt_template is not None and ice_template is not None:
            if prompt_template.ice_token is not None:
                return prompt_template.generate_label_prompt_item(
                    self.test_ds[idx], ice, label, remain_sep)
            else:
                raise NotImplementedError(
                    'ice_token of prompt_template is not provided')
        elif ice_template is not None and prompt_template is None:
            if ice_template.ice_token is not None:
                return ice_template.generate_label_prompt_item(
                    self.test_ds[idx], ice, label, remain_sep)
            else:
                raise NotImplementedError(
                    'ice_token of ice_template is not provided')
        elif ice_template is None and prompt_template is not None:
            return prompt_template.generate_label_prompt_item(
                self.test_ds[idx], ice, label, remain_sep)
        else:
            raise NotImplementedError(
                'Leaving prompt as empty is not supported')

    def generate_prompt_for_generate_task(
            self,
            idx,
            ice,
            gen_field_replace_token='',
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        """Generate the prompt for one test example in generative evaluation
        with `prompt_template`. If `prompt_template` is not provided, the
        `ice_template` will be used to generate the prompt. The answer field
        in the prompt is replaced by `gen_field_replace_token` rather than the
        gold answer, which would otherwise leak the answer.

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context examples for the test example.
            gen_field_replace_token (`str`): The token that stands in for the
                answer field in the prompt. Defaults to ''.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context examples. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                the prompt. Defaults to None.
        """
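        # Prompt assembly depends on which templates are available: when both
        # are given, the prompt template must expose an ``ice_token`` so the
        # rendered in-context examples can be substituted into it; with only
        # one template, that template renders the full prompt directly.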
""" if prompt_template is not None and ice_template is not None: if prompt_template.ice_token is not None: return prompt_template.generate_item( self.test_ds[idx], output_field=self.dataset_reader.output_column, output_field_replace_token=gen_field_replace_token, ice_field_replace_token=ice) else: raise NotImplementedError( 'ice_token of prompt_template is not provided') elif ice_template is not None and prompt_template is None: if ice_template.ice_token is not None: return ice_template.generate_item( self.test_ds[idx], output_field=self.dataset_reader.output_column, output_field_replace_token=gen_field_replace_token, ice_field_replace_token=ice) else: raise NotImplementedError( 'ice_token of ice_template is not provided') elif ice_template is None and prompt_template is not None: return prompt_template.generate_item( self.test_ds[idx], output_field=self.dataset_reader.output_column, output_field_replace_token=gen_field_replace_token, ice_field_replace_token=ice) else: raise NotImplementedError( 'Leaving prompt as empty is not supported') def generate_prompt_for_adv_generate_task( self, idx, ice, extra_prompt=dict(), gen_field_replace_token='', ice_template: Optional[PromptTemplate] = None, prompt_template: Optional[PromptTemplate] = None): """Generate the prompt for one test example in generative evaluation with `prompt_template`. If `prompt_template` is not provided, the `ice_template` will be used to generate the prompt. The token represented by `gen_field_replace_token` will not be replaced by the generated text, or it will leaks the answer. Args: idx (`int`): The index of the test example. ice (`str`): The in-context example for the test example. gen_field_replace_token (`str`): The token of the answer in the prompt. Defaults to ''. ice_template (`Optional[PromptTemplate]`): The template for in-context example. Defaults to None. prompt_template (`Optional[PromptTemplate]`): The template for prompt. Defaults to None. """ if prompt_template is not None and ice_template is not None: if prompt_template.ice_token is not None: return prompt_template.generate_item( { **self.test_ds[idx], **extra_prompt }, output_field=self.dataset_reader.output_column, output_field_replace_token=gen_field_replace_token, ice_field_replace_token=ice) else: raise NotImplementedError( 'ice_token of prompt_template is not provided') elif ice_template is not None and prompt_template is None: if ice_template.ice_token is not None: return ice_template.generate_item( { **self.test_ds[idx], **extra_prompt }, output_field=self.dataset_reader.output_column, output_field_replace_token=gen_field_replace_token, ice_field_replace_token=ice) else: raise NotImplementedError( 'ice_token of ice_template is not provided') elif ice_template is None and prompt_template is not None: return prompt_template.generate_item( { **self.test_ds[idx], **extra_prompt }, output_field=self.dataset_reader.output_column, output_field_replace_token=gen_field_replace_token, ice_field_replace_token=ice) else: raise NotImplementedError( 'Leaving prompt as empty is not supported')