File size: 2,047 Bytes
256a159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Fix-K Retriever."""

from typing import List, Optional

from tqdm import trange

from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.utils.logging import get_logger
from opencompass.registry import ICL_RETRIEVERS

# Module-level logger, obtained from the project's ``get_logger`` helper and
# named after this module.
logger = get_logger(__name__)


@ICL_RETRIEVERS.register_module()
class FixKRetriever(BaseRetriever):
    """Fix-K Retriever.

    Every test example receives the same, fixed list of in-context example
    indices (``fix_id_list``) pointing into the index set.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instances.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        fix_id_list (List[int]): List of in-context example indices shared by
            every test prompt. Each index must be smaller than the size of
            the index set; this is checked when :meth:`retrieve` is called.
        ice_separator (`Optional[str]`): The separator between each in-context
            example template when origin `PromptTemplate` is provided. Defaults
            to '\n'.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\n'.
        ice_num (`Optional[int]`): The number of in-context example template
            when origin `PromptTemplate` is provided. Defaults to 1.
    """

    def __init__(self,
                 dataset,
                 fix_id_list: List[int],
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        # The base class receives everything except ``fix_id_list``, which is
        # specific to this retriever and stored as given.
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        self.fix_id_list = fix_id_list

    def retrieve(self) -> List[List[int]]:
        """Retrieve the in-context example indices for each test example.

        Returns:
            List[List[int]]: One entry per element of ``self.test_ds``. Every
            entry is an independent copy of ``fix_id_list``, so modifying one
            entry affects neither the others nor ``self.fix_id_list``.

        Raises:
            AssertionError: If an index in ``fix_id_list`` is not smaller
                than ``len(self.index_ds)``.
        """
        num_idx = len(self.index_ds)
        for idx in self.fix_id_list:
            # Raised explicitly rather than through ``assert`` so the check
            # is not stripped under ``python -O``. AssertionError and the
            # message are kept so callers observe the same failure as before.
            if not (idx < num_idx):
                raise AssertionError(
                    f'Index {idx} is out of range of {num_idx}')
        # ``trange`` is kept only for its progress bar (shown on the main
        # process); the loop variable itself is not needed.
        return [
            list(self.fix_id_list)
            for _ in trange(len(self.test_ds),
                            disable=not self.is_main_process)
        ]