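# ESMFold Gradio Space: folds a protein sequence with facebook/esmfold_v1,
# renders the predicted structure with 3Dmol.js, and exposes folding and
# embedding endpoints through the Gradio API.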
import gradio as gr
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
import torch
from logging import getLogger

logger = getLogger(__name__)
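
# Convert raw ESMFold outputs (atom14-format positions) into a PDB string using the
# OpenFold utilities; per-residue pLDDT scores are written to the B-factor column.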
def convert_outputs_to_pdb(outputs):
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs[0]
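
# Gradio handler for the "/pdb" endpoint: tokenize the sequence, run ESMFold on the GPU,
# and return the predicted structure as PDB text.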
def fold_prot_locally(sequence):
    logger.info("Folding: " + sequence)
    tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
    with torch.no_grad():
        output = model(tokenized_input)
    pdb = convert_outputs_to_pdb(output)
    return pdb
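
# Gradio handler for "/esm2_embeddings": return the ESM-2 language-model representations
# that ESMFold computes internally for the sequence.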
def get_esm2_embeddings(sequence):
    logger.info("Getting embeddings for: " + sequence)
    tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
    with torch.no_grad():
        aa = tokenized_input
        L = aa.shape[1]
        device = tokenized_input.device
        attention_mask = torch.ones_like(aa, device=device)

        # === ESM ===
        esmaa = model.af2_idx_to_esm_idx(aa, attention_mask)
        esm_s = model.compute_language_model_representations(esmaa)
    return {"res": esm_s.cpu().tolist()}
def get_esmfold_embeddings(sequence):
    logger.info("Getting embeddings for: " + sequence)
    tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
    with torch.no_grad():
        output = model(tokenized_input)
    return {"res": output["s_s"].cpu().tolist()}
def suggest(option):
    if option == "Plastic degradation protein":
        suggestion = "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ"
    elif option == "Antifreeze protein":
        suggestion = "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH"
    elif option == "AI Generated protein":
        suggestion = "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS"
    elif option == "7-bladed propeller fold":
        suggestion = "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK"
    else:
        suggestion = ""
    return suggestion
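
# Wrap a PDB string in a self-contained 3Dmol.js viewer page and embed it in a sandboxed
# iframe so it can be displayed through a gr.HTML component.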
def molecule(mol):
    x = (
        """<!DOCTYPE html>
        <html>
        <head>
        <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
        <style>
        body{
            font-family:sans-serif
        }
        .mol-container {
            width: 100%;
            height: 600px;
            position: relative;
        }
        .mol-container select{
            background-image:None;
        }
        </style>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
        <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
        </head>
        <body>
        <div id="container" class="mol-container"></div>
        <script>
            let pdb = `"""
        + mol
        + """`
            $(document).ready(function () {
                let element = $("#container");
                let config = { backgroundColor: "white" };
                let viewer = $3Dmol.createViewer(element, config);
                viewer.addModel(pdb, "pdb");
                viewer.getModel(0).setStyle({}, { cartoon: { colorscheme:"whiteCarbon" } });
                viewer.zoomTo();
                viewer.render();
                viewer.zoom(0.8, 2000);
            })
        </script>
        </body></html>"""
    )

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
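
# Client-side usage example shown in the UI: it calls this Space's "/pdb" endpoint via
# gradio_client and writes the returned structure to a .pdb file.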
sample_code = """ | |
from gradio_client import Client | |
client = Client("https://wwydmanski-esmfold.hf.space/") | |
def fold_huggingface(sequence, fname=None): | |
result = client.predict( | |
sequence, # str in 'sequence' Textbox component | |
api_name="/pdb") | |
if fname is None: | |
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".pdb", prefix="esmfold_") as fp: | |
fp.write(result) | |
fp.flush() | |
return fp.name | |
else: | |
with open(fname, "w") as fp: | |
fp.write(result) | |
fp.flush() | |
return fname | |
pdb_fname = fold_huggingface("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN") | |
""" | |
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True).cuda()

model.esm = model.esm.half()
torch.backends.cuda.matmul.allow_tf32 = True
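
# Build the Gradio UI: sequence input with a sample-protein selector, the client code snippet,
# the predicted PDB, an interactive 3D view, and a hidden embeddings panel; the handlers are
# wired to named API endpoints.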
with gr.Blocks() as demo:
    gr.Markdown("# ESMFold")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(lines=1, label="Sequence")
            name = gr.Dropdown(label="Choose a Sample Protein", value="Plastic degradation protein", choices=["Antifreeze protein", "Plastic degradation protein", "AI Generated protein", "7-bladed propeller fold", "custom"])
            btn = gr.Button("🔬 Predict Structure ")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Sample code")
            gr.Code(sample_code, label="Sample usage", language="python", interactive=False)

    with gr.Row():
        gr.Markdown("## Output")
    with gr.Row():
        with gr.Column():
            out = gr.Code(label="Output", interactive=False)
        with gr.Column():
            out_mol = gr.HTML(label="3D Structure")

    with gr.Row(visible=False):
        with gr.Column():
            gr.Markdown("## Embeddings")
            embs = gr.JSON(label="Embeddings")

    name.change(fn=suggest, inputs=name, outputs=inp)
    btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
    btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
    btn.click(get_esm2_embeddings, inputs=[inp], outputs=[embs], api_name="esm2_embeddings")
    out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")

demo.launch()