|
--- |
|
license: apache-2.0 |
|
datasets: |
|
- KDAI-NLP/traffy-fondue-type-only |
|
language: |
|
- th |
|
metrics: |
|
- f1 |
|
tags: |
|
- roberta |
|
widget: |
|
- text: "แยกอโศกฝนตกน้ำท่วมหนักมากครับ ต้นไม้ก็ล้มขวางทางรถติดชห" |
|
--- |
|
|
|
# Traffy Complaint Classification |
|
|
|
This multi-label model is trained to automatically classify various types of traffic complaints expressed in Thai text, |
|
with the goal of minimizing the need for manual classification. Please note that the example inference provided by Hugging Face (Right-side UI) |
|
does not yet support multi-label classification. If you require multi-label classification, please use the code provided below. |
|
|
|
### Model Details |
|
|
|
Model Name: KDAI-NLP/wangchanberta-traffy-multi |
|
Tokenizer: airesearch/wangchanberta-base-att-spm-uncased |
|
License: Apache License 2.0 |
|
|
|
### How to Use |
|
|
|
```python |
|
|
|
!pip install sentencepiece |
|
|
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from torch.nn.functional import sigmoid |
|
import json |
|
|
|
# Target lists |
|
target_list = [ |
|
'ความสะอาด', 'สายไฟ', 'สะพาน', 'ถนน', 'น้ำท่วม', |
|
'ร้องเรียน', 'ท่อระบายน้ำ', 'ความปลอดภัย', 'คลอง', 'แสงสว่าง', |
|
'ทางเท้า', 'จราจร', 'กีดขวาง', 'การเดินทาง', 'เสียงรบกวน', |
|
'ต้นไม้', 'สัตว์จรจัด', 'เสนอแนะ', 'คนจรจัด', 'ห้องน้ำ', |
|
'ป้ายจราจร', 'สอบถาม', 'ป้าย', 'PM2.5' |
|
] |
|
|
|
# Load tokenizer and model |
|
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased") |
|
model = AutoModelForSequenceClassification.from_pretrained("KDAI-NLP/wangchanberta-traffy-multi") |
|
|
|
# Example text to classify |
|
text = "ช่วยด้วยครับถนนน้ำท่วมอีกแล้ว ต้นไม้ก็ล้มขวางทาง กลับบ้านไม่ได้" |
|
|
|
# Encode the text using the tokenizer |
|
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256) |
|
|
|
# Get model predictions (logits) |
|
with torch.no_grad(): |
|
logits = model(**inputs).logits |
|
|
|
# Apply sigmoid function to convert logits to probabilities |
|
probabilities = sigmoid(logits) |
|
|
|
# Map probabilities to corresponding labels |
|
probabilities = probabilities.squeeze().tolist() |
|
label_probabilities = zip(target_list, probabilities) |
|
|
|
# Print labels with probabilities |
|
for label, probability in label_probabilities: |
|
print(f"{label}: {probability:.4f}") |
|
|
|
# Or JSON |
|
# Create a dictionary for labels and probabilities |
|
results_dict = {label: probability for label, probability in label_probabilities} |
|
|
|
# Convert dictionary to JSON string |
|
results_json = json.dumps(results_dict, ensure_ascii=False, indent=4) |
|
|
|
# Print the JSON string |
|
print(results_json) |
|
``` |
|
|
|
## Training Details |
|
|
|
The model was trained on traffic complaint data API (included stopwords) using the airesearch/wangchanberta-base-att-spm-uncased base model. This is a multi-label classification task with a total of 24 classes. |
|
|
|
## Training Scores |
|
|
|
| Model | Stopword | Epoch | Training Loss | Validation Loss | F1 | Accuracy | |
|
| ---------------------------------- | -------- | ----- | ------------- | --------------- | ------- | -------- | |
|
| wangchanberta-base-att-spm-uncased | Included | 0 | 0.0322 | 0.034822 | 0.7015 | 0.7569 | |
|
| wangchanberta-base-att-spm-uncased | Included | 2 | 0.0207 | 0.026364 | 0.8405 | 0.7821 | |
|
| wangchanberta-base-att-spm-uncased | Included | 4 | 0.0165 | 0.025142 | 0.8458 | 0.7934 | |
|
|
|
|
|
Feel free to customize the README further if needed. |