from huggingface_hub import HfApi
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.modeling_outputs import MaskedLMOutput
# Fetch suitable ESM models from the HuggingFace Hub
def get_models() -> list[str]:
    """Fetch suitable ESM models from the HuggingFace Hub."""
    # Filter arguments are passed directly to list_models; the ModelFilter
    # helper they used to be wrapped in was removed from huggingface_hub.
    models = [
        m.id
        for m in HfApi().list_models(
            author="facebook",
            model_name="esm",
            task="fill-mask",
            sort="lastModified",
            direction=-1,
        )
    ]
    if not models:
        raise RuntimeError("Error while retrieving models from HuggingFace Hub")
    return models
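# Hedged usage sketch: the list is sorted newest-first above, so the first
# entry is the most recently modified checkpoint (actual IDs depend on the Hub).
# >>> models = get_models()
# >>> models[0]  # e.g. "facebook/esm2_t36_3B_UR50D"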
# Wrapper class for ESM models
class Model:
    """Wrapper for ESM models."""
    def __init__(self, model_name: str = ""):
        """Load the selected model and tokenizer."""
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            # Move the model to the GPU if CUDA is available
            if torch.cuda.is_available():
                self.model = self.model.cuda()
    def tokenise(self, sequence: str) -> BatchEncoding:
        """Convert an input sequence into a batch of token IDs."""
        return self.batch_converter(sequence, return_tensors="pt")
    def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
        """Run the model on a batch of tokens."""
        return self.model(batch_tokens, **kwargs)
    def __getitem__(self, key: str) -> int:
        """Get the token ID for a residue character."""
        return self.alphabet[key]
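    # Hedged usage sketch (model name is illustrative): for
    # m = Model("facebook/esm2_t6_8M_UR50D"), m.tokenise("MKTV").input_ids is a
    # [1, L+2] tensor (BOS/EOS are added around the L residues), and
    # m['<mask>'] returns the mask token ID used by run_model below.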
    def run_model(self, data):
        """Run the model on the data and score its substitutions."""
        def label_row(row, token_probs):
            """Label a substitution row with its score."""
            # Parse wild type, 0-based index and mutant type from, e.g., "M1A"
            wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
            # Score is the difference between the log-probabilities of the mutant
            # and wild types; the +1 offset skips the prepended BOS token
            score = token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
            return score.item()
        # Tokenise the sequence and move it to the model's device
        batch_tokens = self.tokenise(data.seq).input_ids.to(self.model.device)
        # Compute token log-probabilities without updating the model parameters
        with torch.no_grad():
            token_probs = torch.log_softmax(self(batch_tokens).logits, dim=-1)
        # Store the token probabilities on the data object
        data.token_probs = token_probs.cpu().numpy()
        # For the "masked-marginals" scoring strategy, re-score each selected
        # position with that position masked out
        if data.scoring_strategy.startswith("masked-marginals"):
            all_token_probs = []
            for i in range(batch_tokens.size(1)):
                if i in data.resi:
                    # Clone the batch tokens and mask the current position
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = self['<mask>']
                    # Compute the log-probabilities for the masked input
                    with torch.no_grad():
                        masked_token_probs = torch.log_softmax(
                            self(batch_tokens_masked).logits, dim=-1
                        )
                else:
                    # Positions outside the residue list keep the unmasked probabilities
                    masked_token_probs = token_probs
                all_token_probs.append(masked_token_probs[:, i])
            # Reassemble the per-position columns into a [1, seq_len, vocab] tensor
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
        # Score each row of the substitutions dataframe and store the results
        # under this model's name
        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(
                row['0'],
                token_probs,
            ),
            axis=1,
        )
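# Hedged end-to-end sketch: the app's real input object isn't shown here, so a
# SimpleNamespace stands in for it; only the attributes run_model actually uses
# (seq, resi, scoring_strategy, sub, out) are assumed, and the sequence and
# substitutions are illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace
    import pandas as pd

    data = SimpleNamespace(
        seq="MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
        resi=range(1, 10),  # token positions to mask; assumed to include the BOS offset
        scoring_strategy="masked-marginals",
        sub=pd.DataFrame({'0': ["M1A", "K2R"]}),  # substitutions as wild type, position, mutant
        out=pd.DataFrame(index=[0, 1]),
    )
    # Load the most recently modified ESM checkpoint and score the substitutions
    model = Model(get_models()[0])
    model.run_model(data)
    print(data.out)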