---
license: mit
datasets:
- AmelieSchreiber/interaction_pairs
language:
- en
library_name: transformers
tags:
- ESM-2
- biology
- protein language model
---
|
|
|
# ESM-2 for Interacting Proteins
|
|
|
This model is ESM-2 fine-tuned on concatenated pairs of interacting protein sequences, in much the same way as [PepMLM](https://huggingface.co/spaces/TianlaiChen/PepMLM).
|
It is meant to generate interaction partners for a given protein using the masked language modeling capabilities of ESM-2. The model has not been thoroughly tested, so use it with caution: this is only a preliminary experiment.
|
|
|
## Using the Model
|
|
|
To use the model, try running the following (requires `transformers`, `torch`, `numpy`, and `pandas`):
|
|
|
```python
|
from transformers import AutoTokenizer, EsmForMaskedLM |
|
import torch |
|
import pandas as pd |
|
import numpy as np |
|
from torch.distributions import Categorical |
|
|
|
def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Score ``binder_seq`` as a partner of ``protein_seq`` with a masked-LM perplexity.

    The protein and binder are concatenated and every binder position is masked
    at once (PepMLM-style, not one position at a time). The model's mean
    cross-entropy at those positions is exponentiated. Lower values mean the
    model finds the binder more plausible given the protein.

    Args:
        model: A masked language model whose forward call accepts
            ``(input_ids, labels=...)`` and returns an object with ``.loss``.
        tokenizer: The matching tokenizer (provides ``encode`` and
            ``mask_token_id``).
        protein_seq: The target protein sequence.
        binder_seq: The candidate binder sequence; must be non-empty.

    Returns:
        ``exp(loss)`` as returned by ``np.exp``.

    Raises:
        ValueError: If ``binder_seq`` is empty. With no positions to score,
            the loss would be NaN.
    """
    if not binder_seq:
        raise ValueError("binder_seq must be a non-empty sequence")

    tensor_input = tokenizer.encode(protein_seq + binder_seq, return_tensors='pt').to(model.device)

    # Select the binder positions: the last len(binder_seq) tokens before the
    # final token. The slice assumes one token per residue plus a single
    # trailing end-of-sequence token, as the ESM-2 tokenizer produces.
    binder_mask = torch.zeros_like(tensor_input, dtype=torch.bool)
    binder_mask[0, -len(binder_seq) - 1:-1] = True

    # Hide the binder from the model, and score only the binder positions
    # (-100 is the ignore index of the cross-entropy loss).
    masked_input = tensor_input.masked_fill(binder_mask, tokenizer.mask_token_id)
    labels = tensor_input.masked_fill(~binder_mask, -100)

    with torch.no_grad():
        loss = model(masked_input, labels=labels).loss
    return np.exp(loss.item())
|
|
|
|
|
def generate_peptide_for_single_sequence(protein_seq, peptide_length=15, top_k=3, num_binders=4,
                                         *, model=None, tokenizer=None):
    """Sample candidate binders for one protein and score each one.

    ``peptide_length`` mask tokens are appended to ``protein_seq``. The model
    is run once, and ``num_binders`` binders are drawn by top-k sampling at
    each masked position. Sampling is restricted to the 20 canonical amino
    acids, so every position decodes to exactly one residue.

    Args:
        protein_seq: The target protein sequence.
        peptide_length: Number of residues in each generated binder (>= 1).
        top_k: Number of most likely residues sampled from at each position
            (1..20).
        num_binders: How many binders to sample.
        model: Masked LM to use. Defaults to the module-level ``model``.
        tokenizer: Matching tokenizer. Defaults to the module-level ``tokenizer``.

    Returns:
        A list of ``[binder_sequence, pseudo_perplexity]`` pairs.

    Raises:
        ValueError: If ``peptide_length`` < 1 or ``top_k`` is outside 1..20.
    """
    # Fall back to the script-level objects so existing callers keep working.
    if model is None:
        model = globals()["model"]
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]

    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    canonical_residues = list("ACDEFGHIKLMNPQRSTVWY")
    if peptide_length < 1:
        raise ValueError(f"peptide_length must be >= 1, got {peptide_length}")
    if not 1 <= top_k <= len(canonical_residues):
        raise ValueError(f"top_k must be between 1 and {len(canonical_residues)}, got {top_k}")

    # Vocabulary ids of the residues sampling is allowed to choose from.
    residue_ids = torch.tensor(tokenizer.convert_tokens_to_ids(canonical_residues), device=model.device)

    input_sequence = protein_seq + tokenizer.mask_token * peptide_length
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

    # The forward pass does not depend on the sample drawn, so run it once for
    # all binders. This assumes the model is in eval mode (no dropout), which
    # is the state from_pretrained returns it in.
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_positions = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]

    # Keep only the canonical-residue logits at each masked position, then the
    # top_k candidates per position mapped back to vocabulary ids.
    residue_logits = logits[0, mask_positions][:, residue_ids]
    top_k_logits, top_k_local = residue_logits.topk(top_k, dim=-1)
    top_k_token_ids = residue_ids[top_k_local]
    probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)

    binders_with_ppl = []
    for _ in range(num_binders):
        # Independently pick one of the top_k residues at every masked position.
        choice = Categorical(probabilities).sample()
        predicted_token_ids = top_k_token_ids.gather(-1, choice.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        binders_with_ppl.append([generated_binder, ppl_value])

    return binders_with_ppl
|
|
|
def generate_peptide(input_seqs, peptide_length=15, top_k=3, num_binders=4):
    """Generate and score binders for one or several protein sequences.

    Args:
        input_seqs: A single protein sequence (``str``) or a list/tuple of them.
        peptide_length: Number of residues in each generated binder.
        top_k: Number of most likely tokens sampled from at each position.
        num_binders: Binders to sample per input sequence.

    Returns:
        For a single sequence, a DataFrame with columns ``Binder`` and
        ``Pseudo Perplexity``. For a list/tuple, a DataFrame that additionally
        has an ``Input Sequence`` column identifying each row's target.

    Raises:
        TypeError: If ``input_seqs`` is neither a ``str`` nor a list/tuple.
    """
    if isinstance(input_seqs, str):
        binders = generate_peptide_for_single_sequence(input_seqs, peptide_length, top_k, num_binders)
        return pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])

    if isinstance(input_seqs, (list, tuple)):
        rows = [
            [seq, binder, ppl]
            for seq in input_seqs
            for binder, ppl in generate_peptide_for_single_sequence(seq, peptide_length, top_k, num_binders)
        ]
        return pd.DataFrame(rows, columns=['Input Sequence', 'Binder', 'Pseudo Perplexity'])

    # Previously this fell through and silently returned None.
    raise TypeError(
        f"input_seqs must be a str or a list/tuple of str, got {type(input_seqs).__name__}"
    )
|
|
|
# Checkpoints: the fine-tuned interaction model and the ESM-2 tokenizer it is
# paired with in this example.
MODEL_ID = "AmelieSchreiber/esm_interact"
TOKENIZER_ID = "facebook/esm2_t30_150M_UR50D"

# generate_peptide_for_single_sequence looks up these module-level names,
# so keep them called `model` and `tokenizer`.
model = EsmForMaskedLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)

# Example target protein for which interaction partners are generated.
target_protein = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"

# Sample five 15-residue binders (top-3 sampling) and show them with their scores.
print(generate_peptide(target_protein, peptide_length=15, top_k=3, num_binders=5))
|
```