from huggingface_hub import HfApi, ModelFilter
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.modeling_outputs import MaskedLMOutput


# Function to fetch suitable ESM models from the HuggingFace Hub
def get_models() -> list[str | None]:
    """Fetch suitable ESM models from the HuggingFace Hub."""
    if not any(
        out := [
            m.modelId for m in HfApi().list_models(
                filter=ModelFilter(
                    author="facebook", model_name="esm", task="fill-mask"
                ),
                sort="lastModified",
                direction=-1,
            )
        ]
    ):
        raise RuntimeError("Error while retrieving models from HuggingFace Hub")
    return out
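

# Example (illustrative, not exhaustive): get_models() typically returns Hub IDs
# such as "facebook/esm2_t6_8M_UR50D" or "facebook/esm1b_t33_650M_UR50S", with
# the most recently modified model listed first.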


# Class to wrap ESM models
class Model:
    """Wrapper for ESM models."""

    def __init__(self, model_name: str = ""):
        """Load selected model and tokenizer."""
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            # Check if CUDA is available and, if so, use it
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def tokenise(self, input: str) -> BatchEncoding:
        """Convert input string to batch of tokens."""
        return self.batch_converter(input, return_tensors="pt")

    def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
        """Run model on batch of tokens."""
        return self.model(batch_tokens, **kwargs)

    def __getitem__(self, key: str) -> int:
        """Get token ID from character."""
        return self.alphabet[key]

    def run_model(self, data):
        """Run model on data."""

        def label_row(row, token_probs):
            """Label a row with its score."""
            # Extract the wild type, 0-based residue index and mutant type from the row
            wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
            # The score is the difference between the log-probabilities of the
            # mutant type and the wild type at that position
            score = token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
            return score.item()
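
        # Worked example (illustrative): for a substitution string such as "M1A",
        # wt = "M", idx = 0 and mt = "A"; the 1 + idx offset in label_row accounts
        # for the leading special token that the tokeniser prepends to the sequence.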
        # Tokenise the sequence data and move the tokens to the model's device
        batch_tokens = self.tokenise(data.seq).input_ids.to(self.model.device)
        # Calculate the token log-probabilities without updating the model parameters
        with torch.no_grad():
            token_probs = torch.log_softmax(self(batch_tokens).logits, dim=-1)
        # Store the token log-probabilities on the data object
        data.token_probs = token_probs.cpu().numpy()
        # For the "masked-marginals" scoring strategy, re-score each position of
        # interest with that position masked out
        if data.scoring_strategy.startswith("masked-marginals"):
            all_token_probs = []
            # For each token position in the batch
            for i in range(batch_tokens.size()[1]):
                # If the position is in the list of residues of interest
                if i in data.resi:
                    # Clone the batch tokens and mask the current position
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = self['<mask>']
                    # Calculate the log-probabilities for the masked input
                    with torch.no_grad():
                        masked_token_probs = torch.log_softmax(
                            self(batch_tokens_masked).logits, dim=-1
                        )
                else:
                    # Otherwise, reuse the unmasked log-probabilities for this position
                    masked_token_probs = token_probs
                # Append this position's log-probabilities to the list
                all_token_probs.append(masked_token_probs[:, i])
            # Reassemble the per-position log-probabilities into a single tensor
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
        # Apply the label_row function to each row of the substitutions dataframe
        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(
                row['0'],
                token_probs,
            ),
            axis=1,
        )
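

# Usage sketch (illustrative; not part of the original logic). The `data` object
# below is a hypothetical stand-in: the surrounding app is assumed to supply an
# object exposing `seq`, `sub`, `resi`, `scoring_strategy`, `token_probs` and
# `out`, as consumed by `run_model` above. Running this downloads a model from
# the HuggingFace Hub.
if __name__ == "__main__":
    from types import SimpleNamespace

    import pandas as pd

    # Substitution strings in wild-type/position/mutant form, e.g. "M1A"
    sub = pd.DataFrame({"0": ["M1A", "K4R"]})
    data = SimpleNamespace(
        seq="MAGKT",                      # toy protein sequence
        sub=sub,
        resi=[1, 4],                      # token indices of the mutated residues
        scoring_strategy="wt-marginals",  # or "masked-marginals"
        token_probs=None,
        out=pd.DataFrame(index=sub.index),
    )

    model = Model(get_models()[0])        # most recently modified facebook ESM model
    model.run_model(data)
    print(data.out)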