TwT-6's picture
Upload 2667 files
256a159 verified
"""Prompt Template."""
import copy
from typing import Dict, Hashable, List, Optional, Union
from opencompass.registry import ICL_PROMPT_TEMPLATES
from opencompass.utils.prompt import PromptList, safe_format
from opencompass.utils.types import _check_type_list
# Type of a prompt handled by this module: either a plain string or a
# PromptList (the structured form produced by ``_encode_template``).
PromptType = Union[PromptList, str]
@ICL_PROMPT_TEMPLATES.register_module()
class PromptTemplate:
    """In-context Learning Prompt Template Class.

    This class represents a template that guides the generation of prompts
    in the retrieval or inference process.

    Attributes:
        template (:obj:`Dict` or :obj:`str`): A custom template dictionary or
            string. If a dictionary, the keys of the dictionary represent the
            values of the output_column, and the values represent the
            corresponding generated statement. A dictionary whose keys are
            all among ``'begin'``, ``'round'`` and ``'end'`` is treated as a
            single "meta" template instead of a label -> prompt mapping.
            If a string, it represents a string template.
        ice_token (:obj:`str`, optional): A string that represents the
            specific token mapping from in-context examples. None if you want
            to use this template only to generate in-context examples,
            otherwise it can be used to generate the final prompt that is fed
            into the PLM. The ice_token will be invisible when generating
            in-context examples.
        sep_token (:obj:`str`, optional): A marker string that is stripped
            from the template when items are generated, except when
            ``remain_sep=True`` is passed to
            :meth:`generate_label_prompt_item`.
        prompt_type (:obj:`str`): ``'meta'`` or ``'origin'``; set during
            construction from the shape of ``template``.
    """

    def __init__(
        self,
        template: Union[Dict, str],
        ice_token: Optional[str] = None,
        sep_token: Optional[str] = None,
    ) -> None:
        """Store the template and tokens, then validate the template.

        Args:
            template (:obj:`Dict` or :obj:`str`): The raw template.
            ice_token (:obj:`str`, optional): Placeholder for in-context
                examples. Defaults to :obj:`None`.
            sep_token (:obj:`str`, optional): Separator marker. Defaults to
                :obj:`None`.

        Raises:
            AssertionError: If ``template`` is neither a str nor a dict.
            TypeError: If a dict template holds an unsupported value type.
            LookupError: If ``ice_token`` is set but missing from a string
                template (or from a string value of a dict template).
        """
        self.template = template
        assert isinstance(self.template, (str, Dict)), (
            'template must be a str or a dict')
        # NOTE(review): _check_type_list is defined elsewhere; presumably it
        # validates the value against the listed types and returns it.
        self.ice_token = _check_type_list(ice_token, [None, str])
        self.sep_token = _check_type_list(sep_token, [None, str])
        # A sign used to distinguish the prompt type
        self.prompt_type = 'origin'
        self._check_template_legacy()

    def _check_template_legacy(self):
        """Classify the template and verify that ``ice_token`` is present.

        Sets ``self.prompt_type`` to ``'meta'`` when every key of a dict
        template is one of ``'begin'``, ``'round'`` or ``'end'``; otherwise
        the dict is a label -> prompt mapping (``'origin'``).

        Raises:
            TypeError: If a dict value is not a str, list or dict.
            LookupError: If ``ice_token`` is set and a string template (or a
                string value of a dict template) does not contain it.
        """
        if isinstance(self.template, Dict):
            # Check if it's the label-prompt type or just a meta prompt type
            meta_key_count = sum(key in self.template
                                 for key in ('begin', 'round', 'end'))
            if meta_key_count == len(self.template):
                self.prompt_type = 'meta'
            else:
                self.prompt_type = 'origin'

            # Check if token exists in values of tp_dict
            for tp_dict_val in self.template.values():
                if not isinstance(tp_dict_val, (str, list, dict)):
                    raise TypeError(
                        'dictionary of template expects a str, list or a '
                        f"dict, but got '{tp_dict_val}'")
                # Only plain-string values can be searched for the token.
                if (isinstance(tp_dict_val, str) and self.ice_token
                        and self.ice_token not in tp_dict_val):
                    raise LookupError(
                        f"'{self.ice_token}' not in '{tp_dict_val}'")
        elif isinstance(self.template, str):
            if self.ice_token and self.ice_token not in self.template:
                raise LookupError(
                    f"'{self.ice_token}' not in '{self.template}'")

    def _select_template(self, label: Hashable):
        """Return the raw template to use for ``label``.

        String and meta templates are shared by all labels; an ``'origin'``
        dict template is indexed by ``label``.
        """
        if isinstance(self.template, str) or self.prompt_type == 'meta':
            return self.template
        # template is a dict with a label -> prompt mapping
        return self.template[label]

    @staticmethod
    def _fill_entry(template: PromptType, entry: Dict) -> PromptType:
        """Substitute the fields of ``entry`` into an encoded template."""
        if isinstance(template, str):
            # We want to use safe_substitute instead of str.format to avoid
            # KeyError while preserving the original string in curly brackets
            return safe_format(template, **entry)
        return template.format(**entry)

    def generate_ice_item(self, entry: Dict, label: Hashable) -> PromptType:
        """Generate in-context example based on the provided :obj:`entry` data.

        Args:
            entry (:obj:`Dict`): A piece of data to be used for generating the
                in-context example.
            label (:obj:`Hashable`): The value of the output field.

        Returns:
            str or PromptList: The generated in-context example.
        """
        tp = self._encode_template(self._select_template(label), ice=True)
        # Remove sep token. ``replace`` returns a new object, so the result
        # must be reassigned or the separator is left in place.
        if self.sep_token is not None:
            tp = tp.replace(self.sep_token, '')
        # Remove ice_token
        if self.ice_token is not None:
            tp = tp.replace(self.ice_token, '')
        # Replace context token
        return self._fill_entry(tp, entry)

    def generate_label_prompt_item(
            self,
            entry: Dict,
            ice: PromptType,
            label: Hashable,
            remain_sep: Optional[bool] = False) -> PromptType:
        """Generate prompt based on :obj:`entry` data, :obj:`ice` in-context
        example, and the corresponding :obj:`label`.

        Args:
            entry (:obj:`Dict`): A piece of data containing the input field
                content.
            ice (str or PromptList): The generated in-context example.
            label (:obj:`Hashable`): The value of the output field.
            remain_sep (:obj:`bool`): If True, keep ``sep_token`` in the
                output instead of stripping it. Defaults to False.

        Returns:
            str or PromptList: The generated prompt.
        """
        template = self._encode_template(self._select_template(label),
                                         ice=False)
        # Remove sep token
        if not remain_sep and self.sep_token is not None:
            template = template.replace(self.sep_token, '')
        # Insert in-context examples
        if self.ice_token is not None:
            template = template.replace(self.ice_token, ice)
        # Replace context token
        return self._fill_entry(template, entry)

    def generate_item(
            self,
            entry: Dict,
            output_field: Optional[Hashable] = None,
            output_field_replace_token: Optional[str] = '',
            ice_field_replace_token: Optional[str] = '') -> PromptType:
        """Generate an item based on the provided :obj:`entry` data, as well as
        optional output field and ice field tokens.

        Warning:
            This method is only used in generation task, i.e. GenInferencer.

        Args:
            entry (:obj:`Dict`): A piece of data.
            output_field (:obj:`Hashable`, optional): Column name of output
                field. Defaults to :obj:`None`.
            output_field_replace_token (:obj:`str`, optional): Tokens used to
                replace output field. Defaults to ``''``.
            ice_field_replace_token (str, optional): Tokens used to replace
                the :obj:`ice_token`. Defaults to ``''``.

        Returns:
            str or PromptList: The generated item.
        """
        if isinstance(self.template, str) or self.prompt_type != 'origin':
            raw_template = self.template
        else:
            # This branch is only effective when you are using GenInferencer
            # with multi-label prompts; the first label's template is used.
            # Such a combination doesn't make sense at all :)
            # TODO: Check this, seems it is used in XXXRetriever as well
            raw_template = self.template[list(self.template.keys())[0]]
        template = self._encode_template(raw_template, ice=False)

        if self.ice_token is not None:
            template = template.replace(self.ice_token,
                                        ice_field_replace_token)
        # Remove sep token
        if self.sep_token is not None:
            template = template.replace(self.sep_token, '')

        if output_field is not None:
            # Work on a copy so the caller's entry is not modified.
            entry = copy.deepcopy(entry)
            entry[output_field] = output_field_replace_token
        return self._fill_entry(template, entry)

    @staticmethod
    def _check_prompt_template(obj) -> 'PromptTemplate':
        """Return ``obj`` unchanged if it is a :class:`PromptTemplate`.

        Raises:
            TypeError: If ``obj`` is not a :class:`PromptTemplate`.
        """
        if isinstance(obj, PromptTemplate):
            return obj
        raise TypeError(f'Expect a PromptTemplate object, but got {obj}')

    def __repr__(self):
        """Return a multi-line summary of the template and ice_token."""
        return (f'PromptTemplate({{\n\ttemplate: {self.template},\n\t'
                f'ice_token: {self.ice_token}\n}})')

    def _encode_template(self, prompt_template: Union[Dict,
                                                      List[Union[str, Dict]],
                                                      str],
                         ice: bool) -> PromptType:
        """Encode the raw template given in the config into a str or a
        PromptList.

        Each section written to the PromptList is wrapped in a pair of marker
        dicts, ``dict(section=<name>, pos='begin')`` and
        ``dict(section=<name>, pos='end')``.

        Args:
            prompt_template (Dict or str): The raw template given in the
                config, used for generating the prompt. If it's a string, the
                result will be directly returned. Otherwise it is indexed by
                ``'round'`` and, optionally, ``'begin'`` and ``'end'``.
            ice (bool): If the template is used for generating in-context
                examples. When True, the ``'begin'`` and ``'end'`` sections
                are skipped and the round is emitted as section ``'ice'``.

        Returns:
            str or PromptList: The encoded template.
        """
        if isinstance(prompt_template, str):
            return prompt_template

        prompt = PromptList()

        # TODO: Why can't we generate begin & end for ice template?
        # To fix this, first we need to allow specifying prompt_template
        # only
        if 'begin' in prompt_template and not ice:
            prompt.append(dict(section='begin', pos='begin'))
            if isinstance(prompt_template['begin'], list):
                prompt += prompt_template['begin']
            else:
                prompt.append(prompt_template['begin'])
            prompt.append(dict(section='begin', pos='end'))

        # The dialogue body is labelled 'ice' for in-context examples and
        # 'round' for the actual prompt.
        body_section = 'ice' if ice else 'round'
        prompt.append(dict(section=body_section, pos='begin'))
        prompt += prompt_template['round']
        prompt.append(dict(section=body_section, pos='end'))

        if 'end' in prompt_template and not ice:
            # Open the section with pos='begin' so the markers are balanced,
            # matching how the 'begin' and 'round'/'ice' sections are emitted.
            prompt.append(dict(section='end', pos='begin'))
            if isinstance(prompt_template['end'], list):
                prompt += prompt_template['end']
            else:
                prompt.append(prompt_template['end'])
            prompt.append(dict(section='end', pos='end'))

        return prompt