Spaces:
Sleeping
Sleeping
Witold Wydmański
commited on
Commit
•
8ee7dbf
1
Parent(s):
137a7d5
feat: Add get_esm2_embeddings function
Browse files
app.py
CHANGED
@@ -38,6 +38,22 @@ def fold_prot_locally(sequence):
|
|
38 |
pdb = convert_outputs_to_pdb(output)
|
39 |
return pdb
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
def get_esmfold_embeddings(sequence):
|
42 |
logger.info("Getting embeddings for: " + sequence)
|
43 |
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
|
@@ -165,11 +181,12 @@ with gr.Blocks() as demo:
|
|
165 |
with gr.Row(visible=False):
|
166 |
with gr.Column():
|
167 |
gr.Markdown("## Embeddings")
|
168 |
-
embs = gr.JSON(label="Embeddings"
|
169 |
|
170 |
name.change(fn=suggest, inputs=name, outputs=inp)
|
171 |
btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
|
172 |
btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
|
|
|
173 |
out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")
|
174 |
|
175 |
demo.launch()
|
|
|
38 |
pdb = convert_outputs_to_pdb(output)
|
39 |
return pdb
|
40 |
|
41 |
+
def get_esm2_embeddings(sequence):
|
42 |
+
logger.info("Getting embeddings for: " + sequence)
|
43 |
+
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
|
44 |
+
|
45 |
+
with torch.no_grad():
|
46 |
+
aa = tokenized_input
|
47 |
+
L = aa.shape[1]
|
48 |
+
device = tokenized_input.device
|
49 |
+
attention_mask = torch.ones_like(aa, device=device)
|
50 |
+
|
51 |
+
# === ESM ===
|
52 |
+
esmaa = model.af2_idx_to_esm_idx(aa, attention_mask)
|
53 |
+
esm_s = model.compute_language_model_representations(esmaa)
|
54 |
+
|
55 |
+
return {"res": esm_s.cpu().tolist()}
|
56 |
+
|
57 |
def get_esmfold_embeddings(sequence):
|
58 |
logger.info("Getting embeddings for: " + sequence)
|
59 |
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
|
|
|
181 |
with gr.Row(visible=False):
|
182 |
with gr.Column():
|
183 |
gr.Markdown("## Embeddings")
|
184 |
+
embs = gr.JSON(label="Embeddings")
|
185 |
|
186 |
name.change(fn=suggest, inputs=name, outputs=inp)
|
187 |
btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
|
188 |
btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
|
189 |
+
btn.click(get_esm2_embeddings, inputs=[inp], outputs=[embs], api_name="esm2_embeddings")
|
190 |
out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")
|
191 |
|
192 |
demo.launch()
|
client.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
from gradio_client import Client
|
3 |
+
|
4 |
+
#%%
|
5 |
+
# client = Client("https://wwydmanski-esmfold.hf.space/")
|
6 |
+
client = Client("http://localhost:7860")
|
7 |
+
|
8 |
+
# %%
|
9 |
+
result = client.predict(
|
10 |
+
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN", # str in 'sequence' Textbox component
|
11 |
+
api_name="/esm2_embeddings")
|
12 |
+
|
13 |
+
# %%
|
14 |
+
result
|