shuttie commited on
Commit
7176493
1 Parent(s): 8f3adae

set proper input names and types during export

Browse files
Files changed (2) hide show
  1. convert.py +10 -3
  2. pytorch_model.onnx +2 -2
convert.py CHANGED
@@ -1,11 +1,18 @@
1
  from transformers import AutoTokenizer, AutoModel
2
  import torch
3
 
 
 
4
  model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
5
  model.eval()
6
 
7
- sample = torch.randint(low=0, high=1, size=(1,128))
8
- input = (sample, sample, sample)
 
 
 
9
 
10
- torch.onnx.export(model, input, 'pytorch_model.onnx', export_params=True)
 
 
11
 
 
1
  from transformers import AutoTokenizer, AutoModel
2
  import torch
3
 
4
+ max_seq_length=128
5
+
6
  model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
7
  model.eval()
8
 
9
+ inputs = {"input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
10
+ "attention_mask": torch.ones(1, max_seq_length, dtype=torch.int64),
11
+ "token_type_ids": torch.ones(1, max_seq_length, dtype=torch.int64)}
12
+
13
+ symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
14
 
15
+ torch.onnx.export(model, args=tuple(inputs.values()), f='pytorch_model.onnx', export_params=True,
16
+ input_names=['input_ids', 'attention_mask', 'token_type_ids'], output_names=['output'],
17
+ dynamic_axes={'input_ids': symbolic_names, 'attention_mask': symbolic_names, 'token_type_ids': symbolic_names})
18
 
pytorch_model.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5ec02f3d61fd002cc3c4220ba2f7851fad67abf4004312753f866a9c37a5693c
3
- size 90933591
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79c87071eeceb3a40ef53b96aa8b3f9bfd4d3152023a5ae25ed134eed59c724e
3
+ size 90984197