---
datasets:
- snli
- anli
- multi_nli
- multi_nli_mismatch
- fever
license: mit
---

This is a strong pre-trained RoBERTa-Large NLI model.

The training data is a combination of well-known NLI datasets: [`SNLI`](https://nlp.stanford.edu/projects/snli/), [`MNLI`](https://cims.nyu.edu/~sbowman/multinli/), [`FEVER-NLI`](https://github.com/easonnie/combine-FEVER-NSMN/blob/master/other_resources/nli_fever.md), and [`ANLI (R1, R2, R3)`](https://github.com/facebookresearch/anli).
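
If you want to inspect the public portions of this training mix yourself, most of it can be pulled straight from the Hugging Face Hub. The snippet below is only a minimal sketch, assuming the `datasets` library is installed; the FEVER-NLI portion is distributed through the repository linked above rather than as a standard Hub dataset.

```
from datasets import load_dataset

# Sketch only: load the publicly hosted NLI datasets used in the training mix.
snli = load_dataset("snli")        # premise / hypothesis / label
mnli = load_dataset("multi_nli")   # includes matched and mismatched validation splits
anli = load_dataset("anli")        # rounds R1, R2, R3 as separate splits, e.g. "train_r1"

print(anli["train_r1"][0])
```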

Other pre-trained NLI models, including `RoBERTa`, `ALBERT`, `BART`, `ELECTRA`, and `XLNet`, are also available; their checkpoint names are listed in the snippet below.

Trained by [Yixin Nie](https://easonnie.github.io), [original source](https://github.com/facebookresearch/anli).

Try the code snippet below.

```
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

if __name__ == '__main__':
    max_length = 256

    premise = "Two women are embracing while holding to go packages."
    hypothesis = "The men are fighting outside a deli."

    hg_model_hub_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"

    tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)
    model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)

    tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis,
                                                     max_length=max_length,
                                                     return_token_type_ids=True,
                                                     truncation=True)

    input_ids = torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0)
    # Remember that BART doesn't have 'token_type_ids'; remove the line below if you are using BART.
    token_type_ids = torch.Tensor(tokenized_input_seq_pair['token_type_ids']).long().unsqueeze(0)
    attention_mask = torch.Tensor(tokenized_input_seq_pair['attention_mask']).long().unsqueeze(0)

    outputs = model(input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    labels=None)
    # Note the label mapping in the model config:
    # "id2label": {
    #     "0": "entailment",
    #     "1": "neutral",
    #     "2": "contradiction"
    # },

    predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist()  # batch size is one

    print("Premise:", premise)
    print("Hypothesis:", hypothesis)
    print("Entailment:", predicted_probability[0])
    print("Neutral:", predicted_probability[1])
    print("Contradiction:", predicted_probability[2])
```
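
For quick experiments, the same checkpoint can also be wrapped in the `transformers` pipeline API instead of handling tensors manually. This is a minimal sketch, not part of the original release; it assumes a recent `transformers` version in which text-classification pipelines accept a `text`/`text_pair` dictionary and a `top_k` argument.

```
from transformers import pipeline

# Sketch only: run the NLI checkpoint through a text-classification pipeline.
classifier = pipeline("text-classification",
                      model="ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli")

premise = "Two women are embracing while holding to go packages."
hypothesis = "The men are fighting outside a deli."

# Pass the premise/hypothesis pair as a dict; top_k=None returns scores for all three labels.
scores = classifier({"text": premise, "text_pair": hypothesis}, top_k=None)
print(scores)
```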

More examples can be found [here](https://github.com/facebookresearch/anli/blob/master/src/hg_api/interactive_eval.py).

Citation:

```
@inproceedings{nie-etal-2020-adversarial,
    title = "Adversarial {NLI}: A New Benchmark for Natural Language Understanding",
    author = "Nie, Yixin  and
      Williams, Adina  and
      Dinan, Emily  and
      Bansal, Mohit  and
      Weston, Jason  and
      Kiela, Douwe",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    year = "2020",
    publisher = "Association for Computational Linguistics",
}
```