|
from transformers import AutoModel |
|
import torch |
|
|
|
# Sequence length of the dummy input used for tracing. The exported model
# itself accepts variable lengths (see dynamic_axes below).
max_seq_length = 384

# Load the pretrained encoder and switch to inference mode so dropout
# layers behave deterministically during tracing.
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model.eval()

# Dummy inputs for tracing. The values are arbitrary (all ones); only the
# dtype (int64 token ids) and tensor rank matter for the traced graph.
inputs = {
    "input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
    "attention_mask": torch.ones(1, max_seq_length, dtype=torch.int64),
}

# Mark dim 0 (batch) and dim 1 (sequence) as dynamic so the ONNX model
# accepts any batch size / sequence length at inference time.
symbolic_names = {0: "batch_size", 1: "max_seq_len"}

# Trace under no_grad: the forward pass used for tracing does not need
# autograd bookkeeping.
with torch.no_grad():
    torch.onnx.export(
        model,
        args=tuple(inputs.values()),
        f="model.onnx",
        export_params=True,
        # Pin the opset so the export is reproducible regardless of the
        # installed torch version's default.
        opset_version=14,
        # Fold constant subgraphs for a smaller, faster inference graph.
        do_constant_folding=True,
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        dynamic_axes={
            "input_ids": symbolic_names,
            "attention_mask": symbolic_names,
            # Also declare the output dynamic; otherwise the graph metadata
            # bakes in the tracing shape (1, 384, hidden).
            "last_hidden_state": symbolic_names,
        },
    )