---
language:
  - ar
  - bg
  - de
  - el
  - en
  - es
  - fr
  - ru
  - th
  - tr
  - ur
  - vi
  - zh
  - multilingual
tags:
  - zero-shot-classification
datasets:
  - SNLI
  - MNLI
  - ANLI
  - XNLI
---

A cross-attention NLI model trained for zero-shot and few-shot text classification.
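One common way to turn an NLI model into a zero-shot classifier is to treat the text as the premise, phrase each candidate label as a hypothesis, and rank the labels by entailment. The sketch below shows that reduction without loading the model. It assumes a hypothetical `nli_logits(premise, hypothesis)` callable that returns three logits with entailment at index 0 (the usage example reads column 0 as the entailment probability). It then normalizes the entailment logits across labels with a softmax, which is the single-label scheme used by the `transformers` zero-shot pipeline. The `template` string and the toy scorer are illustrative assumptions.

```python
import math
from typing import Callable, Dict, List, Sequence

# Hypothetical scorer: maps (premise, hypothesis) to NLI logits.
NliScorer = Callable[[str, str], Sequence[float]]

ENTAILMENT = 0  # assumed position of the entailment logit


def zero_shot_scores(
    text: str,
    labels: List[str],
    nli_logits: NliScorer,
    template: str = "The sentence is {}.",
) -> Dict[str, float]:
    """Score each candidate label by entailment, normalized across labels."""
    # One hypothesis per candidate label, with the text as the premise.
    entail = [nli_logits(text, template.format(label))[ENTAILMENT] for label in labels]
    # Softmax over the entailment logits (shifted by the max for numerical stability).
    m = max(entail)
    exps = [math.exp(x - m) for x in entail]
    total = sum(exps)
    return {label: e / total for label, e in zip(labels, exps)}


if __name__ == "__main__":
    # Toy stand-in for the real model: "entails" any hypothesis mentioning "positive".
    def toy_nli(premise: str, hypothesis: str) -> List[float]:
        return [2.0, 0.0, -2.0] if "positive" in hypothesis else [-2.0, 0.0, 2.0]

    scores = zero_shot_scores("I like this pizza.", ["positive", "negative"], toy_nli)
    print(max(scores, key=scores.get))  # positive
```

With the real model, `nli_logits` would wrap the tokenizer and model calls from the usage example for a single pair.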

The base model is xlm-roberta-base, trained on SNLI, MNLI, ANLI and XNLI with the code from here.

Usage:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import numpy as np

model = AutoModelForSequenceClassification.from_pretrained("symanto/xlm-roberta-base-snli-mnli-anli-xnli")
tokenizer = AutoTokenizer.from_pretrained("symanto/xlm-roberta-base-snli-mnli-anli-xnli")

# Each pair is (premise, hypothesis): the text to classify and a candidate label phrased as a statement.
input_pairs = [
    ("I like this pizza.", "The sentence is positive."),
    ("I like this pizza.", "The sentence is negative."),
    ("I mag diese Pizza.", "Der Satz ist positiv."),
    ("I mag diese Pizza.", "Der Satz ist negativ."),
    ("Me gusta esta pizza.", "Esta frase es positivo."),
    ("Me gusta esta pizza.", "Esta frase es negativo."),
]
# truncation="only_first" truncates only the premise when a pair is too long.
inputs = tokenizer(input_pairs, truncation="only_first", return_tensors="pt", padding=True)
with torch.no_grad():  # inference only, no gradients needed
    logits = model(**inputs).logits
probs = torch.softmax(logits, dim=1)
# Keep column 0, the entailment probability of each pair.
probs = probs[..., [0]].tolist()
print("probs", probs)
np.testing.assert_almost_equal(probs, [[0.83], [0.04], [1.00], [0.00], [1.00], [0.00]], decimal=2)
```
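When many texts share the same candidate labels, the pairs in the usage example do not have to be written out by hand. The sketch below uses two hypothetical helpers. `build_pairs` produces `(premise, hypothesis)` tuples in text-major order, which can be passed to the tokenizer like `input_pairs`. `best_labels` maps the flat entailment probabilities back to one label per text. It assumes the probabilities come in the same order as the pairs and that a higher entailment probability means a better label.

```python
from typing import List, Sequence, Tuple


def build_pairs(texts: Sequence[str], hypotheses: Sequence[str]) -> List[Tuple[str, str]]:
    """Pair every text (premise) with every hypothesis, grouped by text."""
    return [(text, hyp) for text in texts for hyp in hypotheses]


def best_labels(flat_probs: Sequence[float], labels: Sequence[str]) -> List[str]:
    """For each consecutive group of len(labels) probabilities, return the label
    whose hypothesis got the highest entailment probability.

    Assumes flat_probs follows the text-major order used by build_pairs().
    """
    n = len(labels)
    if n == 0 or len(flat_probs) % n != 0:
        raise ValueError("flat_probs must hold one value per (text, label) pair")
    result = []
    for start in range(0, len(flat_probs), n):
        group = flat_probs[start:start + n]
        result.append(labels[max(range(n), key=lambda j: group[j])])
    return result


if __name__ == "__main__":
    pairs = build_pairs(
        ["I like this pizza."],
        ["The sentence is positive.", "The sentence is negative."],
    )
    print(pairs)
    # Rounded entailment probabilities asserted in the usage example.
    print(best_labels([0.83, 0.04, 1.00, 0.00, 1.00, 0.00], ["positive", "negative"]))
    # ['positive', 'positive', 'positive']
```

The usage example's `probs` is a list of one-element lists, so `[p[0] for p in probs]` gives the flat sequence that `best_labels` expects.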