---
tags:
- transformers
- xlm-roberta
library_name: transformers
license: cc-by-nc-4.0
language:
  - multilingual
  - af
  - am
  - ar
  - as
  - az
  - be
  - bg
  - bn
  - br
  - bs
  - ca
  - cs
  - cy
  - da
  - de
  - el
  - en
  - eo
  - es
  - et
  - eu
  - fa
  - fi
  - fr
  - fy
  - ga
  - gd
  - gl
  - gu
  - ha
  - he
  - hi
  - hr
  - hu
  - hy
  - id
  - is
  - it
  - ja
  - jv
  - ka
  - kk
  - km
  - kn
  - ko
  - ku
  - ky
  - la
  - lo
  - lt
  - lv
  - mg
  - mk
  - ml
  - mn
  - mr
  - ms
  - my
  - ne
  - nl
  - 'no'
  - om
  - or
  - pa
  - pl
  - ps
  - pt
  - ro
  - ru
  - sa
  - sd
  - si
  - sk
  - sl
  - so
  - sq
  - sr
  - su
  - sv
  - sw
  - ta
  - te
  - th
  - tl
  - tr
  - ug
  - uk
  - ur
  - uz
  - vi
  - xh
  - yi
  - zh
---

Modified version of [xlm-roberta-flash-implementation](https://huggingface.co/jinaai/xlm-roberta-flash-implementation) for ONNX conversion.

## Brief Summary of Challenges and Modifications
### Dynamic Matrix Calculation in RoPE
The original RoPE implementation did not compute the entire rotation matrix up front. Instead, it calculated the matrix only for the required sequence length, cached it, and recalculated it whenever a longer sequence arrived. This approach is not compatible with ONNX, which requires a fixed graph during inference. To solve this, I now calculate the entire rotation matrix (up to the maximum sequence length) in advance.
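
A minimal sketch of the idea, in the style of common Hugging Face rotary-embedding modules (the class name, default maximum length, and cache layout below are illustrative, not the repository's actual code):

```python
import torch

class PrecomputedRotaryEmbedding(torch.nn.Module):
    """Rotary cos/sin cache built once for the maximum sequence length."""

    def __init__(self, dim: int, max_position_embeddings: int = 8192, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.outer(t, inv_freq)          # (max_len, dim / 2)
        emb = torch.cat((freqs, freqs), dim=-1)   # (max_len, dim)
        # Build the full table once instead of growing and re-caching it
        # whenever a longer sequence shows up at runtime.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, seq_len: int):
        # Pure slicing, no data-dependent recomputation: the traced ONNX
        # graph stays the same regardless of the input length.
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
```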

### Custom Backward Functions for RoPE
We have custom forward and backward functions for RoPE. ONNX does not support custom backward functions, but since only the forward pass is needed for ONNX inference, I removed the backward function completely.
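
For illustration, a forward-only application of the rotation can be written with plain tensor ops, which the exporter can trace without any custom autograd machinery. This is a generic RoPE formulation, not the actual kernel being replaced:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_forward_only(q, k, cos, sin):
    # No torch.autograd.Function and no custom backward: just traceable ops,
    # which is all that is needed for ONNX inference.
    cos = cos.unsqueeze(0).unsqueeze(0)  # broadcast over (batch, heads)
    sin = sin.unsqueeze(0).unsqueeze(0)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
```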

### ONNX Model Size Limitation
ONNX stores the model in protobuf format, which has a maximum size limit of 2 GB. Our model was too large to fit within this limit, so I had to store the model's parameters in external data files.
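
One way to do this with the `onnx` package is sketched below; the exact flags and file names used for this repository may differ:

```python
import onnx

# Re-save the exported graph with its weights moved to a side file so the
# .onnx protobuf itself stays under the 2 GB limit.
model_proto = onnx.load("onnx/model.onnx")
onnx.save_model(
    model_proto,
    "onnx/model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="model.onnx_data",  # weight file written next to the graph
    size_threshold=1024,         # tensors larger than this (bytes) go external
)
```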

### Lack of Support for the `unique()` Function
We used the `unique()` function to identify the distinct task types in a batch, which matters when a batch mixes task types. However, ONNX does not support the `unique()` function. Since mixing task types within a batch is not needed for inference, I replaced the `adapter_mask` argument (a tensor specifying an independent task ID for each text in the batch) with a single integer `task_id` argument that applies to every text in the batch.
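
Schematically, the change looks like this; the adapter lookup below is a simplified stand-in for the real module, and `lora_weights` is a hypothetical name:

```python
import torch

def forward_with_adapter_mask(x, lora_weights, adapter_mask):
    # Original idea: per-text task IDs resolved with unique().
    # torch.unique() cannot be exported to ONNX.
    out = torch.zeros_like(x)
    for task in torch.unique(adapter_mask):
        rows = adapter_mask == task
        out[rows] = x[rows] @ lora_weights[task]
    return out

def forward_with_task_id(x, lora_weights, task_id: int):
    # ONNX-friendly variant: one integer task_id for the whole batch,
    # so the graph contains a single static adapter lookup.
    return x @ lora_weights[task_id]
```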



## Code

```python
import torch
import torch.onnx
from transformers import AutoModel, AutoTokenizer

# Load the modified model; flash attention is disabled so the traced graph
# only contains standard PyTorch ops.
model = AutoModel.from_pretrained(
    '/home/admin/saba/jina-embeddings-v3',
    trust_remote_code=True,
    use_flash_attn=False,
)
model.eval()

onnx_path = "/home/admin/saba/jina-embeddings-v3/onnx/model.onnx"

# A small example batch used to trace the model.
tokenizer = AutoTokenizer.from_pretrained('/home/admin/saba/jina-embeddings-v3')
inputs = tokenizer(["jina", "ai"], return_tensors="pt", padding="longest")
inps = inputs['input_ids']
mask = inputs['attention_mask']
task_id = 2  # one integer task ID for the whole batch (replaces adapter_mask)

torch.onnx.export(
    model,
    (inps, mask, task_id),
    onnx_path,
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids', 'attention_mask', 'task_id'],
    output_names=['text_embeds'],
    opset_version=16,
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'text_embeds': {0: 'batch_size'},
    },
)
```
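
To sanity-check the exported file, something like the following ONNX Runtime snippet can be used. Note that `task_id` was passed as a plain Python int at export time, so it may have been folded into the graph as a constant rather than exposed as a runtime input; the snippet only feeds it if the session actually declares it:

```python
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

onnx_path = "/home/admin/saba/jina-embeddings-v3/onnx/model.onnx"
session = ort.InferenceSession(onnx_path)

tokenizer = AutoTokenizer.from_pretrained('/home/admin/saba/jina-embeddings-v3')
batch = tokenizer(["jina", "ai"], return_tensors="np", padding="longest")

feeds = {
    "input_ids": batch["input_ids"].astype(np.int64),
    "attention_mask": batch["attention_mask"].astype(np.int64),
}
# Only feed task_id if the exported graph kept it as an input.
if any(inp.name == "task_id" for inp in session.get_inputs()):
    feeds["task_id"] = np.array(2, dtype=np.int64)

(text_embeds,) = session.run(["text_embeds"], feeds)
print(text_embeds.shape)
```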