|
from typing import Any, Dict, List |
|
|
|
from colbert.infra import ColBERTConfig |
|
from colbert.modeling.checkpoint import Checkpoint |
|
import torch |
|
import logging |
|
|
|
# Module-level logger, standard `logging.getLogger(__name__)` idiom.
logger = logging.getLogger(__name__)




# Hugging Face Hub model id of the ColBERT-XM checkpoint loaded by EndpointHandler.
MODEL = "fdurant/colbert-xm-for-inference-api"
|
|
|
class EndpointHandler():
    """Inference Endpoints handler that embeds texts with a ColBERT-XM checkpoint.

    A plain string payload is embedded as a *query* (``queryFromText``); a
    list of strings is embedded as a batch of document *chunks*
    (``docFromText``).
    """

    def __init__(self, path=""):
        # `path` is part of the Inference Endpoints handler contract; the
        # checkpoint is loaded by hub id (MODEL) instead, so it is unused here.
        self._config = ColBERTConfig(
            doc_maxlen=512,   # max tokens per document chunk
            nbits=2,          # bits per dimension for residual compression
            kmeans_niters=4,  # k-means iterations (indexing-time setting)
            nranks=-1,        # NOTE(review): presumably "use all ranks" — confirm against ColBERTConfig docs
            checkpoint=MODEL,
        )
        self._checkpoint = Checkpoint(self._config.checkpoint, colbert_config=self._config, verbose=3)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Embed the payload's ``"inputs"`` and return JSON-serializable results.

        data args:
            inputs (:obj:`str` or :obj:`list` of :obj:`str`)
        Return:
            A :obj:`list` that will be serialized and returned.
            When the input is a single query string, the list contains one dictionary with:
              - input (:obj:`str`) : The input query.
              - query_embedding (:obj:`list`) : The query embedding of shape
                (query_maxlen, 128) — the leading batch axis is dropped.
            When the input is a batch (= list) of chunk strings, the list contains one dictionary per chunk:
              - input (:obj:`str`) : The input chunk.
              - chunk_embedding (:obj:`list`) : The chunk embedding of shape (num_tokens, 128).
              - token_ids (:obj:`list`) : The token ids.
              - token_list (:obj:`list`) : The corresponding token strings.
        Raises:
            KeyError: if the payload has no ``"inputs"`` key.
            ValueError: if ``"inputs"`` is neither a str nor a list of str,
                or is an empty list.
        """
        inputs = data["inputs"]
        # Dispatch by *type*, not by count: a one-element list is a batch of
        # one chunk and must go through the document encoder. (The previous
        # length-based branch wrongly sent single-element lists to the query
        # encoder, contradicting the documented contract.)
        if isinstance(inputs, str):
            is_query = True
            texts = [inputs]
        elif isinstance(inputs, list) and all(isinstance(text, str) for text in inputs):
            is_query = False
            texts = inputs
        else:
            raise ValueError("Invalid input data format")
        if not texts:
            # Empty batch: fail fast before touching the model.
            raise ValueError("No data to process")
        with torch.inference_mode():
            if is_query:
                # %-style lazy args avoid formatting when INFO is disabled.
                logger.info(
                    "Received query of 1 text with %d characters and %d words",
                    len(texts[0]), len(texts[0].split()),
                )
                embedding = self._checkpoint.queryFromText(
                    queries=texts,
                    full_length_search=False,
                )
                logger.info("Query embedding shape: %s", embedding.shape)
                # Drop the leading batch axis so the payload is (query_maxlen, 128).
                return [
                    {"input": texts[0], "query_embedding": embedding.tolist()[0]}
                ]
            logger.info("Received batch of %d chunks", len(texts))
            for i, text in enumerate(texts):
                logger.info("Chunk %d has %d characters and %d words", i, len(text), len(text.split()))
            embeddings, token_id_lists = self._checkpoint.docFromText(
                docs=texts,
                bsize=self._config.bsize,
                keep_dims=True,
                return_tokens=True,
            )
            logger.info("Chunk embeddings shape: %s", embeddings.shape)
            token_lists = []
            for text, embedding, token_ids in zip(texts, embeddings, token_id_lists):
                logger.debug("Chunk: %s", text)
                logger.debug("Chunk embedding shape: %s", embedding.shape)
                logger.debug("Chunk token ids: %s", token_ids)
                # Map ids back to token strings for debuggability on the client side.
                token_list = self._checkpoint.doc_tokenizer.tok.convert_ids_to_tokens(token_ids)
                token_lists.append(token_list)
                logger.debug("Chunk tokens: %s", token_list)
            return [
                {
                    "input": _input,
                    "chunk_embedding": embedding.tolist(),
                    "token_ids": token_ids.tolist(),
                    "token_list": token_list,
                }
                for _input, embedding, token_ids, token_list in zip(texts, embeddings, token_id_lists, token_lists)
            ]
|
|