vkoppaka committed on
Commit 3ac99d5
1 Parent(s): 67ba14f

First Version

Files changed (5)
  1. app.py +73 -0
  2. bert.onnx +3 -0
  3. distilbert.onnx +3 -0
  4. muril.onnx +3 -0
  5. roberta.onnx +3 -0
app.py ADDED
@@ -0,0 +1,73 @@
+ import onnxruntime as ort
+ from transformers import AutoTokenizer
+ import gradio as gr
+
+ # Define available models with their ONNX file paths and tokenizer names
+ models = {
+     "DistilBERT": {
+         "onnx_model_path": "distilbert.onnx",
+         "tokenizer_name": "distilbert-base-multilingual-cased",
+     },
+     "BERT": {
+         "onnx_model_path": "bert.onnx",
+         "tokenizer_name": "bert-base-multilingual-cased",
+     },
+     "MuRIL": {
+         "onnx_model_path": "muril.onnx",
+         "tokenizer_name": "google/muril-base-cased",
+     },
+     "RoBERTa": {
+         "onnx_model_path": "roberta.onnx",
+         "tokenizer_name": "cardiffnlp/twitter-roberta-base-emotion",
+     },
+ }
+
+ # Load models and tokenizers into memory
+ model_sessions = {}
+ tokenizers = {}
+
+ for model_name, config in models.items():
+     print(f"Loading {model_name}...")
+     model_sessions[model_name] = ort.InferenceSession(config["onnx_model_path"])
+     tokenizers[model_name] = AutoTokenizer.from_pretrained(config["tokenizer_name"])
+
+ print("All models loaded!")
+
+ # Prediction function
+ def predict_with_model(text, model_name):
+     # Select the appropriate ONNX session and tokenizer
+     ort_session = model_sessions[model_name]
+     tokenizer = tokenizers[model_name]
+
+     # Tokenize the input text
+     inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True)
+
+     # Run ONNX inference
+     outputs = ort_session.run(None, {
+         "input_ids": inputs["input_ids"],
+         "attention_mask": inputs["attention_mask"],
+     })
+
+     # Post-process the output
+     logits = outputs[0]
+     label = "Hate Speech" if logits[0][1] > logits[0][0] else "Not Hate Speech"
+     return label
+
+ # Define Gradio interface
+ interface = gr.Interface(
+     fn=predict_with_model,
+     inputs=[
+         gr.Textbox(label="Enter text to classify"),
+         gr.Dropdown(
+             choices=list(models.keys()),
+             label="Select a model",
+         ),
+     ],
+     outputs="text",
+     title="Multi-Model Hate Speech Detection",
+     description="Choose a model and enter text to classify whether it's hate speech.",
+ )
+
+ # Launch the app
+ if __name__ == "__main__":
+     interface.launch()
bert.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4d38eb2aeab1422656bed7ceb8a0979ea43cd65f5d3f80cdc3d73f1f02482cb1
+ size 711692681
distilbert.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6cfada1fc91fbf304085d4e72d3fc8c47ad3059196adca64f9dad4341c8b8f82
+ size 541440517
muril.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1142234505ed706b0d9f9a9eb5b2d5b9079647b52ea2b918bb96884758c8395f
+ size 950503817
roberta.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5aa28d6d11a3e37527d7f9428069afb92464cb0c89384ba6e1b31d497466f21
+ size 498870601
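
Note: the four .onnx blobs above are committed as Git LFS pointers only; the export script is not part of this commit. Below is a minimal sketch of how such files could be produced with torch.onnx.export, assuming each model was exported from a Hugging Face sequence-classification checkpoint. The checkpoint names are placeholders taken from the tokenizer names in app.py, not the actual fine-tuned hate-speech weights used here.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical mapping from output file to source checkpoint (assumption:
# the real files were exported from fine-tuned variants of these bases).
checkpoints = {
    "distilbert.onnx": "distilbert-base-multilingual-cased",
    "bert.onnx": "bert-base-multilingual-cased",
    "muril.onnx": "google/muril-base-cased",
    "roberta.onnx": "cardiffnlp/twitter-roberta-base-emotion",
}

for onnx_path, checkpoint in checkpoints.items():
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # return_dict=False makes the model return a plain tuple, which keeps the export simple
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, return_dict=False)
    model.eval()

    # Dummy batch used only to trace the graph during export
    dummy = tokenizer("example text", return_tensors="pt")

    torch.onnx.export(
        model,
        (dummy["input_ids"], dummy["attention_mask"]),
        onnx_path,
        input_names=["input_ids", "attention_mask"],  # must match the feed names in app.py
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
            "logits": {0: "batch"},
        },
        opset_version=14,
    )
    print(f"Exported {checkpoint} -> {onnx_path}")

The dynamic axes keep the exported graphs compatible with the variable-length, padded inputs that app.py feeds through onnxruntime.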