import json, time, os, sys, glob
import gradio as gr
sys.path.append("/home/user/app/ProteinMPNN/vanilla_proteinmpnn")
import matplotlib.pyplot as plt
import shutil
import warnings
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
from protein_mpnn_utils import (
loss_nll,
loss_smoothed,
gather_edges,
gather_nodes,
gather_nodes_t,
cat_neighbors_nodes,
_scores,
_S_to_seq,
tied_featurize,
parse_PDB,
)
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN
import plotly.express as px
import urllib
import jax.numpy as jnp
import tensorflow as tf
if "/home/user/app/af_backprop" not in sys.path:
sys.path.append("/home/user/app/af_backprop")
from utils import *
# AlphaFold / ColabFold imports
import colabfold as cf
from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.model import data, config, model
from alphafold.common import residue_constants
import plotly.graph_objects as go
import ray
import re
import jax
tf.config.set_visible_devices([], "GPU")
def chain_break(idx_res, Ls, length=200):
    # Minkyung's code
    # add a large enough offset to the residue index at each chain boundary to indicate chain breaks
L_prev = 0
for L_i in Ls[:-1]:
idx_res[L_prev + L_i :] += length
L_prev += L_i
return idx_res
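# Illustrative example (values are hypothetical, not from the original script):
# with Ls = [3, 3], an input residue index [0, 1, 2, 3, 4, 5] becomes
# [0, 1, 2, 203, 204, 205]; the +200 jump marks the start of the second chain.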
def setup_model(seq, model_name="model_1_ptm"):
# setup model
    cfg = config.model_config(model_name)
cfg.model.num_recycle = 0
cfg.data.common.num_recycle = 0
cfg.data.eval.max_msa_clusters = 1
cfg.data.common.max_extra_msa = 1
cfg.data.eval.masked_msa_replace_fraction = 0
cfg.model.global_config.subbatch_size = None
model_params = data.get_model_haiku_params(model_name=model_name, data_dir=".")
model_runner = model.RunModel(cfg, model_params, is_training=False)
Ls = [len(s) for s in seq.split("/")]
seq = re.sub("[^A-Z]", "", seq.upper())
length = len(seq)
feature_dict = {
**pipeline.make_sequence_features(
sequence=seq, description="none", num_res=length
),
**pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0] * length]]),
}
feature_dict["residue_index"] = chain_break(feature_dict["residue_index"], Ls)
inputs = model_runner.process_features(feature_dict, random_seed=0)
def runner(seq, opt):
# update sequence
inputs = opt["inputs"]
inputs.update(opt["prev"])
update_seq(seq, inputs)
update_aatype(inputs["target_feat"][..., 1:], inputs)
# mask prediction
mask = seq.sum(-1)
inputs["seq_mask"] = inputs["seq_mask"].at[:].set(mask)
inputs["msa_mask"] = inputs["msa_mask"].at[:].set(mask)
inputs["residue_index"] = jnp.where(mask == 1, inputs["residue_index"], 0)
# get prediction
key = jax.random.PRNGKey(0)
outputs = model_runner.apply(opt["params"], key, inputs)
prev = {
"init_msa_first_row": outputs["representations"]["msa_first_row"][None],
"init_pair": outputs["representations"]["pair"][None],
"init_pos": outputs["structure_module"]["final_atom_positions"][None],
}
aux = {
"final_atom_positions": outputs["structure_module"]["final_atom_positions"],
"final_atom_mask": outputs["structure_module"]["final_atom_mask"],
"plddt": get_plddt(outputs),
"pae": get_pae(outputs),
"inputs": inputs,
"prev": prev,
}
return aux
return jax.jit(runner), {"inputs": inputs, "params": model_params}
def make_tied_positions_for_homomers(pdb_dict_list):
my_dict = {}
for result in pdb_dict_list:
all_chain_list = sorted(
[item[-1:] for item in list(result) if item[:9] == "seq_chain"]
) # A, B, C, ...
tied_positions_list = []
chain_length = len(result[f"seq_chain_{all_chain_list[0]}"])
for i in range(1, chain_length + 1):
temp_dict = {}
            for chain in all_chain_list:
                temp_dict[chain] = [i]  # needs to be a list
tied_positions_list.append(temp_dict)
my_dict[result["name"]] = tied_positions_list
return my_dict
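# Example of the returned structure (hypothetical homodimer "mypdb", chains A/B,
# 3 residues each):
#   {"mypdb": [{"A": [1], "B": [1]}, {"A": [2], "B": [2]}, {"A": [3], "B": [3]}]}
# i.e. position i is tied across all chains, so every chain samples the same residue.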
def renumber(struc):
"""Renumber residues consecutively and remove all hetero residues"""
resid = 0
residue_to_remove = []
chain_to_remove = []
for model in struc:
for chain in model:
for i, residue in enumerate(chain.get_residues()):
res_id = list(residue.id)
res_id[1] = resid
resid += 1
residue.id = tuple(res_id)
if residue.id[0] != " ":
residue_to_remove.append((chain.id, residue.id))
if len(chain) == 0:
chain_to_remove.append(chain.id)
for residue in residue_to_remove:
struc[0][residue[0]].detach_child(residue[1])
for chain in chain_to_remove:
model.detach_child(chain)
return struc
def align_structures(pdb1, pdb2, lenRes):
"""Take two structure and superimpose pdb1 on pdb2"""
import Bio.PDB
# We use all residues
atoms_to_be_aligned = range(0, lenRes)
pdb_parser = Bio.PDB.PDBParser(QUIET=True)
# Get the structures
    ref_structure = pdb_parser.get_structure("reference", pdb2)
    sample_structure = renumber(pdb_parser.get_structure("sample", pdb1))
# Use the first model in the pdb-files for alignment
ref_model = ref_structure[0]
sample_model = sample_structure[0]
# Make a list of the atoms (in the structures) to align.
ref_atoms = []
sample_atoms = []
    # Iterate over all chains in the model in order to find all residues
    for ref_chain in ref_model:
        # Iterate over all residues in each chain in order to find the proper atoms
for ref_res in ref_chain:
# Check if residue number ( .get_id() ) is in the list
if ref_res.get_id()[1] in atoms_to_be_aligned:
# Append CA atom to list
ref_atoms.append(ref_res["CA"])
for sample_chain in sample_model:
for sample_res in sample_chain:
if sample_res.get_id()[1] in atoms_to_be_aligned:
sample_atoms.append(sample_res["CA"])
# Now we initiate the superimposer:
super_imposer = Bio.PDB.Superimposer()
super_imposer.set_atoms(ref_atoms, sample_atoms)
super_imposer.apply(sample_model.get_atoms())
io = Bio.PDB.PDBIO()
io.set_structure(sample_structure)
io.save(f"{pdb1}_aligned.pdb")
return super_imposer.rms, f"{pdb1}_aligned.pdb"
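# Usage sketch (hypothetical file names):
#   rms, aligned_path = align_structures("designed.pdb", "predicted.pdb", 100)
# superimposes the CA atoms of the first 100 residues of designed.pdb onto
# predicted.pdb and writes the moved copy to "designed.pdb_aligned.pdb".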
def save_pdb(outs, filename, LEN):
"""save pdb coordinates"""
p = {
"residue_index": outs["inputs"]["residue_index"][0][:LEN],
"aatype": outs["inputs"]["aatype"].argmax(-1)[0][:LEN],
"atom_positions": outs["final_atom_positions"][:LEN],
"atom_mask": outs["final_atom_mask"][:LEN],
}
b_factors = 100.0 * outs["plddt"][:LEN, None] * p["atom_mask"]
p = protein.Protein(**p, b_factors=b_factors)
pdb_lines = protein.to_pdb(p)
with open(filename, "w") as f:
f.write(pdb_lines)
@ray.remote(num_gpus=1, max_calls=1)
def run_alphafold(sequence):
recycles = 3
RUNNER, OPT = setup_model(sequence)
SEQ = re.sub("[^A-Z]", "", sequence.upper())
MAX_LEN = len(SEQ)
LEN = len(SEQ)
x = np.array([residue_constants.restype_order.get(aa, -1) for aa in SEQ])
x = np.pad(x, [0, MAX_LEN - LEN], constant_values=-1)
x = jax.nn.one_hot(x, 20)
OPT["prev"] = {
"init_msa_first_row": np.zeros([1, MAX_LEN, 256]),
"init_pair": np.zeros([1, MAX_LEN, MAX_LEN, 128]),
"init_pos": np.zeros([1, MAX_LEN, 37, 3]),
}
positions = []
plddts = []
for r in range(recycles + 1):
outs = RUNNER(x, OPT)
outs = jax.tree_map(lambda x: np.asarray(x), outs)
positions.append(outs["prev"]["init_pos"][0, :LEN])
plddts.append(outs["plddt"][:LEN])
OPT["prev"] = outs["prev"]
if recycles > 0:
print(r, plddts[-1].mean())
save_pdb(outs, "out.pdb", LEN)
num_res = int(outs["inputs"]["aatype"][0].sum())
return outs["plddt"], outs["pae"], num_res
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
model_name = "v_48_020" # ProteinMPNN model name: v_48_002, v_48_010, v_48_020, v_48_030, v_32_002, v_32_010; v_32_020, v_32_030; v_48_010=version with 48 edges 0.10A noise
backbone_noise = 0.00 # Standard deviation of Gaussian noise to add to backbone atoms
path_to_model_weights = (
"/home/user/app/ProteinMPNN/vanilla_proteinmpnn/vanilla_model_weights"
)
hidden_dim = 128
num_layers = 3
model_folder_path = path_to_model_weights
if model_folder_path[-1] != "/":
model_folder_path = model_folder_path + "/"
checkpoint_path = model_folder_path + f"{model_name}.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)
noise_level_print = checkpoint["noise_level"]
model = ProteinMPNN(
num_letters=21,
node_features=hidden_dim,
edge_features=hidden_dim,
hidden_dim=hidden_dim,
num_encoder_layers=num_layers,
num_decoder_layers=num_layers,
augment_eps=backbone_noise,
k_neighbors=checkpoint["num_edges"],
)
model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
def get_pdb(pdb_code="", filepath=""):
if pdb_code is None or pdb_code == "":
return filepath.name
else:
os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
return f"{pdb_code}.pdb"
def update(inp, file, designed_chain, fixed_chain, homomer, num_seqs, sampling_temp):
pdb_path = get_pdb(pdb_code=inp, filepath=file)
if designed_chain == "":
designed_chain_list = []
else:
designed_chain_list = re.sub("[^A-Za-z]+", ",", designed_chain).split(",")
if fixed_chain == "":
fixed_chain_list = []
else:
fixed_chain_list = re.sub("[^A-Za-z]+", ",", fixed_chain).split(",")
chain_list = list(set(designed_chain_list + fixed_chain_list))
num_seq_per_target = num_seqs
save_score = 0 # 0 for False, 1 for True; save score=-log_prob to npy files
    save_probs = 0  # 0 for False, 1 for True; save MPNN predicted probabilities per position
score_only = 0 # 0 for False, 1 for True; score input backbone-sequence pairs
conditional_probs_only = 0 # 0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)
conditional_probs_only_backbone = 0 # 0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)
    batch_size = 1  # Batch size; can be set higher on Titan/Quadro GPUs, reduce if running out of GPU memory
max_length = 20000 # Max sequence length
out_folder = "." # Path to a folder to output sequences, e.g. /home/out/
jsonl_path = "" # Path to a folder with parsed pdb into jsonl
omit_AAs = "X" # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine.
    pssm_multi = 0.0  # A value in [0.0, 1.0]; 0.0 means do not use the PSSM, 1.0 means ignore MPNN predictions
    pssm_threshold = 0.0  # A value between -inf and +inf used to restrict per-position amino acids
pssm_log_odds_flag = 0 # 0 for False, 1 for True
pssm_bias_flag = 0 # 0 for False, 1 for True
folder_for_outputs = out_folder
NUM_BATCHES = num_seq_per_target // batch_size
BATCH_COPIES = batch_size
temperatures = [sampling_temp]
omit_AAs_list = omit_AAs
alphabet = "ACDEFGHIKLMNPQRSTVWYX"
omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)
chain_id_dict = None
fixed_positions_dict = None
pssm_dict = None
omit_AA_dict = None
bias_AA_dict = None
bias_by_res_dict = None
bias_AAs_np = np.zeros(len(alphabet))
###############################################################
pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
dataset_valid = StructureDatasetPDB(
pdb_dict_list, truncate=None, max_length=max_length
)
if homomer:
tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)
else:
tied_positions_dict = None
chain_id_dict = {}
chain_id_dict[pdb_dict_list[0]["name"]] = (designed_chain_list, fixed_chain_list)
with torch.no_grad():
        for ix, protein_dict in enumerate(dataset_valid):
score_list = []
all_probs_list = []
all_log_probs_list = []
S_sample_list = []
            batch_clones = [copy.deepcopy(protein_dict) for i in range(BATCH_COPIES)]
(
X,
S,
mask,
lengths,
chain_M,
chain_encoding_all,
chain_list_list,
visible_list_list,
masked_list_list,
masked_chain_length_list_list,
chain_M_pos,
omit_AA_mask,
residue_idx,
dihedral_mask,
tied_pos_list_of_lists_list,
pssm_coef,
pssm_bias,
pssm_log_odds_all,
bias_by_res_all,
tied_beta,
) = tied_featurize(
batch_clones,
device,
chain_id_dict,
fixed_positions_dict,
omit_AA_dict,
tied_positions_dict,
pssm_dict,
bias_by_res_dict,
)
pssm_log_odds_mask = (
pssm_log_odds_all > pssm_threshold
).float() # 1.0 for true, 0.0 for false
name_ = batch_clones[0]["name"]
randn_1 = torch.randn(chain_M.shape, device=X.device)
log_probs = model(
X,
S,
mask,
chain_M * chain_M_pos,
residue_idx,
chain_encoding_all,
randn_1,
)
mask_for_loss = mask * chain_M * chain_M_pos
scores = _scores(S, log_probs, mask_for_loss)
native_score = scores.cpu().data.numpy()
message = ""
for temp in temperatures:
for j in range(NUM_BATCHES):
randn_2 = torch.randn(chain_M.shape, device=X.device)
                    if tied_positions_dict is None:
sample_dict = model.sample(
X,
randn_2,
S,
chain_M,
chain_encoding_all,
residue_idx,
mask=mask,
temperature=temp,
omit_AAs_np=omit_AAs_np,
bias_AAs_np=bias_AAs_np,
chain_M_pos=chain_M_pos,
omit_AA_mask=omit_AA_mask,
pssm_coef=pssm_coef,
pssm_bias=pssm_bias,
pssm_multi=pssm_multi,
pssm_log_odds_flag=bool(pssm_log_odds_flag),
pssm_log_odds_mask=pssm_log_odds_mask,
pssm_bias_flag=bool(pssm_bias_flag),
bias_by_res=bias_by_res_all,
)
S_sample = sample_dict["S"]
else:
sample_dict = model.tied_sample(
X,
randn_2,
S,
chain_M,
chain_encoding_all,
residue_idx,
mask=mask,
temperature=temp,
omit_AAs_np=omit_AAs_np,
bias_AAs_np=bias_AAs_np,
chain_M_pos=chain_M_pos,
omit_AA_mask=omit_AA_mask,
pssm_coef=pssm_coef,
pssm_bias=pssm_bias,
pssm_multi=pssm_multi,
pssm_log_odds_flag=bool(pssm_log_odds_flag),
pssm_log_odds_mask=pssm_log_odds_mask,
pssm_bias_flag=bool(pssm_bias_flag),
tied_pos=tied_pos_list_of_lists_list[0],
tied_beta=tied_beta,
bias_by_res=bias_by_res_all,
)
# Compute scores
S_sample = sample_dict["S"]
log_probs = model(
X,
S_sample,
mask,
chain_M * chain_M_pos,
residue_idx,
chain_encoding_all,
randn_2,
use_input_decoding_order=True,
decoding_order=sample_dict["decoding_order"],
)
mask_for_loss = mask * chain_M * chain_M_pos
scores = _scores(S_sample, log_probs, mask_for_loss)
scores = scores.cpu().data.numpy()
all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
all_log_probs_list.append(log_probs.cpu().data.numpy())
S_sample_list.append(S_sample.cpu().data.numpy())
for b_ix in range(BATCH_COPIES):
masked_chain_length_list = masked_chain_length_list_list[b_ix]
masked_list = masked_list_list[b_ix]
seq_recovery_rate = torch.sum(
torch.sum(
torch.nn.functional.one_hot(S[b_ix], 21)
* torch.nn.functional.one_hot(S_sample[b_ix], 21),
axis=-1,
)
* mask_for_loss[b_ix]
) / torch.sum(mask_for_loss[b_ix])
seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
score = scores[b_ix]
score_list.append(score)
native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])
if b_ix == 0 and j == 0 and temp == temperatures[0]:
start = 0
end = 0
list_of_AAs = []
for mask_l in masked_chain_length_list:
end += mask_l
list_of_AAs.append(native_seq[start:end])
start = end
native_seq = "".join(
list(np.array(list_of_AAs)[np.argsort(masked_list)])
)
l0 = 0
for mc_length in list(
np.array(masked_chain_length_list)[
np.argsort(masked_list)
]
)[:-1]:
l0 += mc_length
native_seq = native_seq[:l0] + "/" + native_seq[l0:]
l0 += 1
sorted_masked_chain_letters = np.argsort(
masked_list_list[0]
)
print_masked_chains = [
masked_list_list[0][i]
for i in sorted_masked_chain_letters
]
sorted_visible_chain_letters = np.argsort(
visible_list_list[0]
)
print_visible_chains = [
visible_list_list[0][i]
for i in sorted_visible_chain_letters
]
native_score_print = np.format_float_positional(
np.float32(native_score.mean()),
unique=False,
precision=4,
)
line = ">{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\n{}\n".format(
name_,
native_score_print,
print_visible_chains,
print_masked_chains,
model_name,
native_seq,
)
message += f"{line}\n"
start = 0
end = 0
list_of_AAs = []
for mask_l in masked_chain_length_list:
end += mask_l
list_of_AAs.append(seq[start:end])
start = end
seq = "".join(
list(np.array(list_of_AAs)[np.argsort(masked_list)])
)
l0 = 0
for mc_length in list(
np.array(masked_chain_length_list)[np.argsort(masked_list)]
)[:-1]:
l0 += mc_length
seq = seq[:l0] + "/" + seq[l0:]
l0 += 1
score_print = np.format_float_positional(
np.float32(score), unique=False, precision=4
)
seq_rec_print = np.format_float_positional(
np.float32(seq_recovery_rate.detach().cpu().numpy()),
unique=False,
precision=4,
)
line = (
">T={}, sample={}, score={}, seq_recovery={}\n{}\n".format(
temp, b_ix, score_print, seq_rec_print, seq
)
)
message += f"{line}\n"
all_probs_concat = np.concatenate(all_probs_list)
all_log_probs_concat = np.concatenate(all_log_probs_list)
np.savetxt("all_probs_concat.csv", all_probs_concat.mean(0).T, delimiter=",")
np.savetxt(
"all_log_probs_concat.csv",
np.exp(all_log_probs_concat).mean(0).T,
delimiter=",",
)
S_sample_concat = np.concatenate(S_sample_list)
fig = px.imshow(
np.exp(all_log_probs_concat).mean(0).T,
labels=dict(x="positions", y="amino acids", color="probability"),
y=list(alphabet),
template="simple_white",
)
fig.update_xaxes(side="top")
fig_tadjusted = px.imshow(
all_probs_concat.mean(0).T,
labels=dict(x="positions", y="amino acids", color="probability"),
y=list(alphabet),
template="simple_white",
)
fig_tadjusted.update_xaxes(side="top")
return (
message,
fig,
fig_tadjusted,
gr.File.update(value="all_log_probs_concat.csv", visible=True),
gr.File.update(value="all_probs_concat.csv", visible=True),
pdb_path,
)
def update_AF(startsequence, pdb):
    # run AlphaFold using ray
    plddts, pae, num_res = ray.get(run_alphafold.remote(startsequence))
plotAF_plddt = go.Figure(
data=go.Scatter(
x=np.arange(len(plddts)),
y=plddts,
hovertemplate="pLDDT: %{y:.2f}
Residue index: %{x}",
)
)
plotAF_plddt.update_layout(
title="pLDDT",
xaxis_title="Residue index",
yaxis_title="pLDDT",
height=500,
template="simple_white",
)
plotAF_pae = px.imshow(
pae,
labels=dict(x="Scored residue", y="Aligned residue", color=""),
)
plotAF_pae.update_layout(title="Predicted Aligned Error", template="simple_white")
return molecule(pdb, "af_backprop/out.pdb", num_res), plotAF_plddt, plotAF_pae
def read_mol(molpath):
    """Read a molecule file into a single string."""
    with open(molpath, "r") as fp:
        return fp.read()
def molecule(pdb, afpdb, num_res):
rms, aligned_pdb = align_structures(pdb, afpdb, num_res)
mol = read_mol(pdb)
pred_mol = read_mol(aligned_pdb)
x = (
"""