---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---

# ESM-2 QLoRA

These are the checkpoints for the first ever QLoRA for ESM-2! They haven't been checked for overfitting yet, so use them with caution.

You can load and use them in the same way as the LoRA models. They are based on the smallest ESM-2 model, `esm2_t6_8M_UR50D`, so the metrics are modest.
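
Given `library_name: peft` above, loading presumably means attaching a PEFT adapter to a base ESM-2 model. The sketch below shows that pattern under stated assumptions; it is not the author's code. It assumes the base model is `facebook/esm2_t6_8M_UR50D` on the Hugging Face Hub, a two-label token-classification head (0 = not a binding site, 1 = binding site), and that `adapter_id` is a hypothetical placeholder for the repo id or local path of one of these checkpoints. The base model is loaded in full precision, which should be sufficient for inference with an 8M-parameter model.

```python
from typing import List


def load_binding_site_model(adapter_id: str, base_id: str = "facebook/esm2_t6_8M_UR50D"):
    """Load ESM-2 with a 2-label token-classification head and attach a PEFT adapter.

    `adapter_id` is a hypothetical placeholder for the Hub repo id or local
    path of one of these checkpoints.
    """
    # Imported lazily so the pure-Python helpers below work without these packages.
    from peft import PeftModel
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(base_id)
    # Assumption: label 0 = not a binding site, label 1 = binding site.
    base_model = AutoModelForTokenClassification.from_pretrained(base_id, num_labels=2)
    model = PeftModel.from_pretrained(base_model, adapter_id)
    model.eval()
    return model, tokenizer


def predict_labels(model, tokenizer, sequence: str) -> List[int]:
    """Predict one 0/1 label per residue of a protein sequence."""
    import torch

    inputs = tokenizer(sequence, return_tensors="pt")
    with torch.no_grad():
        # For standard one-letter residues: shape (1, len(sequence) + 2, 2).
        logits = model(**inputs).logits
    token_labels = logits.argmax(dim=-1)[0].tolist()
    return strip_special_tokens(token_labels)


def strip_special_tokens(token_labels: List[int]) -> List[int]:
    """Drop the predictions for the <cls> and <eos> tokens added by the ESM-2 tokenizer."""
    return token_labels[1:-1]


def binding_site_positions(residue_labels: List[int]) -> List[int]:
    """Return 1-based residue positions predicted as binding sites (label 1)."""
    return [i + 1 for i, label in enumerate(residue_labels) if label == 1]
```

Calling `binding_site_positions(predict_labels(model, tokenizer, sequence))` on the returned model and tokenizer would then list the residues predicted as binding sites. The sketch does not handle sequences longer than the model's maximum input length.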

Scaling to larger models for better metrics is in progress. These checkpoints were trained on [the 600K dataset](https://huggingface.co/datasets/AmelieSchreiber/600K_data).

To replicate the QLoRA training for ESM-2 models, you can use the `conda-environment.yml` file. However, as of 28/09/2023 and for roughly the following week or two, you will need to uninstall `transformers` and install it from the GitHub source instead:

```bash
pip uninstall -y transformers
pip install --upgrade git+https://github.com/huggingface/transformers.git
```

Once these changes are included in a `transformers` release, you should be able to simply use the latest version; gradient checkpointing will then be fully enabled, and QLoRA compatibility should be fully integrated into ESM-2 models.
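
For orientation, a QLoRA setup for ESM-2 token classification might look like the sketch below. It is not the training script used for these checkpoints: the 4-bit NF4 quantization settings, `rank=2`, `lora_alpha=8`, and `lora_dropout=0.1` are assumptions, and `build_qlora_esm2` is a hypothetical helper name. Running it needs a CUDA GPU with `bitsandbytes`, plus `peft` and a `transformers` version that includes the ESM-2 gradient checkpointing support described above, since `prepare_model_for_kbit_training` enables gradient checkpointing by default.

```python
def build_qlora_esm2(base_id: str = "facebook/esm2_t6_8M_UR50D", rank: int = 2):
    """Sketch: 4-bit ESM-2 token classifier with LoRA adapters on query/key/value.

    All hyperparameters here are illustrative assumptions, not the values
    used to train these checkpoints.
    """
    # Imported lazily so this module can be imported without GPU dependencies.
    import torch
    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
    from transformers import AutoModelForTokenClassification, BitsAndBytesConfig

    # QLoRA: keep the frozen base weights in 4-bit NF4 with double quantization.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # assumption; float16 on GPUs without bf16
    )
    model = AutoModelForTokenClassification.from_pretrained(
        base_id, num_labels=2, quantization_config=bnb_config
    )
    # Prepares the quantized model for training; by default this also
    # turns on gradient checkpointing.
    model = prepare_model_for_kbit_training(model)

    lora_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        r=rank,
        lora_alpha=8,  # assumption
        lora_dropout=0.1,  # assumption
        target_modules=["query", "key", "value"],  # ESM-2 self-attention projections
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model
```

Targeting only `query`, `key`, and `value` mirrors the description in the QLoRA Info section; the model returned here could then be passed to a standard `Trainer`.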

## QLoRA Info

Note that we are training only 0.58% of the parameters, applying LoRA only to the query, key, and value weight matrices.

```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```
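
As a sanity check on the trainable count above: LoRA of rank r on a weight matrix of shape d_out × d_in adds r × (d_in + d_out) parameters. The sketch below assumes the `esm2_t6_8M_UR50D` shapes (6 layers, hidden size 320), a LoRA rank of r = 2, and a trainable 2-class token-classification head; neither the rank nor the head is stated in this card. Under those assumptions the count comes out to exactly 23,682 (18 adapted matrices × 1,280 = 23,040, plus a 642-parameter head). The total of 4,075,265 is taken from the printout, not derived.

```python
# Assumed esm2_t6_8M_UR50D shapes: 6 layers, hidden size 320.
NUM_LAYERS = 6
HIDDEN = 320
TARGETS_PER_LAYER = 3  # query, key, value
NUM_LABELS = 2  # assumption: binding site vs. non-binding site


def lora_params(rank: int, d_in: int, d_out: int) -> int:
    """LoRA adds A (rank x d_in) and B (d_out x rank) to each adapted matrix."""
    return rank * d_in + d_out * rank


def head_params(hidden: int, num_labels: int) -> int:
    """Linear classification head: weight (num_labels x hidden) plus bias."""
    return num_labels * hidden + num_labels


def trainable_params(rank: int) -> int:
    adapters = NUM_LAYERS * TARGETS_PER_LAYER * lora_params(rank, HIDDEN, HIDDEN)
    return adapters + head_params(HIDDEN, NUM_LABELS)


def trainable_percent(trainable: int, total: int) -> float:
    return 100 * trainable / total


print(trainable_params(rank=2))  # 23682
print(round(trainable_percent(23682, 4075265), 2))  # 0.58
```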

## Testing for Overfitting

### Checkpoint 1

### Checkpoint 2

### Checkpoint 3

### Checkpoint 4

```python
# Train metrics
{'eval_loss': 0.24070295691490173,
 'eval_accuracy': 0.9018779246397052,
 'eval_precision': 0.16624103834249204,
 'eval_recall': 0.8651772818812425,
 'eval_f1': 0.27889357183237473,
 'eval_auc': 0.8839390799308487,
 'eval_mcc': 0.3536803490333407}

# Test metrics
{'eval_loss': 0.26776671409606934,
 'eval_accuracy': 0.8902711124906878,
 'eval_precision': 0.13008662855482372,
 'eval_recall': 0.7084623832213568,
 'eval_f1': 0.219811797752809,
 'eval_auc': 0.8013943890942485,
 'eval_mcc': 0.2721459410994918}
```
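
One simple way to read these numbers for overfitting is to compare each metric between the train and test splits. The sketch below uses the Checkpoint 4 values listed above; `generalization_gap` is a hypothetical helper, and what counts as an acceptable gap is a judgment call rather than a fixed rule.

```python
# Checkpoint 4 metrics, copied from the output above.
train_metrics = {
    "eval_loss": 0.24070295691490173,
    "eval_accuracy": 0.9018779246397052,
    "eval_precision": 0.16624103834249204,
    "eval_recall": 0.8651772818812425,
    "eval_f1": 0.27889357183237473,
    "eval_auc": 0.8839390799308487,
    "eval_mcc": 0.3536803490333407,
}
test_metrics = {
    "eval_loss": 0.26776671409606934,
    "eval_accuracy": 0.8902711124906878,
    "eval_precision": 0.13008662855482372,
    "eval_recall": 0.7084623832213568,
    "eval_f1": 0.219811797752809,
    "eval_auc": 0.8013943890942485,
    "eval_mcc": 0.2721459410994918,
}

# Metrics where a lower value is better.
LOWER_IS_BETTER = {"eval_loss"}


def generalization_gap(train: dict, test: dict) -> dict:
    """How much better each metric is on train than on test (positive = better on train)."""
    gaps = {}
    for name, train_value in train.items():
        if name not in test:
            continue
        diff = train_value - test[name]
        gaps[name] = -diff if name in LOWER_IS_BETTER else diff
    return gaps


# Largest gaps first.
for name, gap in sorted(generalization_gap(train_metrics, test_metrics).items(), key=lambda kv: -kv[1]):
    print(f"{name:>15}: {gap:+.3f}")
```

For Checkpoint 4, every metric is better on the train split, and recall shows the largest gap (about 0.157).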