# D-SCRIPT / app.py
# Last commit by samsl (d43f920): "Add error checking for fasta file loading"
# File size: 2.25 kB
import gradio as gr
import pandas as pd
from pathlib import Path
from Bio import SeqIO
from dscript.pretrained import get_pretrained
from dscript.language_model import lm_embed
from tqdm.auto import tqdm
from uuid import uuid4
# Display name shown in the UI dropdown -> identifier handed to
# ``get_pretrained`` inside ``predict``.
model_map = {
    "D-SCRIPT": "human_v1",
    "Topsy-Turvy": "human_v2",
}
def _load_sequences(sequence_file):
    """Parse the uploaded FASTA file into a dict of records keyed by record ID.

    Args:
        sequence_file: Uploaded file object exposing the on-disk path as
            ``.name``.

    Returns:
        The mapping produced by ``SeqIO.to_dict``.

    Raises:
        gr.Error: If ``SeqIO.to_dict`` rejects the file with ``ValueError``
            (reported to the user as a duplicate entry).
    """
    try:
        return SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
    except ValueError as e:
        # gr.Error is an exception type: it only reaches the UI (and only
        # stops the run) when it is raised, not merely constructed.
        raise gr.Error("Invalid FASTA file - duplicate entry") from e


def _load_pairs(pairs_file):
    """Read the uploaded pairs table and normalise its column names.

    Args:
        pairs_file: Uploaded file object exposing the on-disk path as
            ``.name``; must end in ``.csv`` or ``.tsv``.

    Returns:
        A DataFrame with exactly two string columns, ``protein1`` and
        ``protein2``.

    Raises:
        gr.Error: If the extension is unsupported or the table does not have
            exactly two columns.
    """
    separators = {".csv": ",", ".tsv": "\t"}
    suffix = Path(pairs_file.name).suffix.lower()
    if suffix not in separators:
        raise gr.Error("Invalid pairs file - expected a .csv or .tsv file")

    # NOTE: pandas treats the first row as a header by default; that header
    # is then overwritten below. dtype=str keeps numeric-looking IDs as text
    # so they can match the (string) FASTA record IDs.
    pairs = pd.read_csv(pairs_file.name, sep=separators[suffix], dtype=str)
    if pairs.shape[1] != 2:
        raise gr.Error(
            f"Invalid pairs file - expected 2 columns, found {pairs.shape[1]}"
        )
    pairs.columns = ["protein1", "protein2"]
    return pairs


def predict(model, sequence_file, pairs_file):
    """Score every protein pair in ``pairs_file`` with the selected model.

    Args:
        model: Key of ``model_map`` naming the pretrained model to use.
        sequence_file: Uploaded FASTA file (path available as ``.name``).
        pairs_file: Uploaded ``.csv``/``.tsv`` file of protein ID pairs
            (path available as ``.name``).

    Returns:
        A tuple ``(results, file_path)``: a DataFrame with columns
        ``Protein 1``, ``Protein 2`` and ``Interaction``, and the path of a
        TSV copy of that table written under ``/tmp``.

    Raises:
        gr.Error: If either input file is invalid, or if the pairs file
            references a protein ID that is absent from the FASTA file.
    """
    run_id = uuid4()

    gr.Info("Loading model...")
    # Presumably a warm-up call so the language model is loaded before any
    # real sequence is embedded; the result is intentionally discarded.
    _ = lm_embed("M")
    # Use a separate local name rather than rebinding the ``model`` argument.
    predictor = get_pretrained(model_map[model])

    gr.Info("Loading files...")
    seqs = _load_sequences(sequence_file)
    pairs = _load_pairs(pairs_file)

    # Fail fast with a readable message instead of a KeyError mid-run.
    requested = set(pairs["protein1"]) | set(pairs["protein2"])
    missing = sorted(str(p) for p in requested if p not in seqs)
    if missing:
        preview = ", ".join(missing[:5])
        raise gr.Error(
            f"Pairs file references {len(missing)} protein(s) not found in "
            f"the FASTA file: {preview}"
        )

    gr.Info("Predicting...")
    total = len(pairs)
    results = []
    # NOTE(review): Gradio documents gr.Progress as a default argument of the
    # handler; instantiated here it may not be connected to the UI - confirm.
    progress = gr.Progress(track_tqdm=True)
    rows = pairs.itertuples(index=False, name=None)
    for i, (prot1, prot2) in enumerate(tqdm(rows, total=total)):
        gr.Info(f"[{i+1}/{total}]")
        lm1 = lm_embed(str(seqs[prot1].seq))
        lm2 = lm_embed(str(seqs[prot2].seq))
        interaction = predictor.predict(lm1, lm2).item()
        results.append([prot1, prot2, interaction])
        progress((i, total))

    results = pd.DataFrame(
        results, columns=["Protein 1", "Protein 2", "Interaction"]
    )

    file_path = f"/tmp/{run_id}.tsv"
    results.to_csv(file_path, sep="\t", index=False, header=True)

    return results, file_path
# Gradio UI: a model selector plus two file uploads in; a results table and a
# downloadable file out. The two outputs line up with the
# ``(results, file_path)`` tuple returned by ``predict``.
demo = gr.Interface(
    fn=predict,
    inputs=[
        # Choices come from model_map so the dropdown can never offer a name
        # that predict() cannot resolve.
        gr.Dropdown(
            label="Model",
            choices=list(model_map),
            value="Topsy-Turvy",
        ),
        gr.File(label="Sequences (.fasta)", file_types=[".fasta"]),
        gr.File(label="Pairs (.csv/.tsv)", file_types=[".csv", ".tsv"]),
    ],
    outputs=[
        gr.DataFrame(
            label='Results',
            headers=['Protein 1', 'Protein 2', 'Interaction'],
        ),
        gr.File(label="Download results", type="file"),
    ],
)
if __name__ == "__main__":
    # Enable the request queue (at most 20 waiting requests), then start the
    # app.
    demo.queue(max_size=20).launch()