---
license: mit
---

# ESM-2 QLoRA for Predicting Binding Sites

This model is the ESM-2 model [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) finetuned with QLoRA on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/2600K_binding_sites) of 2.6M protein sequences with binding and active site annotations from UniProt. The model and dataset sizes were scaled one-to-one (following the Chinchilla paper) up from the smaller QLoRA adaptations of the `esm2_t6_8M_UR50D` models, which were trained on 600K proteins. Since this model is 4.375 times larger, a dataset approximately 4.375 times larger is needed if Chinchilla scaling laws hold for QLoRA finetuning of protein language models. Determining whether such scaling laws hold is part of this project, so performance metrics are being tracked over 3 epochs, and each epoch is being checked for signs of overfitting.

## QLoRA Info

```
trainable params: 71046 || all params: 17246053 || trainable%: 0.41195512967517844
```

```python
'eval_loss': 0.6011912822723389,
'eval_accuracy': 0.9297529150299436,
'eval_precision': 0.22835223718675476,
'eval_recall': 0.697386656717114,
'eval_f1': 0.3440490710592986,
'eval_auc': 0.8167222019799886,
'eval_mcc': 0.3730152153022164
```

To use this model, run:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_qlora_binding_2600K_cp1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the base model and apply the LoRA adapter weights
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token, skipping special tokens
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```