How do I get token_weights from ONNX inference?


When I try onnx-community/gte-multilingual-base, I found there are two outputs:
['token_embeddings', 'sentence_embedding']

        # list the output names declared by the ONNX session
        outputs_meta = self._session.get_outputs()
        output = [node.name for node in outputs_meta]
        print('Outputs: ', output)
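
For context, this is roughly how I run the session outside that class (the model path and tokenizer handling here are simplified placeholders for my actual setup):

    import numpy as np
    import onnxruntime as ort
    from transformers import AutoTokenizer

    # placeholder path -- points at the exported file from onnx-community/gte-multilingual-base
    session = ort.InferenceSession("model.onnx")
    tokenizer = AutoTokenizer.from_pretrained("onnx-community/gte-multilingual-base")

    enc = tokenizer("Where is Munich?", return_tensors="np")

    # build the feed from the session's declared inputs, so no input name is guessed
    feed = {m.name: enc[m.name].astype(np.int64) for m in session.get_inputs() if m.name in enc}

    # outputs_onnx[0] -> token_embeddings, outputs_onnx[1] -> sentence_embedding
    outputs_onnx = session.run(None, feed)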

After inference, outputs_onnx has two members; outputs_onnx[1] is the sentence_embedding.

But outputs_onnx[0] contains 6 vectors (List[float]) like below:

token_embeddings outputs:

[[[-1.5273438   1.9277344  -0.77685547 ...  0.6074219   0.28857422  -1.5166016 ]
  [-1.6230469   1.2304688  -0.79003906 ...  0.5473633   0.35253906  -0.6826172 ]
  [-1.5107422   1.1708984  -0.74609375 ...  0.49047852  0.1928711   -0.5253906 ]
  [-1.5         1.0712891  -0.79248047 ...  0.3005371   0.0668335   -0.29370117]
  [-1.4746094   1.171875   -0.8046875  ...  0.6591797   0.20800781  -0.8466797 ]
  [-1.4042969   1.2041016  -0.6269531  ...  0.29174805  0.3125      -1.0458984 ]]]
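
So token_embeddings looks like one full hidden-state vector per token rather than one scalar per token; a quick shape check shows this (the 768 hidden size is just what I expect for this model):

    token_embeddings, sentence_embedding = outputs_onnx
    print(token_embeddings.shape)    # something like (1, 6, 768): batch x tokens x hidden size
    print(sentence_embedding.shape)  # something like (1, 768)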

Compare this to gte_embeder.py, where after

            token_weights = torch.relu(model_out.logits).squeeze(-1)

token_weights is 6 float weights, as below:

tensor([[0.9023, 0.8047, 2.7324, 3.8789, 0.9346, 1.5498]], device='cuda:0',
       dtype=torch.float16)
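
For reference, my understanding of the PyTorch side that produces these weights is roughly the following (the model class, repo id, and dtype/device handling are my assumptions; only the relu/squeeze line is taken verbatim from gte_embeder.py):

    import torch
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    # assumption: the sparse token weights come from a token-classification head
    # with a single logit per token, loaded from the original (non-ONNX) repo
    model_id = "Alibaba-NLP/gte-multilingual-base"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForTokenClassification.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch.float16
    ).to("cuda")

    enc = tokenizer("Where is Munich?", return_tensors="pt").to("cuda")
    with torch.no_grad():
        model_out = model(**enc)

    # the line from gte_embeder.py: one scalar weight per token
    token_weights = torch.relu(model_out.logits).squeeze(-1)
    print(token_weights)  # tensor([[0.9023, 0.8047, 2.7324, ...]], dtype=torch.float16)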

Both runs encode the string 'Where is Munich?'.

So how do I convert these 6 List[float] vectors into those 6 real weights?
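
My current guess is that the ONNX token_embeddings are just the final hidden states, and the 6 scalar weights would come from an extra linear head applied on top of them, something like the sketch below, but I don't know whether that head's weights (W and b here are hypothetical) are available anywhere in the ONNX export:

    import numpy as np

    def token_weights_from_embeddings(token_embeddings, W, b):
        # hypothetical conversion: token_embeddings is (1, num_tokens, hidden),
        # W is (hidden, 1) and b is (1,), taken from the classification head
        logits = token_embeddings @ W + b            # (1, num_tokens, 1)
        return np.maximum(logits, 0.0).squeeze(-1)   # relu + squeeze, as in gte_embeder.py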
