|
from transformers import pipeline |
|
import gradio as gr |
|
|
|
# Hugging Face checkpoint backing the zero-shot classifier.
MODEL_NAME = "DeepPavlov/xlm-roberta-large-en-ru-mnli"

# Built once at import time and shared by every request handled by the UI.
classifier = pipeline(task="zero-shot-classification", model=MODEL_NAME)
|
|
|
def wrap_classifier(text, labels, template, *, classifier_fn=None):
    """Return the single most likely candidate label for *text*.

    Args:
        text: Text to classify.
        labels: Comma-separated candidate labels, e.g. "World,Sports".
            Whitespace around each label is stripped, empty entries are
            dropped and duplicates are removed (first occurrence wins), so
            "World, Sports," is treated like "World,Sports". ``None`` is
            treated as an empty string.
        template: Hypothesis template forwarded to the pipeline as
            ``hypothesis_template``; ``{}`` marks where a label is inserted.
        classifier_fn: Optional zero-shot callable used instead of the
            module-level ``classifier`` (e.g. another model or a test double).
            It is called as ``classifier_fn(text, labels,
            hypothesis_template=template)`` and must return a mapping whose
            ``"labels"`` entry lists labels best-first.

    Returns:
        The first entry of the classifier's ``"labels"`` output, i.e. the
        top-ranked label.

    Raises:
        ValueError: If *labels* contains no non-blank label.
    """
    # Users often type "A, B" despite the UI hint; without stripping, " B"
    # would be a distinct label and would alter the hypothesis sentence.
    # dict.fromkeys de-duplicates while preserving the user's order.
    candidates = list(
        dict.fromkeys(
            part.strip() for part in (labels or "").split(",") if part.strip()
        )
    )
    if not candidates:
        raise ValueError("Please provide at least one non-empty candidate label.")

    if classifier_fn is None:
        # Fall back to the pipeline created at module import time.
        classifier_fn = classifier

    outputs = classifier_fn(text, candidates, hypothesis_template=template)
    return outputs["labels"][0]
|
|
|
# Default values shared by the input widgets and the clickable example.
DEFAULT_LABELS = "World,Sports,Business,Sci/Tech"
DEFAULT_TEMPLATE = "The topic of this text is {}."

# Build and launch the demo UI: three text inputs (text, labels, template)
# mapped positionally onto wrap_classifier, one text output.
# NOTE(review): gr.inputs / gr.outputs and the enable_queue /
# allow_screenshot options are legacy Gradio APIs that later releases
# deprecated or removed — confirm the pinned Gradio version before upgrading.
demo = gr.Interface(
    fn=wrap_classifier,
    title="Zero-shot Classification",
    inputs=[
        gr.inputs.Textbox(
            lines=3,
            label="Text to classify",
            default="Sneaky Credit Card Tactics Keep an eye on your credit card issuers -- they may be about to raise your rates.",
        ),
        gr.inputs.Textbox(
            lines=1,
            label="Candidate labels separated with commas (no spaces)",
            default=DEFAULT_LABELS,
            placeholder=DEFAULT_LABELS,
        ),
        gr.inputs.Textbox(
            lines=1,
            label="Template",
            default=DEFAULT_TEMPLATE,
            placeholder=DEFAULT_TEMPLATE,
        ),
    ],
    outputs=[
        gr.outputs.Textbox(label="Predicted label"),
    ],
    enable_queue=True,
    allow_screenshot=False,
    allow_flagging=False,
    # Gradio expects one value per input component in each example row, so
    # every row supplies text, labels and template; clicking it then fills
    # the whole form rather than only the text box.
    examples=[
        [
            "Indian state rolls out wireless broadband Government in South Indian state of Kerala sets up wireless kiosks as part of initiative to bridge digital divide.",
            DEFAULT_LABELS,
            DEFAULT_TEMPLATE,
        ],
    ],
)

demo.launch(debug=True)