Upload 2 files
- FAPM_inference.py (+86, new file)
- README.md (+75, new file)

FAPM_inference.py (ADDED)

import torch
import torch.nn.functional as F
import pandas as pd
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
# Alternative backbone: Blip2ProteinOPT from lavis.models.blip2_models.blip2_opt
import argparse

# When True, the generated captions are also mapped back to GO terms and
# propagated to their ancestors in the GO hierarchy.
prop = True

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FAPM')
    parser.add_argument('--model_path', type=str, help='Path to the trained FAPM checkpoint')
    parser.add_argument('--example_path', type=str, help='Path to the example protein ESM2 embedding (.pt)')
    parser.add_argument('--device', type=str, default='cuda', help='Which device to use (default: cuda)')
    parser.add_argument('--prompt', type=str, default='none', help='Input prompt for protein function prediction')
    parser.add_argument('--ground_truth', type=str, default='none', help='Ground-truth function (optional)')
    args = parser.parse_args()

    # Load the Mistral-based FAPM model and its checkpoint.
    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
    model.load_checkpoint(args.model_path)
    model.to(args.device)

    # Load the per-token ESM2 embedding (layer 36) and zero-pad it to length 1024.
    esm_emb = torch.load(args.example_path)['representations'][36]
    esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to(args.device)
    samples = {'name': ['P18281'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': [args.ground_truth],
               'prompt': [args.prompt]}
    prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10,
                                temperature=1., repetition_penalty=1.0)
    print(f"Text Prediction: {prediction}")

    if prop:
        from data.evaluate_data.utils import Ontology
        import difflib
        import re

        # GO ontology and term descriptions used to map generated text back to GO IDs.
        godb = Ontology('data/go1.4-basic.obo', with_rels=True)

        go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
        go_des.columns = ['id', 'text']
        go_des = go_des.dropna()
        go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
        go_obo_set = set(go_des['id'].tolist())
        go_des['text'] = go_des['text'].apply(lambda x: x.lower())
        GO_dict = dict(zip(go_des['text'], go_des['id']))
        Func_dict = dict(zip(go_des['id'], go_des['text']))

        # Candidate molecular-function terms that the predictions are matched against.
        terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
        choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
        choices = {x.lower(): x for x in choices_mf}

        pred_terms_list = []
        pred_go_list = []
        prop_annotations = []
        for x in prediction:
            # Each caption is expected to be a '; '-separated list of ('term text', score) pairs.
            x = [eval(i) for i in x.split('; ')]
            pred_terms = []
            pred_go = []
            annot_set = set()
            for txt, prob in x:
                # Fuzzy-match the generated term text against the known GO term descriptions.
                sim_list = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
                if len(sim_list) > 0:
                    pred_terms.append((sim_list[0], prob))
                    pred_go.append((GO_dict[sim_list[0]], prob))
                    # Propagate the matched GO term to all of its ancestors.
                    annot_set |= godb.get_anchestors(GO_dict[sim_list[0]])
            pred_terms_list.append(pred_terms)
            pred_go_list.append(pred_go)
            prop_annotations.append(list(annot_set))

        print(f"Predictions of GO terms: \n{pred_terms_list} \n"
              f"Predictions of GO ids: \n{pred_go_list} \n"
              f"Predictions of GO ids propagated: \n{prop_annotations}")

README.md (ADDED)

## Introduction
<p align="center">
  <br>
  <img src="assets/FAPM.png"/>
  <br>
</p>

## Installation

1. (Optional) Create a conda environment

```bash
conda create -n lavis python=3.8
conda activate lavis
```

2. For development, build from source

```bash
git clone https://github.com/xiangwenkai/FAPM.git
cd FAPM
pip install -e .

pip install Biopython
pip install fair-esm
```

### Datasets
#### 1. Raw dataset
Raw data are available at *https://ftp.uniprot.org/pub/databases/uniprot/previous_releases/release-2023_04/knowledgebase/*. This file is very large and needs to be processed to extract each protein's name, sequence, GO labels, function description, and prompt; a minimal preprocessing sketch is shown below.
The domain-level protein dataset we used is available at *https://ftp.ebi.ac.uk/pub/databases/interpro/releases/95.0/protein2ipr.dat.gz*.
In this repository, we provide the experimental train/val/test splits of Swiss-Prot, available at data/swissprot_exp.
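As a rough illustration of that preprocessing step, the sketch below pulls the entry name, sequence, GO IDs, and FUNCTION comment out of the Swiss-Prot flat file with Biopython (installed above). The file paths and output columns here are assumptions for illustration only, not the repository's actual pipeline.

```python
# Hypothetical preprocessing sketch (illustrative, not the repository's pipeline):
# extract name, sequence, GO labels and function text from the Swiss-Prot flat file.
import gzip
import pandas as pd
from Bio import SwissProt

records = []
with gzip.open('uniprot_sprot.dat.gz', 'rt') as handle:  # assumed input path
    for rec in SwissProt.parse(handle):
        go_ids = [ref[1] for ref in rec.cross_references if ref[0] == 'GO']
        function = ' '.join(c for c in rec.comments if c.startswith('FUNCTION:'))
        records.append({'name': rec.entry_name,
                        'sequence': rec.sequence,
                        'go_labels': ';'.join(go_ids),
                        'function': function})

pd.DataFrame(records).to_csv('swissprot_processed.csv', index=False)  # assumed output
```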
#### 2. ESM2 embeddings
Source code for ESM2 embedding generation: *https://github.com/facebookresearch/esm*
The generation command:
```bash
python esm_scripts/extract.py esm2_t36_3B_UR50D your_path/protein.fasta your_path_to_save_embedding_files --repr_layers 36 --truncation_seq_length 1024 --include per_tok
```
The default path for the embedding files in this repository is **data/emb_esm2_3b**.
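For reference, FAPM_inference.py in this commit consumes these files by taking layer 36 of the saved representations and zero-padding each per-token embedding to a fixed length of 1024. A minimal loading sketch, assuming an embedding produced by the command above for P18281:

```python
# Minimal sketch: load a per-token ESM2 embedding saved by extract.py and pad it
# to the fixed sequence length of 1024 that FAPM expects.
import torch
import torch.nn.functional as F

emb = torch.load('data/emb_esm2_3b/P18281.pt')['representations'][36]  # (seq_len, hidden_dim)
emb = F.pad(emb.t(), (0, 1024 - len(emb))).t()                         # (1024, hidden_dim)
batch = emb.unsqueeze(0)                                               # (1, 1024, hidden_dim)
```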
## Pretrained language model
Source: *https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B*

## Training
- data config: lavis/configs/datasets/protein/GO_defaults_cap.yaml
- stage 1 config: lavis/projects/blip2/train/protein_pretrain_stage1.yaml
- stage 1 training command: run_scripts/blip2/train/protein_pretrain_domain_stage1.sh
- stage 2 config: lavis/projects/blip2/train/protein_pretrain_stage2.yaml
- stage 2 training/finetuning command: run_scripts/blip2/train/protein_pretrain_domain_stage2.sh

## Trained models
Our trained models are available on Google Drive: *https://drive.google.com/drive/folders/1aA0eSYxNw3DvrU5GU1Cu-4q2kIxxAGSE?usp=drive_link*

## Testing
- config: lavis/projects/blip2/eval/caption_protein_eval.yaml
- command: run_scripts/blip2/eval/eval_cap_protein.sh

## Inference example
```bash
python FAPM_inference.py \
    --model_path model/checkpoint_mf2.pth \
    --example_path data/emb_esm2_3b/P18281.pt \
    --device cuda \
    --prompt Acanthamoeba
```
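The script prints the raw text predictions and, when `prop` is enabled, the matched GO terms, their GO IDs, and the ancestor-propagated annotations. Its post-processing assumes each generated caption is a '; '-separated list of ('term text', score) pairs; the snippet below shows that parsing on a made-up caption (terms and scores are illustrative only):

```python
# Illustrative only: a made-up caption in the format FAPM_inference.py expects.
caption = "('GTP binding', 0.98); ('GTPase activity', 0.87)"
pairs = [eval(s) for s in caption.split('; ')]
print(pairs)  # [('GTP binding', 0.98), ('GTPase activity', 0.87)]
```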