AmelieSchreiber committed
Commit c4d5de5
1 Parent(s): 82214f5

Update README.md

Files changed (1)
  1. README.md +100 -0
README.md CHANGED
---
license: mit
datasets:
- AmelieSchreiber/interaction_pairs
language:
- en
library_name: transformers
tags:
- ESM-2
- biology
- protein language model
---

# ESM-2 for Interacting Proteins

This model was finetuned on concatenated pairs of interacting proteins, in much the same way as [PepMLM](https://huggingface.co/spaces/TianlaiChen/PepMLM).
It is meant to generate interaction partners for a given protein using the masked language modeling capabilities of ESM-2. The model is not
well tested, so use it with caution; this is only a preliminary experiment.
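
For intuition, here is a minimal, hypothetical sketch of what a PepMLM-style concatenated training pair looks like. The sequences and the exact masking layout are assumptions for illustration only, not the documented format of the `AmelieSchreiber/interaction_pairs` dataset:

```python
# Hypothetical illustration only: PepMLM-style concatenation of an interacting pair.
# The sequences below are placeholders, not entries from the training data.
receptor = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # first interaction partner (kept visible)
partner = "GSWEELKEA"                           # second interaction partner (to be masked)

# Training input: the two partners concatenated into a single sequence.
concatenated = receptor + partner

# During finetuning, the second partner's residues are replaced with <mask> tokens,
# and ESM-2 learns to reconstruct them conditioned on the visible receptor.
masked_for_training = receptor + "<mask>" * len(partner)

print(concatenated)
print(masked_for_training)
```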

## Using the Model

To use the model, try running:

```python
from transformers import AutoTokenizer, EsmForMaskedLM
import torch
import pandas as pd
import numpy as np
from torch.distributions import Categorical

def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    sequence = protein_seq + binder_seq
    tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)

    # Create a mask for the binder sequence
    binder_mask = torch.zeros(tensor_input.shape).to(model.device)
    binder_mask[0, -len(binder_seq)-1:-1] = 1

    # Mask the binder sequence in the input and create labels
    masked_input = tensor_input.clone().masked_fill_(binder_mask.bool(), tokenizer.mask_token_id)
    labels = tensor_input.clone().masked_fill_(~binder_mask.bool(), -100)

    with torch.no_grad():
        loss = model(masked_input, labels=labels).loss
    return np.exp(loss.item())


def generate_peptide_for_single_sequence(protein_seq, peptide_length=15, top_k=3, num_binders=4):

    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    binders_with_ppl = []

    for _ in range(num_binders):
        # Generate binder
        masked_peptide = '<mask>' * peptide_length
        input_sequence = protein_seq + masked_peptide
        inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        logits_at_masks = logits[0, mask_token_indices]

        # Apply top-k sampling
        top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
        probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
        predicted_indices = Categorical(probabilities).sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)

        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        # Compute PPL for the generated binder
        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)

        # Add the generated binder and its PPL to the results list
        binders_with_ppl.append([generated_binder, ppl_value])

    return binders_with_ppl

def generate_peptide(input_seqs, peptide_length=15, top_k=3, num_binders=4):
    if isinstance(input_seqs, str):  # Single sequence
        binders = generate_peptide_for_single_sequence(input_seqs, peptide_length, top_k, num_binders)
        return pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])

    elif isinstance(input_seqs, list):  # List of sequences
        results = []
        for seq in input_seqs:
            binders = generate_peptide_for_single_sequence(seq, peptide_length, top_k, num_binders)
            for binder, ppl in binders:
                results.append([seq, binder, ppl])
        return pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'Pseudo Perplexity'])

model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_interact")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")

protein_seq = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"

results_df = generate_peptide(protein_seq, peptide_length=15, top_k=3, num_binders=5)
print(results_df)
```
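
Lower pseudo-perplexity values roughly correspond to binders the model finds more plausible. Since `generate_peptide` also accepts a list of sequences, a sketch like the following could screen several proteins at once and keep the lowest-scoring binder for each; the input sequences here are placeholders for illustration, not curated targets:

```python
# Hypothetical batch usage; the input sequences are placeholders for illustration.
protein_seqs = [
    "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE",
    "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
]

batch_df = generate_peptide(protein_seqs, peptide_length=12, top_k=3, num_binders=3)

# Keep the binder with the lowest pseudo-perplexity for each input sequence.
best_rows = batch_df.loc[batch_df.groupby("Input Sequence")["Pseudo Perplexity"].idxmin()]
print(best_rows)
```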