---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---

# ESM-2 QLoRA

These are the checkpoints for the first ever QLoRA for ESM-2! They haven't been checked for overfitting yet, so use with caution!
You can load and use them the same way as the LoRA models. They are built on the smallest ESM-2 model, `esm2_t6_8M_UR50D`, so the metrics aren't great.
Scaling to larger models for better metrics is in progress. These checkpoints were trained on
[the 600K dataset](https://huggingface.co/datasets/AmelieSchreiber/600K_data). To replicate the QLoRA training for ESM-2 models,
you can use the `conda-environment.yml` file. However, for the next week or two (as of 28/09/2023), you will need to uninstall `transformers`
and install it from source instead:

```
pip install --upgrade git+https://github.com/huggingface/transformers.git
```

Once the `transformers` library is updated, you should be able to simply use its latest version. Gradient checkpointing
should then be fully enabled and QLoRA compatibility fully integrated into ESM-2 models.
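As a rough illustration of loading a checkpoint the same way as a LoRA model, here is a minimal sketch. It assumes the adapters were saved with `peft` on top of a two-label token classification head, and that the base model is `facebook/esm2_t6_8M_UR50D`. The helper names and the `adapter_path` argument are hypothetical; point `adapter_path` at whichever checkpoint folder or repo you actually want to use.

```python
def load_qlora_checkpoint(adapter_path,
                          base_model_name="facebook/esm2_t6_8M_UR50D",
                          num_labels=2):
    """Attach a saved adapter checkpoint to the ESM-2 base model.

    `adapter_path` is a placeholder (hypothetical) for a local folder or
    Hub repo id containing the adapter weights and adapter config.
    """
    # Imported lazily so the helpers can be defined before the heavy
    # dependencies are loaded.
    from transformers import AutoModelForTokenClassification, AutoTokenizer
    from peft import PeftModel

    # Assumption: the adapters sit on a token classification head with
    # `num_labels` classes (binding site vs. not).
    base_model = AutoModelForTokenClassification.from_pretrained(
        base_model_name, num_labels=num_labels
    )
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    return model, tokenizer


def predict_residue_labels(model, tokenizer, sequence):
    """Return one predicted class id per token of `sequence`.

    The ESM tokenizer adds special tokens at both ends, so the first and
    last entries do not correspond to residues.
    """
    import torch

    inputs = tokenizer(sequence, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return logits.argmax(dim=-1).squeeze(0).tolist()
```

This loads the base model in full precision for simplicity; it does not reproduce the 4-bit quantized setup used during QLoRA training.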

## QLoRA Info

Note that we train only 0.58% of the parameters, adapting only the query, key, and value weight matrices.

```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```
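The percentage in the log line above is just the ratio of the two counts. The short check below reproduces it, and then shows one decomposition that happens to match the trainable count. That decomposition is an assumption, not something stated in this card: rank-2 adapters on the query, key, and value projections of each of the 6 layers (hidden size 320), plus a fully trainable 2-label classification head.

```python
# Counts copied from the log line above.
trainable_params = 23_682
all_params = 4_075_265

trainable_pct = 100 * trainable_params / all_params
print(f"trainable%: {trainable_pct:.4f}")

# Hypothetical decomposition (assumed, not confirmed by the card):
# each adapted linear layer gets two low-rank matrices of shapes
# (rank x hidden_size) and (hidden_size x rank).
hidden_size = 320   # esm2_t6_8M_UR50D hidden size
num_layers = 6      # esm2_t6_8M_UR50D layer count
num_targets = 3     # query, key, value
rank = 2            # assumed LoRA rank
num_labels = 2      # assumed binary per-residue labels

lora_params = num_layers * num_targets * rank * (hidden_size + hidden_size)
head_params = hidden_size * num_labels + num_labels  # weights + biases
print(lora_params, head_params, lora_params + head_params)
```

If the assumed rank or head size differs from what was actually used, the sum will not match the logged trainable count.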

## Testing for Overfitting

### Checkpoint 1

### Checkpoint 2

### Checkpoint 3

### Checkpoint 4

```python
# Train metrics:
{'eval_loss': 0.24070295691490173,
'eval_accuracy': 0.9018779246397052,
'eval_precision': 0.16624103834249204,
'eval_recall': 0.8651772818812425,
'eval_f1': 0.27889357183237473,
'eval_auc': 0.8839390799308487,
'eval_mcc': 0.3536803490333407}

# Test metrics:
{'eval_loss': 0.26776671409606934,
'eval_accuracy': 0.8902711124906878,
'eval_precision': 0.13008662855482372,
'eval_recall': 0.7084623832213568,
'eval_f1': 0.219811797752809,
'eval_auc': 0.8013943890942485,
'eval_mcc': 0.2721459410994918}
```
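One simple way to read the Checkpoint 4 numbers is to look at the train/test gap per metric, and to confirm that the reported F1 matches the usual harmonic mean of precision and recall. The sketch below does both with the values copied from above; the helper `f1_from` is introduced here only for illustration.

```python
# Values copied from the Checkpoint 4 metrics above.
train = {
    "eval_loss": 0.24070295691490173,
    "eval_accuracy": 0.9018779246397052,
    "eval_precision": 0.16624103834249204,
    "eval_recall": 0.8651772818812425,
    "eval_f1": 0.27889357183237473,
    "eval_auc": 0.8839390799308487,
    "eval_mcc": 0.3536803490333407,
}
test = {
    "eval_loss": 0.26776671409606934,
    "eval_accuracy": 0.8902711124906878,
    "eval_precision": 0.13008662855482372,
    "eval_recall": 0.7084623832213568,
    "eval_f1": 0.219811797752809,
    "eval_auc": 0.8013943890942485,
    "eval_mcc": 0.2721459410994918,
}


def f1_from(precision, recall):
    """F1 as the harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# Sanity check: recompute F1 from the reported precision and recall.
for name, metrics in (("train", train), ("test", test)):
    recomputed = f1_from(metrics["eval_precision"], metrics["eval_recall"])
    print(f"{name}: reported F1 {metrics['eval_f1']:.4f}, recomputed {recomputed:.4f}")

# Train-minus-test gap per metric. For the loss, a negative gap means the
# training loss is lower; for the other metrics, a positive gap means the
# training split scores higher.
gaps = {key: train[key] - test[key] for key in train}
for key, gap in gaps.items():
    print(f"{key}: {gap:+.4f}")
```

Here every score is higher on the training split and the loss is lower, with the largest gaps in recall and AUC. How much of a gap counts as overfitting is a judgment call that these numbers alone do not settle.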