# -*- coding: utf-8 -*-
from __future__ import annotations

from dataclasses import dataclass
import pickle
import os
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
from nlp4web_codebase.ir.data_loaders.dm import Document
from collections import Counter
import tqdm
import re
import nltk

nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords

LANGUAGE = "english"
# Bound `findall` of a pattern matching runs of two or more word characters.
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    """Lower-case `text` and split it into words with `word_splitter`."""
    return word_splitter(text.lower())


def lemmatization(words: List[str]) -> List[str]:
    """Return `words` unchanged."""
    return words  # We ignore lemmatization here for simplicity


def simple_tokenize(text: str) -> List[str]:
    """Tokenize `text`: split into words, drop stopwords, then lemmatize."""
    words = word_splitting(text)
    tokenized = [w for w in words if w not in stopwords]
    return lemmatization(tokenized)


T = TypeVar("T", bound="InvertedIndex")


@dataclass
class PostingList:
    """Postings of a single term, stored as two parallel lists."""

    term: str  # The term
    # docid_postings[i] means the docid (int) of the i-th associated posting
    docid_postings: List[int]
    # tweight_postings[i] means the term weight (float) of the i-th associated posting
    tweight_postings: List[float]


@dataclass
class InvertedIndex:
    """Inverted index that can be pickled to and restored from a directory."""

    posting_lists: List[PostingList]  # term id -> posting_list
    vocab: Dict[str, int]  # term -> term id
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to `<output_dir>/index.pkl`, creating the directory."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load the index pickled at `<saved_dir>/index.pkl`.

        The unpickled object is returned as is; its type is whatever was saved.
        """
        # NOTE: pickle can execute arbitrary code while loading; only load
        # index files that come from a trusted source.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)


# The output of the counting function:
@dataclass
class Counting:
    """Raw corpus statistics produced by `run_counting`."""

    posting_lists: List[PostingList]  # term id -> posting_list (weights are raw TFs)
    vocab: Dict[str, int]  # term -> term id
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float  # mean of dls
    nterms: int  # total number of tokens over all counted documents
    doc_texts: Optional[List[str]] = None  # docid -> document text


def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Counting TFs, DFs, doc_lengths, etc.

    Args:
        documents: Documents to count; each is read through `collection_id`
            and `text`. A document whose `collection_id` was already seen is
            skipped.
        tokenize_fn: Maps a document text to its list of tokens.
        store_raw: Keep every counted document's text in `doc_texts`;
            otherwise `doc_texts` is None.
        ndocs: Passed to tqdm as the progress-bar total.
        show_progress_bar: Whether to show the tqdm progress bar.

    Returns:
        A `Counting` with posting lists holding raw term frequencies.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = [] if store_raw else None

    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        docid = len(cid2docid)
        cid2docid[doc.collection_id] = docid
        collection_ids.append(doc.collection_id)

        tok2tf = Counter(tokenize_fn(doc.text))
        doc_length = sum(tok2tf.values())
        dls.append(doc_length)
        nterms += doc_length

        for tok, tf in tok2tf.items():
            tid = vocab.get(tok)
            if tid is None:
                # First occurrence of the term: allocate its id, its posting
                # list and its df slot together so the three stay aligned.
                tid = len(vocab)
                vocab[tok] = tid
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                dfs.append(0)
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            # Every document containing the term counts, including the first
            # one (starting a new term at df=0 under-counted each df by one).
            dfs[tid] += 1

        if doc_texts is not None:
            doc_texts.append(doc.text)

    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        # An empty corpus has no average length; avoid dividing by zero.
        avgdl=sum(dls) / len(dls) if dls else 0.0,
        nterms=nterms,
        doc_texts=doc_texts,
    )


from nlp4web_codebase.ir.data_loaders.sciq import load_sciq

sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))

"""### BM25 Index"""

from dataclasses import asdict, dataclass
import math
import os
from typing import Iterable, List, Optional, Type
import tqdm
from nlp4web_codebase.ir.data_loaders.dm import Document


