What input variables does the NLLB model require during ONNX inference?
After exporting the NLLB model to ONNX, I don't know which input variables to pass when running inference with onnxruntime. Currently I pass the following:
inputs = tokenizer(text, return_tensors="pt", padding=True)
target_lang_id = np.array([[tokenizer.convert_tokens_to_ids('zho_Hans')]])
decoder_input_ids = target_lang_id
ort_input = {
    "input_ids": inputs["input_ids"].numpy(),
    "attention_mask": inputs["attention_mask"].numpy(),
    "decoder_input_ids": decoder_input_ids,
}
However, this does not produce an accurate translation.
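One detail worth checking when assembling these inputs: ONNX Runtime does not convert input dtypes for you, and graphs exported from PyTorch usually declare token ids as int64, while `np.array` of Python ints can default to int32 on some platforms (for example Windows with NumPy 1.x). Below is a minimal sketch that packs the three inputs named above and casts each one to `np.int64`. The helper name `build_ort_feed` is hypothetical, and the int64 input type is an assumption you can confirm by printing `ort_session.get_inputs()`.

```python
import numpy as np


def build_ort_feed(input_ids, attention_mask, decoder_input_ids):
    """Pack the three NLLB inputs into an ONNX Runtime feed dict.

    Assumption: the exported graph names its inputs exactly as below and
    expects int64 tensors. Accepts lists, NumPy arrays, or CPU torch tensors
    (np.asarray converts each of them).
    """
    return {
        "input_ids": np.asarray(input_ids, dtype=np.int64),
        "attention_mask": np.asarray(attention_mask, dtype=np.int64),
        "decoder_input_ids": np.asarray(decoder_input_ids, dtype=np.int64),
    }
```

With the tokenizer output from the question, the call would look like `ort_session.run(None, build_ort_feed(inputs["input_ids"], inputs["attention_mask"], decoder_input_ids))`.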
The complete code is as follows:
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
ort_session = ort.InferenceSession('nllb-200-distilled-600M-ONNX/model_quantized.onnx')
text = "hello"
tokenizer = AutoTokenizer.from_pretrained('nllb-200-distilled-600M-ONNX', use_auth_token=True, src_lang='eng_Latn')
# Encoder inputs: source-language token ids and their attention mask
inputs = tokenizer(text, return_tensors="pt", padding=True)
# Decoder input: currently only the target-language code (int64 to match the graph)
target_lang_id = np.array([[tokenizer.convert_tokens_to_ids('zho_Hans')]], dtype=np.int64)
decoder_input_ids = target_lang_id
print(inputs)
# Print the input names, shapes, and types the ONNX graph expects.
# Use a separate variable so the tokenizer output in `inputs` is not overwritten.
for session_input in ort_session.get_inputs():
    print(session_input)
outputs = ort_session.run(None, {
    "input_ids": inputs["input_ids"].numpy(),
    "attention_mask": inputs["attention_mask"].numpy(),
    "decoder_input_ids": decoder_input_ids,
})
print(outputs)
# Take the highest-scoring token at each decoder position
logits = outputs[0]
predicted_ids = np.argmax(logits, axis=-1)
print(predicted_ids)
output_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
print("output: ", output_text)
Which input variables are required, and how can I use the ONNX model to get an accurate translation?
Here is the output from running it:
inputs: {'input_ids': tensor([[256047, 1537, 17606, 248, 524, 300, 53, 43804, 248079,
30, 5076, 5057, 12620, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
outputs: [array([[[ 0.07672643, 0.2473299 , 10.727799 , ..., 0.43490347,
-0.05795102, 0.4599975 ]]], dtype=float32)]
predicted_ids: [[248506]]
output_text: 的
Update: setting decoder_input_ids to [[2, target_lang_id]] (the decoder start token `</s>`, id 2, followed by the target-language code) solves the problem above.
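A single `ort_session.run` call returns logits for the current decoder positions only, so one `argmax` yields just the next token (hence the single id in `predicted_ids: [[248506]]`). Even with the corrected prefix, a full sentence needs an autoregressive loop that feeds each predicted token back into `decoder_input_ids`. The sketch below shows greedy decoding under these assumptions: `run_model` is a hypothetical wrapper around `ort_session.run` that keeps the encoder inputs fixed and returns the first output (the logits) with shape (batch, decoder_length, vocab_size), which is what the printed (1, 1, vocab) array suggests; the decoder start id is the tokenizer's `</s>` id. Without a KV cache this recomputes the whole prefix at every step, which is slow but simple.

```python
from typing import Callable, List

import numpy as np


def greedy_decode(
    run_model: Callable[[np.ndarray], np.ndarray],
    decoder_start_id: int,
    target_lang_id: int,
    eos_id: int,
    max_new_tokens: int = 64,
) -> List[int]:
    """Greedy autoregressive decoding for an encoder-decoder ONNX model.

    run_model (hypothetical wrapper) maps decoder_input_ids of shape (1, t)
    to logits of shape (1, t, vocab_size).
    """
    # NLLB-style prefix: decoder start token, then the target-language code
    ids = [decoder_start_id, target_lang_id]
    for _ in range(max_new_tokens):
        decoder_input_ids = np.array([ids], dtype=np.int64)
        logits = run_model(decoder_input_ids)
        # Only the last position predicts the next token
        next_id = int(np.argmax(logits[0, -1]))
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


# Hypothetical wiring for the session in the question (not executed here):
# run_model = lambda dec: ort_session.run(None, {
#     "input_ids": inputs["input_ids"].numpy(),
#     "attention_mask": inputs["attention_mask"].numpy(),
#     "decoder_input_ids": dec,
# })[0]
# ids = greedy_decode(run_model, tokenizer.eos_token_id,
#                     tokenizer.convert_tokens_to_ids("zho_Hans"),
#                     tokenizer.eos_token_id)
# print(tokenizer.decode(ids, skip_special_tokens=True))
```

Greedy search is the simplest choice here; `transformers` generation typically uses beam search or other strategies, so results may differ from `model.generate` in PyTorch.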