File size: 1,295 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
"""Random Retriever."""
from typing import Optional
import numpy as np
from tqdm import trange
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.utils.logging import get_logger
# Module-level logger; RandomRetriever.retrieve() uses it to announce the
# start of retrieval.
logger = get_logger(__name__)
class RandomRetriever(BaseRetriever):
    """Random Retriever. Each in-context example of the test prompts is
    retrieved in a random way.

    **WARNING**: This class has not been tested thoroughly. Please use it with
    caution.

    Args:
        dataset: Forwarded unchanged to ``BaseRetriever.__init__``.
        ice_separator (str, optional): Forwarded unchanged to
            ``BaseRetriever.__init__``. Defaults to a newline character.
        ice_eos_token (str, optional): Forwarded unchanged to
            ``BaseRetriever.__init__``. Defaults to a newline character.
        ice_num (int, optional): Forwarded to ``BaseRetriever.__init__``;
            ``retrieve`` samples ``self.ice_num`` indices per test entry.
            Defaults to 1.
        seed (int, optional): Seed of the random number generator used by
            ``retrieve``. Defaults to 43.
    """

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 seed: Optional[int] = 43) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        self.seed = seed

    def retrieve(self) -> list:
        """Sample random index-set positions for every test entry.

        For each entry of ``self.test_ds``, ``self.ice_num`` positions are
        drawn without replacement from ``range(len(self.index_ds))``.

        Returns:
            list: One list of sampled indices per entry of ``self.test_ds``.

        Raises:
            ValueError: If NumPy cannot draw ``self.ice_num`` indices without
                replacement from the index set (for example when ``ice_num``
                exceeds the index-set size).
        """
        # A private RandomState yields exactly the same stream as
        # ``np.random.seed(self.seed)`` followed by ``np.random.choice``, but
        # leaves NumPy's process-wide global RNG state untouched.
        rng = np.random.RandomState(self.seed)
        num_idx = len(self.index_ds)
        rtr_idx_list = []
        logger.info('Retrieving data for test set...')
        try:
            # The progress bar is disabled when ``self.is_main_process`` is
            # falsy.
            for _ in trange(len(self.test_ds),
                            disable=not self.is_main_process):
                idx_list = rng.choice(num_idx, self.ice_num,
                                      replace=False).tolist()
                rtr_idx_list.append(idx_list)
        except ValueError as err:
            # Same exception type as before, now naming the offending values.
            raise ValueError(
                f'Failed to sample ice_num={self.ice_num} indices without '
                f'replacement from an index set of size {num_idx}.') from err
        return rtr_idx_list
|