--- license: mit datasets: - AmelieSchreiber/general_binding_sites language: - en metrics: - precision - recall - f1 library_name: transformers tags: - biology - esm - esm2 - ESM-2 - protein language model --- # ESM-2 for General Protein Binding Site Prediction This model is trained to predict general binding sites of proteins using on the sequence. This is a finetuned version of `esm2_t6_8M_UR50D`, trained on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/general_binding_sites). The data is not filtered by family, and thus the model may be overfit to some degree. ## Training ``` epoch 3: Training Loss Validation Loss Precision Recall F1 Auc 0.031100 0.074720 0.684798 0.966856 0.801743 0.980853 ``` ``` wandb: lr: 0.0004977045729600779 wandb: lr_scheduler_type: cosine wandb: max_grad_norm: 0.5 wandb: num_train_epochs: 3 wandb: per_device_train_batch_size: 8 wandb: weight_decay: 0.025 ``` ## Using the Model Try pasting a protein sequence into the cell on the right and clicking on "Compute". For example, try ``` MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY ``` To use the model, try running: ```python import torch from transformers import AutoModelForTokenClassification, AutoTokenizer def predict_binding_sites(model_path, protein_sequences): """ Predict binding sites for a collection of protein sequences. Parameters: - model_path (str): Path to the saved model. - protein_sequences (List[str]): List of protein sequences. Returns: - List[List[str]]: Predicted labels for each sequence. """ # Load tokenizer and model tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForTokenClassification.from_pretrained(model_path) # Ensure model is in evaluation mode model.eval() # Tokenize sequences inputs = tokenizer(protein_sequences, return_tensors="pt", padding=True, truncation=True) # Move to the same device as model and obtain logits with torch.no_grad(): logits = model(**inputs).logits # Obtain predicted labels predicted_labels = torch.argmax(logits, dim=-1).cpu().numpy() # Convert label IDs to human-readable labels id2label = model.config.id2label human_readable_labels = [[id2label[label_id] for label_id in sequence] for sequence in predicted_labels] return human_readable_labels # Usage: model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2" # Replace with your model's path unseen_proteins = [ "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD", "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD", "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD", "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD" ] # Replace with your unseen protein sequences predictions = predict_binding_sites(model_path, unseen_proteins) predictions ```