"""
Based on transformers python API.
This script turn list of string into embeddings.
"""
from transformers import AutoTokenizer, TFAutoModel
import tensorflow as tf


MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"


class Embed:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = TFAutoModel.from_pretrained(MODEL_NAME)

    # Mean pooling: average the token embeddings, taking the attention mask
    # into account so that padding tokens do not skew the result.
    @staticmethod
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output.last_hidden_state
        # Expand the mask to the embedding dimension so padded positions
        # contribute zero to the sum.
        input_mask_expanded = tf.cast(
            tf.tile(tf.expand_dims(attention_mask, -1), [1, 1, token_embeddings.shape[-1]]),
            tf.float32)
        # Sum the real-token embeddings and divide by the number of real
        # tokens; maximum() guards against division by zero on all-padding rows.
        return tf.math.reduce_sum(token_embeddings * input_mask_expanded, 1) / tf.math.maximum(
            tf.math.reduce_sum(input_mask_expanded, 1), 1e-9)

    # Encode text
    def encode(self, texts):
        # Tokenize sentences
        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='tf')

        # Compute token embeddings
        model_output = self.model(**encoded_input, return_dict=True)

        # Perform pooling
        embeddings = Embed.mean_pooling(model_output, encoded_input['attention_mask'])

        # Normalize embeddings to unit length so cosine similarity between
        # two embeddings reduces to a plain dot product
        embeddings = tf.math.l2_normalize(embeddings, axis=1)

        return embeddings
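

# A minimal usage sketch, added for illustration and not part of the original
# script; the example sentences below are hypothetical. Instantiating Embed
# downloads the model on first run.
if __name__ == "__main__":
    embedder = Embed()
    sentences = ["This is an example sentence.", "Each sentence becomes one vector."]
    embeddings = embedder.encode(sentences)
    # paraphrase-multilingual-mpnet-base-v2 produces 768-dimensional vectors,
    # so this should print (2, 768).
    print(embeddings.shape)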