|
import tempfile
from pathlib import Path
from uuid import uuid4

import gradio as gr
import pandas as pd
from Bio import SeqIO
from dscript.language_model import lm_embed
from dscript.pretrained import get_pretrained
from tqdm.auto import tqdm
|
|
|
# Maps the model name selected in the UI to the identifier that
# ``predict`` passes to ``get_pretrained``.
model_map = {
    "D-SCRIPT": "human_v1",
    "Topsy-Turvy": "human_v2",
}
|
|
|
def _load_pairs(pairs_path):
    """Read the protein-pair table from a ``.csv`` or ``.tsv`` file.

    Args:
        pairs_path: Path of the uploaded pairs file.

    Returns:
        A DataFrame whose two columns are renamed ``protein1`` and
        ``protein2``.

    Raises:
        gr.Error: If the file extension is not ``.csv``/``.tsv`` or the
            table does not have exactly two columns.
    """
    separators = {".csv": ",", ".tsv": "\t"}
    suffix = Path(pairs_path).suffix.lower()
    if suffix not in separators:
        raise gr.Error(
            f"Unsupported pairs file type '{suffix}'; expected .csv or .tsv."
        )

    # dtype=str keeps identifiers as text so they can match the string keys
    # of the FASTA dictionary even when they look numeric.
    # NOTE(review): read_csv consumes the first row as a header, which is
    # then replaced below -- a headerless file would lose its first pair;
    # confirm the expected upload format.
    pairs = pd.read_csv(pairs_path, sep=separators[suffix], dtype=str)
    if pairs.shape[1] != 2:
        raise gr.Error(
            f"Pairs file must have exactly 2 columns, found {pairs.shape[1]}."
        )
    pairs.columns = ["protein1", "protein2"]
    return pairs


def predict(model, sequence_file, pairs_file):
    """Score every protein pair with the selected pretrained model.

    Args:
        model: Key of ``model_map`` naming the pretrained model to load.
        sequence_file: Uploaded file object; its ``.name`` path is parsed
            as FASTA.
        pairs_file: Uploaded file object; its ``.name`` path is read as a
            two-column ``.csv``/``.tsv`` table of protein identifiers.

    Returns:
        A tuple ``(results, file_path)``: a DataFrame with columns
        ``Protein 1``, ``Protein 2`` and ``Interaction``, and the path of a
        tab-separated copy of it written to the temporary directory.

    Raises:
        gr.Error: If an upload is missing, the pairs file is malformed, or
            it references identifiers absent from the FASTA file.
    """
    if sequence_file is None or pairs_file is None:
        raise gr.Error("Please upload both a sequences file and a pairs file.")

    run_id = uuid4()

    gr.Info("Loading model...")
    # The embedding of this one-residue sequence is discarded; presumably the
    # call only forces the language model to load up front -- TODO confirm.
    _ = lm_embed("M")

    # Use a separate name so the ``model`` argument is not shadowed.
    predictor = get_pretrained(model_map[model])

    gr.Info("Loading files...")
    seqs = SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
    pairs = _load_pairs(pairs_file.name)

    # Fail early with a readable message instead of a KeyError mid-run.
    wanted = set(pairs["protein1"]) | set(pairs["protein2"])
    missing = sorted(str(p) for p in wanted if p not in seqs)
    if missing:
        preview = ", ".join(missing[:5])
        raise gr.Error(
            f"{len(missing)} protein ID(s) in the pairs file are not in the "
            f"sequences file: {preview}"
        )

    gr.Info("Predicting...")
    results = []
    # NOTE(review): Gradio documents attaching Progress through a
    # default-valued function parameter; confirm that an instance created
    # here actually reports progress in the deployed Gradio version.
    progress = gr.Progress(track_tqdm=True)
    total = len(pairs)
    rows = pairs.itertuples(index=False, name=None)
    for done, (prot1, prot2) in enumerate(tqdm(rows, total=total), start=1):
        gr.Info(f"[{done}/{total}]")
        lm1 = lm_embed(str(seqs[prot1].seq))
        lm2 = lm_embed(str(seqs[prot2].seq))
        interaction = predictor.predict(lm1, lm2).item()
        results.append([prot1, prot2, interaction])
        # Report the number of completed pairs so the final update is
        # total/total.
        progress((done, total))

    results = pd.DataFrame(
        results, columns=["Protein 1", "Protein 2", "Interaction"]
    )

    # The run id keeps output files from different requests apart.
    file_path = str(Path(tempfile.gettempdir()) / f"{run_id}.tsv")
    results.to_csv(file_path, sep="\t", index=False, header=True)

    return results, file_path
|
|
|
# Gradio interface wiring ``predict`` to a model dropdown, a FASTA upload and
# a pairs-table upload; it shows the results table and a downloadable file.
demo = gr.Interface(
    fn=predict,
    inputs=[
        # Choices come from model_map so the dropdown cannot drift from the
        # keys that predict() looks up.
        gr.Dropdown(label="Model", choices=list(model_map), value="Topsy-Turvy"),
        gr.File(label="Sequences (.fasta)", file_types=[".fasta"]),
        gr.File(label="Pairs (.csv/.tsv)", file_types=[".csv", ".tsv"]),
    ],
    outputs=[
        gr.DataFrame(
            label="Results",
            headers=["Protein 1", "Protein 2", "Interaction"],
        ),
        gr.File(label="Download results", type="file"),
    ],
)
|
|
|
if __name__ == "__main__":
    # Turn on the request queue (max_size=20) before starting the app.
    demo.queue(max_size=20)
    demo.launch()