---
license: mit
language:
- en
metrics:
- accuracy
- mse
- f1
base_model:
- dmis-lab/biobert-base-cased-v1.2
- google-bert/bert-base-cased
pipeline_tag: text-classification
model-index:
- name: bert-causation-rating-dr2
  results:
  - task:
      type: text-classification
    dataset:
      name: rating_dr2
      type: dataset
    metrics:
    - name: off by 1 accuracy
      type: accuracy
      value: 74.78991596638656
    - name: mean squared error for ordinal data
      type: mse
      value: 0.773109243697479
    - name: weighted F1 score
      type: f1
      value: 0.76386248572931
    - name: Kendall's tau coefficient
      type: Kendall's tau
      value: 0.8081294201575603
    source:
      name: Keling Wang
      url: https://github.com/Keling-Wang
datasets:
- kelingwang/causation_strength_rating
---
# Model description
`bert-causation-rating-dr2` is a [biobert-base-cased-v1.2](https://huggingface.co/dmis-lab/biobert-base-cased-v1.2) model fine-tuned on a small set of manually annotated sentences with causation labels. It classifies a sentence according to the strength of causation the sentence expresses.
Before being tuned on this dataset, `biobert-base-cased-v1.2` was first fine-tuned on a dataset of causation labels taken from a published paper. The result of that first stage is [`kelingwang/bert-causation-rating-pubmed`](https://huggingface.co/kelingwang/bert-causation-rating-pubmed), which is the starting checkpoint for this model. For more information, see that model card and my [GitHub page](https://github.com/Keling-Wang/causation_rating).
The sentences in the dataset were rated independently by two researchers. This `dr2` version is tuned on the set of sentences labeled by Raters 2 and 3.
# Intended use and limitations
This model is primarily intended to rate the strength of causation expressed in a sentence extracted from a clinical guideline in the field of diabetes mellitus management.
Given a text input, the model predicts one of the following strength of causation (SoC) labels:
* -1: No correlation or variable relationships mentioned in the sentence.
* 0: The sentence mentions correlational relationships but no causation.
* 1: The sentence expresses weak causation.
* 2: The sentence expresses moderate causation.
* 3: The sentence expresses strong causation.
*NOTE:* The model outputs five logits, one per class, and the class indices are 0-based (0 to 4) rather than -1 to 3. Taking the SoC labels in the order listed above, class index 0 corresponds to SoC label -1 and class index 4 to SoC label 3. To make predictions, it is recommended to use [this `python` module](https://github.com/Keling-Wang/causation_rating/blob/main/tests/prediction_from_pretrained.py).
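The index shift described in the note can be sketched in a few lines of plain Python. This is a minimal illustration, not the linked module: the helper name `logits_to_soc` and the example logits are hypothetical, and the mapping assumes class index `i` corresponds to SoC label `i - 1`, following the order of the label list above.

```python
def logits_to_soc(logits):
    """Map five class logits to a strength-of-causation (SoC) label in -1..3."""
    if len(logits) != 5:
        raise ValueError("expected five logits, one per SoC class")
    # Argmax over the class dimension gives a 0-based class index (0..4).
    class_index = max(range(len(logits)), key=lambda i: logits[i])
    # Assumed mapping: class index 0 -> SoC -1, ..., class index 4 -> SoC 3.
    return class_index - 1


# Made-up logits for illustration only; the largest value sits at index 3.
example_logits = [-2.1, 0.3, 0.9, 4.2, 1.0]
print(logits_to_soc(example_logits))  # class index 3 maps to SoC label 2
```

In practice the logits would come from the model's classification head for one tokenized sentence; the same shift applies to each row of a batch.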
# Performance and hyperparameters
## Test metrics
This model achieves the following results on the test dataset. The test dataset is a 25% held-out stratified split of the entire dataset with `SEED=114514`.
* Loss: 18.2347
* Off-by-1 accuracy: 74.7899%
* Off-by-2 accuracy: 91.5966%
* MSE for ordinal data: 0.7731
* Weighted F1: 0.7639
* Kendall's Tau: 0.8081
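For readers who want to compute similar ordinal metrics on their own predictions, the sketch below shows one way to do it in plain Python. It is an illustration under assumptions: this card does not state the exact tolerance behind "off-by-1" and "off-by-2", so the hypothetical helper `within_k_accuracy` takes the tolerance `k` explicitly, and the toy labels are made up.

```python
def within_k_accuracy(y_true, y_pred, k):
    """Percentage of predictions whose absolute error is at most k classes."""
    hits = sum(abs(t - p) <= k for t, p in zip(y_true, y_pred))
    return 100.0 * hits / len(y_true)


def ordinal_mse(y_true, y_pred):
    """Mean squared distance between true and predicted class indices."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy 0-based class indices, for illustration only.
y_true = [0, 1, 2, 3, 4, 2]
y_pred = [0, 2, 2, 1, 4, 3]

print(within_k_accuracy(y_true, y_pred, k=0))  # exact matches: 50.0
print(within_k_accuracy(y_true, y_pred, k=1))  # 5 of 6 within one class
print(ordinal_mse(y_true, y_pred))             # 1.0
```

Weighted F1 and Kendall's tau are available in `scikit-learn` (`f1_score` with `average="weighted"`) and `scipy` (`scipy.stats.kendalltau`) respectively.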
## Hyperparameter tuning metrics
With the best hyperparameters found during hyperparameter tuning, this model achieves the following results, averaged over 4-fold cross-validation:
* Loss: 0.519251
* Off-by-1 accuracy: 98.3803%
* Off-by-2 accuracy: 99.8944%
* MSE for ordinal data: 0.02359
* Weighted F1: 0.9837
* Kendall's Tau: 0.9901
This performance is achieved with the following hyperparameters:
* Learning rate: 7.96862e-05
* Weight decay: 0.148775
* Warmup ratio: 0.460611
* Power of polynomial learning rate scheduler: 1.129829
* Power to the distance measure used in the loss function \alpha: 3.0
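To give a feel for how the learning rate, warmup ratio, and scheduler power interact, here is a minimal plain-Python sketch of a polynomial decay schedule with linear warmup. It follows the usual formula for this kind of schedule rather than the exact training code: the function name `polynomial_lr`, the total of 100 steps, and the rounding of the warmup length are assumptions, and the final learning rate of `1e-8` is taken from the training settings in this card.

```python
import math


def polynomial_lr(step, total_steps, lr_init=7.96862e-05, lr_end=1e-8,
                  warmup_ratio=0.460611, power=1.129829):
    """Learning rate at `step` for linear warmup followed by polynomial decay."""
    # Assumption: the warmup length is the ratio times total steps, rounded up.
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step >= total_steps:
        return lr_end
    if step < warmup_steps:
        # Linear ramp from 0 up to lr_init.
        return lr_init * step / max(1, warmup_steps)
    # Fraction of the decay phase already completed, in [0, 1).
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return (lr_init - lr_end) * (1.0 - progress) ** power + lr_end


# Hypothetical run of 100 optimizer steps: print a few points on the curve.
for s in (0, 20, 47, 70, 100):
    print(s, polynomial_lr(s, total_steps=100))
```

With a warmup ratio near 0.46, almost half of training is spent ramping up, and a power slightly above 1 makes the decay a little steeper than linear at the start of the decay phase.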
# Training settings
The following training configurations apply:
* Pre-trained model: `kelingwang/bert-causation-rating-pubmed`
* `seed`: 114514
* `batch_size`: 128
* `epoch`: 8
* `max_length` in `torch.utils.data.Dataset`: 128
* Loss function: the [OLL loss](https://aclanthology.org/2022.coling-1.407/) with a tunable hyperparameter \alpha (Power to the distance measure used in the loss function).
* `lr`: 7.96862e-05
* `weight_decay`: 0.148775
* `warmup_ratio`: 0.460611
* `lr_scheduler_type`: polynomial
* `lr_scheduler_kwargs`: `{"power": 1.129829, "lr_end": 1e-8}`
* Power to the distance measure used in the loss function \alpha: 3.0
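The role of `\alpha` in the loss can be illustrated with a small sketch. It is based on my reading of the linked OLL paper, where each wrong class `i` contributes `-log(1 - p_i)` weighted by the distance `|i - y|` raised to the power `\alpha`. The function name `oll_loss`, the single-example form, and the `eps` guard are assumptions; the training code used for this model may differ, for example in how the loss is reduced over a batch.

```python
import math


def oll_loss(logits, target, alpha=3.0):
    """OLL-style ordinal loss for one example with a 0-based target class."""
    # Numerically stable softmax over the class logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    eps = 1e-12  # guards against log(0) when a class has probability ~1
    # Each class is penalized by -log(1 - p_i) times its distance to the
    # target raised to alpha; the target class itself has distance 0.
    return -sum(
        math.log(max(1.0 - p, eps)) * abs(i - target) ** alpha
        for i, p in enumerate(probs)
    )


confident_wrong = [0.0, 0.0, 0.0, 0.0, 10.0]  # made-up logits favoring class 4
print(oll_loss(confident_wrong, target=3))  # near miss: smaller penalty
print(oll_loss(confident_wrong, target=0))  # far miss: much larger penalty
```

With `alpha=3.0`, a confident prediction four classes away from the target is weighted 64 times more heavily than one a single class away, which is what pushes the model to respect the ordering of the labels.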
# Framework versions and devices
This model was run on an NVIDIA P100 GPU provided by Kaggle.
Framework versions are:
* python==3.10.14
* cuda==12.4
* NVIDIA-SMI==550.90.07
* torch==2.4.0
* transformers==4.45.1
* scikit-learn==1.2.2
* optuna==4.0.0
* nlpaug==1.1.11