Spaces:
Sleeping
Sleeping
ubuntu
committed on
Commit
•
3091067
1
Parent(s):
bfa4bdd
MVP
Browse files- app.py +15 -4
- constants.py +13 -0
- load_model.py +23 -0
- pretrained_acc935/config.json +38 -0
- pretrained_acc935/pytorch_model.bin +3 -0
- pretrained_acc935/training_args.bin +3 -0
app.py
CHANGED
@@ -1,7 +1,18 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
-
iface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from load_model import load_model, predict_probs
|
3 |
+
from constants import title, description, examples
|
4 |
|
5 |
+
# Load the classifier once at start-up; every request reuses this instance.
model = load_model()


def _classify(text):
    """Return the label -> probability mapping for one piece of news text."""
    return predict_probs(model, text)


# Single text box in, ranked label probabilities out.
demo = gr.Interface(
    fn=_classify,
    inputs=gr.Textbox(label='News article title and description'),
    outputs=gr.Label(num_top_classes=4),
    examples=examples,
    allow_flagging='never',
    title=title,
    description=description,
)

demo.launch()
|
18 |
|
|
|
|
constants.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Static UI text and sample inputs shown by the demo interface.

# Heading displayed for the demo.
title = 'Demo of News Classifier'

# Short explanation displayed for the demo.
description = (
    'This is the demo of News Classifier. You can submit your news title and '
    'description and NN will classify it into 4 classes: World, Sports, '
    'Business and Sci/Tech'
)

# Sample inputs offered to the user. Every entry must end with a comma:
# adjacent string literals are silently concatenated by Python, which would
# merge two examples into one.
examples = [
    'Yandex School of Data Analysis is cool!',
    "Five Killed in Al Qaeda Jailbreak in Kabul KABUL (Reuters) - Three Afghan prison guards and two prisoners were killed in a jail break attempt by al Qaeda inmates Friday and a shoot-out was going on between police and another two, the chief of Kabul's Pul-i-Charki prison told Reuters.",
    'Olympic history for India, UAE An Indian army major shot his way to his country #39;s first ever individual Olympic silver medal on Tuesday, while in the same event an member of Dubai #39;s ruling family became the first ever medallist from the United Arab Emirates.',
    # NOTE(review): the apostrophe right before the closing quote looks like a
    # leftover from re-quoting this string -- confirm and drop if unintended.
    "Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.'",
    'Monster Mashes Attract Masses Kaiju Big Battel -- a multimedia event in which costumed combatants spew toxic ooze on audience members -- is growing in popularity. There are already dedicated websites and a DVD series. Coming next: a book and TV pilot. By Xeni Jardin.',
    '<script> alert("I love ML"); </script>',
]

__all__ = ['title', 'description', 'examples']
|
load_model.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import transformers
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
+
|
5 |
+
# Run on the GPU when torch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Tokenizer is loaded by name rather than from the checkpoint directory.
# NOTE(review): presumably this must match the tokenizer used when the
# checkpoint was fine-tuned -- confirm.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
|
7 |
+
|
8 |
+
def predict_probs(model, text,
                  labels=('World', 'Sports', 'Business', 'Sci/Tech')):
    """Classify ``text`` and return a ``{label: probability}`` mapping.

    Args:
        model: Object called as ``model(**tokens)`` whose result exposes a
            ``.logits`` tensor (e.g. the value returned by ``load_model``).
        text: Input handed to the module-level ``tokenizer``.
        labels: Class names, matched by position to the entries of the first
            row of logits. A tuple is used as the default so the default
            cannot be mutated between calls; any sequence is accepted.

    Returns:
        dict: label -> probability as a Python float. If ``labels`` and the
        model output differ in length, the surplus entries of the longer one
        are dropped.
    """
    with torch.no_grad():
        # NOTE(review): padding="max_length" presumably pads this single
        # input up to the tokenizer's maximum length; confirm whether the
        # padding is needed at all when only one text is encoded.
        tokens = tokenizer(text, padding="max_length", truncation=True, return_tensors='pt').to(device)
        logits = model(**tokens).logits
        # Normalise over the class axis explicitly; calling softmax without
        # ``dim`` relies on a deprecated implicit choice of dimension.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    # zip stops at the shorter of the two sequences.
    return dict(zip(labels, probs.tolist()))
|
16 |
+
|
17 |
+
|
18 |
+
def load_model(labels_count=4):
    """Load the checkpoint stored in ``pretrained_acc935/`` and move it to ``device``.

    Args:
        labels_count: Forwarded to ``from_pretrained`` as ``num_labels``.

    Returns:
        The loaded sequence-classification model, placed on ``device``.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(
        "pretrained_acc935/",
        num_labels=labels_count,
    )
    return classifier.to(device)
|
21 |
+
|
22 |
+
|
23 |
+
__all__ = ['predict_probs', 'load_model']
|
pretrained_acc935/config.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/content/drive/MyDrive/checkpoint-2000",
|
3 |
+
"activation": "gelu",
|
4 |
+
"architectures": [
|
5 |
+
"DistilBertForSequenceClassification"
|
6 |
+
],
|
7 |
+
"attention_dropout": 0.1,
|
8 |
+
"dim": 768,
|
9 |
+
"dropout": 0.1,
|
10 |
+
"hidden_dim": 3072,
|
11 |
+
"id2label": {
|
12 |
+
"0": "LABEL_0",
|
13 |
+
"1": "LABEL_1",
|
14 |
+
"2": "LABEL_2",
|
15 |
+
"3": "LABEL_3"
|
16 |
+
},
|
17 |
+
"initializer_range": 0.02,
|
18 |
+
"label2id": {
|
19 |
+
"LABEL_0": 0,
|
20 |
+
"LABEL_1": 1,
|
21 |
+
"LABEL_2": 2,
|
22 |
+
"LABEL_3": 3
|
23 |
+
},
|
24 |
+
"max_position_embeddings": 512,
|
25 |
+
"model_type": "distilbert",
|
26 |
+
"n_heads": 12,
|
27 |
+
"n_layers": 6,
|
28 |
+
"output_past": true,
|
29 |
+
"pad_token_id": 0,
|
30 |
+
"problem_type": "single_label_classification",
|
31 |
+
"qa_dropout": 0.1,
|
32 |
+
"seq_classif_dropout": 0.2,
|
33 |
+
"sinusoidal_pos_embds": false,
|
34 |
+
"tie_weights_": true,
|
35 |
+
"torch_dtype": "float32",
|
36 |
+
"transformers_version": "4.28.1",
|
37 |
+
"vocab_size": 28996
|
38 |
+
}
|
pretrained_acc935/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:176060d9274dd8947e999474862dc04e8eae26ce55c465d56c73b1b0f23d41d3
|
3 |
+
size 263173805
|
pretrained_acc935/training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:197660446f201b2c7ee748857bb121f05f9102745742686f5f1a77d71930bb92
|
3 |
+
size 3579
|