jpohhhh commited on
Commit
2b69a61
1 Parent(s): 6f39dac

Upload philschmid/all-MiniLM-L6-v2-optimum-embeddings

Browse files
README.md CHANGED
@@ -1,3 +1,12 @@
1
  ---
2
  license: mit
 
 
 
 
 
3
  ---
 
 
 
 
 
1
  ---
2
  license: mit
3
+ tags:
4
+ - sentence-embeddings
5
+ - endpoints-template
6
+ - optimum
7
+ library_name: generic
8
  ---
9
+
10
+ This repository is a fork of philschmid/all-MiniLM-L6-v2-optimum-embeddings.
11
+ My own ONNX conversion seems to be about 4x slower, no discernable reason why: the quantized models seem roughly the same.
12
+ The idea here is by forking we can ex. upgrade the Optimum lib used as well.
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "sentence-transformers/all-MiniLM-L6-v2",
3
+ "architectures": [
4
+ "BertModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 384,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 1536,
14
+ "layer_norm_eps": 1e-12,
15
+ "max_position_embeddings": 512,
16
+ "num_attention_heads": 12,
17
+ "num_hidden_layers": 6,
18
+ "pad_token_id": 0,
19
+ "position_embedding_type": "absolute",
20
+ "transformers_version": "4.20.1",
21
+ "type_vocab_size": 2,
22
+ "use_cache": false,
23
+ "vocab_size": 30522
24
+ }
convert.ipynb ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Convert & Optimize model with Optimum \n",
8
+ "\n",
9
+ "\n",
10
+ "Steps:\n",
11
+ "1. Convert model to ONNX\n",
12
+ "2. Optimize & quantize model with Optimum\n",
13
+ "3. Create Custom Handler for Inference Endpoints\n",
14
+ "\n",
15
+ "Helpful links:\n",
16
+ "* [Accelerate Sentence Transformers with Hugging Face Optimum](https://www.philschmid.de/optimize-sentence-transformers)\n",
17
+ "* [Create Custom Handler Endpoints](https://link-to-docs)"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## Setup & Installation"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 2,
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "name": "stdout",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "Writing requirements.txt\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "%%writefile requirements.txt\n",
42
+ "optimum[onnxruntime]==1.3.0\n",
43
+ "mkl-include\n",
44
+ "mkl"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "!pip install -r requirements.txt"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "markdown",
58
+ "metadata": {},
59
+ "source": [
60
+ "## 1. Convert model to ONNX"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 6,
66
+ "metadata": {},
67
+ "outputs": [
68
+ {
69
+ "data": {
70
+ "application/vnd.jupyter.widget-view+json": {
71
+ "model_id": "2920b55a58bb41b78436f64d24b31d27",
72
+ "version_major": 2,
73
+ "version_minor": 0
74
+ },
75
+ "text/plain": [
76
+ "Downloading: 0%| | 0.00/612 [00:00<?, ?B/s]"
77
+ ]
78
+ },
79
+ "metadata": {},
80
+ "output_type": "display_data"
81
+ },
82
+ {
83
+ "data": {
84
+ "text/plain": [
85
+ "('./tokenizer_config.json',\n",
86
+ " './special_tokens_map.json',\n",
87
+ " './vocab.txt',\n",
88
+ " './added_tokens.json',\n",
89
+ " './tokenizer.json')"
90
+ ]
91
+ },
92
+ "execution_count": 6,
93
+ "metadata": {},
94
+ "output_type": "execute_result"
95
+ }
96
+ ],
97
+ "source": [
98
+ "from optimum.onnxruntime import ORTModelForFeatureExtraction\n",
99
+ "from transformers import AutoTokenizer\n",
100
+ "from pathlib import Path\n",
101
+ "\n",
102
+ "\n",
103
+ "model_id=\"sentence-transformers/all-MiniLM-L6-v2\"\n",
104
+ "onnx_path = Path(\".\")\n",
105
+ "\n",
106
+ "# load vanilla transformers and convert to onnx\n",
107
+ "model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True)\n",
108
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
109
+ "\n",
110
+ "# save onnx checkpoint and tokenizer\n",
111
+ "model.save_pretrained(onnx_path)\n",
112
+ "tokenizer.save_pretrained(onnx_path)"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "markdown",
117
+ "metadata": {},
118
+ "source": [
119
+ "## 2. Optimize & quantize model with Optimum"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 7,
125
+ "metadata": {},
126
+ "outputs": [
127
+ {
128
+ "name": "stderr",
129
+ "output_type": "stream",
130
+ "text": [
131
+ "2022-08-31 19:22:18.331832429 [W:onnxruntime:, inference_session.cc:1488 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.\n",
132
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
133
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
134
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
135
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
136
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
137
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
138
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
139
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
140
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
141
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
142
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
143
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n",
144
+ "WARNING:fusion_skiplayernorm:symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model\n"
145
+ ]
146
+ }
147
+ ],
148
+ "source": [
149
+ "from optimum.onnxruntime import ORTOptimizer, ORTQuantizer\n",
150
+ "from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig\n",
151
+ "\n",
152
+ "# create ORTOptimizer and define optimization configuration\n",
153
+ "optimizer = ORTOptimizer.from_pretrained(model_id, feature=model.pipeline_task)\n",
154
+ "optimization_config = OptimizationConfig(optimization_level=99) # enable all optimizations\n",
155
+ "\n",
156
+ "# apply the optimization configuration to the model\n",
157
+ "optimizer.export(\n",
158
+ " onnx_model_path=onnx_path / \"model.onnx\",\n",
159
+ " onnx_optimized_model_output_path=onnx_path / \"model-optimized.onnx\",\n",
160
+ " optimization_config=optimization_config,\n",
161
+ ")\n",
162
+ "\n",
163
+ "\n",
164
+ "# create ORTQuantizer and define quantization configuration\n",
165
+ "dynamic_quantizer = ORTQuantizer.from_pretrained(model_id, feature=model.pipeline_task)\n",
166
+ "dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)\n",
167
+ "\n",
168
+ "# apply the quantization configuration to the model\n",
169
+ "model_quantized_path = dynamic_quantizer.export(\n",
170
+ " onnx_model_path=onnx_path / \"model-optimized.onnx\",\n",
171
+ " onnx_quantized_model_output_path=onnx_path / \"model-quantized.onnx\",\n",
172
+ " quantization_config=dqconfig,\n",
173
+ ")\n",
174
+ "\n"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "markdown",
179
+ "metadata": {},
180
+ "source": [
181
+ "## 3. Create Custom Handler for Inference Endpoints\n"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": 2,
187
+ "metadata": {},
188
+ "outputs": [
189
+ {
190
+ "name": "stdout",
191
+ "output_type": "stream",
192
+ "text": [
193
+ "Overwriting pipeline.py\n"
194
+ ]
195
+ }
196
+ ],
197
+ "source": [
198
+ "%%writefile pipeline.py\n",
199
+ "from typing import Dict, List, Any\n",
200
+ "from optimum.onnxruntime import ORTModelForFeatureExtraction\n",
201
+ "from transformers import AutoTokenizer\n",
202
+ "import torch.nn.functional as F\n",
203
+ "import torch\n",
204
+ "\n",
205
+ "# copied from the model card\n",
206
+ "def mean_pooling(model_output, attention_mask):\n",
207
+ " token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n",
208
+ " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
209
+ " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
210
+ "\n",
211
+ "\n",
212
+ "class PreTrainedPipeline():\n",
213
+ " def __init__(self, path=\"\"):\n",
214
+ " # load the optimized model\n",
215
+ " self.model = ORTModelForFeatureExtraction.from_pretrained(path, file_name=\"model-quantized.onnx\")\n",
216
+ " self.tokenizer = AutoTokenizer.from_pretrained(path)\n",
217
+ "\n",
218
+ " def __call__(self, data: Any) -> List[List[Dict[str, float]]]:\n",
219
+ " \"\"\"\n",
220
+ " Args:\n",
221
+ " data (:obj:):\n",
222
+ " includes the input data and the parameters for the inference.\n",
223
+ " Return:\n",
224
+ " A :obj:`list`:. The list contains the embeddings of the inference inputs\n",
225
+ " \"\"\"\n",
226
+ " inputs = data.get(\"inputs\", data)\n",
227
+ "\n",
228
+ " # tokenize the input\n",
229
+ " encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')\n",
230
+ " # run the model\n",
231
+ " outputs = self.model(**encoded_inputs)\n",
232
+ " # Perform pooling\n",
233
+ " sentence_embeddings = mean_pooling(outputs, encoded_inputs['attention_mask'])\n",
234
+ " # Normalize embeddings\n",
235
+ " sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n",
236
+ " # postprocess the prediction\n",
237
+ " return {\"embeddings\": sentence_embeddings.tolist()}"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "markdown",
242
+ "metadata": {},
243
+ "source": [
244
+ "test custom pipeline"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": 1,
250
+ "metadata": {},
251
+ "outputs": [
252
+ {
253
+ "name": "stdout",
254
+ "output_type": "stream",
255
+ "text": [
256
+ "1.55 ms ± 2.04 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
257
+ ]
258
+ }
259
+ ],
260
+ "source": [
261
+ "from pipeline import PreTrainedPipeline\n",
262
+ "\n",
263
+ "# init handler\n",
264
+ "my_handler = PreTrainedPipeline(path=\".\")\n",
265
+ "\n",
266
+ "# prepare sample payload\n",
267
+ "request = {\"inputs\": \"I am quite excited how this will turn out\"}\n",
268
+ "\n",
269
+ "# test the handler\n",
270
+ "%timeit my_handler(request)\n"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 2,
276
+ "metadata": {},
277
+ "outputs": [
278
+ {
279
+ "data": {
280
+ "text/plain": [
281
+ "{'embeddings': [[-0.021580450236797333,\n",
282
+ " 0.021715054288506508,\n",
283
+ " 0.00979710929095745,\n",
284
+ " -0.0005379787762649357,\n",
285
+ " 0.04682469740509987,\n",
286
+ " -0.013600599952042103,\n",
287
+ " -0.003064213553443551,\n",
288
+ " 0.007061154581606388,\n",
289
+ " 0.026638098061084747,\n",
290
+ " -0.011613409034907818,\n",
291
+ " -0.06916121393442154,\n",
292
+ " 0.061429575085639954,\n",
293
+ " 0.013463253155350685,\n",
294
+ " -0.022426923736929893,\n",
295
+ " 0.04116947948932648,\n",
296
+ " 0.03925771266222,\n",
297
+ " 0.014005577191710472,\n",
298
+ " -0.07909698039293289,\n",
299
+ " -0.028196798637509346,\n",
300
+ " -0.003196786157786846,\n",
301
+ " 0.013688302598893642,\n",
302
+ " -0.044537559151649475,\n",
303
+ " -0.04594269394874573,\n",
304
+ " -0.04054776579141617,\n",
305
+ " -0.038281939923763275,\n",
306
+ " 0.06411226838827133,\n",
307
+ " -0.013305696658790112,\n",
308
+ " -0.02935652621090412,\n",
309
+ " -0.0150306923314929,\n",
310
+ " -0.0434146448969841,\n",
311
+ " 0.03218410909175873,\n",
312
+ " 0.018695568665862083,\n",
313
+ " -0.012916717678308487,\n",
314
+ " 0.009855723939836025,\n",
315
+ " -0.022609280422329903,\n",
316
+ " -0.08628173172473907,\n",
317
+ " 0.03853229060769081,\n",
318
+ " -0.03584187850356102,\n",
319
+ " 0.05425931513309479,\n",
320
+ " -0.002929823938757181,\n",
321
+ " -0.011379950679838657,\n",
322
+ " -0.15505683422088623,\n",
323
+ " 0.01120749581605196,\n",
324
+ " 0.03100745379924774,\n",
325
+ " 0.043684810400009155,\n",
326
+ " 0.008617725223302841,\n",
327
+ " 0.00824501272290945,\n",
328
+ " -0.01545825693756342,\n",
329
+ " -0.001658946624957025,\n",
330
+ " 0.027067873626947403,\n",
331
+ " -0.019667934626340866,\n",
332
+ " -0.09459519386291504,\n",
333
+ " 0.048974718898534775,\n",
334
+ " -0.02965048886835575,\n",
335
+ " -0.08003880828619003,\n",
336
+ " 0.045240651816129684,\n",
337
+ " -0.012594419531524181,\n",
338
+ " -0.05546975135803223,\n",
339
+ " 0.05608676001429558,\n",
340
+ " -0.04186442866921425,\n",
341
+ " -0.02615668624639511,\n",
342
+ " 0.02160278894007206,\n",
343
+ " 0.03741847351193428,\n",
344
+ " 0.0008759248885326087,\n",
345
+ " 0.03592744097113609,\n",
346
+ " -0.12200205773115158,\n",
347
+ " 0.06229585036635399,\n",
348
+ " 0.01601075753569603,\n",
349
+ " 0.040825288742780685,\n",
350
+ " -0.08544802665710449,\n",
351
+ " -0.029977118596434593,\n",
352
+ " 0.03295058012008667,\n",
353
+ " 0.05928152799606323,\n",
354
+ " -0.052630465477705,\n",
355
+ " 0.020404687151312828,\n",
356
+ " 0.00725224195048213,\n",
357
+ " 0.0009453881066292524,\n",
358
+ " 0.04398864880204201,\n",
359
+ " 0.071522556245327,\n",
360
+ " 0.032034359872341156,\n",
361
+ " 0.038474190980196,\n",
362
+ " -0.032708171755075455,\n",
363
+ " -0.011295354925096035,\n",
364
+ " -0.045965589582920074,\n",
365
+ " -0.041425369679927826,\n",
366
+ " 0.0482286661863327,\n",
367
+ " 0.008450332097709179,\n",
368
+ " -0.03801262006163597,\n",
369
+ " -0.0420663058757782,\n",
370
+ " 0.05417492985725403,\n",
371
+ " -0.09063713997602463,\n",
372
+ " -0.007592180278152227,\n",
373
+ " -0.009322550147771835,\n",
374
+ " -0.02063363790512085,\n",
375
+ " -0.03594734147191048,\n",
376
+ " 0.07223387807607651,\n",
377
+ " -0.03899461403489113,\n",
378
+ " -0.0934303030371666,\n",
379
+ " -0.03475493937730789,\n",
380
+ " 0.09417884796857834,\n",
381
+ " -0.03771593049168587,\n",
382
+ " 0.0638294667005539,\n",
383
+ " 0.032066185027360916,\n",
384
+ " -0.08843936026096344,\n",
385
+ " 0.012369371019303799,\n",
386
+ " -0.03089563362300396,\n",
387
+ " -0.005824724677950144,\n",
388
+ " 0.08723752945661545,\n",
389
+ " 0.02237764000892639,\n",
390
+ " -0.03896152228116989,\n",
391
+ " 0.025661000981926918,\n",
392
+ " -0.005460284650325775,\n",
393
+ " 0.05766639858484268,\n",
394
+ " 0.025396188721060753,\n",
395
+ " -0.03150532767176628,\n",
396
+ " 0.09431672841310501,\n",
397
+ " 0.035403359681367874,\n",
398
+ " 0.09509390592575073,\n",
399
+ " -0.015979617834091187,\n",
400
+ " 0.04350188001990318,\n",
401
+ " 0.046271294355392456,\n",
402
+ " 0.009891007095575333,\n",
403
+ " -0.0044189076870679855,\n",
404
+ " -0.017476193606853485,\n",
405
+ " 0.015222891233861446,\n",
406
+ " 0.009962008334696293,\n",
407
+ " -0.05670330300927162,\n",
408
+ " -1.8742520903182187e-33,\n",
409
+ " 0.017962634563446045,\n",
410
+ " 0.023281103000044823,\n",
411
+ " -0.013410707004368305,\n",
412
+ " 0.10924902558326721,\n",
413
+ " 0.036854133009910583,\n",
414
+ " -0.039277151226997375,\n",
415
+ " 0.026224950328469276,\n",
416
+ " -0.04877658933401108,\n",
417
+ " -0.0805993378162384,\n",
418
+ " -0.0030330857262015343,\n",
419
+ " -0.0028494936414062977,\n",
420
+ " 0.018921272829174995,\n",
421
+ " -0.01530009601265192,\n",
422
+ " 0.1219208613038063,\n",
423
+ " -0.07319916784763336,\n",
424
+ " -0.10112590342760086,\n",
425
+ " 0.006891624070703983,\n",
426
+ " -0.002260996960103512,\n",
427
+ " -0.007901495322585106,\n",
428
+ " 0.017701659351587296,\n",
429
+ " -0.08319021016359329,\n",
430
+ " 0.048608407378196716,\n",
431
+ " -0.05502907559275627,\n",
432
+ " -0.03751670941710472,\n",
433
+ " -0.004041539039462805,\n",
434
+ " 0.07481412589550018,\n",
435
+ " 0.0022187645081430674,\n",
436
+ " -0.03369564935564995,\n",
437
+ " -0.11100229620933533,\n",
438
+ " 0.01231460366398096,\n",
439
+ " -0.03582797944545746,\n",
440
+ " 0.026462607085704803,\n",
441
+ " -0.03912581503391266,\n",
442
+ " -0.011205351911485195,\n",
443
+ " -0.03137337043881416,\n",
444
+ " 0.0059767672792077065,\n",
445
+ " -0.1009056344628334,\n",
446
+ " -0.06049555912613869,\n",
447
+ " 0.021796569228172302,\n",
448
+ " -0.014793958514928818,\n",
449
+ " 0.03098255582153797,\n",
450
+ " -0.00538264773786068,\n",
451
+ " -0.04653438180685043,\n",
452
+ " -0.02799016609787941,\n",
453
+ " 0.023156380280852318,\n",
454
+ " 0.07959774136543274,\n",
455
+ " 0.043343499302864075,\n",
456
+ " 0.02526552602648735,\n",
457
+ " 0.05564416944980621,\n",
458
+ " -0.0895266905426979,\n",
459
+ " 0.02035175822675228,\n",
460
+ " 0.00761762959882617,\n",
461
+ " -0.01012750156223774,\n",
462
+ " 0.10514233261346817,\n",
463
+ " -0.00832043495029211,\n",
464
+ " -0.018016740679740906,\n",
465
+ " 0.01773231290280819,\n",
466
+ " -0.13199643790721893,\n",
467
+ " 0.11118609458208084,\n",
468
+ " 0.0027006398886442184,\n",
469
+ " 0.035123299807310104,\n",
470
+ " 0.017120877280831337,\n",
471
+ " -0.08685944974422455,\n",
472
+ " 0.014364459551870823,\n",
473
+ " -0.0697159692645073,\n",
474
+ " 0.03414931520819664,\n",
475
+ " 0.051882319152355194,\n",
476
+ " -0.049169816076755524,\n",
477
+ " -0.07678680121898651,\n",
478
+ " 0.03500046953558922,\n",
479
+ " -0.027233436703681946,\n",
480
+ " 0.019955039024353027,\n",
481
+ " -0.035047441720962524,\n",
482
+ " -0.03964361920952797,\n",
483
+ " -0.01907966658473015,\n",
484
+ " 0.05322276055812836,\n",
485
+ " -0.03573837876319885,\n",
486
+ " -0.02035624347627163,\n",
487
+ " 0.03240324929356575,\n",
488
+ " 0.023124489933252335,\n",
489
+ " 0.04587593674659729,\n",
490
+ " 0.006914089433848858,\n",
491
+ " 0.02254929207265377,\n",
492
+ " -0.048369478434324265,\n",
493
+ " 0.07502789050340652,\n",
494
+ " -0.04454338923096657,\n",
495
+ " 0.009581719525158405,\n",
496
+ " -0.08176697790622711,\n",
497
+ " -0.026596812531352043,\n",
498
+ " 0.05699768289923668,\n",
499
+ " 0.03196358308196068,\n",
500
+ " -0.0818556547164917,\n",
501
+ " 0.04586222395300865,\n",
502
+ " 0.026800116524100304,\n",
503
+ " 0.053372107446193695,\n",
504
+ " 4.116422800348778e-34,\n",
505
+ " 0.04144074022769928,\n",
506
+ " -0.00046204423415474594,\n",
507
+ " -0.05304589495062828,\n",
508
+ " 0.006641748361289501,\n",
509
+ " -0.05266479030251503,\n",
510
+ " -0.02192983590066433,\n",
511
+ " 0.010295987129211426,\n",
512
+ " 0.1503780037164688,\n",
513
+ " 0.06841202080249786,\n",
514
+ " 0.012436892837285995,\n",
515
+ " 0.02130315639078617,\n",
516
+ " 0.05735220015048981,\n",
517
+ " 0.020133396610617638,\n",
518
+ " -0.019417081028223038,\n",
519
+ " 0.018597068265080452,\n",
520
+ " -0.060950521379709244,\n",
521
+ " 0.14569053053855896,\n",
522
+ " 0.046135421842336655,\n",
523
+ " 0.014004155062139034,\n",
524
+ " 0.06448501348495483,\n",
525
+ " -0.03540049120783806,\n",
526
+ " 0.05386977270245552,\n",
527
+ " -0.04851151257753372,\n",
528
+ " 0.04860413447022438,\n",
529
+ " 0.003418552689254284,\n",
530
+ " 0.026858657598495483,\n",
531
+ " 0.08443755656480789,\n",
532
+ " 0.0688081830739975,\n",
533
+ " -0.027870699763298035,\n",
534
+ " -0.02680159918963909,\n",
535
+ " -0.10730879008769989,\n",
536
+ " -0.09660787880420685,\n",
537
+ " -0.010721202939748764,\n",
538
+ " 0.03249472752213478,\n",
539
+ " -0.010227357968688011,\n",
540
+ " -0.005592911038547754,\n",
541
+ " -0.02233457751572132,\n",
542
+ " 0.003959502559155226,\n",
543
+ " -0.0025461087934672832,\n",
544
+ " -0.07056054472923279,\n",
545
+ " -0.01288093812763691,\n",
546
+ " 0.03734854981303215,\n",
547
+ " -0.0930633544921875,\n",
548
+ " 0.06263089179992676,\n",
549
+ " -0.022451557219028473,\n",
550
+ " 0.011584922671318054,\n",
551
+ " 0.07056082785129547,\n",
552
+ " 0.07839607447385788,\n",
553
+ " -0.03750450536608696,\n",
554
+ " 0.08674977719783783,\n",
555
+ " -0.0174140315502882,\n",
556
+ " 0.037801019847393036,\n",
557
+ " -0.04431292042136192,\n",
558
+ " -0.003121826099231839,\n",
559
+ " -0.04473913460969925,\n",
560
+ " -0.009062718600034714,\n",
561
+ " 0.06917019933462143,\n",
562
+ " -0.07210793346166611,\n",
563
+ " 0.02439814619719982,\n",
564
+ " 0.06415946036577225,\n",
565
+ " -0.11128300428390503,\n",
566
+ " 0.07395494729280472,\n",
567
+ " -0.019613103941082954,\n",
568
+ " -0.0576956532895565,\n",
569
+ " 0.03607752546668053,\n",
570
+ " -0.049007922410964966,\n",
571
+ " -0.00931280292570591,\n",
572
+ " 0.02782956324517727,\n",
573
+ " -0.016698531806468964,\n",
574
+ " 0.04213561490178108,\n",
575
+ " 0.02651999704539776,\n",
576
+ " -0.021170292049646378,\n",
577
+ " -0.10422325879335403,\n",
578
+ " 0.02582547254860401,\n",
579
+ " 0.07547233253717422,\n",
580
+ " -0.07150454074144363,\n",
581
+ " 0.10658326745033264,\n",
582
+ " -0.08328848332166672,\n",
583
+ " -0.006845302879810333,\n",
584
+ " -0.018662545830011368,\n",
585
+ " -0.009805584326386452,\n",
586
+ " 0.035663068294525146,\n",
587
+ " 0.0027744239196181297,\n",
588
+ " -0.03721313178539276,\n",
589
+ " 0.06117653474211693,\n",
590
+ " 0.03830438479781151,\n",
591
+ " -0.01618945226073265,\n",
592
+ " -0.02423257753252983,\n",
593
+ " -0.0009939797455444932,\n",
594
+ " -0.003057157387956977,\n",
595
+ " -0.07808902114629745,\n",
596
+ " 0.057173147797584534,\n",
597
+ " 0.015869930386543274,\n",
598
+ " 0.01918310485780239,\n",
599
+ " 0.08144430071115494,\n",
600
+ " -2.1998719290650115e-08,\n",
601
+ " -0.025966359302401543,\n",
602
+ " -0.024850135669112206,\n",
603
+ " 0.02227822132408619,\n",
604
+ " 0.0793970599770546,\n",
605
+ " 0.044460248202085495,\n",
606
+ " 0.03317498043179512,\n",
607
+ " 0.03564529865980148,\n",
608
+ " 0.013410663232207298,\n",
609
+ " -0.05888325348496437,\n",
610
+ " -0.0570887066423893,\n",
611
+ " 0.02409365586936474,\n",
612
+ " -0.0031824831385165453,\n",
613
+ " 0.07432717829942703,\n",
614
+ " 0.00491950660943985,\n",
615
+ " 0.037177130579948425,\n",
616
+ " 0.1214393675327301,\n",
617
+ " -0.02980734035372734,\n",
618
+ " 0.08316365629434586,\n",
619
+ " -0.03441021963953972,\n",
620
+ " -0.05670581012964249,\n",
621
+ " -0.08702761679887772,\n",
622
+ " -0.033726878464221954,\n",
623
+ " 0.09084504842758179,\n",
624
+ " 0.030235234647989273,\n",
625
+ " 0.014355660416185856,\n",
626
+ " 0.008767222985625267,\n",
627
+ " -0.0827459916472435,\n",
628
+ " 0.08210321515798569,\n",
629
+ " -0.061066679656505585,\n",
630
+ " 0.03521161153912544,\n",
631
+ " -0.04115701839327812,\n",
632
+ " 0.014578152447938919,\n",
633
+ " -0.05554644390940666,\n",
634
+ " 0.031068438664078712,\n",
635
+ " -0.08362201601266861,\n",
636
+ " -0.023382432758808136,\n",
637
+ " -0.09858708828687668,\n",
638
+ " 0.017514051869511604,\n",
639
+ " 0.10520247370004654,\n",
640
+ " -0.04585810378193855,\n",
641
+ " -0.03088274411857128,\n",
642
+ " -0.06560547649860382,\n",
643
+ " -0.07936973869800568,\n",
644
+ " 0.038559265434741974,\n",
645
+ " -0.086161307990551,\n",
646
+ " -0.07989706099033356,\n",
647
+ " 0.06426848471164703,\n",
648
+ " -0.04678329452872276,\n",
649
+ " -0.005842810496687889,\n",
650
+ " -9.329108434030786e-05,\n",
651
+ " 0.005526330322027206,\n",
652
+ " -0.060696180909872055,\n",
653
+ " 0.045042477548122406,\n",
654
+ " 0.020842568948864937,\n",
655
+ " 0.10796718299388885,\n",
656
+ " 0.016674820333719254,\n",
657
+ " -0.03490869328379631,\n",
658
+ " 0.050079092383384705,\n",
659
+ " 0.046036623418331146,\n",
660
+ " 0.1225607842206955,\n",
661
+ " 0.03865363076329231,\n",
662
+ " -0.06910006701946259,\n",
663
+ " 0.03865937888622284,\n",
664
+ " 4.1704730392666534e-05]]}"
665
+ ]
666
+ },
667
+ "execution_count": 2,
668
+ "metadata": {},
669
+ "output_type": "execute_result"
670
+ }
671
+ ],
672
+ "source": [
673
+ "my_handler(request)"
674
+ ]
675
+ },
676
+ {
677
+ "cell_type": "code",
678
+ "execution_count": null,
679
+ "metadata": {},
680
+ "outputs": [],
681
+ "source": []
682
+ }
683
+ ],
684
+ "metadata": {
685
+ "kernelspec": {
686
+ "display_name": "Python 3.9.12 ('base')",
687
+ "language": "python",
688
+ "name": "python3"
689
+ },
690
+ "language_info": {
691
+ "codemirror_mode": {
692
+ "name": "ipython",
693
+ "version": 3
694
+ },
695
+ "file_extension": ".py",
696
+ "mimetype": "text/x-python",
697
+ "name": "python",
698
+ "nbconvert_exporter": "python",
699
+ "pygments_lexer": "ipython3",
700
+ "version": "3.9.12"
701
+ },
702
+ "orig_nbformat": 4,
703
+ "vscode": {
704
+ "interpreter": {
705
+ "hash": "7a2c4b191d1ae843dde5cb5f4d1f62fa892f6b79b0f9392a84691e890e33c5a4"
706
+ }
707
+ }
708
+ },
709
+ "nbformat": 4,
710
+ "nbformat_minor": 2
711
+ }
handler.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from optimum.onnxruntime import ORTModelForFeatureExtraction
3
+ from transformers import AutoTokenizer
4
+ import torch.nn.functional as F
5
+ import torch
6
+
7
+ # copied from the model card
8
+ def mean_pooling(model_output, attention_mask):
9
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
10
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
11
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
12
+
13
+
14
+ class EndpointHandler():
15
+ def __init__(self, path=""):
16
+ # load the optimized model
17
+ self.model = ORTModelForFeatureExtraction.from_pretrained(path, file_name="model-quantized.onnx")
18
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
19
+
20
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
21
+ """
22
+ Args:
23
+ data (:obj:):
24
+ includes the input data and the parameters for the inference.
25
+ Return:
26
+ A :obj:`list`:. The list contains the embeddings of the inference inputs
27
+ """
28
+ inputs = data.get("inputs", data)
29
+
30
+ # tokenize the input
31
+ encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')
32
+ # run the model
33
+ outputs = self.model(**encoded_inputs)
34
+ # Perform pooling
35
+ sentence_embeddings = mean_pooling(outputs, encoded_inputs['attention_mask'])
36
+ # Normalize embeddings
37
+ sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
38
+ # postprocess the prediction
39
+ return {"embeddings": sentence_embeddings.tolist()}
model-optimized.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d86e5aac5aaf9b1ba7d91401ccceb7c6a014e05161b71b92a7252099d19f6b7
3
+ size 90868852
model-quantized.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d1835268a3fdee3b431eb86f49aa4e7a4fe584ad98f4cd76fe3c1adc4076f14
3
+ size 66553074
model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e947acf87027bfa67f0f79e083a7ffabf2728c60de2ec7b60f5b26a3b4df6325
3
+ size 90908097
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ optimum[onnxruntime]==1.3.0
3
+ mkl-include
4
+ mkl
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "do_basic_tokenize": true,
4
+ "do_lower_case": true,
5
+ "mask_token": "[MASK]",
6
+ "model_max_length": 512,
7
+ "name_or_path": "sentence-transformers/all-MiniLM-L6-v2",
8
+ "never_split": null,
9
+ "pad_token": "[PAD]",
10
+ "sep_token": "[SEP]",
11
+ "special_tokens_map_file": "/home/ubuntu/.cache/huggingface/transformers/828163b9cc16a2e7d13324e55d0bc0433dab54d1ae271e02d2e3cb1387e1135b.dd8bd9bfd3664b530ea4e645105f557769387b3da9f79bdb55ed556bdd80611d",
12
+ "strip_accents": null,
13
+ "tokenize_chinese_chars": true,
14
+ "tokenizer_class": "BertTokenizer",
15
+ "unk_token": "[UNK]"
16
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff