import argparse
import os.path
def p_m_r(ca_only, path_to_model_weights, model_name, seed, save_score, save_probs, score_only,
          conditional_probs_only, conditional_probs_only_backbone, unconditional_probs_only,
          backbone_noise, num_seq_per_target, batch_size, max_length, sampling_temp, out_folder,
          pdb_path, pdb_path_chains, jsonl_path, chain_id_jsonl, fixed_positions_jsonl, omit_AAs,
          bias_AA_jsonl, bias_by_res_jsonl, omit_AA_jsonl, pssm_jsonl, pssm_multi, pssm_threshold,
          pssm_log_odds_flag, pssm_bias_flag, tied_positions_jsonl):
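    """Run ProteinMPNN inference and return a list of generated output file paths.

    A function port of the original protein_mpnn_run.py CLI: flag-like arguments
    are cast to int (0/1), and pdb_path plus several *_jsonl arguments arrive as
    file-like objects exposing a .name path (or are falsy when unused).
    """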
seed = int(seed)
save_score = int(save_score)
save_probs = int(save_probs)
score_only = int(score_only)
conditional_probs_only = int(conditional_probs_only)
conditional_probs_only_backbone = int(conditional_probs_only_backbone)
unconditional_probs_only = int(unconditional_probs_only)
num_seq_per_target = int(num_seq_per_target)
batch_size = int(batch_size)
max_length = int(max_length)
pssm_log_odds_flag = int(pssm_log_odds_flag)
pssm_bias_flag = int(pssm_bias_flag)
    # Deferred imports keep module import cheap; only what the function uses.
    import json, time, os, copy, random, subprocess
    import numpy as np
    import torch
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, \
cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN
    if not seed:
        # seed == 0 means "pick a random seed"
        seed = int(np.random.randint(0, high=999, size=1, dtype=int)[0])
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
hidden_dim = 128
num_layers = 3
if path_to_model_weights:
model_folder_path = path_to_model_weights
if model_folder_path[-1] != '/':
model_folder_path = model_folder_path + '/'
else:
file_path = os.path.realpath(__file__)
        # changed: derive the weights folder from this script's location
k = file_path.rfind("/")
if ca_only:
model_folder_path = file_path[:k] + '/ca_model_weights/'
else:
model_folder_path = file_path[:k] + '/vanilla_model_weights/'
checkpoint_path = model_folder_path + f'{model_name}.pt'
folder_for_outputs = out_folder
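    # num_seq_per_target sequences are produced as NUM_BATCHES passes over
    # BATCH_COPIES identical clones of each target.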
NUM_BATCHES = num_seq_per_target // batch_size
BATCH_COPIES = batch_size
temperatures = [float(item) for item in sampling_temp.split()]
omit_AAs_list = omit_AAs
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
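    # Binary mask over the alphabet: 1.0 marks letters excluded from sampling.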
omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    # os.path.isfile(path): True only if path names an existing regular file
    # changed: several *_jsonl arguments are file objects, so check their .name path
    if chain_id_jsonl and os.path.isfile(chain_id_jsonl.name):
        with open(chain_id_jsonl.name, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            chain_id_dict = json.loads(json_str)
    else:
        chain_id_dict = None
        print(40 * '-')
        print('chain_id_jsonl is NOT loaded')
    if fixed_positions_jsonl and os.path.isfile(fixed_positions_jsonl.name):
        with open(fixed_positions_jsonl.name, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            fixed_positions_dict = json.loads(json_str)
    else:
        print(40 * '-')
        print('fixed_positions_jsonl is NOT loaded')
        fixed_positions_dict = None
    if pssm_jsonl and os.path.isfile(pssm_jsonl):
with open(pssm_jsonl, 'r') as json_file:
json_list = list(json_file)
pssm_dict = {}
for json_str in json_list:
pssm_dict.update(json.loads(json_str))
else:
print(40 * '-')
print('pssm_jsonl is NOT loaded')
pssm_dict = None
    if omit_AA_jsonl and os.path.isfile(omit_AA_jsonl):
with open(omit_AA_jsonl, 'r') as json_file:
json_list = list(json_file)
for json_str in json_list:
omit_AA_dict = json.loads(json_str)
else:
print(40 * '-')
print('omit_AA_jsonl is NOT loaded')
omit_AA_dict = None
    if bias_AA_jsonl and os.path.isfile(bias_AA_jsonl.name):
        with open(bias_AA_jsonl.name, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            bias_AA_dict = json.loads(json_str)
    else:
        print(40 * '-')
        print('bias_AA_jsonl is NOT loaded')
        bias_AA_dict = None
    if tied_positions_jsonl and os.path.isfile(tied_positions_jsonl.name):
        with open(tied_positions_jsonl.name, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            tied_positions_dict = json.loads(json_str)
    else:
        print(40 * '-')
        print('tied_positions_jsonl is NOT loaded')
        tied_positions_dict = None
    if bias_by_res_jsonl and os.path.isfile(bias_by_res_jsonl):
with open(bias_by_res_jsonl, 'r') as json_file:
json_list = list(json_file)
for json_str in json_list:
bias_by_res_dict = json.loads(json_str)
print('bias by residue dictionary is loaded')
else:
print(40 * '-')
print('bias by residue dictionary is not loaded, or not provided')
bias_by_res_dict = None
print(40 * '-')
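    # Global per-letter logit bias applied during sampling (0.0 = no bias).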
bias_AAs_np = np.zeros(len(alphabet))
if bias_AA_dict:
for n, AA in enumerate(alphabet):
if AA in list(bias_AA_dict.keys()):
bias_AAs_np[n] = bias_AA_dict[AA]
    # changed: single-PDB inputs also arrive as file objects with a .name path
if pdb_path:
pdb_dict_list = parse_PDB(pdb_path.name, ca_only=ca_only)
dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
all_chain_list = [item[-1:] for item in list(pdb_dict_list[0]) if item[:9] == 'seq_chain'] # ['A','B', 'C',...]
if pdb_path_chains:
designed_chain_list = [str(item) for item in pdb_path_chains.split()]
else:
designed_chain_list = all_chain_list
fixed_chain_list = [letter for letter in all_chain_list if letter not in designed_chain_list]
chain_id_dict = {}
chain_id_dict[pdb_dict_list[0]['name']] = (designed_chain_list, fixed_chain_list)
else:
dataset_valid = StructureDataset(jsonl_path.name, truncate=None, max_length=max_length)
print(40 * '-')
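    # Rebuild the model with the checkpoint's neighbor count and load its weights.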
checkpoint = torch.load(checkpoint_path, map_location=device)
print('Number of edges:', checkpoint['num_edges'])
noise_level_print = checkpoint['noise_level']
print(f'Training noise level: {noise_level_print}A')
model = ProteinMPNN(ca_only=ca_only, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim,
hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers,
augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# Build paths for experiment
base_folder = folder_for_outputs
if base_folder[-1] != '/':
base_folder = base_folder + '/'
if not os.path.exists(base_folder):
os.makedirs(base_folder)
if not os.path.exists(base_folder + 'seqs'):
os.makedirs(base_folder + 'seqs')
if save_score:
if not os.path.exists(base_folder + 'scores'):
os.makedirs(base_folder + 'scores')
if score_only:
if not os.path.exists(base_folder + 'score_only'):
os.makedirs(base_folder + 'score_only')
if conditional_probs_only:
if not os.path.exists(base_folder + 'conditional_probs_only'):
os.makedirs(base_folder + 'conditional_probs_only')
if unconditional_probs_only:
if not os.path.exists(base_folder + 'unconditional_probs_only'):
os.makedirs(base_folder + 'unconditional_probs_only')
if save_probs:
if not os.path.exists(base_folder + 'probs'):
os.makedirs(base_folder + 'probs')
# Timing
start_time = time.time()
total_residues = 0
protein_list = []
total_step = 0
# Validation epoch
with torch.no_grad():
test_sum, test_weights = 0., 0.
        # changed: collect output file paths so they can be returned to the caller
        results = []
        # enumerate() pairs each dataset entry with its index.
for ix, protein in enumerate(dataset_valid):
score_list = []
global_score_list = []
all_probs_list = []
all_log_probs_list = []
S_sample_list = []
            # deep-copy the protein once per clone so batch members stay independent
batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(
batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict,
bias_by_res_dict, ca_only=ca_only)
pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() # 1.0 for true, 0.0 for false
name_ = batch_clones[0]['name']
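            # One of four modes: score native sequences, dump conditional or
            # unconditional probabilities, or (default) design new sequences.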
if score_only:
structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + '.npz'
native_score_list = []
global_native_score_list = []
for j in range(NUM_BATCHES):
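                    # Each pass scores under a fresh random decoding order;
                    # per-pass scores are concatenated below.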
randn_1 = torch.randn(chain_M.shape, device=X.device)
log_probs = model(X, S, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_1)
mask_for_loss = mask * chain_M * chain_M_pos
scores = _scores(S, log_probs, mask_for_loss)
native_score = scores.cpu().data.numpy()
native_score_list.append(native_score)
global_scores = _scores(S, log_probs, mask)
global_native_score = global_scores.cpu().data.numpy()
global_native_score_list.append(global_native_score)
native_score = np.concatenate(native_score_list, 0)
global_native_score = np.concatenate(global_native_score_list, 0)
ns_mean = native_score.mean()
ns_mean_print = np.format_float_positional(np.float32(ns_mean), unique=False, precision=4)
ns_std = native_score.std()
ns_std_print = np.format_float_positional(np.float32(ns_std), unique=False, precision=4)
global_ns_mean = global_native_score.mean()
global_ns_mean_print = np.format_float_positional(np.float32(global_ns_mean), unique=False, precision=4)
global_ns_std = global_native_score.std()
global_ns_std_print = np.format_float_positional(np.float32(global_ns_std), unique=False, precision=4)
ns_sample_size = native_score.shape[0]
np.savez(structure_sequence_score_file, score=native_score, global_score=global_native_score)
print(
f'Score for {name_}, mean: {ns_mean_print}, std: {ns_std_print}, sample size: {ns_sample_size}, Global Score for {name_}, mean: {global_ns_mean_print}, std: {global_ns_std_print}, sample size: {ns_sample_size}')
results.append(structure_sequence_score_file)
elif conditional_probs_only:
print(f'Calculating conditional probabilities for {name_}')
conditional_probs_only_file = base_folder + '/conditional_probs_only/' + batch_clones[0]['name']
log_conditional_probs_list = []
for j in range(NUM_BATCHES):
randn_1 = torch.randn(chain_M.shape, device=X.device)
log_conditional_probs = model.conditional_probs(X, S, mask, chain_M * chain_M_pos, residue_idx,
chain_encoding_all, randn_1,
conditional_probs_only_backbone)
log_conditional_probs_list.append(log_conditional_probs.cpu().numpy())
concat_log_p = np.concatenate(log_conditional_probs_list, 0) # [B, L, 21]
mask_out = (chain_M * chain_M_pos * mask)[0,].cpu().numpy()
np.savez(conditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(),
mask=mask[0,].cpu().numpy(), design_mask=mask_out)
elif unconditional_probs_only:
print(f'Calculating sequence unconditional probabilities for {name_}')
                # changed: save with an explicit .npz extension
unconditional_probs_only_file = base_folder + '/unconditional_probs_only/' + batch_clones[0]['name'] + '.npz'
log_unconditional_probs_list = []
for j in range(NUM_BATCHES):
log_unconditional_probs = model.unconditional_probs(X, mask, residue_idx, chain_encoding_all)
log_unconditional_probs_list.append(log_unconditional_probs.cpu().numpy())
concat_log_p = np.concatenate(log_unconditional_probs_list, 0) # [B, L, 21]
mask_out = (chain_M * chain_M_pos * mask)[0,].cpu().numpy()
np.savez(unconditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(),
mask=mask[0,].cpu().numpy(), design_mask=mask_out)
results.append(unconditional_probs_only_file)
else:
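                # Design mode: score the native sequence once, then sample.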
randn_1 = torch.randn(chain_M.shape, device=X.device)
log_probs = model(X, S, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_1)
mask_for_loss = mask * chain_M * chain_M_pos
scores = _scores(S, log_probs, mask_for_loss) # score only the redesigned part
native_score = scores.cpu().data.numpy()
global_scores = _scores(S, log_probs, mask) # score the whole structure-sequence
global_native_score = global_scores.cpu().data.numpy()
# Generate some sequences
ali_file = base_folder + '/seqs/' + batch_clones[0]['name'] + '.fa'
score_file = base_folder + '/scores/' + batch_clones[0]['name'] + '.npz'
probs_file = base_folder + '/probs/' + batch_clones[0]['name'] + '.npz'
print(f'Generating sequences for: {name_}')
t0 = time.time()
with open(ali_file, 'w') as f:
for temp in temperatures:
for j in range(NUM_BATCHES):
randn_2 = torch.randn(chain_M.shape, device=X.device)
                            if tied_positions_dict is None:
sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np,
bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
pssm_bias=pssm_bias, pssm_multi=pssm_multi,
pssm_log_odds_flag=bool(pssm_log_odds_flag),
pssm_log_odds_mask=pssm_log_odds_mask,
pssm_bias_flag=bool(pssm_bias_flag),
bias_by_res=bias_by_res_all)
S_sample = sample_dict["S"]
else:
sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np,
bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
pssm_bias=pssm_bias, pssm_multi=pssm_multi,
pssm_log_odds_flag=bool(pssm_log_odds_flag),
pssm_log_odds_mask=pssm_log_odds_mask,
pssm_bias_flag=bool(pssm_bias_flag),
tied_pos=tied_pos_list_of_lists_list[0],
tied_beta=tied_beta, bias_by_res=bias_by_res_all)
# Compute scores
S_sample = sample_dict["S"]
log_probs = model(X, S_sample, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all,
randn_2, use_input_decoding_order=True,
decoding_order=sample_dict["decoding_order"])
mask_for_loss = mask * chain_M * chain_M_pos
scores = _scores(S_sample, log_probs, mask_for_loss)
scores = scores.cpu().data.numpy()
global_scores = _scores(S_sample, log_probs, mask) # score the whole structure-sequence
global_scores = global_scores.cpu().data.numpy()
all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
all_log_probs_list.append(log_probs.cpu().data.numpy())
S_sample_list.append(S_sample.cpu().data.numpy())
for b_ix in range(BATCH_COPIES):
masked_chain_length_list = masked_chain_length_list_list[b_ix]
masked_list = masked_list_list[b_ix]
seq_recovery_rate = torch.sum(torch.sum(
torch.nn.functional.one_hot(S[b_ix], 21) * torch.nn.functional.one_hot(
S_sample[b_ix], 21), axis=-1) * mask_for_loss[b_ix]) / torch.sum(
mask_for_loss[b_ix])
seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
score = scores[b_ix]
score_list.append(score)
global_score = global_scores[b_ix]
global_score_list.append(global_score)
native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])
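                                # Write the native sequence record once, before the first design.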
if b_ix == 0 and j == 0 and temp == temperatures[0]:
start = 0
end = 0
list_of_AAs = []
for mask_l in masked_chain_length_list:
end += mask_l
list_of_AAs.append(native_seq[start:end])
start = end
native_seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
l0 = 0
for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[
:-1]:
l0 += mc_length
native_seq = native_seq[:l0] + '/' + native_seq[l0:]
l0 += 1
sorted_masked_chain_letters = np.argsort(masked_list_list[0])
print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]
sorted_visible_chain_letters = np.argsort(visible_list_list[0])
print_visible_chains = [visible_list_list[0][i] for i in
sorted_visible_chain_letters]
native_score_print = np.format_float_positional(np.float32(native_score.mean()),
unique=False, precision=4)
global_native_score_print = np.format_float_positional(
np.float32(global_native_score.mean()), unique=False, precision=4)
script_dir = os.path.dirname(os.path.realpath(__file__))
try:
commit_str = subprocess.check_output(
f'git --git-dir {script_dir}/.git rev-parse HEAD',
shell=True).decode().strip()
except subprocess.CalledProcessError:
commit_str = 'unknown'
if ca_only:
print_model_name = 'CA_model_name'
else:
print_model_name = 'model_name'
f.write(
'>{}, score={}, global_score={}, fixed_chains={}, designed_chains={}, {}={}, git_hash={}, seed={}\n{}\n'.format(
name_, native_score_print, global_native_score_print, print_visible_chains,
print_masked_chains, print_model_name, model_name, commit_str, seed,
native_seq)) # write the native sequence
start = 0
end = 0
list_of_AAs = []
for mask_l in masked_chain_length_list:
end += mask_l
list_of_AAs.append(seq[start:end])
start = end
seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
l0 = 0
for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
l0 += mc_length
seq = seq[:l0] + '/' + seq[l0:]
l0 += 1
score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
global_score_print = np.format_float_positional(np.float32(global_score), unique=False,
precision=4)
seq_rec_print = np.format_float_positional(
np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
sample_number = j * BATCH_COPIES + b_ix + 1
f.write(
'>T={}, sample={}, score={}, global_score={}, seq_recovery={}\n{}\n'.format(temp,
sample_number,
score_print,
global_score_print,
seq_rec_print,
seq)) # write generated sequence
results.append(ali_file)
if save_score:
np.savez(score_file, score=np.array(score_list, np.float32),
global_score=np.array(global_score_list, np.float32))
if save_probs:
all_probs_concat = np.concatenate(all_probs_list)
all_log_probs_concat = np.concatenate(all_log_probs_list)
S_sample_concat = np.concatenate(S_sample_list)
np.savez(probs_file, probs=np.array(all_probs_concat, np.float32),
log_probs=np.array(all_log_probs_concat, np.float32),
S=np.array(S_sample_concat, np.int32), mask=mask_for_loss.cpu().data.numpy(),
chain_order=chain_list_list)
t1 = time.time()
dt = round(float(t1 - t0), 4)
num_seqs = len(temperatures) * NUM_BATCHES * BATCH_COPIES
total_length = X.shape[1]
print(f'{num_seqs} sequences of length {total_length} generated in {dt} seconds')
return results
# if __name__ == "__main__":
# argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
#
# argparser.add_argument("--ca_only", action="store_true", default=False,
# help="Parse CA-only structures and use CA-only models (default: false)")
# argparser.add_argument("--path_to_model_weights", type=str, default="", help="Path to model weights folder;")
# argparser.add_argument("--model_name", type=str, default="v_48_020",
# help="ProteinMPNN model name: v_48_002, v_48_010, v_48_020, v_48_030; v_48_010=version with 48 edges 0.10A noise")
#
# argparser.add_argument("--seed", type=int, default=0, help="If set to 0 then a random seed will be picked;")
#
# argparser.add_argument("--save_score", type=int, default=0,
# help="0 for False, 1 for True; save score=-log_prob to npy files")
# argparser.add_argument("--save_probs", type=int, default=0,
# help="0 for False, 1 for True; save MPNN predicted probabilites per position")
#
# argparser.add_argument("--score_only", type=int, default=0,
# help="0 for False, 1 for True; score input backbone-sequence pairs")
#
# argparser.add_argument("--conditional_probs_only", type=int, default=0,
# help="0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)")
# argparser.add_argument("--conditional_probs_only_backbone", type=int, default=0,
# help="0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)")
# argparser.add_argument("--unconditional_probs_only", type=int, default=0,
# help="0 for False, 1 for True; output unconditional probabilities p(s_i given backbone) in one forward pass")
#
# argparser.add_argument("--backbone_noise", type=float, default=0.00,
# help="Standard deviation of Gaussian noise to add to backbone atoms")
# argparser.add_argument("--num_seq_per_target", type=int, default=1,
# help="Number of sequences to generate per target")
# argparser.add_argument("--batch_size", type=int, default=1,
# help="Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory")
# argparser.add_argument("--max_length", type=int, default=200000, help="Max sequence length")
# argparser.add_argument("--sampling_temp", type=str, default="0.1",
# help="A string of temperatures, 0.2 0.25 0.5. Sampling temperature for amino acids. Suggested values 0.1, 0.15, 0.2, 0.25, 0.3. Higher values will lead to more diversity.")
#
# argparser.add_argument("--out_folder", type=str, help="Path to a folder to output sequences, e.g. /home/out/")
# argparser.add_argument("--pdb_path", type=str, default='', help="Path to a single PDB to be designed")
# argparser.add_argument("--pdb_path_chains", type=str, default='',
# help="Define which chains need to be designed for a single PDB ")
# argparser.add_argument("--jsonl_path", type=str, help="Path to a folder with parsed pdb into jsonl")
# argparser.add_argument("--chain_id_jsonl", type=str, default='',
# help="Path to a dictionary specifying which chains need to be designed and which ones are fixed, if not specied all chains will be designed.")
# argparser.add_argument("--fixed_positions_jsonl", type=str, default='',
# help="Path to a dictionary with fixed positions")
# argparser.add_argument("--omit_AAs", type=list, default='X',
# help="Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine.")
# argparser.add_argument("--bias_AA_jsonl", type=str, default='',
# help="Path to a dictionary which specifies AA composion bias if neededi, e.g. {A: -1.1, F: 0.7} would make A less likely and F more likely.")
#
# argparser.add_argument("--bias_by_res_jsonl", default='', help="Path to dictionary with per position bias.")
# argparser.add_argument("--omit_AA_jsonl", type=str, default='',
# help="Path to a dictionary which specifies which amino acids need to be omited from design at specific chain indices")
# argparser.add_argument("--pssm_jsonl", type=str, default='', help="Path to a dictionary with pssm")
# argparser.add_argument("--pssm_multi", type=float, default=0.0,
# help="A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions")
# argparser.add_argument("--pssm_threshold", type=float, default=0.0,
# help="A value between -inf + inf to restric per position AAs")
# argparser.add_argument("--pssm_log_odds_flag", type=int, default=0, help="0 for False, 1 for True")
# argparser.add_argument("--pssm_bias_flag", type=int, default=0, help="0 for False, 1 for True")
#
# argparser.add_argument("--tied_positions_jsonl", type=str, default='',
# help="Path to a dictionary with tied positions")
#
# args = argparser.parse_args()
# p_m_r(**vars(args))
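
# Example call (hypothetical, mirroring the commented CLI defaults above).
# `SimpleNamespace(name=...)` stands in for the upload-style file objects with
# a .name attribute that pdb_path and several *_jsonl arguments expect.
#
# from types import SimpleNamespace
# pdb = SimpleNamespace(name='inputs/my_protein.pdb')
# outputs = p_m_r(
#     ca_only=False, path_to_model_weights='', model_name='v_48_020', seed=37,
#     save_score=1, save_probs=0, score_only=0, conditional_probs_only=0,
#     conditional_probs_only_backbone=0, unconditional_probs_only=0,
#     backbone_noise=0.0, num_seq_per_target=4, batch_size=2, max_length=200000,
#     sampling_temp='0.1', out_folder='./out', pdb_path=pdb, pdb_path_chains='A B',
#     jsonl_path=None, chain_id_jsonl=None, fixed_positions_jsonl=None,
#     omit_AAs='X', bias_AA_jsonl=None, bias_by_res_jsonl='', omit_AA_jsonl='',
#     pssm_jsonl='', pssm_multi=0.0, pssm_threshold=0.0,
#     pssm_log_odds_flag=0, pssm_bias_flag=0, tied_positions_jsonl=None)
# print(outputs)  # e.g. ['./out/seqs/my_protein.fa']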