D-SCRIPT / app.py
samsl's picture
File download and multiple models
ff2b104
raw
history blame
2.15 kB
import gradio as gr
import pandas as pd
from pathlib import Path
from Bio import SeqIO
from dscript.pretrained import get_pretrained
from dscript.language_model import lm_embed
from tqdm.auto import tqdm
from uuid import uuid4
# Display name shown in the UI -> identifier handed to ``get_pretrained``.
model_map = {
    "D-SCRIPT": "human_v1",
    "Topsy-Turvy": "human_v2",
}
def predict(
    model,
    sequence_file,
    pairs_file,
    progress=gr.Progress(track_tqdm=True),
):
    """Score every requested protein pair with a pretrained model.

    Args:
        model: Display name of the model; must be a key of ``model_map``.
        sequence_file: Uploaded FASTA file object; only its ``.name`` path
            is read.
        pairs_file: Uploaded ``.csv`` or ``.tsv`` file object. Its first two
            columns are taken as the protein identifiers of each pair.
        progress: Gradio progress tracker. It has to be declared as a default
            argument for Gradio to attach it to the running event; with
            ``track_tqdm=True`` the ``tqdm`` loop below drives the progress
            bar. Callers do not pass it explicitly.

    Returns:
        A tuple ``(results, file_path)`` where ``results`` is a DataFrame with
        the columns ``Protein 1``, ``Protein 2`` and ``Interaction`` and
        ``file_path`` is the path of a TSV copy written under ``/tmp``.

    Raises:
        gr.Error: If an upload is missing, the model name is unknown, the
            pairs file has an unsupported extension or fewer than two
            columns, or a pair references an id absent from the FASTA file.
    """
    if sequence_file is None or pairs_file is None:
        raise gr.Error("Please upload both a sequence file and a pairs file.")
    if model not in model_map:
        raise gr.Error(f"Unknown model: {model!r}")

    run_id = uuid4()

    gr.Info("Loading model...")
    # The embedding of this one-residue sequence is discarded; presumably the
    # call only forces the language model to load up front -- TODO confirm.
    _ = lm_embed("M")
    predictor = get_pretrained(model_map[model])

    gr.Info("Loading files...")
    seqs = SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))

    # NOTE(review): read_csv consumes the first line as a header row, which is
    # then renamed below. A header-less pairs file would silently lose its
    # first pair -- confirm the expected input format.
    # dtype=str keeps numeric-looking ids as text so they match FASTA keys.
    suffix = Path(pairs_file.name).suffix.lower()
    if suffix == ".csv":
        pairs = pd.read_csv(pairs_file.name, dtype=str)
    elif suffix == ".tsv":
        pairs = pd.read_csv(pairs_file.name, sep="\t", dtype=str)
    else:
        raise gr.Error(
            f"Unsupported pairs file type {suffix!r}; expected .csv or .tsv."
        )

    if pairs.shape[1] < 2:
        raise gr.Error("The pairs file must contain at least two columns.")
    # Extra columns are ignored instead of breaking the rename.
    pairs = pairs.iloc[:, :2]
    pairs.columns = ["protein1", "protein2"]

    # Fail fast on unknown ids, before any expensive embedding is computed.
    requested = pd.unique(pairs[["protein1", "protein2"]].values.ravel())
    missing = [str(name) for name in requested if name not in seqs]
    if missing:
        shown = ", ".join(missing[:5])
        more = "" if len(missing) <= 5 else f" (and {len(missing) - 5} more)"
        raise gr.Error(f"Ids not found in the sequence file: {shown}{more}")

    gr.Info("Predicting...")
    results = []
    # Progress is reported through the tracked tqdm bar rather than through
    # one toast per pair.
    for prot1, prot2 in tqdm(
        pairs.itertuples(index=False, name=None), total=len(pairs)
    ):
        lm1 = lm_embed(str(seqs[prot1].seq))
        lm2 = lm_embed(str(seqs[prot2].seq))
        interaction = predictor.predict(lm1, lm2).item()
        results.append([prot1, prot2, interaction])

    results = pd.DataFrame(
        results, columns=["Protein 1", "Protein 2", "Interaction"]
    )

    # The random run id keeps runs from overwriting each other's file.
    file_path = f"/tmp/{run_id}.tsv"
    results.to_csv(file_path, sep="\t", index=False, header=True)

    return results, file_path
# Web UI: a model selector plus two uploads in; a results table and a
# downloadable file out, matching the two values returned by ``predict``.
demo = gr.Interface(
    fn=predict,
    inputs=[
        # Choices come from model_map so the dropdown cannot drift away from
        # the names predict() is able to resolve.
        gr.Dropdown(
            label="Model",
            choices=list(model_map),
            value="Topsy-Turvy",
        ),
        gr.File(label="Sequences (.fasta)", file_types=[".fasta"]),
        gr.File(label="Pairs (.csv/.tsv)", file_types=[".csv", ".tsv"]),
    ],
    outputs=[
        gr.DataFrame(
            label="Results",
            headers=["Protein 1", "Protein 2", "Interaction"],
        ),
        gr.File(label="Download results", type="file"),
    ],
)
if __name__ == "__main__":
    # Enable the request queue (at most 20 waiting requests), then serve.
    demo.queue(max_size=20).launch()