---
license: mit
language:
- en
base_model:
- distilbert/distilbert-base-uncased
pipeline_tag: text-classification
---
# Topic Classifier

This repository contains the Topic Classifier model developed by DAXA.AI. The Topic Classifier is a machine learning model that assigns text documents to one of four topics: corporate documents, financial texts, harmful content, and medical documents.
## Model Details

### Model Description

The Topic Classifier is a DistilBERT-based model, fine-tuned from `distilbert-base-uncased`. It categorizes text into one of four topics: "CORPORATE_DOCUMENTS," "FINANCIAL," "HARMFUL," and "MEDICAL." Because these topics span several sectors, the model suits a range of business text classification tasks.

- **Developed by:** DAXA.AI
- **Funded by:** Open Source
- **Model type:** Text classification
- **Language(s):** English
- **License:** MIT
- **Fine-tuned from:** `distilbert-base-uncased`
### Model Sources

- **Repository:** [https://huggingface.co/daxa-ai/topic-classifier](https://huggingface.co/daxa-ai/Topic-Classifier-2)
- **Demo:** [https://huggingface.co/spaces/daxa-ai/Topic-Classifier-2](https://huggingface.co/spaces/daxa-ai/Topic-Classifier-2)
## Usage

### How to Get Started with the Model

To use the Topic Classifier in your Python project, follow the steps below:
```python
# Import necessary libraries
import joblib
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Repository ID on the Hugging Face Hub and the label encoder file stored in it
REPO_NAME = "daxa-ai/topic-classifier"
LABEL_ENCODER_FILE = "label_encoder.joblib"

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(REPO_NAME)
model = AutoModelForSequenceClassification.from_pretrained(REPO_NAME)

# Example text (inputs longer than 512 tokens are truncated)
text = "Please enter your text here."
encoded_input = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

# Run inference without tracking gradients
with torch.no_grad():
    output = model(**encoded_input)

# Apply softmax to the logits
probabilities = torch.nn.functional.softmax(output.logits, dim=-1)

# Get the predicted class index
predicted_label = torch.argmax(probabilities, dim=-1)

# Download and cache the label encoder file
# (hf_hub_download replaces cached_download, which newer huggingface_hub releases no longer provide)
filename = hf_hub_download(repo_id=REPO_NAME, filename=LABEL_ENCODER_FILE)

# Load the label encoder
label_encoder = joblib.load(filename)

# Decode the predicted class index into its label name
decoded_label = label_encoder.inverse_transform(predicted_label.numpy())

print(decoded_label)
```
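The post-processing steps in the example above (softmax, argmax, label decoding) can be tried in isolation without downloading the model. The sketch below is pure Python and uses made-up logits. It also assumes the label encoder maps class indices to the four labels in alphabetical order, which is how scikit-learn's `LabelEncoder` orders classes; check the real mapping in `label_encoder.classes_` before relying on it.

```python
import math

# Assumed index-to-label order (alphabetical); verify against label_encoder.classes_
LABELS = ["CORPORATE_DOCUMENTS", "FINANCIAL", "HARMFUL", "MEDICAL"]


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(logits, labels=LABELS):
    """Return the most probable label and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]


# Hypothetical logits for one document (not real model output)
label, confidence = decode([0.2, 3.1, -1.0, 0.4])
print(label, round(confidence, 3))
```

The returned probability can serve as a rough confidence value, for example to route low-confidence documents to manual review.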
## Training Details

### Training Data

The training dataset consists of 29,286 entries, categorized into four distinct labels. The distribution of these labels is presented below:

| Document Type       | Instances |
| ------------------- | --------- |
| CORPORATE_DOCUMENTS | 17,649    |
| FINANCIAL           | 3,385     |
| HARMFUL             | 2,388     |
| MEDICAL             | 5,864     |
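The counts in the table above are imbalanced: CORPORATE_DOCUMENTS accounts for roughly 60% of the training set and is about 7.4 times larger than HARMFUL. A minimal sketch for reproducing these shares from the table (variable names are illustrative):

```python
# Training-set counts copied from the table above
train_counts = {
    "CORPORATE_DOCUMENTS": 17_649,
    "FINANCIAL": 3_385,
    "HARMFUL": 2_388,
    "MEDICAL": 5_864,
}

total = sum(train_counts.values())

# Share of the training set held by each label
shares = {label: count / total for label, count in train_counts.items()}

for label, share in shares.items():
    print(f"{label:<20} {share:6.1%}")

# Ratio between the largest and the smallest class
imbalance_ratio = max(train_counts.values()) / min(train_counts.values())
print(f"largest/smallest: {imbalance_ratio:.1f}x")
```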
### Evaluation

#### Testing Data & Metrics

The model was evaluated on a dataset consisting of 4,565 entries. The distribution of labels in the evaluation set is shown below:

| Document Type       | Instances |
| ------------------- | --------- |
| CORPORATE_DOCUMENTS | 3,051     |
| FINANCIAL           | 409       |
| HARMFUL             | 246       |
| MEDICAL             | 859       |

The evaluation metrics include precision, recall, and F1-score, calculated for each label:

| Document Type       | Precision | Recall | F1-Score | Support |
| ------------------- | --------- | ------ | -------- | ------- |
| CORPORATE_DOCUMENTS | 1.00      | 1.00   | 1.00     | 3,051   |
| FINANCIAL           | 0.95      | 0.96   | 0.96     | 409     |
| HARMFUL             | 0.95      | 0.95   | 0.95     | 246     |
| MEDICAL             | 0.99      | 1.00   | 0.99     | 859     |
| Accuracy            |           |        | 0.99     | 4,565   |
| Macro Avg           | 0.97      | 0.98   | 0.97     | 4,565   |
| Weighted Avg        | 0.99      | 0.99   | 0.99     | 4,565   |
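The two average rows differ in how they combine the per-label rows: the macro average treats every label equally, while the weighted average weights each label by its support. The sketch below recomputes both from the per-label values in the table. Because those values are already rounded to two decimals, the results agree with the table only to within rounding.

```python
# Per-label metrics copied from the table above: (precision, recall, f1, support)
report = {
    "CORPORATE_DOCUMENTS": (1.00, 1.00, 1.00, 3051),
    "FINANCIAL": (0.95, 0.96, 0.96, 409),
    "HARMFUL": (0.95, 0.95, 0.95, 246),
    "MEDICAL": (0.99, 1.00, 0.99, 859),
}


def macro_avg(column):
    """Unweighted mean over labels: every label counts equally."""
    return sum(row[column] for row in report.values()) / len(report)


def weighted_avg(column):
    """Mean weighted by support: frequent labels dominate."""
    total_support = sum(row[3] for row in report.values())
    return sum(row[column] * row[3] for row in report.values()) / total_support


for name, col in (("precision", 0), ("recall", 1), ("f1", 2)):
    print(f"{name:<9} macro={macro_avg(col):.4f} weighted={weighted_avg(col):.4f}")
```

Since CORPORATE_DOCUMENTS holds about two thirds of the support, the weighted averages sit close to that label's scores, while the macro averages are pulled down by the smaller FINANCIAL and HARMFUL classes.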
#### Test Data Evaluation Results

The model's evaluation results are as follows:

- **Evaluation Loss:** 0.0233
- **Accuracy:** 0.9908
- **Precision:** 0.9909
- **Recall:** 0.9908
- **F1-Score:** 0.9908
- **Evaluation Runtime:** 30.1149 seconds
- **Evaluation Samples Per Second:** 151.586
- **Evaluation Steps Per Second:** 2.391
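The throughput figures above follow from the runtime and the size of the evaluation set. The sketch below redoes that arithmetic and, as an inference only, lists the batch sizes that would produce the implied number of steps. It assumes a single device and that the final partial batch is kept; the actual evaluation batch size is not stated in this card.

```python
import math

eval_samples = 4_565        # size of the evaluation set
runtime_s = 30.1149         # reported evaluation runtime in seconds
steps_per_second = 2.391    # reported evaluation steps per second

samples_per_second = eval_samples / runtime_s
total_steps = round(steps_per_second * runtime_s)

# Batch sizes b for which ceil(samples / b) equals the step count
# (an inference from the reported numbers, not a documented setting)
candidate_batch_sizes = [
    b for b in range(1, eval_samples + 1)
    if math.ceil(eval_samples / b) == total_steps
]

print(f"samples/s: {samples_per_second:.3f}")
print(f"steps: {total_steps}")
print(f"batch sizes consistent with that step count: {candidate_batch_sizes}")
```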
#### Inference Code

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


def model_fn(model_dir):
    """
    Load the model and tokenizer from the specified directory.

    :param model_dir: path to the directory containing the model and tokenizer files
    :return: a (model, tokenizer) tuple
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    """
    Classify the text in data["inputs"] and return a score for every label.

    :param data: dict with the text to classify under the "inputs" key
    :param model_and_tokenizer: the (model, tokenizer) tuple returned by model_fn
    :return: the pipeline output with one score per label
    """
    # Unpack the model and tokenizer
    model, tokenizer = model_and_tokenizer

    # top_k=None returns scores for all labels (replaces the deprecated return_all_scores=True)
    bert_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer,
                         truncation=True, max_length=512, top_k=None)

    # Tokenize the input and keep only the first 512 tokens before passing it on
    tokens = tokenizer.encode(data["inputs"], add_special_tokens=False, max_length=512, truncation=True)
    input_data = tokenizer.decode(tokens)
    return bert_pipe(input_data)
```
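`predict_fn` returns a score for every label rather than a single prediction. Assuming the result is a list of `{"label", "score"}` dictionaries, possibly nested one level deep per input text, a caller could reduce it to the top label as sketched below. The helper `top_prediction` and the example output are hypothetical; the real label strings depend on the model's `id2label` configuration and may be generic names such as `LABEL_0`.

```python
def top_prediction(pipeline_output):
    """Pick the highest-scoring entry from a text-classification result.

    Accepts either a flat list of {"label", "score"} dicts or the same list
    nested one level deep (one inner list per input text).
    """
    scores = pipeline_output
    if scores and isinstance(scores[0], list):
        scores = scores[0]  # single input: take its list of label scores
    return max(scores, key=lambda item: item["score"])


# Hypothetical output for one input (labels and scores are made up)
example_output = [[
    {"label": "CORPORATE_DOCUMENTS", "score": 0.02},
    {"label": "FINANCIAL", "score": 0.91},
    {"label": "HARMFUL", "score": 0.01},
    {"label": "MEDICAL", "score": 0.06},
]]

best = top_prediction(example_output)
print(best["label"], best["score"])
```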
## Conclusion

On the 4,565-entry test set, the Topic Classifier reaches an accuracy of 0.99, with per-label F1-scores ranging from 0.95 (HARMFUL) to 1.00 (CORPORATE_DOCUMENTS). These results make it a reliable model for categorizing text across the domains of corporate documents, financial content, harmful content, and medical texts. The model can be loaded directly with the `transformers` library for use in applications.

For more information or to try the model yourself, check out the public space [here](https://huggingface.co/spaces/daxa-ai/Topic-Classifier-2).