Tihsrah-CD committed on
Commit
a9a3816
1 Parent(s): 53ece0c

feat: Add inference code for the Topic Classifier model

Added `model_fn` and `predict_fn` functions to load the model and run inference. Updated `README.md` to include the new inference instructions and usage example.
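
As a rough sketch of how the two functions described above fit together: the value returned by `model_fn` is passed unchanged as the second argument of `predict_fn`, and the request payload is a dict with an `inputs` key. The `run_inference` helper and the two stand-in functions below are hypothetical and exist only so the sketch runs without downloading a model; they are not part of this commit.

```python
from typing import Any, Callable, Dict


def run_inference(
    text: str,
    model_dir: str,
    model_fn: Callable[[str], Any],
    predict_fn: Callable[[Dict[str, str], Any], Any],
) -> Any:
    """Load once with model_fn, then hand the loaded objects to predict_fn."""
    loaded = model_fn(model_dir)   # in this commit: a (model, tokenizer) tuple
    payload = {"inputs": text}     # predict_fn reads data['inputs']
    return predict_fn(payload, loaded)


# Hypothetical stand-ins that mimic the call signatures only.
def fake_model_fn(model_dir: str):
    return ("model from " + model_dir, "tokenizer from " + model_dir)


def fake_predict_fn(data, model_and_tokenizer):
    model, tokenizer = model_and_tokenizer
    return {"text": data["inputs"], "model": model, "tokenizer": tokenizer}


result = run_inference(
    "Quarterly revenue rose 4%.", "./model", fake_model_fn, fake_predict_fn
)
print(result)
```

With the real functions from `code/code_inference.py`, the same call would be `run_inference(text, model_dir, model_fn, predict_fn)`, assuming `model_dir` points at a directory containing the saved model and tokenizer.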

Files changed (2)
  1. README.md +30 -0
  2. code/code_inference.py +24 -0
README.md CHANGED
@@ -129,6 +129,36 @@ The model's evaluation results are as follows:
 - **Evaluation Samples Per Second:** 151.586
 - **Evaluation Steps Per Second:** 2.391
 
+#### Inference Code
+
+```python
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
+
+
+def model_fn(model_dir):
+    """
+    Load the model and tokenizer from the specified paths
+    :param model_dir:
+    :return:
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+    return model, tokenizer
+
+
+def predict_fn(data, model_and_tokenizer):
+    # destruct model and tokenizer
+    model, tokenizer = model_and_tokenizer
+
+    bert_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer,
+                         truncation=True, max_length=512, return_all_scores=True)
+    # Tokenize the input, pick up first 512 tokens before passing it further
+    tokens = tokenizer.encode(data['inputs'], add_special_tokens=False, max_length=512, truncation=True)
+    input_data = tokenizer.decode(tokens)
+    return bert_pipe(input_data)
+
+```
+
 ## Conclusion
 
 The Topic Classifier achieves high accuracy, precision, recall, and F1-score, making it a reliable model for categorizing text across the domains of corporate documents, financial content, harmful content, and medical texts. The model is optimized for immediate deployment and works efficiently in real-world applications.
code/code_inference.py ADDED
@@ -0,0 +1,24 @@
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
+
+
+def model_fn(model_dir):
+    """
+    Load the model and tokenizer from the specified paths
+    :param model_dir:
+    :return:
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+    return model, tokenizer
+
+
+def predict_fn(data, model_and_tokenizer):
+    # destruct model and tokenizer
+    model, tokenizer = model_and_tokenizer
+
+    bert_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer,
+                         truncation=True, max_length=512, return_all_scores=True)
+    # Tokenize the input, pick up first 512 tokens before passing it further
+    tokens = tokenizer.encode(data['inputs'], add_special_tokens=False, max_length=512, truncation=True)
+    input_data = tokenizer.decode(tokens)
+    return bert_pipe(input_data)
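
`predict_fn` returns the raw pipeline output, which holds one `{"label", "score"}` entry per class. A caller that only needs the winning topic could reduce it as sketched below. This is a minimal sketch, not part of the commit: the `top_prediction` helper is hypothetical, and the labels and scores in `example_output` are made up for illustration. Depending on the transformers version and the arguments used, the scores for a single input may arrive either wrapped in an outer list or as a flat list, so the helper accepts both shapes.

```python
from typing import Any, Dict, List


def top_prediction(pipeline_output: List[Any]) -> Dict[str, Any]:
    """Return the highest-scoring {"label", "score"} entry for a single input."""
    scores = pipeline_output
    # Unwrap the outer list if the per-class scores are nested: [[{...}, ...]].
    if scores and isinstance(scores[0], list):
        scores = scores[0]
    if not scores:
        raise ValueError("empty pipeline output")
    return max(scores, key=lambda entry: entry["score"])


# Hypothetical output; label names and scores are illustrative only.
example_output = [[
    {"label": "LABEL_0", "score": 0.05},
    {"label": "LABEL_1", "score": 0.80},
    {"label": "LABEL_2", "score": 0.10},
    {"label": "LABEL_3", "score": 0.05},
]]

best = top_prediction(example_output)
print(best["label"], best["score"])
```

The actual label names depend on the `id2label` mapping stored in the model's configuration, which this commit does not show.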