philipchung committed · Commit 04fab7f · Parent: 3d742af

Upload export_onnx.py with huggingface_hub

export_onnx.py (ADDED, +209 lines)
"""This script exports BGEM3 to ONNX format which can be run using ONNX Runtime.

By default, the script does not apply any optimization to the ONNX model.
"""

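# A minimal invocation sketch (assumed, based on the Typer options defined in
# `main` below; Typer maps underscores in parameter names to dashes in flags):
#
#   python export_onnx.py --output ./onnx --opset 17 --device cpu
#   python export_onnx.py --output ./onnx --optimize O2 --push-to-hub-repo-id namespace/model_name
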
import copy
import os
import shutil
from collections import OrderedDict
from pathlib import Path
from typing import Annotated

import torch
import typer
from huggingface_hub import snapshot_download
from optimum.exporters.onnx import onnx_export_from_model
from optimum.exporters.onnx.model_configs import XLMRobertaOnnxConfig
from optimum.onnxruntime import ORTModelForCustomTasks
from torch import Tensor
from transformers import (
    AutoConfig,
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
    XLMRobertaConfig,
)


class BGEM3InferenceModel(PreTrainedModel):
    """Based on:
    1. https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
    2. https://huggingface.co/aapot/bge-m3-onnx/blob/main/export_onnx.py

    The main change here is that we inherit from `PreTrainedModel`, which provides the
    `.from_pretrained` and `.push_to_hub` methods. This makes it easy to convert the
    model and push the result to the Huggingface Hub.
    """

    config_class = XLMRobertaConfig
    base_model_prefix = "BGEM3InferenceModel"
    model_tags = ["BAAI/bge-m3"]

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        colbert_dim: int = -1,
    ) -> None:
        super().__init__(config=PretrainedConfig())

        model_name = snapshot_download(repo_id=model_name)
        self.config = AutoConfig.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.colbert_linear = torch.nn.Linear(
            in_features=self.model.config.hidden_size,
            out_features=(
                self.model.config.hidden_size if colbert_dim == -1 else colbert_dim
            ),
        )
        self.sparse_linear = torch.nn.Linear(
            in_features=self.model.config.hidden_size, out_features=1
        )
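        # Load weights for the extra ColBERT and sparse heads, which the BAAI/bge-m3
        # repo ships as separate `colbert_linear.pt` and `sparse_linear.pt` files
        # alongside the base XLM-RoBERTa weights in the downloaded snapshot.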
        colbert_state_dict = torch.load(
            os.path.join(model_name, "colbert_linear.pt"), map_location="cpu"
        )
        sparse_state_dict = torch.load(
            os.path.join(model_name, "sparse_linear.pt"), map_location="cpu"
        )
        self.colbert_linear.load_state_dict(colbert_state_dict)
        self.sparse_linear.load_state_dict(sparse_state_dict)

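    # The three methods below reproduce BGE-M3's heads: dense uses the CLS token's
    # last hidden state, sparse applies a ReLU-gated per-token linear weight, and
    # colbert projects every non-CLS token embedding.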
    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        return last_hidden_state[:, 0]

    def sparse_embedding(self, last_hidden_state: Tensor) -> Tensor:
        with torch.no_grad():
            return torch.relu(self.sparse_linear(last_hidden_state))

    def colbert_embedding(
        self, last_hidden_state: Tensor, attention_mask: Tensor
    ) -> Tensor:
        with torch.no_grad():
            colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
            colbert_vecs = colbert_vecs * attention_mask[:, 1:][:, :, None].float()
            return colbert_vecs

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor
    ) -> dict[str, dict[str, Tensor]]:
        """Forward pass of the model with custom output dict with dense, sparse, and
        colbert embeddings. Dense and colbert embeddings are normalized."""
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state

        output = {}
        dense_vecs = self.dense_embedding(last_hidden_state)
        output["dense_vecs"] = torch.nn.functional.normalize(dense_vecs, dim=-1)

        sparse_vecs = self.sparse_embedding(last_hidden_state)
        output["sparse_vecs"] = sparse_vecs

        colbert_vecs = self.colbert_embedding(last_hidden_state, attention_mask)
        output["colbert_vecs"] = torch.nn.functional.normalize(colbert_vecs, dim=-1)

        return output


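# Output shapes (a sketch, assuming batch size B, sequence length T, hidden size H,
# and the default colbert_dim=-1 so the ColBERT dimension equals H):
#   dense_vecs:   (B, H)         L2-normalized CLS embedding
#   sparse_vecs:  (B, T, 1)      per-token ReLU weights
#   colbert_vecs: (B, T - 1, H)  L2-normalized per-token embeddings (CLS dropped)
# These correspond to the dynamic axes declared in BGEM3OnnxConfig below.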
class BGEM3OnnxConfig(XLMRobertaOnnxConfig):
    """Modify XLMRobertaOnnxConfig to include the additional outputs of the model
    (dense_vecs, sparse_vecs, colbert_vecs)."""

    @property
    def outputs(self) -> dict[str, dict[int, str]]:
        """
        Dict containing the axis definition of the output tensors to provide to the model.

        Returns:
            `Dict[str, Dict[int, str]]`: A mapping of each output name to a mapping of axis
            position to the axes symbolic name.
        """
        return copy.deepcopy(
            OrderedDict(
                {
                    "dense_vecs": {0: "batch_size", 1: "embedding"},
                    "sparse_vecs": {0: "batch_size", 1: "token", 2: "weight"},
                    "colbert_vecs": {0: "batch_size", 1: "token", 2: "embedding"},
                }
            )
        )


def main(
    output: Annotated[
        str, typer.Option(help="Path to the directory where the generated ONNX model is stored.")
    ] = "./onnx",
    opset: Annotated[int, typer.Option(help="ONNX opset version number.")] = 17,
    device: Annotated[
        str, typer.Option(help="Device used to perform the export: 'cpu' or 'cuda'.")
    ] = "cpu",
    optimize: Annotated[
        str,
        typer.Option(
            help=(
                "Allows running ONNX Runtime optimizations directly during the export. "
                "Some of these optimizations are specific to ONNX Runtime, and "
                "the resulting ONNX will not be usable with other runtimes such as OpenVINO or TensorRT. "
                "Possible options:\n"
                " - None: No optimization\n"
                " - O1: Basic general optimizations\n"
                " - O2: Basic and extended general optimizations, transformers-specific fusions\n"
                " - O3: Same as O2 with GELU approximation\n"
                " - O4: Same as O3 with mixed precision (fp16, GPU-only, requires `--device cuda`)"
            ),
        ),
    ] = None,
    atol: Annotated[
        float,
        typer.Option(
            help=(
                "If specified, the absolute difference tolerance when validating the model. "
                "Otherwise, the default atol for the model will be used."
            )
        ),
    ] = None,
    push_to_hub_repo_id: Annotated[
        str,
        typer.Option(
            help="Huggingface Hub repo id in `namespace/model_name` format. "
            "If None, then the model will not be pushed to the Huggingface Hub."
        ),
    ] = None,
) -> None:
    model = BGEM3InferenceModel(model_name="BAAI/bge-m3")
    # tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
    onnx_config = BGEM3OnnxConfig(config=model.config)
    onnx_export_from_model(
        model,
        output=output,
        task="feature-extraction",
        custom_onnx_configs={"model": onnx_config},
        opset=opset,
        optimize=optimize,
        atol=atol,
        device=device,
    )

    # Copy this script and model card to export directory so it gets uploaded to Hub
    try:
        shutil.copy(__file__, output)
    except Exception as ex:
        print(f"Error copying script to export directory: {ex}")
    try:
        shutil.copy(str(Path(__file__).parent / "model_card.md"), output)
        shutil.move(f"{output}/model_card.md", f"{output}/README.md")
    except Exception as ex:
        print(f"Error copying model card to export directory: {ex}")

    # Optionally push ONNX model to Hub
    if push_to_hub_repo_id:
        local_onnx_model = ORTModelForCustomTasks.from_pretrained(output)
        local_onnx_model.push_to_hub(
            save_directory=output,
            repository_id=push_to_hub_repo_id,
            use_auth_token=True,
        )


if __name__ == "__main__":
    typer.run(main)
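
A minimal sketch of loading the exported model for inference (hedged: it assumes the export was written to ./onnx and that the tokenizer is taken from BAAI/bge-m3, as in the commented-out line in main above; the ONNX model returns the dense_vecs, sparse_vecs, and colbert_vecs outputs declared in BGEM3OnnxConfig):

from optimum.onnxruntime import ORTModelForCustomTasks
from transformers import AutoTokenizer

onnx_model = ORTModelForCustomTasks.from_pretrained("./onnx")
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")

inputs = tokenizer("What is BGE M3?", return_tensors="pt")
outputs = onnx_model(**inputs)
dense_vecs = outputs["dense_vecs"]      # (batch_size, embedding)
sparse_vecs = outputs["sparse_vecs"]    # (batch_size, token, 1)
colbert_vecs = outputs["colbert_vecs"]  # (batch_size, token, embedding)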