# Protpardelle demo (Hugging Face Space, running on a T4 GPU)
import gradio as gr
import re
import urllib.request  # plain "import urllib" does not reliably expose urllib.request
import tempfile
from output_helpers import viewer_html, output_html, load_js, get_js
import json
import os
import shlex
import subprocess
from datetime import datetime

from einops import repeat
import torch

from core import data
from core import utils
import models
import sampling

# from draw_samples import draw_and_save_samples, parse_resample_idx_string

print("working directory", os.getcwd())
def draw_and_save_samples(
    model,
    samples_per_len=8,
    lengths=range(50, 512),
    save_dir="./",
    mode="backbone",
    **sampling_kwargs,
):
    """Draw samples_per_len samples at each requested length and save them as PDB files."""
    device = model.device
    sample_files = []
    if mode == "backbone":
        total_sampling_time = 0
        for l in lengths:
            prot_lens = torch.ones(samples_per_len).long() * l
            seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
            aux = sampling.draw_backbone_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=f"{save_dir}/len{format(l, '03d')}_samp",
                return_aux=True,
                return_sampling_runtime=True,
                **sampling_kwargs,
            )
            total_sampling_time += aux["runtime"]
            sample_files += [
                f"{save_dir}/len{format(l, '03d')}_samp{i}.pdb"
                for i in range(samples_per_len)
            ]
        return sample_files
    elif mode == "allatom":
        total_sampling_time = 0
        for l in lengths:
            prot_lens = torch.ones(samples_per_len).long() * l
            seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
            aux = sampling.draw_allatom_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=f"{save_dir}/len{format(l, '03d')}",
                return_aux=True,
                **sampling_kwargs,
            )
            total_sampling_time += aux["runtime"]
            sample_files += [
                f"{save_dir}/len{format(l, '03d')}_samp{i}.pdb"
                for i in range(samples_per_len)
            ]
        return sample_files
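
# For a single length of 50 with samples_per_len=2, this returns paths like
# "<save_dir>/len050_samp0.pdb" and "<save_dir>/len050_samp1.pdb".
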
def parse_idx_string(idx_str):
    """Parse a comma-delimited index string such as "0,2-5,7" into a list of ints."""
    spans = idx_str.split(",")
    idxs = []
    for s in spans:
        if "-" in s:
            start, stop = s.split("-")
            idxs.extend(list(range(int(start), int(stop))))
        else:
            idxs.append(int(s))
    return idxs
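
# Worked example: dash spans are half-open, so "2-5" expands to 2, 3, 4 and
# parse_idx_string("0,2-5,7") returns [0, 2, 3, 4, 7].
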
def changemode(m):
    # Toggle visibility of (uncond, cond, btn_unconditional, btn_conditional, size_uncond),
    # matching the outputs wired to m.change below.
    if m == "unconditional":
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
        )
    else:
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
        )


def fileselection(val):
    # Toggle visibility of (pdbcode, pdbfile, btn_load) based on the structure source.
    if val == "upload":
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
def viewer_iframe(path, representations):
    # Wrap the 3D viewer HTML in a sandboxed iframe for display in a gr.HTML component.
    return f"""<iframe style="width: 100%; height: 930px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{viewer_html(path, representations=representations)}'></iframe>"""


def update_structuresel(pdb, radio_val):
    pdb_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdb")
    representations = [
        {
            "model": 0,
            "chain": "",
            "resname": "",
            "style": "cartoon",
            "color": "whiteCarbon",
            "residue_range": "",
            "around": 0,
            "byres": False,
            "visible": False,
        }
    ]
    if radio_val == "PDB":
        if len(pdb) != 4:
            return gr.update(open=True), gr.update(), gr.update(value="", visible=False)
        urllib.request.urlretrieve(
            f"https://files.rcsb.org/download/{pdb.lower()}.pdb1",
            pdb_file.name,
        )
        return (
            gr.update(open=False),
            gr.update(value=pdb_file.name),
            gr.update(value=viewer_iframe(pdb_file.name, representations), visible=True),
        )
    elif radio_val == "AF2 EBI DB":  # matches the pdb_radio choice defined below
        # UniProt accession pattern; fetch the AlphaFold DB v2 model on a match.
        if re.match("[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}", pdb) is not None:
            urllib.request.urlretrieve(
                f"https://alphafold.ebi.ac.uk/files/AF-{pdb}-F1-model_v2.pdb",
                pdb_file.name,
            )
            return (
                gr.update(open=False),
                gr.update(value=pdb_file.name),
                gr.update(value=viewer_iframe(pdb_file.name, representations), visible=True),
            )
        else:
            # Return three updates, matching the (accordion, path, viewer) outputs wired below.
            return gr.update(open=True), gr.update(), gr.update(value="regex not matched", visible=True)
    else:
        # "upload": pdb is the uploaded gr.File object.
        return (
            gr.update(open=False),
            gr.update(value=f"{pdb.name}"),
            gr.update(value=viewer_iframe(pdb.name, representations), visible=True),
        )
from Bio.PDB import PDBParser, cealign
from Bio.PDB.PDBIO import PDBIO


class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
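
# For example, dotdict({"minlen": 50}).minlen == 50, and a missing key such as
# dotdict({}).minlen returns None (via dict.get) instead of raising AttributeError.
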
def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    # Set up params, arguments, sampling config
    ####################
    args = {}
    args["model_checkpoint"] = "checkpoints"  # Path to denoiser model weights and config
    args["mpnnpath"] = "checkpoints/minimpnn_state_dict.pth"  # Path to minimpnn model weights
    args["modeldir"] = None  # Model base directory, ex 'training_logs/other/lemon-shape-51'
    args["modelepoch"] = None  # Model epoch, ex 1000
    args["type"] = modeltype  # Type of model
    if m == "conditional":
        args["param"] = None  # Which sampling param to vary
        args["paramval"] = None  # Which param val to use
        args["parampath"] = None  # Path to json file with params; use param/paramval or parampath, not both
        args["perlen"] = int(perlen)  # How many samples per sequence length
        args["minlen"] = None  # Minimum sequence length
        args["maxlen"] = None  # Maximum sequence length, not inclusive
        args["steplen"] = int(steplen)  # How frequently to select sequence length; for steplen 2: 50, 52, 54, etc.
        args["num_lens"] = None  # If steplen not provided, how many random lengths to sample at
        args["targetdir"] = "."  # Directory to save results
        args["input_pdb"] = path_to_file  # PDB file to condition on
        # Indices from the PDB file to resample (zero-indexed, comma-delimited, dashes
        # allowed, e.g. 0,2-5,7); strip the surrounding brackets from the textbox value.
        args["resample_idxs"] = resample_idx[1:-1]
    else:
        args["param"] = None  # Which sampling param to vary
        args["paramval"] = None  # Which param val to use
        args["parampath"] = None  # Path to json file with params; use param/paramval or parampath, not both
        args["perlen"] = int(perlen)  # How many samples per sequence length
        args["minlen"] = int(minlen)  # Minimum sequence length
        args["maxlen"] = int(maxlen) + 1  # Maximum sequence length (slider value made inclusive)
        args["steplen"] = int(steplen)  # How frequently to select sequence length; for steplen 2: 50, 52, 54, etc.
        args["num_lens"] = None  # If steplen not provided, how many random lengths to sample at
        args["targetdir"] = "."  # Directory to save results
        args["input_pdb"] = None  # No conditioning structure in unconditional mode
        args["resample_idxs"] = None
    args = dotdict(args)
    is_test_run = False
    seed = 0
    samples_per_len = args.perlen
    min_len = args.minlen
    max_len = args.maxlen
    len_step_size = args.steplen
    device = "cuda:0"

    # setting default sampling config
    if args.type == "backbone":
        sampling_config = sampling.default_backbone_sampling_config()
    elif args.type == "allatom":
        sampling_config = sampling.default_allatom_sampling_config()
    sampling_kwargs = vars(sampling_config)

    # Parse conditioning inputs
    input_pdb_len = None
    if args.input_pdb:
        input_feats = utils.load_feats_from_pdb(args.input_pdb, protein_only=True)
        input_pdb_len = input_feats["aatype"].shape[0]
        if args.resample_idxs:
            print(
                f"Warning: when sampling conditionally, the input pdb length ({input_pdb_len} residues) is used automatically for the sampling lengths."
            )
            resample_idxs = parse_idx_string(args.resample_idxs)
        else:
            resample_idxs = list(range(input_pdb_len))
        cond_idxs = [i for i in range(input_pdb_len) if i not in resample_idxs]

        to_batch_size = lambda x: repeat(x, "... -> b ...", b=samples_per_len).to(device)

        # For unconditional model, center coords on whole structure
        centered_coords = data.apply_random_se3(
            input_feats["atom_positions"],
            atom_mask=input_feats["atom_mask"],
            translation_scale=0.0,
        )
        cond_kwargs = {}
        cond_kwargs["gt_coords"] = to_batch_size(centered_coords)
        cond_kwargs["gt_cond_atom_mask"] = to_batch_size(input_feats["atom_mask"])
        cond_kwargs["gt_cond_atom_mask"][:, resample_idxs] = 0
        cond_kwargs["gt_aatype"] = to_batch_size(input_feats["aatype"])
        cond_kwargs["gt_cond_seq_mask"] = torch.zeros_like(cond_kwargs["gt_aatype"])
        cond_kwargs["gt_cond_seq_mask"][:, cond_idxs] = 1
        sampling_kwargs.update(cond_kwargs)
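        # Conditioning convention: positions being resampled get gt_cond_atom_mask = 0,
        # so their coordinates are regenerated; the remaining cond_idxs keep both their
        # coordinates (mask 1) and their amino acid identity (gt_cond_seq_mask = 1).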
print("input_pdb_len", input_pdb_len) | |
    # Determine lengths to sample at
    if min_len is not None and max_len is not None:
        if len_step_size is not None:
            sampling_lengths = range(min_len, max_len, len_step_size)
        else:
            sampling_lengths = list(
                torch.randint(min_len, max_len, size=(args.num_lens,))
            )
    elif input_pdb_len is not None:
        sampling_lengths = [input_pdb_len]
    else:
        raise Exception("Need to provide a set of protein lengths or an input pdb.")
    total_num_samples = len(list(sampling_lengths)) * samples_per_len

    model_directory = args.modeldir
    epoch = args.modelepoch
    base_dir = args.targetdir
    date_string = datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    if is_test_run:
        date_string = f"test-{date_string}"

    # Update sampling config with arguments
    if args.param:
        var_param = args.param
        var_value = args.paramval
        # n_steps is an int; any other param is parsed as a float ("None" clears it).
        sampling_kwargs[var_param] = (
            None
            if var_value == "None"
            else int(var_value)
            if var_param == "n_steps"
            else float(var_value)
        )
    elif args.parampath:
        with open(args.parampath) as f:
            var_params = json.loads(f.read())
        sampling_kwargs.update(var_params)

    # this is only used for the readme; keep s_min and s_max as params instead of struct_noise_schedule
    sampling_kwargs_readme = list(sampling_kwargs.items())

    print("Base directory:", base_dir)
    save_dir = f"{base_dir}/samples/{date_string}"
    save_init_dir = f"{base_dir}/samples_inits/{date_string}"
    # make dirs if they do not exist
    if not os.path.exists(save_dir):
        subprocess.run(shlex.split(f"mkdir -p {save_dir}"))
    if not os.path.exists(save_init_dir):
        subprocess.run(shlex.split(f"mkdir -p {save_init_dir}"))
    print("Samples saved to:", save_dir)
    torch.manual_seed(seed)
    # Load model
    if args.type == "backbone":
        if args.model_checkpoint:
            checkpoint = f"{args.model_checkpoint}/backbone_state_dict.pth"
            cfg_path = f"{args.model_checkpoint}/backbone_pretrained.yml"
        else:
            checkpoint = f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
            cfg_path = f"{model_directory}/configs/backbone.yml"
        config = utils.load_config(cfg_path)
        weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
        model = models.Protpardelle(config, device=device)
        model.load_state_dict(weights)
        model.to(device)
        model.eval()
        model.device = device
    elif args.type == "allatom":
        if args.model_checkpoint:
            checkpoint = f"{args.model_checkpoint}/allatom_state_dict.pth"
            cfg_path = f"{args.model_checkpoint}/allatom_pretrained.yml"
        else:
            checkpoint = f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
            cfg_path = f"{model_directory}/configs/allatom.yml"
        config = utils.load_config(cfg_path)
        weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
        model = models.Protpardelle(config, device=device)
        model.load_state_dict(weights)
        model.load_minimpnn(args.mpnnpath)  # mini-MPNN supplies sequence design for all-atom sampling
        model.to(device)
        model.eval()
        model.device = device
    if config.train.home_dir == "":
        config.train.home_dir = os.getcwd()
    with open(save_dir + "/run_parameters.txt", "w") as f:
        f.write(f"Sampling run for {date_string}\n")
        f.write(f"Random seed {seed}\n")
        f.write(f"Model checkpoint: {checkpoint}\n")
        f.write(
            f"{samples_per_len} samples per length from {min_len}:{max_len}:{len_step_size}\n"
        )
        f.write("Sampling params:\n")
        for k, v in sampling_kwargs_readme:
            f.write(f"{k}\t{v}\n")

    print(f"Model loaded from {checkpoint}")
    print(f"Beginning sampling for {date_string}...")

    # Draw samples
    output_files = draw_and_save_samples(
        model,
        samples_per_len=samples_per_len,
        lengths=sampling_lengths,
        save_dir=save_dir,
        mode=args.type,
        **sampling_kwargs,
    )
    return output_files
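
# For instance, an unconditional all-atom run over lengths 50..60 with two samples
# per length (mirroring the UI defaults below) would be:
#   protpardelle(None, "unconditional", "", "allatom", 50, 60, 1, 2)
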
def api_predict(pdb_content, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    if m == "conditional":
        # Write the raw PDB string to a temp file so protpardelle can read it.
        tempPDB = tempfile.NamedTemporaryFile(delete=False, suffix=".pdb")
        tempPDB.write(pdb_content.encode())
        tempPDB.close()
        path_to_file = tempPDB.name
    else:
        path_to_file = None
    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        raise gr.Error(str(e))
    # load each design as string
    design_str = []
    for d in designs:
        with open(d, "r") as f:
            design_str.append(f.read())
    results = list(zip(designs, design_str))
    return json.dumps(results)
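
# The API route returns a JSON-encoded list of [pdb_path, pdb_file_contents] pairs,
# one entry per generated design.
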
def predict(pdb_radio, path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    print("running predict")
    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        raise gr.Error(str(e))

    parser = PDBParser()
    aligner = cealign.CEAligner()
    io = PDBIO()
    aligned_designs = []
    metrics = []
    if m == "conditional":
        # Align each design to the conditioning structure and record the CE-align RMSD.
        ref = parser.get_structure("ref", path_to_file)
        aligner.set_reference(ref)
        for d in designs:
            design = parser.get_structure("design", d)
            aligner.align(design)
            metrics.append({"rms": f"{aligner.rms:.1f}", "len": len(list(design[0].get_residues()))})
            io.set_structure(design)
            io.save(d.replace(".pdb", "_al.pdb"))
            aligned_designs.append(d.replace(".pdb", "_al.pdb"))
    else:
        for d in designs:
            design = parser.get_structure("design", d)
            metrics.append({"len": len(list(design[0].get_residues()))})
        aligned_designs = designs
    output_view = f"""<iframe style="width: 100%; height: 900px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{output_html(path_to_file, aligned_designs, metrics, resample_idx=resample_idx, mode=m)}'></iframe>"""
    return gr.update(open=False), gr.update(value=output_view, visible=True)
protpardelleDemo = gr.Blocks()

with protpardelleDemo:
    gr.Markdown("# Protpardelle")
    gr.Markdown(
        """An all-atom protein generative model.
Alexander E. Chu, Lucy Cheng, Gina El Nesr, Minkai Xu, Po-Ssu Huang
doi: https://doi.org/10.1101/2023.05.24.542194"""
    )
    with gr.Accordion(label="Input options", open=True) as input_accordion:
        model = gr.Dropdown(["backbone", "allatom"], value="allatom", label="What to sample?")
        m = gr.Radio(["unconditional", "conditional"], value="unconditional", label="Choose a Mode")

        # unconditional
        with gr.Group(visible=True) as uncond:
            gr.Markdown("Unconditional Sampling")
            # length = gr.Slider(minimum=0, maximum=200, step=1, value=50, label="length")
            # param = gr.Dropdown(["length", "param"], value="length", label="Which sampling param to vary?")
            # paramval = gr.Dropdown(["nsteps"], label="paramval", info="Which param val to use?")

        # conditional
        with gr.Group(visible=False) as cond:
            with gr.Accordion(label="Structure to condition on", open=True) as input_accordion:
                pdb_radio = gr.Radio(["PDB", "AF2 EBI DB", "upload"], value="PDB", label="source of the structure")
                pdbcode = gr.Textbox(label="PDB code, or UniProt accession for the AlphaFold2 Database", visible=True)
                pdbfile = gr.File(label="PDB File", visible=False)
                btn_load = gr.Button("Load PDB")
                pdb_radio.change(fileselection, inputs=pdb_radio, outputs=[pdbcode, pdbfile, btn_load])
            pdb_html = gr.HTML("", visible=False)
            path_to_file = gr.Textbox(label="Path to file", visible=False)
            resample_idxs = gr.Textbox(
                label="Cond Idxs",
                interactive=False,
                info="Zero-indexed list of indices to condition on; select in the sequence viewer above",
            )
            btn_load.click(update_structuresel, inputs=[pdbcode, pdb_radio], outputs=[input_accordion, path_to_file, pdb_html])
            pdbfile.change(update_structuresel, inputs=[pdbfile, pdb_radio], outputs=[input_accordion, path_to_file, pdb_html])

    with gr.Accordion(label="Sizes", open=True) as size_uncond:
        with gr.Row():
            minlen = gr.Slider(minimum=2, maximum=200, value=50, step=1, label="minlen", info="Minimum sequence length")
            maxlen = gr.Slider(minimum=3, maximum=200, value=60, step=1, label="maxlen", info="Maximum sequence length")
            steplen = gr.Slider(minimum=1, maximum=50, step=1, value=1, label="steplen", info="How frequently to select sequence length?")
            perlen = gr.Slider(minimum=1, maximum=200, step=1, value=2, label="perlen", info="How many samples per sequence length?")

    btn_conditional = gr.Button("Run conditional", visible=False)
    btn_unconditional = gr.Button("Run unconditional")
    m.change(changemode, inputs=m, outputs=[uncond, cond, btn_unconditional, btn_conditional, size_uncond])

    out = gr.HTML("", visible=True)
    btn_unconditional.click(
        predict,
        inputs=[pdb_radio, path_to_file, m, resample_idxs, model, minlen, maxlen, steplen, perlen],
        outputs=[input_accordion, out],
    )
    # The conditional button runs client-side JS (get_js) that writes the selected
    # residue indices into resample_idxs; the resample_idxs.change event below then
    # triggers predict.
    btn_conditional.click(fn=None, inputs=[resample_idxs], outputs=[resample_idxs], _js=get_js)

    out_text = gr.Textbox(label="Output", visible=False)
    # hidden button for named api route
    pdb_content = gr.Textbox(label="PDB Content", visible=False)
    btn_api = gr.Button("Run API", visible=False)
    btn_api.click(
        api_predict,
        inputs=[pdb_content, m, resample_idxs, model, minlen, maxlen, steplen, perlen],
        outputs=[out_text],
        api_name="protpardelle",
    )
    resample_idxs.change(
        predict,
        inputs=[pdb_radio, path_to_file, m, resample_idxs, model, minlen, maxlen, steplen, perlen],
        outputs=[input_accordion, out],
    )

    protpardelleDemo.load(None, None, None, _js=load_js)

protpardelleDemo.queue()
protpardelleDemo.launch(allowed_paths=["samples"])
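
# The hidden "protpardelle" route above can also be called programmatically. A
# minimal sketch using gradio_client, assuming SPACE_URL is the (hypothetical)
# address of this running app; the positional arguments follow btn_api's inputs
# (pdb_content, m, resample_idxs, model, minlen, maxlen, steplen, perlen):
#
#   from gradio_client import Client
#   client = Client("SPACE_URL")
#   result = client.predict(
#       "",               # pdb_content: raw PDB text, only used in conditional mode
#       "unconditional",  # mode
#       "",               # resample idxs
#       "allatom",        # model type
#       50, 60, 1, 2,     # minlen, maxlen, steplen, perlen
#       api_name="/protpardelle",
#   )
#   # result is the JSON string of [path, pdb_contents] pairs produced by api_predict.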