{ "cells": [ { "cell_type": "markdown", "id": "5d9aca72-957a-4ee2-862f-e011b9cd3a62", "metadata": {}, "source": [ "# Introduction\n", "## Goal\n", "I have a dataset I want to embed for semantic search (or QA, or RAG), and I want the easiest way to embed it and store the result in a new dataset.\n", "\n", "## Approach\n", "I'm using a dataset from my favorite subreddit, r/bestofredditorupdates. Since it has such long entries, I will use the new [jinaai/jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) since it has an 8k context length. Since I'm GPU-poor, I will deploy this using an [Inference Endpoint](https://huggingface.co/inference-endpoints) to save money and time. To follow along, you will need to add a payment method. To make it even easier, I'll make this fully API-based." ] }, { "cell_type": "markdown", "id": "d2534669-003d-490c-9d7a-32607fa5f404", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "markdown", "id": "b6f72042-173d-4a72-ade1-9304b43b528d", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "id": "e2beecdd-d033-4736-bd45-6754ec53b4ac", "metadata": { "tags": [] }, "outputs": [], "source": [ "import asyncio\n", "from getpass import getpass\n", "import json\n", "from pathlib import Path\n", "import time\n", "\n", "from aiohttp import ClientSession, ClientTimeout\n", "from datasets import load_dataset, Dataset, DatasetDict\n", "from huggingface_hub import notebook_login\n", "import pandas as pd\n", "import requests\n", "from tqdm.auto import tqdm" ] }, { "cell_type": "markdown", "id": "5eece903-64ce-435d-a2fd-096c0ff650bf", "metadata": {}, "source": [ "## Config\n", "You need to fill this in with your desired repos. Note that I used 5 for `MAX_WORKERS` since `jina-embeddings-v2` is quite memory-hungry. 
" ] }, { "cell_type": "code", "execution_count": 2, "id": "dcd7daed-6aca-4fe7-85ce-534bdcd8bc87", "metadata": { "tags": [] }, "outputs": [], "source": [ "dataset_in = 'derek-thomas/dataset-creator-reddit-bestofredditorupdates'\n", "dataset_out = \"processed-bestofredditorupdates\"\n", "endpoint_name = \"boru-jina-embeddings-demo\"\n", "\n", "MAX_WORKERS = 5 " ] }, { "cell_type": "code", "execution_count": 3, "id": "88cdbd73-5923-4ae9-9940-b6be935f70fa", "metadata": { "tags": [] }, "outputs": [ { "name": "stdin", "output_type": "stream", "text": [ "What is your Hugging Face 🤗 username? (with a credit card) ········\n", "What is your Hugging Face 🤗 token? ········\n" ] } ], "source": [ "username = getpass(prompt=\"What is your Hugging Face 🤗 username? (with an added payment method)\")\n", "hf_token = getpass(prompt='What is your Hugging Face 🤗 token?')" ] }, { "cell_type": "markdown", "id": "b972a719-2aed-4d2e-a24f-fae7776d5fa4", "metadata": {}, "source": [ "## Get Dataset" ] }, { "cell_type": "code", "execution_count": 4, "id": "27835fa4-3a4f-44b1-a02a-5e31584a1bba", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", "    features: ['date_utc', 'title', 'flair', 'content', 'poster', 'permalink', 'id', 'content_length', 'score'],\n", "    num_rows: 9991\n", "})" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset = load_dataset(dataset_in, token=hf_token)\n", "dataset['train']" ] }, { "cell_type": "code", "execution_count": 5, "id": "8846087e-4d0d-4c0e-8aeb-ea95d9e97126", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(9991,\n", " {'date_utc': Timestamp('2022-12-31 18:16:22'),\n", "  'title': 'To All BORU contributors, Thank you :)',\n", "  'flair': 'CONCLUDED',\n", "  'content': '[removed]',\n", "  'poster': 'IsItAcOnSeQuEnCe',\n", "  'permalink': '/r/BestofRedditorUpdates/comments/10004zw/to_all_boru_contributors_thank_you/',\n", "  'id': '10004zw',\n", 
'content_length': 9,\n", "  'score': 1})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "documents = dataset['train'].to_pandas().to_dict('records')\n", "len(documents), documents[0]" ] }, { "cell_type": "markdown", "id": "93096cbc-81c6-4137-a283-6afb0f48fbb9", "metadata": {}, "source": [ "# Inference Endpoints\n", "## Create Inference Endpoint\n", "We are going to use the [API](https://huggingface.co/docs/inference-endpoints/api_reference) to create an [Inference Endpoint](https://huggingface.co/inference-endpoints). This should provide a few main benefits:\n", "- It's convenient (No clicking)\n", "- It's repeatable (We have the code to run it easily)\n", "- It's cheaper (No time spent waiting for it to load, and we can shut it down automatically)" ] }, { "cell_type": "code", "execution_count": 6, "id": "3a8f67b9-6ac6-4b5e-91ee-e48463191e1b", "metadata": { "tags": [] }, "outputs": [], "source": [ "headers = {\n", "\t\"Authorization\": f\"Bearer {hf_token}\",\n", "\t\"Content-Type\": \"application/json\"\n", "}\n", "base_url = f\"https://api.endpoints.huggingface.cloud/v2/endpoint/{username}\"\n", "endpoint_url = f\"https://api.endpoints.huggingface.cloud/v2/endpoint/{username}/{endpoint_name}\"" ] }, { "cell_type": "markdown", "id": "0f2c97dc-34e8-49e9-b60e-f5b7366294c0", "metadata": {}, "source": [ "There are a few design choices here:\n", "- I'm using the `g5.2xlarge` since it is big and `jina-embeddings-v2` is memory-hungry (remember the 8k context length). 
\n", "- I didn't alter the default `MAX_BATCH_TOKENS` or `MAX_CONCURRENT_REQUESTS`\n", "    - You should consider this if you are making this production-ready\n", "    - You will need to restrict these to match the hardware you are running on\n", "- As mentioned before, I chose the repo and the corresponding revision\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "f1ea29cb-b69d-4340-859f-3646d650c68e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "202\n" ] } ], "source": [ "data = {\n", "    \"accountId\": None,\n", "    \"compute\": {\n", "        \"accelerator\": \"gpu\",\n", "        \"instanceType\": \"g5.2xlarge\",\n", "        \"instanceSize\": \"medium\",\n", "        \"scaling\": {\n", "            \"maxReplica\": 1,\n", "            \"minReplica\": 1\n", "        }\n", "    },\n", "    \"model\": {\n", "        \"framework\": \"pytorch\",\n", "        \"image\": {\n", "            \"custom\": {\n", "                \"url\": \"ghcr.io/huggingface/text-embeddings-inference:0.3.0\",\n", "                \"health_route\": \"/health\",\n", "                \"env\": {\n", "                    \"MAX_BATCH_TOKENS\": \"16384\",\n", "                    \"MAX_CONCURRENT_REQUESTS\": \"512\",\n", "                    \"MODEL_ID\": \"/repository\"\n", "                }\n", "            }\n", "        },\n", "        \"repository\": \"jinaai/jina-embeddings-v2-base-en\",\n", "        \"revision\": \"8705ed9657208b2d5220fffad1c3a30980d279d0\",\n", "        \"task\": \"sentence-embeddings\",\n", "    },\n", "    \"name\": endpoint_name,\n", "    \"provider\": {\n", "        \"region\": \"us-east-1\",\n", "        \"vendor\": \"aws\"\n", "    },\n", "    \"type\": \"protected\"\n", "}\n", "\n", "response = requests.post(base_url, headers={**headers, 'accept': 'application/json'}, json=data)\n", "\n", "print(response.status_code)" ] }, { "cell_type": "markdown", "id": "96d173b2-8980-4554-9039-c62843d3fc7d", "metadata": {}, "source": [ "## Wait until it's running\n", "Here we use `tqdm` as a pretty way of displaying our status. It took about 30s for the Inference Endpoint with this model to get running."
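, "\n",
 "If you want this wait to be more robust, the poll can be wrapped in a small helper with a timeout so a failed deployment doesn't spin forever. This is a sketch, not part of the original flow; `get_state` is an assumed stand-in for any callable returning the endpoint's current state string, e.g. `lambda: requests.get(endpoint_url, headers=headers).json()['status']['state']`:\n",
 "\n",
 "```python\n",
 "import time\n",
 "\n",
 "def wait_for_state(get_state, target='running', interval=2, timeout=600):\n",
 "    # Poll get_state() until it returns `target`; give up after `timeout` seconds\n",
 "    start = time.monotonic()\n",
 "    while (state := get_state()) != target:\n",
 "        if time.monotonic() - start > timeout:\n",
 "            raise TimeoutError(f'endpoint stuck in state {state!r}')\n",
 "        time.sleep(interval)\n",
 "    return state\n",
 "```"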
] }, { "cell_type": "code", "execution_count": 8, "id": "b8aa66a9-3c8a-4040-9465-382c744f36cf", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a6f27d86f68b4000aa40e09ae079c6b0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Waiting for status to change: 0s [00:00, ?s/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Status is 'running'.\n" ] } ], "source": [ "with tqdm(desc=\"Waiting for status to change\", unit=\"s\") as pbar:\n", "    while True:\n", "        response_json = requests.get(endpoint_url, headers=headers).json()\n", "        current_status = response_json['status']['state']\n", "\n", "        if current_status == 'running':\n", "            print(\"Status is 'running'.\")\n", "            break\n", "\n", "        pbar.set_description(f\"Status: {current_status}\")\n", "        time.sleep(2)\n", "        pbar.update(1)\n", "\n", "embedding_url = response_json['status']['url']" ] }, { "cell_type": "markdown", "id": "063fa066-e4d0-4a65-a82d-cf17db4af8d8", "metadata": {}, "source": [ "I found that even though the status is `running`, I want to send a test request and confirm it succeeds before running our batch in parallel." ] }, { "cell_type": "code", "execution_count": 9, "id": "66e00960-1d3d-490d-bedc-3eaf1924db76", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4e03e5a3d07a498ca6b3631605724b62", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Waiting for endpoint to accept requests: 0s [00:00, ?s/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Endpoint is accepting requests\n" ] } ], "source": [ "payload = {\"inputs\": \"This sound track was beautiful! It paints the senery in your mind so well I would recomend it even to people who hate vid. 
game music!\"}\n", "\n", "with tqdm(desc=\"Waiting for endpoint to accept requests\", unit=\"s\") as pbar:\n", "    while True:\n", "        try:\n", "            response_json = requests.post(embedding_url, headers=headers, json=payload).json()\n", "\n", "            # Assuming the successful response has a specific structure\n", "            if len(response_json[0]) == 768:\n", "                print(\"Endpoint is accepting requests\")\n", "                break\n", "\n", "        except requests.ConnectionError:\n", "            pass\n", "\n", "        # Delay between retries\n", "        time.sleep(5)\n", "        pbar.update(1)\n" ] }, { "cell_type": "markdown", "id": "f7186126-ef6a-47d0-b158-112810649cd9", "metadata": {}, "source": [ "# Get Embeddings" ] }, { "cell_type": "markdown", "id": "1dadfd68-6d46-4ce8-a165-bfeb43b1f114", "metadata": {}, "source": [ "Here I send a document, update it with the embedding, and return it. This happens in parallel, limited to `MAX_WORKERS` concurrent requests." ] }, { "cell_type": "code", "execution_count": 10, "id": "ad3193fb-3def-42a8-968e-c63f2b864ca8", "metadata": { "tags": [] }, "outputs": [], "source": [ "async def request(document, semaphore):\n", "    # Semaphore guard\n", "    async with semaphore:\n", "        payload = {\n", "            \"inputs\": document['content'] or document['title'] or '[deleted]',\n", "            \"truncate\": True\n", "        }\n", "\n", "        timeout = ClientTimeout(total=10)  # Set a timeout for requests (10 seconds here)\n", "\n", "        async with ClientSession(timeout=timeout, headers=headers) as session:\n", "            async with session.post(embedding_url, json=payload) as resp:\n", "                if resp.status != 200:\n", "                    raise RuntimeError(await resp.text())\n", "                result = await resp.json()\n", "\n", "        document['embedding'] = result[0]  # Assuming the API's output can be directly assigned\n", "        return document\n", "\n", "async def main(documents):\n", "    # Semaphore to limit concurrent requests. 
Adjust the number as needed.\n", " semaphore = asyncio.BoundedSemaphore(MAX_WORKERS)\n", "\n", " # Creating a list of tasks\n", " tasks = [request(document, semaphore) for document in documents]\n", " \n", " # Using tqdm to show progress. It's been integrated into the async loop.\n", " for f in tqdm(asyncio.as_completed(tasks), total=len(documents)):\n", " await f" ] }, { "cell_type": "code", "execution_count": 11, "id": "ec4983af-65eb-4841-808a-3738fb4d682d", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cb73af52244e40d2aab8bdac3a55d443", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/9991 [00:00 `pd.DataFrame` -> `Dataset`" ] }, { "cell_type": "code", "execution_count": 13, "id": "9bb993f8-d624-4192-9626-8e9ed9888a1b", "metadata": { "tags": [] }, "outputs": [], "source": [ "df = pd.DataFrame(documents)\n", "dd = DatasetDict({'train': Dataset.from_pandas(df)})" ] }, { "cell_type": "code", "execution_count": 14, "id": "f48e7c55-d5b7-4ed6-8516-272ae38716b1", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "84a481e0cf74494cb2eb9d9857701212", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Pushing dataset shards to the dataset hub: 0%| | 0/1 [00:00