import logging
from typing import Any, Dict, List

import torch
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint

logger = logging.getLogger(__name__)

MODEL = "fdurant/colbert-xm-for-inference-api"


class EndpointHandler():

    def __init__(self, path=""):
        # `path` is part of the Inference Endpoints handler contract; the checkpoint
        # is loaded from the Hub by its model id (MODEL) instead.
        self._config = ColBERTConfig(
            doc_maxlen=512,
            nbits=2,
            kmeans_niters=4,
            nranks=-1,
            checkpoint=MODEL,
        )
        self._checkpoint = Checkpoint(self._config.checkpoint, colbert_config=self._config, verbose=3)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        inputs = data["inputs"]

        # Accept either a single string or a list of strings.
        if isinstance(inputs, str):
            texts = [inputs]
        elif isinstance(inputs, list) and all(isinstance(text, str) for text in inputs):
            texts = inputs
        else:
            raise ValueError("Invalid input data format")

        with torch.inference_mode():

            if len(texts) == 1:
                # A single text is treated as a query and encoded with the query encoder.
                logger.info(f"Query: {texts}")
                embedding = self._checkpoint.queryFromText(
                    queries=texts,
                    full_length_search=False,
                )
                logger.info(f"Query embedding shape: {embedding.shape}")
                return [
                    {"input": inputs, "query_embedding": embedding.tolist()[0]}
                ]

            elif len(texts) > 1:
                # A list of texts is treated as a batch of document chunks and
                # encoded with the document encoder.
                logger.info(f"Batch of chunks: {texts}")
                embeddings, token_counts = self._checkpoint.docFromText(
                    docs=texts,
                    bsize=self._config.bsize,
                    keep_dims=True,
                    return_tokens=True,
                )
                for text, embedding, token_count in zip(texts, embeddings, token_counts):
                    logger.info(f"Chunk: {text}")
                    logger.info(f"Chunk embedding shape: {embedding.shape}")
                    logger.info(f"Chunk token count: {token_count}")
                return [
                    {"input": _input, "chunk_embedding": embedding.tolist(), "token_count": token_count.tolist()}
                    for _input, embedding, token_count in zip(texts, embeddings, token_counts)
                ]

            else:
                raise ValueError("No data to process")
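# Minimal local smoke test, offered as a hedged sketch: the example payloads and
# prints below are illustrative assumptions, not part of the original handler.
# It assumes the MODEL checkpoint can be downloaded and loaded on the local
# machine, and mirrors the payload shape that Hugging Face Inference Endpoints
# pass to EndpointHandler.__call__.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    handler = EndpointHandler()

    # A single string is encoded as a query.
    query_result = handler({"inputs": "What is late interaction retrieval?"})
    print("query tokens embedded:", len(query_result[0]["query_embedding"]))

    # A list of strings is encoded as a batch of document chunks.
    chunk_results = handler({"inputs": ["ColBERT encodes every token.", "Scores use MaxSim over token embeddings."]})
    print("chunks embedded:", len(chunk_results))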