@dataclass
class BM25Index(InvertedIndex):
    """Inverted index whose posting weights are cached BM25 term weights."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize `text` with `simple_tokenize`."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute term weights and caching

        Overwrites, in place, every raw term frequency stored in
        `posting_lists` with `regularized_tf * idf`.

        Args:
            posting_lists: Posting lists indexed by term id, holding raw TFs.
            total_docs: Number of documents, used as N in the IDF.
            avgdl: Average document length.
            dfs: Document frequency per term id.
            dls: Document length per docid.
            k1: BM25 k1 parameter.
            b: BM25 b parameter.
        """
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=total_docs)
            # Slice assignment keeps the same list object, so the update stays
            # in place exactly like element-wise assignment would.
            posting_list.tweight_postings[:] = [
                BM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                * idf
                for docid, tf in zip(
                    posting_list.docid_postings, posting_list.tweight_postings
                )
            ]

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return `tf / (tf + k1 * (1 - b + b * dl / avgdl))`."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """Return `log(1 + (N - df + 0.5) / (df + 0.5))`."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        """Build an index from `documents`.

        Args:
            documents: Documents to index.
            store_raw: Forwarded to `run_counting`.
            output_dir: When given, the built index is also saved there.
            ndocs: Forwarded to `run_counting`.
            show_progress_bar: Forwarded to `run_counting`.
            k1: BM25 k1 parameter.
            b: BM25 b parameter.

        Returns:
            The built index, an instance of `cls`.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        cls.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        # Instantiate through `cls` so a subclass gets an index of its own type.
        index = cls(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        # `output_dir` used to be accepted but silently ignored.
        if output_dir is not None:
            index.save(output_dir)
        return index
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")

"""### BM25 Retriever"""

from nlp4web_codebase.ir.models import BaseRetriever
from typing import Type, TypeVar
from abc import abstractmethod


class BaseInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents from a saved list-based inverted index."""

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        """Index class whose `from_saved` loads the index."""
        pass

    def __init__(self, index_dir: str) -> None:
        """Load the index stored under `index_dir`."""
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the cached weight of each query term found in document `cid`.

        Raises:
            KeyError: If `cid` is not in the index.
        """
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the term weights of `query` for document `cid`."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the `topk` best documents as collection_id -> score.

        The score of a document is the sum of the cached weights of the query
        tokens over its postings; the result is ordered by descending score.
        """
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score[docid] = docid2score.get(docid, 0) + tweight
        ranked = sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)
        return {
            self.index.collection_ids[docid]: score for docid, score in ranked[:topk]
        }


class BM25Retriever(BaseInvertedIndexRetriever):
    """Retriever backed by a saved `BM25Index`."""

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index


bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
bm25_retriever.retrieve(
    "What type of diseases occur when the immune system attacks normal body cells?"
)

# Raw string: the LaTeX backslashes below are not valid Python escape sequences.
r"""# TASK1: tune b and k1 (4 points)

Tune b and k1 on the **dev** split of SciQ using the metric MAP@10. The evaluation function (`evaluate_map`) is provided. Record the values in `plots_k1` and `plots_b`. Do it in a greedy manner: as the influence from b is larger, please first tune b (with k1 fixed to the default value 0.9) and use the best value of b to further tune k1.

$${\displaystyle {\text{score}}(D,Q)=\sum _{i=1}^{n}{\text{IDF}}(q_{i})\cdot {\frac {f(q_{i},D)\cdot (k_{1}+1)}{f(q_{i},D)+k_{1}\cdot \left(1-b+b\cdot {\frac {|D|}{\text{avgdl}}}\right)}}}$$
"""

from nlp4web_codebase.ir.data_loaders import Split
import pytrec_eval
import numpy as np


def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
    """Return the mean `map_cut_10` of `rankings` against the qrels of `split`.

    Args:
        rankings: Query id -> {collection_id -> score}.
        split: SciQ split whose qrels are used.
    """
    metric = "map_cut_10"
    qrels = sciq.get_qrels_dict(split)
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
    qps = evaluator.evaluate(rankings)
    return float(np.mean([qp[metric] for qp in qps.values()]))


"""Example of using the pre-requisite code:"""

# Loading dataset:
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq

sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))

# Building BM25 index and save:
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus), ndocs=len(sciq.corpus), show_progress_bar=True
)
bm25_index.save("output/bm25_index")

# Loading index and use BM25 retriever to retrieve:
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")

# Raw string: the LaTeX backslashes below are not valid Python escape sequences.
r"""# TASK2: CSC matrix and `CSCBM25Index` (12 points)

Recall that we use Python lists to implement posting lists, mapping term IDs to the documents in which they appear. This is inefficient due to its naive design. Actually [Compressed Sparse Column matrix](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html) is very suitable for storing the posting lists and can boost the efficiency.

## TASK2.1: learn about `scipy.sparse.csc_matrix` (2 point)

Convert the matrix
\begin{bmatrix}
0 & 1 & 0 & 3 \\
10 & 2 & 1 & 0 \\
0 & 0 & 0 & 9
\end{bmatrix}
to a `csc_matrix` by specifying `data`, `indices`, `indptr` and `shape`.
"""

# Public import path of the same class as `scipy.sparse._csc.csc_matrix`.
from scipy.sparse import csc_matrix

input_matrix = [[0, 1, 0, 3], [10, 2, 1, 0], [0, 0, 0, 9]]
data = None
indices = None
indptr = None
shape = None
## YOUR_CODE_STARTS_HERE
# Please assign the values to data, indices, indptr and shape
# One can just do it in a hard-coded manner
input_matrix_np = np.array(input_matrix)
# `nonzero()` walks row by row, so calling it on the transpose enumerates the
# non-zeros of the input column by column, which is the order CSC stores them.
col_idx, row_idx = input_matrix_np.T.nonzero()
data = input_matrix_np[row_idx, col_idx]
indices = row_idx  # CSC `indices` are the ROW indices of the stored values
# indptr[j]:indptr[j + 1] delimits column j, hence one entry per column plus one.
indptr = np.zeros(input_matrix_np.shape[1] + 1, dtype=int)
np.add.at(indptr, col_idx + 1, 1)
indptr = np.cumsum(indptr)
shape = input_matrix_np.shape
## YOUR_CODE_ENDS_HERE

output_matrix = csc_matrix((data, indices, indptr), shape=shape)

## TEST_CASES (should be 3 and 11)
print((output_matrix.indices + output_matrix.data).tolist()[2])
print((output_matrix.indices + output_matrix.data).tolist()[-1])

## RESULT_CHECKING_POINT
print((output_matrix.indices + output_matrix.data).tolist())

"""## TASK2.2: implement `CSCBM25Index` (4 points)

Implement `CSCBM25Index` by completing the missing code. Note that `CSCInvertedIndex` is similar to `InvertedIndex` which we talked about during the class. The main difference is posting lists are represented by a CSC sparse matrix.
"""

TCSC = TypeVar("TCSC", bound="CSCInvertedIndex")


@dataclass
class CSCInvertedIndex:
    """Inverted index whose posting lists are the columns of a CSC matrix."""

    # Rows are docids and columns are term ids: column `tid` is the posting
    # list of that term.
    posting_lists_matrix: csc_matrix
    vocab: Dict[str, int]  # term -> term id
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to `<output_dir>/index.pkl`, creating the directory."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[TCSC], saved_dir: str) -> TCSC:
        """Load the index pickled at `<saved_dir>/index.pkl`.

        The unpickled object is returned as is; its type is whatever was saved.
        """
        # NOTE: pickle can execute arbitrary code while loading; only load
        # index files that come from a trusted source.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)


@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """CSC inverted index whose stored values are cached BM25 term weights."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize `text` with `simple_tokenize`."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute term weights and caching

        Args:
            posting_lists: Posting lists indexed by term id, holding raw TFs.
            total_docs: Number of documents, used as N in the IDF and as the
                number of matrix rows.
            avgdl: Average document length.
            dfs: Document frequency per term id.
            dls: Document length per docid.
            k1: BM25 k1 parameter.
            b: BM25 b parameter.

        Returns:
            A float32 `csc_matrix` of shape `(total_docs, len(posting_lists))`
            whose entry `[docid, tid]` is `regularized_tf * idf`.
        """
        ## YOUR_CODE_STARTS_HERE
        data = []
        indices = []
        indptr = [0]
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=total_docs)
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                data.append(regularized_tf * idf)
                indices.append(docid)
            # Column `tid` ends where the next one starts.
            indptr.append(len(indices))
        return csc_matrix(
            (
                np.array(data, dtype=np.float32),
                np.asarray(indices, dtype=np.int64),
                np.asarray(indptr, dtype=np.int64),
            ),
            shape=(total_docs, len(posting_lists)),
        )
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return `tf / (tf + k1 * (1 - b + b * dl / avgdl))`."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """Return `log(1 + (N - df + 0.5) / (df + 0.5))`."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Build a CSC index from `documents`.

        Args:
            documents: Documents to index.
            store_raw: Forwarded to `run_counting`.
            output_dir: When given, the built index is also saved there.
            ndocs: Forwarded to `run_counting`.
            show_progress_bar: Forwarded to `run_counting`.
            k1: BM25 k1 parameter.
            b: BM25 b parameter.

        Returns:
            The built index, an instance of `cls`.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists_matrix = cls.cache_term_weights(
            posting_lists=counting.posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        # Instantiate through `cls` so a subclass gets an index of its own type.
        index = cls(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        # `output_dir` used to be accepted but silently ignored.
        if output_dir is not None:
            index.save(output_dir)
        return index


class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents from a saved CSC inverted index."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        """Index class whose `from_saved` loads the index."""
        pass

    def __init__(self, index_dir: str) -> None:
        """Load the index stored under `index_dir`."""
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the positive cached weights of the query terms for document `cid`.

        Raises:
            KeyError: If `cid` is not in the index.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        w_matrix = self.index.posting_lists_matrix
        term_weights = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            weight = w_matrix[target_docid, tid]
            if weight > 0:
                term_weights[tok] = float(weight)
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the term weights of `query` for document `cid`."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the `topk` best documents as collection_id -> score.

        The score of a document is the sum of the stored weights of the query
        tokens; the result is ordered by descending score.
        """
        ## YOUR_CODE_STARTS_HERE
        matrix = self.index.posting_lists_matrix
        indptr, indices, data = matrix.indptr, matrix.indices, matrix.data
        docid2score: Dict[int, float] = {}
        for tok in self.index.tokenize(query):
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            # Column `tid` lives in indices/data[indptr[tid]:indptr[tid + 1]],
            # so only the documents that contain the term are visited instead
            # of every row of the column.
            start, end = indptr[tid], indptr[tid + 1]
            for docid, weight in zip(
                indices[start:end].tolist(), data[start:end].tolist()
            ):
                docid2score[docid] = docid2score.get(docid, 0.0) + weight
        ranked = sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)
        return {
            self.index.collection_ids[docid]: score for docid, score in ranked[:topk]
        }
        ## YOUR_CODE_ENDS_HERE


class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """Retriever backed by a saved `CSCBM25Index`."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index


import gradio as gr
from typing import TypedDict


class Hit(TypedDict):
    """One search result of the demo."""

    cid: str
    score: float
    text: str


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
# Build and load the CSC variants; the list-based classes were used here
# before, which left the CSC implementation unused by the demo.
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus), ndocs=len(sciq.corpus), show_progress_bar=False
)
csc_bm25_index.save("output/csc_bm25_index")
csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index")


def BM25_search(term):
    """Retrieve the top documents for `term` and return them as a list of `Hit`."""
    index = csc_bm25_retriever.index
    result = csc_bm25_retriever.retrieve(term)
    # docids number the de-duplicated documents, so the text is read from the
    # index's own `doc_texts` (stored by default) rather than by position in
    # `sciq.corpus`.
    return [
        Hit(cid=cid, score=float(score), text=index.doc_texts[index.cid2docid[cid]])
        for cid, score in result.items()
    ]


demo = gr.Interface(fn=BM25_search, inputs=gr.Textbox(), outputs=gr.Textbox())
## YOUR_CODE_ENDS_HERE
demo.launch()