Websocket for GBERT
Hi!
Is it possible to use a websocket in Python to obtain word embeddings through the Inference API for the GBERT-base model?
I have been able to collect word embeddings through individual HTTP requests for every sentence. However, that is very time-consuming. I've been struggling to implement a websocket (as in https://huggingface.co/docs/api-inference/parallelism) to obtain word embeddings: the bulk websocket (wss://api-inference.huggingface.co/bulk/stream/cpu/deepset/gbert-base) only works for sentences containing a [MASK] token, and the feature-extraction websocket for GBERT-base (wss://api-inference.huggingface.co/pipeline/feature-extraction/deepset/gbert-base) does not seem to work at all.
Any ideas? Thank you for your help in advance! (Response in German would also be fine with me)
Hi @lucazug,
I just tested with the following Python code using websockets and it worked fine.
But I'd recommend you look into our Inference Endpoints solution:
https://ui.endpoints.huggingface.co/new?repository=deepset/gbert-base
Here HUGGING_FACE_HUB_TOKEN is your access token:
HUGGING_FACE_HUB_TOKEN=hf_XXXXXXXXXX python app.py
import os
import asyncio
import json
import uuid

import websockets

MODEL_ID = "deepset/gbert-base"
COMPUTE_TYPE = "cpu"  # or "gpu"
API_TOKEN = os.environ["HUGGING_FACE_HUB_TOKEN"]


async def send(websocket, payloads):
    # You need to log in with a first message, as headers are not
    # forwarded for websockets
    await websocket.send(f"Bearer {API_TOKEN}".encode("utf-8"))
    for payload in payloads:
        await websocket.send(json.dumps(payload).encode("utf-8"))
        print("Sent")


async def recv(websocket, last_id):
    outputs = []
    while True:
        data = await websocket.recv()
        payload = json.loads(data)
        if payload["type"] == "results":
            # {"type": "results", "outputs": JSON-formatted results, "id": the id we sent}
            print(payload["outputs"])
            outputs.append(payload["outputs"])
            if payload["id"] == last_id:
                return outputs
        else:
            # {"type": "status", "message": "Some information about the queue"}
            print(f"< {payload['message']}")


async def main():
    uri = f"wss://api-inference.huggingface.co/bulk/stream/{COMPUTE_TYPE}/{MODEL_ID}"
    print(f"Connecting to {uri}")
    async with websockets.connect(uri) as websocket:
        # "inputs" and "parameters" are the classic pipeline arguments;
        # "id" is a way to track that query
        payloads = [
            {
                "id": str(uuid.uuid4()),
                "inputs": "Das Ziel des Lebens ist [MASK].",
            }
            for i in range(10)
        ]
        last_id = payloads[-1]["id"]
        future = send(websocket, payloads)
        future_r = recv(websocket, last_id)
        _, outputs = await asyncio.gather(future, future_r)
        # Each output is the fill-mask pipeline's list of top predictions
        results = [out[0]["sequence"] for out in outputs]
        return results


loop = asyncio.get_event_loop()
if loop.is_running():
    # When running in notebooks, the loop is already running, so
    # nest_asyncio is needed to allow a nested run_until_complete
    import nest_asyncio
    nest_asyncio.apply()
results = loop.run_until_complete(main())
[
    {'score': 0.25408321619033813, 'token': 3531, 'token_str': 'erreicht', 'sequence': 'Das Ziel des Lebens ist erreicht.'},
    {'score': 0.04424438253045082, 'token': 7939, 'token_str': 'erfüllt', 'sequence': 'Das Ziel des Lebens ist erfüllt.'},
    {'score': 0.031005313619971275, 'token': 199, 'token_str': 'das', 'sequence': 'Das Ziel des Lebens ist das.'},
    {'score': 0.029783131554722786, 'token': 288, 'token_str': 'es', 'sequence': 'Das Ziel des Lebens ist es.'},
    {'score': 0.021744653582572937, 'token': 2458, 'token_str': 'klar', 'sequence': 'Das Ziel des Lebens ist klar.'}
]
Hi Radames,
thank you for your quick reply. The websocket solution works fine for me too, for the input with a sentence and a masked token, producing the output you inserted above. However, I want to use a websocket to 'upload' a list of sentences (without masked tokens) and have the Inference API and the GBERT model return word embeddings (a list of lists of 768 elements). I have successfully implemented this using HTTP requests like this:
import requests

for text in X_train:
    # headers carries the bearer token, e.g. {"Authorization": f"Bearer {API_TOKEN}"}
    response = requests.post(
        "https://api-inference.huggingface.co/pipeline/feature-extraction/deepset/gbert-base",
        headers=headers, json={"inputs": text, "options": {"wait_for_model": True}})
    embeddings = response.json()[0]
    inputs_train.append(embeddings)
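Each embeddings value here is a list of per-token vectors; a quick sanity check for the 768 dimension mentioned above could be:

# every token vector from GBERT-base should have 768 elements
assert all(len(vec) == 768 for vec in embeddings)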
Sadly, this is slow and expensive. Now I'm looking for a websocket solution to this problem and have been unable to find one online. Is there a GBERT-base URI for a websocket application that would let me retrieve word embeddings through the Inference API?
Sorry for not being more specific in the first place!
Hi @lucazug, sorry I didn't understand your request at first. Indeed, you can't override the task via the websocket endpoint the way you do via HTTP.
The recommended way is to clone the model and set a new pipeline_tag: feature-extraction in the README.md metadata headers. I've tested this approach with the code above and got the embeddings.
---
language: de
license: mit
tags:
- feature-extraction
pipeline_tag: feature-extraction
datasets:
- wikipedia
- OPUS
- OpenLegalData
---
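For illustration, here is a minimal sketch of how the main() from the code above could target such a clone. The repository name below is hypothetical, and send and recv are reused unchanged:

# Hypothetical clone of deepset/gbert-base whose README metadata sets
# pipeline_tag: feature-extraction (replace with your own repo)
MODEL_ID = "your-username/gbert-base-feature-extraction"

async def main():
    uri = f"wss://api-inference.huggingface.co/bulk/stream/{COMPUTE_TYPE}/{MODEL_ID}"
    async with websockets.connect(uri) as websocket:
        # Plain sentences, no [MASK] token needed for feature extraction
        payloads = [
            {"id": str(uuid.uuid4()), "inputs": text}
            for text in ["Das ist ein Beispielsatz.", "Noch ein Satz."]
        ]
        last_id = payloads[-1]["id"]
        _, outputs = await asyncio.gather(
            send(websocket, payloads),
            recv(websocket, last_id),
        )
        # Expected shape, matching the HTTP endpoint: one entry per
        # sentence, each a list of per-token vectors with 768 floats
        return outputs

Since the websocket payload carries the same "inputs" field as the HTTP request, the only real change is which model the endpoint resolves, and the pipeline_tag on the clone is what determines the task.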