huynhdoo committed on
Commit
2349468
1 Parent(s): 4fb7ad9

Upload with huggingface_hub

Browse files
Files changed (2) hide show
  1. requirements.txt +3 -0
  2. sentence_camembert_base.py +39 -0
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ optimum[onnxruntime]
2
+ mkl-include
3
+ mkl
sentence_camembert_base.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from optimum.onnxruntime import ORTModelForFeatureExtraction
3
+ from transformers import AutoTokenizer
4
+ import torch.nn.functional as F
5
+ import torch
6
+
7
+ # copied from the model card
8
def mean_pooling(model_output, attention_mask):
    """Return the attention-masked mean of the token embeddings.

    Args:
        model_output: Indexable model output; element 0 is taken as the
            token embeddings tensor (assumed (batch, seq_len, hidden) --
            TODO confirm against the model in use).
        attention_mask: Tensor with one entry per token. It gains a trailing
            axis and is expanded to the size of the token embeddings, so it
            must be expandable to that size.

    Returns:
        A tensor holding, for each row, the sum over dim 1 of the
        mask-weighted token embeddings divided by the sum of the mask.
    """
    # First element of model_output contains all token embeddings.
    token_embeddings = model_output[0]

    # Broadcast the mask across the embedding axis so masked tokens
    # contribute zero to the sum.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

    summed = torch.sum(token_embeddings * mask, dim=1)
    # Clamp the denominator so a fully masked row yields zeros rather than
    # a division by zero.
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts
12
+
13
+
14
class sentence_embeddings:
    """Callable that turns input text into L2-normalized sentence embeddings.

    Wraps an ONNX feature-extraction model and its tokenizer, both loaded
    from the same directory.
    """

    def __init__(self, path: str = '.'):
        """Load the model and tokenizer from *path*.

        Args:
            path: Location passed to both ``from_pretrained`` calls. The
                model is read from the file ``model_quantized.onnx`` there.
                Defaults to the current directory.
        """
        # NOTE: the default lives here, not in the class header. Keyword
        # arguments in a class statement go to the metaclass /
        # __init_subclass__, which fails when the class is defined.
        self.model = ORTModelForFeatureExtraction.from_pretrained(
            path, file_name="model_quantized.onnx"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Any) -> Dict[str, List[List[float]]]:
        """Compute embeddings for the given inputs.

        Args:
            data: Either a dict, in which case the value under ``"inputs"``
                is used (falling back to the dict itself when the key is
                absent), or the raw inputs to hand to the tokenizer.

        Returns:
            A dict with a single key ``"embeddings"`` mapping to the
            normalized embeddings converted to nested Python lists.
        """
        # Accept both {"inputs": ...} payloads and bare inputs; calling
        # .get() on a str or list would raise AttributeError.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data

        # Tokenize with padding and truncation, returning PyTorch tensors.
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, return_tensors='pt'
        )
        # Run the model.
        outputs = self.model(**encoded_inputs)
        # Pool token embeddings into one vector per input, ignoring padding.
        embeddings = mean_pooling(outputs, encoded_inputs['attention_mask'])
        # L2-normalize each embedding along dim 1.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        # Convert to plain lists for the response.
        return {'embeddings': embeddings.tolist()}