Spaces:
Runtime error
Runtime error
from typing import Union | |
from colbert import Searcher | |
from colbert.data import Queries | |
from colbert.infra.config import ColBERTConfig | |
# Input shapes accepted wherever queries (or their contexts) are passed in:
# a single string, a list of strings, an int-keyed dict of strings, or a
# ColBERT ``Queries`` object. The container forms are written as string
# forward references, so they are not evaluated when the alias is built.
TextQueries = Union[
    str,
    "list[str]",
    "dict[int, str]",
    Queries,
]
class HopSearcher(Searcher):
    """A ``Searcher`` whose queries are encoded together with a context.

    Every query may be paired with a context entry that is forwarded to the
    checkpoint's ``queryFromText`` alongside the query text. Passing
    ``context=None`` is supported by all methods and forwards no context.
    """

    def __init__(self, *args, config=None, interaction='flipr', **kw_args):
        """Build the searcher with hop-specific configuration defaults.

        Args:
            *args: Positional arguments forwarded unchanged to ``Searcher``.
            config: Optional configuration combined with the defaults below
                via ``ColBERTConfig.from_existing``.
            interaction: Interaction mode stored in the default config.
            **kw_args: Keyword arguments forwarded unchanged to ``Searcher``.
        """
        defaults = ColBERTConfig(query_maxlen=64, interaction=interaction)
        # NOTE(review): presumably settings in `config` take precedence over
        # `defaults` here -- confirm against ColBERTConfig.from_existing.
        config = ColBERTConfig.from_existing(defaults, config)
        super().__init__(*args, config=config, **kw_args)

    def encode(self, text: TextQueries, context: Union[TextQueries, None]):
        """Encode one or more queries, each with an optional context.

        Args:
            text: A list of queries, or a single query. Any value that is
                not a list is treated as ONE query and wrapped in a
                single-element list.
            context: ``None`` for no context, a list of contexts, or a single
                context (wrapped in a single-element list like ``text``).

        Returns:
            Whatever ``self.checkpoint.queryFromText`` returns for the batch
            (requested with ``to_cpu=True``).
        """
        # isinstance (rather than an exact type comparison) so that list
        # subclasses are passed through as batches instead of being wrapped
        # into a nested one-element list.
        queries = text if isinstance(text, list) else [text]
        if context is not None and not isinstance(context, list):
            context = [context]

        # An explicit batch size is only requested for more than 128 queries;
        # otherwise None is passed and the checkpoint decides.
        bsize = 128 if len(queries) > 128 else None

        # Keep the tokenizer's maximum query length in sync with this
        # searcher's config before encoding.
        self.checkpoint.query_tokenizer.query_maxlen = self.config.query_maxlen
        return self.checkpoint.queryFromText(
            queries, context=context, bsize=bsize, to_cpu=True
        )

    def search(self, text: str, context: Union[str, None], k=10):
        """Encode a single query with its context and run ``dense_search``.

        Args:
            text: The query text.
            context: The context for the query, or ``None``.
            k: Passed to ``dense_search`` (presumably the number of results
                to return -- confirm against ``Searcher.dense_search``).

        Returns:
            The result of ``self.dense_search`` for the encoded query.
        """
        return self.dense_search(self.encode(text, context), k)

    def search_all(self, queries: TextQueries, context: Union[TextQueries, None], k=10):
        """Encode a collection of queries with contexts and search them all.

        Args:
            queries: Queries in any form accepted by ``Queries.cast``.
            context: Contexts in any form accepted by ``Queries.cast``, or
                ``None`` for no context.
            k: Passed through to ``self._search_all_Q``.

        Returns:
            The result of ``self._search_all_Q`` for the encoded queries.
        """
        queries = Queries.cast(queries)
        if context is not None:
            context = Queries.cast(context)

        # Queries and contexts are handed to the encoder as plain lists of
        # values. NOTE(review): this assumes both collections iterate in the
        # same order so that entries line up positionally -- verify callers.
        query_texts = list(queries.values())
        context_texts = None if context is None else list(context.values())

        encoded = self.encode(query_texts, context_texts)
        return self._search_all_Q(queries, encoded, k)