File size: 1,295 Bytes
256a159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Random Retriever."""

from typing import Optional

import numpy as np
from tqdm import trange

from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.utils.logging import get_logger

logger = get_logger(__name__)


class RandomRetriever(BaseRetriever):
    """Random Retriever. Each in-context example of the test prompts is
    retrieved in a random way.

    **WARNING**: This class has not been tested thoroughly. Please use it with
    caution.

    Args:
        dataset: Dataset object, forwarded unchanged to ``BaseRetriever``.
        ice_separator (str, optional): Forwarded unchanged to
            ``BaseRetriever``. Defaults to ``'\\n'``.
        ice_eos_token (str, optional): Forwarded unchanged to
            ``BaseRetriever``. Defaults to ``'\\n'``.
        ice_num (int, optional): Forwarded unchanged to ``BaseRetriever``;
            ``retrieve`` draws ``self.ice_num`` indices per test example.
            Defaults to 1.
        seed (int, optional): Seed for the random generator used by
            ``retrieve``. Defaults to 43.
    """

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 seed: Optional[int] = 43) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        self.seed = seed

    def retrieve(self):
        """Draw random in-context example indices for every test example.

        Reads ``self.index_ds``, ``self.test_ds``, ``self.ice_num`` and
        ``self.is_main_process``; none of them is assigned in this class, so
        they are presumably provided by ``BaseRetriever`` (verify there).

        Returns:
            list: One entry per element of ``self.test_ds``. Each entry is a
            list of ``self.ice_num`` distinct integers drawn from
            ``range(len(self.index_ds))``.

        Raises:
            ValueError: Raised by NumPy's ``choice`` when ``self.ice_num``
                exceeds ``len(self.index_ds)``, because sampling is done
                without replacement.
        """
        # Use a private generator instead of ``np.random.seed``: reseeding
        # NumPy's global RNG here would silently change random draws made by
        # unrelated code in the same process. ``RandomState(seed)`` yields the
        # same stream as the seeded global functions, so the sampled indices
        # are identical to what the global-seed approach produced.
        rng = np.random.RandomState(self.seed)
        num_idx = len(self.index_ds)
        rtr_idx_list = []
        logger.info('Retrieving data for test set...')
        # The progress bar is only shown on the main process.
        for _ in trange(len(self.test_ds), disable=not self.is_main_process):
            # replace=False guarantees no index repeats within one example.
            idx_list = rng.choice(num_idx, self.ice_num,
                                  replace=False).tolist()
            rtr_idx_list.append(idx_list)
        return rtr_idx_list