Add handler.py, start_emulator.sh and test scripts
- embed_single_query.sh +9 -0
- embed_two_chunks.sh +9 -0
- handler.py +65 -0
- start_emulator.sh +4 -0
- test_endpoint.py +67 -0
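Together these files add a custom Inference Endpoints handler for fdurant/colbert-xm-for-inference-api, a launcher for the local endpoint emulator, and curl and pytest scripts that exercise both request shapes against localhost:4999.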
embed_single_query.sh
ADDED
@@ -0,0 +1,9 @@
+#!/bin/sh
+set -x
+
+curl \
+  --request POST \
+  --url http://localhost:4999 \
+  --header 'Content-Type: application/json' \
+  --data '{"inputs": "Please embed me"}' \
+  -w "\n"
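Posting a single string as inputs exercises the query path of handler.py below: the response is a JSON list with one object holding the original input and a query_embedding field, a list of per-token vectors (32 tokens of 128 dimensions each, per the assertions in test_endpoint.py).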
embed_two_chunks.sh
ADDED
@@ -0,0 +1,9 @@
+#!/bin/sh
+set -x
+
+curl \
+  --request POST \
+  --url http://localhost:4999 \
+  --header 'Content-Type: application/json' \
+  --data '{"inputs": ["Please embed me", "And me too, please!"]}' \
+  -w "\n"
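Posting a list of strings exercises the batch path instead: the response holds one object per chunk, each with input, chunk_embedding (again one 128-dimensional vector per token) and token_count fields, as asserted in test_endpoint.py.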
handler.py
ADDED
@@ -0,0 +1,65 @@
+from typing import Any, Dict, List
+
+from colbert.infra import ColBERTConfig
+from colbert.modeling.checkpoint import Checkpoint
+import torch
+import logging
+
+logger = logging.getLogger(__name__)
+
+MODEL = "fdurant/colbert-xm-for-inference-api"
+
+class EndpointHandler():
+
+    def __init__(self, path=""):
+        self._config = ColBERTConfig(
+            # Defaults copied from https://github.com/datastax/ragstack-ai/blob/main/libs/colbert/ragstack_colbert/colbert_embedding_model.py
+            doc_maxlen=512,  # Maximum number of tokens for document chunks. Should equal the chunk_size.
+            nbits=2,  # The number of bits that each dimension encodes to.
+            kmeans_niters=4,  # Number of iterations for k-means clustering during quantization.
+            nranks=-1,  # Number of ranks (processors) to use for distributed computing; -1 uses all available CPUs/GPUs.
+            checkpoint=MODEL,
+        )
+        self._checkpoint = Checkpoint(self._config.checkpoint, colbert_config=self._config, verbose=3)
+
+    def __call__(self, data: Any) -> List[Dict[str, Any]]:
+        inputs = data["inputs"]
+        texts = []
+        if isinstance(inputs, str):
+            texts = [inputs]
+        elif isinstance(inputs, list) and all(isinstance(text, str) for text in inputs):
+            texts = inputs
+        else:
+            raise ValueError("Invalid input data format")
+        with torch.inference_mode():
+
+            if len(texts) == 1:
+                # It's a query
+                logger.info(f"Query: {texts}")
+                embedding = self._checkpoint.queryFromText(
+                    queries=texts,
+                    full_length_search=False,  # Indicates whether to encode the query for a full-length search.
+                )
+                logger.info(f"Query embedding shape: {embedding.shape}")
+                return [
+                    {"input": inputs, "query_embedding": embedding.tolist()[0]}
+                ]
+            elif len(texts) > 1:
+                # It's a batch of chunks
+                logger.info(f"Batch of chunks: {texts}")
+                embeddings, token_counts = self._checkpoint.docFromText(
+                    docs=texts,
+                    bsize=self._config.bsize,  # Batch size
+                    keep_dims=True,  # Do NOT flatten the embeddings
+                    return_tokens=True,  # Return the tokens as well
+                )
+                for text, embedding, token_count in zip(texts, embeddings, token_counts):
+                    logger.info(f"Chunk: {text}")
+                    logger.info(f"Chunk embedding shape: {embedding.shape}")
+                    logger.info(f"Token count: {token_count}")
+                return [
+                    {"input": _input, "chunk_embedding": embedding.tolist(), "token_count": token_count.tolist()}
+                    for _input, embedding, token_count in zip(texts, embeddings, token_counts)
+                ]
+            else:
+                raise ValueError("No data to process")
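For quick iteration, the handler can also be exercised directly, without the emulator or curl. A minimal smoke-test sketch, assuming the colbert and torch dependencies are installed and the checkpoint can be downloaded from the Hub:

from handler import EndpointHandler

handler = EndpointHandler()

# A single string is embedded as a query
result = handler({"inputs": "Please embed me"})
print(len(result[0]["query_embedding"]))  # number of per-token query vectors

# A list of strings is embedded as a batch of chunks
results = handler({"inputs": ["Please embed me", "And me too, please!"]})
for item in results:
    print(item["input"], len(item["chunk_embedding"]), len(item["token_count"]))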
start_emulator.sh
ADDED
@@ -0,0 +1,4 @@
+#!/bin/bash -e
+export SHELL=/bin/bash
+
+hf-endpoints-emulator "$@"
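The launcher just forwards its arguments to hf-endpoints-emulator, which serves the EndpointHandler from handler.py locally; the curl scripts and tests assume it ends up listening on localhost:4999, so the port presumably has to be passed through here.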
test_endpoint.py
ADDED
@@ -0,0 +1,67 @@
+import os
+import pytest
+import requests
+
+URL = "http://localhost:4999/"
+HEADERS = {"Content-Type": "application/json"}
+
+def test_returns_200():
+    payload = {"inputs": "try me"}
+
+    response = requests.request("POST", URL, json=payload, headers=HEADERS)
+
+    assert response.status_code == 200
+
+def test_query_returns_expected_result():
+    query = "try me"
+    payload = {"inputs": query}
+
+    response = requests.request("POST", URL, json=payload, headers=HEADERS)
+    response_data = response.json()
+
+    # print(response_data)
+
+    # Check structure and input
+    assert isinstance(response_data, list)
+    assert len(response_data) == 1
+    assert isinstance(response_data[0], dict)
+    assert response_data[0].get("input") == query
+
+    # Check query embedding (actually a list of embeddings, one per token in the query)
+    query_embedding = response_data[0].get("query_embedding")
+    assert isinstance(query_embedding, list)
+    assert len(query_embedding) == 32
+
+    # Check first of the token embeddings
+    first_token_embedding = query_embedding[0]
+    assert isinstance(first_token_embedding, list)
+    assert len(first_token_embedding) == 128
+    assert all(isinstance(value, float) for value in first_token_embedding)
+
+def test_batch_returns_expected_result():
+    chunks = ["try me", "try me again and again and again"]
+    expected_token_counts = [11, 11]  # Including start and stop tokens, I presume. Not exactly clear!
+    payload = {"inputs": chunks}
+
+    response = requests.request("POST", URL, json=payload, headers=HEADERS)
+    response_data = response.json()
+
+    # Check structure
+    assert isinstance(response_data, list)
+    assert len(response_data) == len(chunks)
+
+    for i, response_chunk in enumerate(response_data):
+        # Check input
+        assert response_chunk.get("input") == chunks[i]
+
+        # Check chunk embedding (actually a list of embeddings, one per token in the chunk)
+        chunk_embedding = response_chunk.get("chunk_embedding")
+        token_count = response_chunk.get("token_count")
+        assert isinstance(chunk_embedding, list)
+        assert len(chunk_embedding) == len(token_count)
+        assert len(token_count) == expected_token_counts[i]
+
+        # Check first of the token embeddings
+        first_token_embedding = chunk_embedding[0]
+        assert len(first_token_embedding) == 128
+        assert all(isinstance(value, float) for value in first_token_embedding)
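With the emulator running (see start_emulator.sh above), the whole suite runs with a plain pytest test_endpoint.py; the two curl scripts cover the same two request shapes interactively.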