File size: 639 Bytes
9ae1ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import AutoModel
import torch

# Sequence length of the example tensors used to trace the model. The
# sequence axis is declared dynamic in the export call below, so this value
# only determines the shape of the dummy inputs.
max_seq_length = 384

# Load the pretrained checkpoint and put it in evaluation mode before tracing,
# so train-only behaviour (e.g. dropout) is not baked into the exported graph.
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model.eval()

# Dummy int64 inputs of shape (1, max_seq_length). The insertion order of this
# dict matters: the tensors are handed to the model positionally and the keys
# are reused as the ONNX graph input names.
inputs = {
    "input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
    "attention_mask": torch.ones(1, max_seq_length, dtype=torch.int64),
}

# Axes that stay variable in the exported graph: axis index -> symbolic name.
symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}

torch.onnx.export(
    model,
    args=tuple(inputs.values()),
    f="model.onnx",
    export_params=True,
    # Derived from `inputs` so names and positional args cannot drift apart.
    input_names=list(inputs),
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": symbolic_names,
        "attention_mask": symbolic_names,
        # Tensors not listed here are exported with the example's shape, so
        # the output has to be declared dynamic as well or the graph would
        # not advertise variable batch/sequence sizes for it.
        # Assumes the first model output is laid out as
        # (batch, sequence, hidden) - TODO confirm for other checkpoints.
        "last_hidden_state": symbolic_names,
    },
)