TwT-6's picture
Upload 2667 files
256a159 verified
"""BM25 Retriever."""
from typing import List, Optional
import numpy as np
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi
from tqdm import trange
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.utils.logging import get_logger
from opencompass.registry import ICL_RETRIEVERS
# Module-level logger; ``BM25Retriever.retrieve`` uses it to announce the
# retrieval pass over the test set.
logger = get_logger(__name__)
@ICL_RETRIEVERS.register_module()
class BM25Retriever(BaseRetriever):
    """BM25 Retriever.

    In information retrieval, Okapi BM25 (BM is an abbreviation of best
    matching) is a ranking function used by search engines to estimate the
    relevance of documents to a given search query. You can find more
    details in https://en.wikipedia.org/wiki/Okapi_BM25.

    Each in-context example of the test prompts is retrieved by the BM25
    algorithm: every test input is used as a query against a BM25 index built
    over the index (candidate) inputs, and the ``ice_num`` highest-scoring
    candidates are selected.

    The index and test splits are not constructor arguments of this class.
    They are read from ``self.index_ds`` and ``self.test_ds``, which are
    presumably prepared by ``BaseRetriever.__init__`` from ``dataset``
    (confirm against the base class).

    Args:
        dataset (`BaseDataset`): Any BaseDataset instances.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        ice_separator (`Optional[str]`): The separator between each in-context
            example template when origin `PromptTemplate` is provided.
            Defaults to '\\n'.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\\n'.
        ice_num (`Optional[int]`): The number of in-context example template
            when origin `PromptTemplate` is provided. Defaults to 1.
    """

    # Class-level defaults; every instance overwrites them in ``__init__``.
    bm25 = None
    index_corpus = None
    test_corpus = None

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        # Tokenise the candidate pool and build the BM25 index over it; the
        # tokenised test inputs are the queries issued later in ``retrieve``.
        self.index_corpus = self._tokenize_corpus(self.index_ds)
        self.bm25 = BM25Okapi(self.index_corpus)
        self.test_corpus = self._tokenize_corpus(self.test_ds)

    def _tokenize_corpus(self, ds) -> List[List[str]]:
        """Word-tokenise every input-field text that the reader yields for
        ``ds``.

        Args:
            ds: A dataset split, passed straight to
                ``self.dataset_reader.generate_input_field_corpus``.

        Returns:
            One list of NLTK ``word_tokenize`` tokens per text, in the order
            the reader produced them.
        """
        return [
            word_tokenize(data)
            for data in self.dataset_reader.generate_input_field_corpus(ds)
        ]

    def retrieve(self) -> List[List]:
        """Retrieve the in-context example index for each test example.

        Returns:
            List[List[int]]: One entry per test example, in ``test_corpus``
            order. Each entry holds positions in ``index_corpus`` of the
            ``ice_num`` highest-scoring candidates, best first. The entry is
            shorter than ``ice_num`` when the index holds fewer candidates.
        """
        rtr_idx_list = []
        logger.info('Retrieving data for test set...')
        for idx in trange(len(self.test_corpus),
                          disable=not self.is_main_process):
            scores = self.bm25.get_scores(self.test_corpus[idx])
            # Ascending argsort reversed gives descending score order.
            # NOTE: numpy's default argsort kind is not stable, so candidates
            # with equal scores may appear in an implementation-dependent
            # order.
            ranked = np.argsort(scores)[::-1][:self.ice_num]
            # Cast numpy integer indices to built-in ``int``.
            rtr_idx_list.append([int(pos) for pos in ranked])
        return rtr_idx_list