Websocket for GBERT

by lucazug

Hi!
Is it possible to use a websocket in python to obtain word embeddings through the Inference API for the GBERT-base model?

I have been able to collect word embeddings through individual HTTP requests for every sentence, but that is very time consuming. I've been struggling to implement a websocket (as in https://huggingface.co/docs/api-inference/parallelism) to obtain word embeddings: the websocket only works for multiple sentences containing a [MASK] token (wss://api-inference.huggingface.co/bulk/stream/cpu/deepset/gbert-base). A feature-extraction websocket for GBERT-base (wss://api-inference.huggingface.co/pipeline/feature-extraction/deepset/gbert-base) does not seem to work, though.

Any ideas? Thank you in advance for your help! (A response in German would also be fine with me.)

Hi @lucazug ,

I just tested with the following Python code using websockets and it worked fine. That said, I'd recommend you look into our Inference Endpoints solution:
https://ui.endpoints.huggingface.co/new?repository=deepset/gbert-base

Here, HUGGING_FACE_HUB_TOKEN is your access token:

HUGGING_FACE_HUB_TOKEN=hf_XXXXXXXXXX python app.py
import os
import asyncio
import json
import uuid
import websockets
MODEL_ID = "deepset/gbert-base"
COMPUTE_TYPE = "cpu"  # or "gpu"
API_TOKEN = os.environ["HUGGING_FACE_HUB_TOKEN"]

async def send(websocket, payloads):
    # You need to login with a first message as headers are not forwarded
    # for websockets
    await websocket.send(f"Bearer {API_TOKEN}".encode("utf-8"))
    for payload in payloads:
        await websocket.send(json.dumps(payload).encode("utf-8"))
        print("Sent ")
async def recv(websocket, last_id):
    outputs = []
    while True:
        data = await websocket.recv()
        payload = json.loads(data)
        if payload["type"] == "results":
            # {"type": "results", "outputs": JSONFormatted results, "id": the id we sent}
            print(payload["outputs"])
            outputs.append(payload["outputs"])
            if payload["id"] == last_id:
                return outputs
        else:
            # {"type": "status", "message": "Some information about the queue"}
            print(f"< {payload['message']}")
            pass
async def main():
    uri = f"wss://api-inference.huggingface.co/bulk/stream/{COMPUTE_TYPE}/{MODEL_ID}"
    print(f"Connecting to {uri}")
    async with websockets.connect(uri) as websocket:
        # inputs and parameters are classic, "id" is a way to track that query
        payloads = [
            {
                "id": str(uuid.uuid4()),
                "inputs": "Das Ziel des Lebens ist [MASK].",
            }
            for i in range(10)
        ]
        last_id = payloads[-1]["id"]
        future = send(websocket, payloads)
        future_r = recv(websocket, last_id)
        _, outputs = await asyncio.gather(future, future_r)
    # Each output is the list of fill-mask candidates for one input; keep the top sequence
    results = [out[0]["sequence"] for out in outputs]
    return results
loop = asyncio.get_event_loop()
if loop.is_running():
    # When running in notebooks
    import nest_asyncio
    nest_asyncio.apply()
results = loop.run_until_complete(main())
Sample output for one query:

[
    {'score': 0.25408321619033813, 'token': 3531, 'token_str': 'erreicht', 'sequence': 'Das Ziel des Lebens ist erreicht.'},
    {'score': 0.04424438253045082, 'token': 7939, 'token_str': 'erfüllt', 'sequence': 'Das Ziel des Lebens ist erfüllt.'},
    {'score': 0.031005313619971275, 'token': 199, 'token_str': 'das', 'sequence': 'Das Ziel des Lebens ist das.'},
    {'score': 0.029783131554722786, 'token': 288, 'token_str': 'es', 'sequence': 'Das Ziel des Lebens ist es.'},
    {'score': 0.021744653582572937, 'token': 2458, 'token_str': 'klar', 'sequence': 'Das Ziel des Lebens ist klar.'}
]

Hi Radames,
thank you for your quick reply. The websocket solution works fine for me too, for inputs with a masked token, with the output you posted above. However, I want to use a websocket to 'upload' a list of sentences (without masked tokens) and have the Inference API and the GBERT model return word embeddings (a list of lists of 768 elements each). I have successfully implemented this using HTTP requests like this:

import requests

for text in X_train:
    response = requests.post("https://api-inference.huggingface.co/pipeline/feature-extraction/deepset/gbert-base", headers=headers, json={"inputs": text, "options": {"wait_for_model": True}})
    embeddings = response.json()[0]
    inputs_train.append(embeddings)

Sadly, this is slow and expensive. Now I'm looking for a websocket solution for this problem and have been unable to find one online. Is there a GBERT-base URI for a websocket application that would let me retrieve word embeddings through the Inference API?

Sorry for not being more specific in the first place!
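Side note: independent of websockets, the documented input for the feature-extraction pipeline is a string or a list of strings, so several sentences can be sent in a single HTTP request, which already cuts the number of round trips. A minimal sketch, assuming X_train and headers are defined as in the snippet above (the batch size of 32 is an arbitrary choice):

import requests

API_URL = "https://api-inference.huggingface.co/pipeline/feature-extraction/deepset/gbert-base"

def embed_batch(sentences):
    # One request for several sentences; the response contains one
    # embeddings array per input sentence, in order.
    response = requests.post(API_URL, headers=headers, json={"inputs": sentences, "options": {"wait_for_model": True}})
    return response.json()

inputs_train = []
for i in range(0, len(X_train), 32):
    inputs_train.extend(embed_batch(X_train[i:i + 32]))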

hi @lucazug , sorry I didn't understand your request at first. Indeed, you can't override the task via the websocket endpoint the way you do via HTTP.
The recommended way is to clone the model and set a new pipeline_tag: feature-extraction in the README.md metadata headers. I've tested this approach with the code above and got the embeddings. The metadata block looks like this:

---
language: de
license: mit
tags:
  - feature-extraction
pipeline_tag: feature-extraction
datasets:
- wikipedia
- OPUS
- OpenLegalData
--- 
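To make the connection explicit, here is a minimal sketch of how the websocket script above changes once the clone serves feature-extraction. The repository name your-username/gbert-base-feature-extraction is a placeholder for your clone, and X_train is the sentence list from the HTTP example; everything else (send, recv, COMPUTE_TYPE, the auth message) stays the same:

# Placeholder: point MODEL_ID at your clone that carries pipeline_tag: feature-extraction
MODEL_ID = "your-username/gbert-base-feature-extraction"

async def main():
    uri = f"wss://api-inference.huggingface.co/bulk/stream/{COMPUTE_TYPE}/{MODEL_ID}"
    async with websockets.connect(uri) as websocket:
        # Plain sentences now; no [MASK] token is needed for feature-extraction
        payloads = [{"id": str(uuid.uuid4()), "inputs": text} for text in X_train]
        last_id = payloads[-1]["id"]
        _, outputs = await asyncio.gather(send(websocket, payloads), recv(websocket, last_id))
    # Each entry in outputs is the embeddings array for one sentence
    return outputs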
