|
"""Prompt Template.""" |
|
import copy |
|
from typing import Dict, Hashable, List, Optional, Union |
|
|
|
from opencompass.registry import ICL_PROMPT_TEMPLATES |
|
from opencompass.utils.prompt import PromptList, safe_format |
|
from opencompass.utils.types import _check_type_list |
|
|
|
PromptType = Union[PromptList, str] |
|
|
|
|
|
@ICL_PROMPT_TEMPLATES.register_module() |
|
class PromptTemplate: |
|
"""In-context Learning Prompt Template Class This class represents a |
|
template that guides the generation of prompts in the retrieval or |
|
inference process. |
|
|
|
Attributes: |
|
template (:obj:`Dict` or :obj:`str`): A custom template dictionary or |
|
string. If a dictionary, the keys of the dictionary represent the |
|
values of the output_column, and the values represent the |
|
corresponding generated statement. If a string, it represents a |
|
string template. |
|
ice_token(:obj:`str`, optional): A string that represents the specific |
|
token mapping from in-context examples. None if you want to use |
|
this template only to generate in-context examples, otherwise it |
|
can be used to generate the final prompt that is fed into the PLM. |
|
The ice_token will be invisible when generating in-context |
|
examples. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
template: Union[Dict, str], |
|
ice_token: Optional[str] = None, |
|
sep_token: Optional[str] = None, |
|
) -> None: |
|
self.template = template |
|
assert isinstance(self.template, (str, Dict)) |
|
self.ice_token = _check_type_list(ice_token, [None, str]) |
|
self.sep_token = _check_type_list(sep_token, [None, str]) |
|
|
|
self.prompt_type = 'origin' |
|
self._check_template_legacy() |
|
|
|
def _check_template_legacy(self): |
|
if isinstance(self.template, Dict): |
|
|
|
ctr = sum(key in self.template |
|
for key in ('begin', 'round', 'end')) |
|
self.prompt_type = 'meta' if ctr == len( |
|
self.template.keys()) else 'origin' |
|
|
|
|
|
for tp_dict_val in self.template.values(): |
|
if not isinstance(tp_dict_val, (str, list, dict)): |
|
raise TypeError( |
|
'dictionary of template expects a str, list or a ' |
|
f"dict, but got '{tp_dict_val}'") |
|
if isinstance( |
|
tp_dict_val, str |
|
) and self.ice_token and self.ice_token not in tp_dict_val: |
|
raise LookupError( |
|
f"'{self.ice_token}' not in '{tp_dict_val}'") |
|
|
|
if isinstance(self.template, str): |
|
if self.ice_token and self.ice_token not in self.template: |
|
raise LookupError( |
|
f"'{self.ice_token}' not in '{self.template}'") |
|
|
|
def generate_ice_item(self, entry: Dict, label: Hashable) -> PromptType: |
|
"""Generate in-context example based on the provided :obj:`entry` data. |
|
|
|
Args: |
|
entry (:obj:`Dict`): A piece of data to be used for generating the |
|
in-context example. |
|
label (:obj:`Hashable`): The value of the output field. |
|
|
|
Returns: |
|
str or PromptList: The generated in-context example. |
|
""" |
|
|
|
if isinstance(self.template, str) or self.prompt_type == 'meta': |
|
tp = self.template |
|
else: |
|
|
|
tp = self.template[label] |
|
|
|
tp = self._encode_template(tp, ice=True) |
|
|
|
if self.sep_token is not None: |
|
tp.replace(self.sep_token, '') |
|
|
|
if self.ice_token is not None: |
|
tp = tp.replace(self.ice_token, '') |
|
|
|
if isinstance(tp, str): |
|
|
|
|
|
tp = safe_format(tp, **entry) |
|
else: |
|
tp = tp.format(**entry) |
|
return tp |
|
|
|
def generate_label_prompt_item(self, |
|
entry: Dict, |
|
ice: PromptType, |
|
label: Hashable, |
|
remain_sep: Optional[bool] = False) -> str: |
|
"""Generate prompt based on :obj:`entry` data, :obj:`ice` in-context |
|
example, and the corresponding :obj:`label`. |
|
|
|
Args: |
|
|
|
entry (:obj:`Dict`): A piece of data containing the input field |
|
content. |
|
ice (str or PromptList): The generated in-context example. |
|
label (:obj:`Hashable`): The value of the output field. |
|
remain_sep (:obj:`bool`): If remain sep_token |
|
|
|
Returns: |
|
:obj:`str`: The generated prompt. |
|
""" |
|
|
|
if isinstance(self.template, str) or self.prompt_type == 'meta': |
|
template = self.template |
|
else: |
|
|
|
template = self.template[label] |
|
template = self._encode_template(template, ice=False) |
|
|
|
if not remain_sep and self.sep_token is not None: |
|
template = template.replace(self.sep_token, '') |
|
|
|
if self.ice_token is not None: |
|
template = template.replace(self.ice_token, ice) |
|
|
|
if isinstance(template, str): |
|
|
|
|
|
template = safe_format(template, **entry) |
|
else: |
|
template = template.format(**entry) |
|
return template |
|
|
|
def generate_item( |
|
self, |
|
entry: Dict, |
|
output_field: Optional[Hashable] = None, |
|
output_field_replace_token: Optional[str] = '', |
|
ice_field_replace_token: Optional[str] = '') -> PromptType: |
|
"""Generate an item based on the provided :obj:`entry` data, as well as |
|
optional output field and ice field tokens. |
|
|
|
Warning: |
|
This method is only used in generation task, i.e. GenInferencer. |
|
|
|
Args: |
|
entry (:obj:`Dict`): A piece of data. |
|
output_field (:obj:`Hashable`, optional): Column name of output |
|
field. Defaults to :obj:`None`. |
|
output_field_replace_token (:obj:`str`, optional): Tokens used to |
|
replace output field. Defaults to ``''``. |
|
ice_field_replace_token (str, optional): Tokens used to replace |
|
the :obj:`ice_token`. Defaults to ``''``. |
|
|
|
Returns: |
|
str or PromptList: The generated item. |
|
""" |
|
template = None |
|
if isinstance(self.template, str): |
|
template = self.template |
|
elif self.prompt_type == 'origin': |
|
|
|
|
|
|
|
|
|
template = self.template[list(self.template.keys())[0]] |
|
template = self._encode_template(template, ice=False) |
|
else: |
|
template = self._encode_template(self.template, ice=False) |
|
if self.ice_token is not None: |
|
template = template.replace(self.ice_token, |
|
ice_field_replace_token) |
|
|
|
if self.sep_token is not None: |
|
template = template.replace(self.sep_token, '') |
|
if output_field is not None: |
|
entry = copy.deepcopy(entry) |
|
entry[output_field] = output_field_replace_token |
|
if isinstance(template, str): |
|
|
|
|
|
template = safe_format(template, **entry) |
|
else: |
|
template = template.format(**entry) |
|
return template |
|
|
|
def _check_prompt_template(obj) -> 'PromptTemplate': |
|
if isinstance(obj, PromptTemplate): |
|
return obj |
|
else: |
|
raise TypeError(f'Expect a PromptTemplate object, but got {obj}') |
|
|
|
def __repr__(self): |
|
return (f'PromptTemplate({{\n\ttemplate: {self.template},\n\t' |
|
f'ice_token: {self.ice_token}\n}})') |
|
|
|
def _encode_template(self, prompt_template: Union[List[Union[str, Dict]], |
|
str], |
|
ice: bool) -> PromptType: |
|
"""Encode the raw template given in the config into a str or a |
|
PromptList. |
|
|
|
Args: |
|
prompt_template (List[Dict]] or str): The raw template given in the |
|
config, used for generating the prompt. If it's a string, the |
|
result will be directly returned. |
|
ice (bool): If the template is used for generating in-context |
|
examples. |
|
|
|
Returns: |
|
str or PromptList: The encoded template. |
|
""" |
|
if isinstance(prompt_template, str): |
|
return prompt_template |
|
|
|
prompt = PromptList() |
|
|
|
|
|
|
|
|
|
if 'begin' in prompt_template and not ice: |
|
prompt.append(dict(section='begin', pos='begin')) |
|
if isinstance(prompt_template['begin'], list): |
|
prompt += prompt_template['begin'] |
|
else: |
|
prompt.append(prompt_template['begin']) |
|
prompt.append(dict(section='begin', pos='end')) |
|
|
|
if ice: |
|
prompt.append(dict(section='ice', pos='begin')) |
|
else: |
|
prompt.append(dict(section='round', pos='begin')) |
|
prompt += prompt_template['round'] |
|
if ice: |
|
prompt.append(dict(section='ice', pos='end')) |
|
else: |
|
prompt.append(dict(section='round', pos='end')) |
|
|
|
if 'end' in prompt_template and not ice: |
|
prompt.append(dict(section='end', pos='end')) |
|
if isinstance(prompt_template['end'], list): |
|
prompt += prompt_template['end'] |
|
else: |
|
prompt.append(prompt_template['end']) |
|
prompt.append(dict(section='end', pos='end')) |
|
|
|
return prompt |
|
|