from huggingface_hub import HfApi, ModelFilter
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.modeling_outputs import MaskedLMOutput


# Function to fetch suitable ESM models from HuggingFace Hub
def get_models() -> list[str]:
    """Fetch suitable ESM models from HuggingFace Hub."""
    # NOTE: ModelFilter is deprecated in recent huggingface_hub releases
    if not any(
        out := [
            m.modelId
            for m in HfApi().list_models(
                filter=ModelFilter(
                    author="facebook", model_name="esm", task="fill-mask"
                ),
                sort="lastModified",
                direction=-1,
            )
        ]
    ):
        raise RuntimeError("Error while retrieving models from HuggingFace Hub")
    return out


# Class to wrap ESM models
class Model:
    """Wrapper for ESM models."""

    def __init__(self, model_name: str = ""):
        """Load selected model and tokenizer."""
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            # Check if CUDA is available and, if so, use it
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def tokenise(self, input: str) -> BatchEncoding:
        """Convert an input string to a batch of tokens."""
        return self.batch_converter(input, return_tensors="pt")

    def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
        """Run the model on a batch of tokens."""
        return self.model(batch_tokens, **kwargs)

    def __getitem__(self, key: str) -> int:
        """Get the token ID for a character."""
        return self.alphabet[key]

    def run_model(self, data):
        """Run the model on the data."""

        def label_row(row, token_probs):
            """Label a row with its score."""
            # Extract wild type, index and mutant type from the row (e.g. "M1A")
            wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
            # The score is the difference between the log-probabilities of the
            # mutant type and the wild type (offset by 1 to skip the BOS token)
            score = (
                token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
            )
            return score.item()

        # Tokenise the sequence and move the tokens to the model's device
        batch_tokens = self.tokenise(data.seq).input_ids.to(self.model.device)
        # Calculate the token probabilities without updating the model parameters
        with torch.no_grad():
            token_probs = torch.log_softmax(self(batch_tokens).logits, dim=-1)
        # Store the token probabilities in the data
        data.token_probs = token_probs.cpu().numpy()
        # If the scoring strategy starts with "masked-marginals"
        if data.scoring_strategy.startswith("masked-marginals"):
            all_token_probs = []
            # For each token in the batch
            for i in range(batch_tokens.size()[1]):
                # If the token is in the list of residues
                if i in data.resi:
                    # Clone the batch tokens and mask the current token
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = self["<mask>"]
                    # Calculate the masked token probabilities
                    with torch.no_grad():
                        masked_token_probs = torch.log_softmax(
                            self(batch_tokens_masked).logits, dim=-1
                        )
                else:
                    # If the token is not in the list of residues, use the
                    # original token probabilities
                    masked_token_probs = token_probs
                # Append the token probabilities to the list
                all_token_probs.append(masked_token_probs[:, i])
            # Concatenate all token probabilities
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
        # Apply the label_row function to each row of the substitutions dataframe
        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(
                row["0"],
                token_probs,
            ),
            axis=1,
        )
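

# Minimal usage sketch (illustrative only): `MutationData` is a hypothetical
# stand-in for the `data` object that run_model expects, i.e. anything exposing
# `seq`, `resi`, `sub`, `out` and `scoring_strategy`; the real data object may
# differ.
if __name__ == "__main__":
    from dataclasses import dataclass, field

    import pandas as pd

    @dataclass
    class MutationData:
        """Hypothetical container matching the attributes used by run_model."""

        seq: str                                 # protein sequence
        resi: list[int]                          # token positions to mask (BOS is 0)
        sub: "pd.DataFrame"                      # substitutions, e.g. "M1A", in column '0'
        scoring_strategy: str = "masked-marginals"
        out: dict = field(default_factory=dict)  # per-model scores are written here
        token_probs: object = None               # filled in by run_model

    # Smallest public ESM-2 checkpoint; get_models()[0] would instead pick the
    # most recently modified facebook/esm* fill-mask model
    model = Model("facebook/esm2_t6_8M_UR50D")
    data = MutationData(
        seq="MKTAYIAKQR",
        resi=[1, 2],  # residues 1 and 2 (positions shifted by the BOS token)
        sub=pd.DataFrame({"0": ["M1A", "K2G"]}),
    )
    model.run_model(data)
    print(data.out[model.model_name])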