Fb_improved_zeroshot
Zero-shot model for classifying academic search logs in German and English. Developed by students at ETH Zürich.
This model starts from the bart-large-mnli checkpoint published by Meta on Hugging Face (facebook/bart-large-mnli) and was then fine-tuned to suit the needs of this project.
NLI-based Zero-Shot Text Classification
This method is based on Natural Language Inference (NLI); see Yin et al. The following tutorials are adapted from the model card of bart-large-mnli.
With the zero-shot classification pipeline
The model can be loaded with the zero-shot-classification pipeline like so:
from transformers import pipeline
classifier = pipeline("zero-shot-classification",
                      model="oigele/Fb_improved_zeroshot")
You can then use this pipeline to classify sequences into any of the class names you specify.
sequence_to_classify = "natural language processing"
candidate_labels = ['Location & Address', 'Employment', 'Organizational', 'Name', 'Service', 'Studies', 'Science']
classifier(sequence_to_classify, candidate_labels)
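The call returns a dictionary with the keys sequence, labels and scores, where labels are sorted from most to least likely. A minimal sketch of reading such a result follows. The helper name top_label and its threshold parameter are illustrative assumptions, and the scores in the example dictionary are hand-written to show the shape only; they are not real output of this model.

```python
def top_label(result, threshold=0.5):
    """Return (label, score) for the best-scoring label, or None if it is below threshold."""
    # The pipeline sorts labels by descending score, so index 0 is the best one.
    label, score = result["labels"][0], result["scores"][0]
    return (label, score) if score >= threshold else None

# Hand-written example in the pipeline's output format (scores are illustrative only).
example_result = {
    "sequence": "natural language processing",
    "labels": ["Science", "Studies", "Service"],
    "scores": [0.7, 0.2, 0.1],
}
print(top_label(example_result))
```

With a real pipeline, the same helper could be applied directly to the value returned by classifier(sequence_to_classify, candidate_labels).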
If more than one candidate label can be correct, pass multi_label=True (named multi_class in older transformers releases)
to calculate each class independently:
candidate_labels = ['Location & Address', 'Employment', 'Organizational', 'Name', 'Service', 'Studies', 'Science']
classifier(sequence_to_classify, candidate_labels, multi_label=True)
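The difference between the default mode and independent scoring comes down to how the NLI logits are normalized. The sketch below illustrates the two schemes in plain Python. It assumes each candidate label yields a [contradiction, neutral, entailment] logit triple in that order; the function names and the logit values are made up for illustration and are not part of the transformers API.

```python
import math

def softmax(values):
    # Numerically stable softmax over a list of floats.
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def single_label_scores(nli_logits):
    # Softmax over the entailment logits of all labels: scores compete and sum to 1.
    return softmax([triple[2] for triple in nli_logits])

def independent_scores(nli_logits):
    # Softmax over (contradiction, entailment) for each label separately:
    # every label gets its own probability, so several can be high at once.
    return [softmax([triple[0], triple[2]])[1] for triple in nli_logits]

# Made-up logits for two candidate labels that both fit the sequence equally well.
logits = [[-1.0, 0.0, 2.0], [-1.0, 0.0, 2.0]]
print(single_label_scores(logits))
print(independent_scores(logits))
```

In this toy case the competing scheme splits the probability mass evenly between the two labels, while the independent scheme lets both labels score high.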
With manual PyTorch
# pose sequence as a NLI premise and label as a hypothesis
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
nli_model = AutoModelForSequenceClassification.from_pretrained('oigele/Fb_improved_zeroshot').to(device)
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

# example inputs; replace with your own sequence and candidate label
sequence = "natural language processing"
label = "Science"
premise = sequence
hypothesis = f'This is {label}.'

# run the premise/hypothesis pair through the fine-tuned NLI model
x = tokenizer.encode(premise, hypothesis, return_tensors='pt',
                     truncation='only_first')
with torch.no_grad():
    logits = nli_model(x.to(device))[0]

# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
prob_label_is_true = probs[:,1]
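The manual snippet scores a single label. To cover several candidate labels, one premise/hypothesis pair is needed per label. A minimal sketch of building those pairs follows; the helper name build_nli_pairs is an illustrative assumption, and the default template simply mirrors the hypothesis string used above.

```python
def build_nli_pairs(sequence, candidate_labels, template="This is {}."):
    # One (premise, hypothesis) pair per candidate label; the premise is always the sequence.
    return [(sequence, template.format(label)) for label in candidate_labels]

pairs = build_nli_pairs("natural language processing", ["Science", "Studies"])
for premise, hypothesis in pairs:
    print(premise, "->", hypothesis)
```

Each pair can then be tokenized and passed through the NLI model in the same way as the single premise and hypothesis above.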