TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
12.6 kB
"""Basic Retriever."""
from abc import abstractmethod
from typing import Dict, List, Optional
from mmengine.dist import is_main_process
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.utils.prompt import PromptList
class BaseRetriever:
    """Base class for In-context Learning Example Retriever, without any
    retrieval method implemented.

    Subclasses are expected to override :meth:`retrieve`. Note that this
    class does not use ``ABCMeta``, so the ``@abstractmethod`` marker on
    ``retrieve`` is not enforced at instantiation time.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instances.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        ice_separator (`Optional[str]`): The separator between each in-context
            example template when origin `PromptTemplate` is provided. Defaults
            to '\n'.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\n'.
        ice_num (`Optional[int]`): The number of in-context example template
            when origin `PromptTemplate` is provided. Defaults to 1.
    """

    # Class-level placeholders; both are rebound per instance in __init__.
    index_ds = None
    test_ds = None

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        self.ice_separator = ice_separator
        self.ice_eos_token = ice_eos_token
        self.ice_num = ice_num
        self.is_main_process = is_main_process()
        self.dataset_reader = dataset.reader
        # In-context examples are drawn from the train split; prompts are
        # built for the test split.
        self.index_ds = dataset.train
        self.test_ds = dataset.test

    @abstractmethod
    def retrieve(self) -> List[List[int]]:
        """Retrieve the in-context example index for each test example."""

    def get_labels(
            self,
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None) -> List[str]:
        """Get the labels of the dataset, especially useful for ppl inferencer.

        The labels are resolved in the following order:

        1. If `prompt_template` is provided and its ``template`` is a dict,
           the labels are the keys of that dict.
        2. Otherwise, if `ice_template` is provided, has an ``ice_token`` and
           its ``template`` is a dict, the labels are the keys of that dict.
        3. Otherwise, the labels are the unique values of the output column
           of the test set, in order of first appearance.

        Args:
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.

        Returns:
            `List[str]`: The resolved labels.
        """
        if prompt_template is not None and isinstance(
                prompt_template.template, dict):
            return list(prompt_template.template.keys())
        if (ice_template is not None and ice_template.ice_token is not None
                and isinstance(ice_template.template, dict)):
            return list(ice_template.template.keys())
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # label order is reproducible across runs (a plain set() is not for
        # string labels because of hash randomisation).
        return list(
            dict.fromkeys(self.test_ds[self.dataset_reader.output_column]))

    def generate_ice(self,
                     idx_list: List[int],
                     ice_template: Optional[PromptTemplate] = None) -> str:
        """Generate the in-context example for one test example.

        If ``ice_template.prompt_type`` is ``'meta'``, the `ice_separator`
        and `ice_eos_token` will be set as empty.

        Args:
            idx_list (`List[int]`): The index of in-context examples for the
                test example.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.

        Returns:
            The rendered in-context examples joined by the separator and
            terminated by the eos token. This is a string, or a list of
            prompt items when the template renders `PromptList` items.

        Raises:
            AssertionError: If `ice_template` is None while `idx_list` is
                not empty.
        """
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept for compatibility.
        if ice_template is None and len(idx_list) != 0:
            raise AssertionError(
                'You have not specified ice_template while retrieving examples '
                'from train set! Please either specify ice_template or use '
                '`ZeroRetriever`.')

        if ice_template is not None and ice_template.prompt_type == 'meta':
            ice_separator, ice_eos_token = '', ''
        else:
            ice_separator = self.ice_separator
            ice_eos_token = self.ice_eos_token

        generated_ice_list = []
        for idx in idx_list:
            entry = self.index_ds[idx]
            generated_ice_list.append(
                ice_template.generate_ice_item(
                    entry, entry[self.dataset_reader.output_column]))

        if len(generated_ice_list) > 0 and isinstance(generated_ice_list[0],
                                                      PromptList):
            # PromptList items are flattened into one list, each followed by
            # the separator, with the eos token appended at the very end.
            generated_ice = []
            for ice in generated_ice_list:
                generated_ice += ice + ice_separator
            generated_ice.append(ice_eos_token)
        else:
            generated_ice = ice_separator.join(
                generated_ice_list) + ice_eos_token
        return generated_ice

    def _select_prompt_template(
            self,
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        """Pick the template used to render the final prompt.

        `prompt_template` takes precedence over `ice_template`. A template
        that has to embed in-context examples must define an ``ice_token``:
        this is required of `prompt_template` when both templates are given,
        and of `ice_template` when it is the only one given.

        Args:
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.

        Returns:
            The selected template.

        Raises:
            NotImplementedError: If no template is given, or the selected
                template lacks a required ``ice_token``.
        """
        if prompt_template is not None:
            if ice_template is not None and prompt_template.ice_token is None:
                raise NotImplementedError(
                    'ice_token of prompt_template is not provided')
            return prompt_template
        if ice_template is not None:
            if ice_template.ice_token is None:
                raise NotImplementedError(
                    'ice_token of ice_template is not provided')
            return ice_template
        raise NotImplementedError('Leaving prompt as empty is not supported')

    def generate_label_prompt(self,
                              idx: int,
                              ice: str,
                              label,
                              ice_template: Optional[PromptTemplate] = None,
                              prompt_template: Optional[PromptTemplate] = None,
                              remain_sep: Optional[bool] = False) -> str:
        """Generate the prompt for one test example in perplexity evaluation
        with `prompt_template`. If `prompt_template` is not provided, the
        `ice_template` will be used to generate the prompt.

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context example for the test example.
            label (`str`): The label of the test example.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.
            remain_sep (`Optional[bool]`): Whether to remain the sep token.
                Defaults to False.

        Returns:
            The prompt rendered by the selected template's
            ``generate_label_prompt_item``.

        Raises:
            NotImplementedError: If no template is given, or the selected
                template lacks a required ``ice_token``.
        """
        template = self._select_prompt_template(ice_template, prompt_template)
        return template.generate_label_prompt_item(self.test_ds[idx], ice,
                                                   label, remain_sep)

    def generate_prompt_for_generate_task(
            self,
            idx,
            ice,
            gen_field_replace_token='',
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        """Generate the prompt for one test example in generative evaluation
        with `prompt_template`. If `prompt_template` is not provided, the
        `ice_template` will be used to generate the prompt.

        `gen_field_replace_token` is passed to the template as
        ``output_field_replace_token`` for the output column; it is meant to
        stand in for the answer so that the prompt does not leak it.

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context example for the test example.
            gen_field_replace_token (`str`): The token of the answer in the
                prompt. Defaults to ''.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.

        Returns:
            The prompt rendered by the selected template's ``generate_item``.

        Raises:
            NotImplementedError: If no template is given, or the selected
                template lacks a required ``ice_token``.
        """
        template = self._select_prompt_template(ice_template, prompt_template)
        return template.generate_item(
            self.test_ds[idx],
            output_field=self.dataset_reader.output_column,
            output_field_replace_token=gen_field_replace_token,
            ice_field_replace_token=ice)

    def generate_prompt_for_adv_generate_task(
            self,
            idx,
            ice,
            extra_prompt: Optional[Dict] = None,
            gen_field_replace_token='',
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        """Generate the prompt for one test example in generative evaluation
        with `prompt_template`, after merging extra fields into the example.
        If `prompt_template` is not provided, the `ice_template` will be used
        to generate the prompt.

        `gen_field_replace_token` is passed to the template as
        ``output_field_replace_token`` for the output column; it is meant to
        stand in for the answer so that the prompt does not leak it.

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context example for the test example.
            extra_prompt (`Optional[Dict]`): Extra fields merged over a copy
                of the test example before rendering; its keys override
                same-named fields of the example. Defaults to None, which
                means no extra fields.
            gen_field_replace_token (`str`): The token of the answer in the
                prompt. Defaults to ''.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.

        Returns:
            The prompt rendered by the selected template's ``generate_item``.

        Raises:
            NotImplementedError: If no template is given, or the selected
                template lacks a required ``ice_token``.
        """
        template = self._select_prompt_template(ice_template, prompt_template)
        if extra_prompt is None:
            extra_prompt = {}
        return template.generate_item(
            {
                **self.test_ds[idx],
                **extra_prompt
            },
            output_field=self.dataset_reader.output_column,
            output_field_replace_token=gen_field_replace_token,
            ice_field_replace_token=ice)