TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
1.3 kB
"""Random Retriever."""
from typing import Optional
import numpy as np
from tqdm import trange
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.utils.logging import get_logger
# Module-level logger, obtained from the project's logging helper under this
# module's name; used below to announce the start of retrieval.
logger = get_logger(__name__)
class RandomRetriever(BaseRetriever):
    """Retriever that draws in-context examples at random.

    For each entry of the test split, ``ice_num`` indices into the index
    split are sampled without replacement using NumPy's global random
    generator, which is re-seeded with ``seed`` at the start of every
    :meth:`retrieve` call.

    **WARNING**: This class has not been tested thoroughly. Please use it with
    caution.

    Args:
        dataset: Forwarded unchanged to ``BaseRetriever.__init__``.
        ice_separator: Forwarded unchanged to ``BaseRetriever.__init__``.
        ice_eos_token: Forwarded unchanged to ``BaseRetriever.__init__``.
        ice_num: Forwarded to ``BaseRetriever.__init__``; read back in
            :meth:`retrieve` as ``self.ice_num``, the number of indices
            sampled per test entry (presumably stored by the base class).
        seed: Value handed to ``np.random.seed`` before sampling.
    """

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 seed: Optional[int] = 43) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        # Kept on the instance so every retrieve() call starts from the
        # same random state.
        self.seed = seed

    def retrieve(self):
        """Sample random index lists, one per test entry.

        Returns:
            A list with ``len(self.test_ds)`` items. When ``self.ice_num``
            is an integer, each item is a list of ``ice_num`` distinct
            integers in ``[0, len(self.index_ds))``.

        Raises:
            ValueError: Raised by NumPy when ``ice_num`` is larger than
                ``len(self.index_ds)``, since sampling is done without
                replacement.

        Note:
            This re-seeds NumPy's *global* random state, so it also affects
            any later ``np.random`` usage in the same process.
        """
        np.random.seed(self.seed)
        pool_size = len(self.index_ds)
        logger.info('Retrieving data for test set...')
        # Progress bar is only shown on the main process.
        progress = trange(len(self.test_ds),
                          disable=not self.is_main_process)
        return [
            np.random.choice(pool_size, self.ice_num, replace=False).tolist()
            for _ in progress
        ]