|
|
|
|
|
""" |
|
Created on Fri Jun 16 14:27:44 2023 |
|
|
|
@author: mheinzinger |
|
""" |
|
|
|
import argparse |
|
import time |
|
from pathlib import Path |
|
|
|
from urllib import request |
|
import shutil |
|
|
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
from transformers import T5EncoderModel, T5Tokenizer |
|
|
|
|
|
# Run on the first GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("Using device: {}".format(device))
|
|
|
|
|
|
|
class CNN(nn.Module):
    """Two-layer convolutional head mapping per-residue ProstT5 embeddings
    (1024 features) to logits over the 20 3Di structure states."""

    def __init__(self):
        super().__init__()
        # Conv2d with a (7, 1) kernel slides a 7-residue window along the
        # sequence axis; Dropout(0.0) is a no-op kept for checkpoint parity.
        self.classifier = nn.Sequential(
            nn.Conv2d(1024, 32, kernel_size=(7, 1), padding=(3, 0)),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Conv2d(32, 20, kernel_size=(7, 1), padding=(3, 0)),
        )

    def forward(self, x):
        """Return per-residue class logits.

        L = protein length
        B = batch-size
        F = number of features (1024 for embeddings)
        N = number of classes (20 for 3Di)

        Input is (B, L, F); output is (B, N, L).
        """
        features = x.permute(0, 2, 1).unsqueeze(dim=-1)  # (B, F, L, 1)
        logits = self.classifier(features)               # (B, N, L, 1)
        return logits.squeeze(dim=-1)                    # (B, N, L)
|
|
|
def one_hot_3di_sequence(sequence, vocab):
    """One-hot encode a 3Di state sequence.

    Parameters
    ----------
    sequence : str
        3Di state string.
    vocab : dict
        Mapping from 3Di character to column index.

    Returns
    -------
    torch.Tensor of shape (1, len(sequence), len(vocab)), float32.

    Raises
    ------
    ValueError
        If a character of `sequence` is not in `vocab`. (The original
        used `assert`, which is silently stripped under `python -O`.)
    """
    foldseek_enc = torch.zeros(
        len(sequence), len(vocab), dtype=torch.float32
    )
    for i, char in enumerate(sequence):
        if char not in vocab:
            raise ValueError(
                "Character {!r} not in 3Di vocabulary".format(char)
            )
        foldseek_enc[i, vocab[char]] = 1
    # leading singleton dim = batch dimension of size 1
    return foldseek_enc.unsqueeze(0)
|
|
|
|
|
def get_T5_model(model_dir):
    """Load the ProstT5 encoder and its tokenizer onto the global device.

    `model_dir` is either a local checkpoint directory or a HuggingFace
    repository id. Returns (encoder in eval mode, tokenizer).
    """
    print("Loading T5 from: {}".format(model_dir))
    encoder = T5EncoderModel.from_pretrained(model_dir).to(device)
    encoder.eval()
    tokenizer = T5Tokenizer.from_pretrained(model_dir, do_lower_case=False)
    return encoder, tokenizer
|
|
|
|
|
def read_fasta(fasta_path, split_char, id_field):
    """Read a FASTA file containing one or more sequences.

    Parameters
    ----------
    fasta_path : str or Path
        Path to the FASTA file.
    split_char : str
        Character used to split the header line.
    id_field : int
        Index of the identifier field after splitting.

    Returns
    -------
    dict mapping identifier -> amino-acid string, or None if the file
    contains lower-case (3Di) input.

    Raises
    ------
    ValueError
        If sequence data appears before the first '>' header (the
        original crashed with a NameError in that case).
    """
    sequences = dict()
    uniprot_id = None
    with open(fasta_path, 'r') as fasta_f:
        for line in fasta_f:
            if line.startswith('>'):
                uniprot_id = line.replace('>', '').strip().split(split_char)[id_field]
                # replace characters that are problematic in file paths
                uniprot_id = uniprot_id.replace("/", "_").replace(".", "_")
                sequences[uniprot_id] = ''
            else:
                s = ''.join(line.split()).replace("-", "")
                if not s:
                    continue  # skip blank lines
                if s.islower():
                    # fixed missing spaces between the concatenated fragments
                    print("The input file was in lower-case which indicates 3Di-input. " +
                          "This predictor only operates on amino-acid-input (upper-case). " +
                          "Exiting now ..."
                          )
                    return None
                if uniprot_id is None:
                    raise ValueError(
                        "Malformed FASTA: sequence data found before the first '>' header"
                    )
                sequences[uniprot_id] += s
    return sequences
|
|
|
def write_predictions(predictions, out_path):
    """Write 3Di predictions to `out_path` in FASTA format.

    `predictions` maps sequence identifiers to iterables of class
    indices (0-19); indices are translated back to 3Di letters.
    """
    # index -> 3Di letter of the 20-state structure alphabet
    ss_mapping = dict(enumerate("ACDEFGHIKLMNPQRSTVWY"))

    records = []
    for seq_id, yhats in predictions.items():
        three_di = "".join(ss_mapping[int(yhat)] for yhat in yhats)
        records.append(">{}\n{}".format(seq_id, three_di))

    with open(out_path, 'w+') as out_f:
        out_f.write('\n'.join(records))
    print(f"Finished writing results to {out_path}")
    return None
|
|
|
def predictions_to_dict(predictions):
    """Convert per-residue class indices to 3Di strings (no disk I/O).

    Returns {identifier: 3Di string} for every entry of `predictions`.
    """
    # index -> 3Di letter of the 20-state structure alphabet
    ss_mapping = dict(enumerate("ACDEFGHIKLMNPQRSTVWY"))

    results = {}
    for seq_id, yhats in predictions.items():
        results[seq_id] = "".join(ss_mapping[int(yhat)] for yhat in yhats)
    return results
|
|
|
def toCPU(tensor):
    """Detach `tensor`, move it to host memory and return a NumPy array.

    Tensors with more than one dimension additionally get a trailing
    singleton dimension squeezed away (matches the predictor's
    keepdim=True argmax output).
    """
    result = tensor.detach().cpu()
    if result.dim() > 1:
        result = result.squeeze(dim=-1)
    return result.numpy()
|
|
|
|
|
def download_file(url, local_path):
    """Download `url` to `local_path`, creating parent directories as needed.

    A browser-like User-Agent is sent because some servers reject the
    default urllib agent.
    """
    # parents=True/exist_ok=True: the original bare mkdir() raised when the
    # parent's parent was missing or when the directory already existed by
    # the time we tried to create it.
    local_path.parent.mkdir(parents=True, exist_ok=True)

    print("Downloading: {}".format(url))
    req = request.Request(url, headers={
        'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; Win64; x64)'
    })

    # stream the response to disk; both handles are closed by the with-block
    with request.urlopen(req) as response, open(local_path, 'wb') as outfile:
        shutil.copyfileobj(response, outfile)
    return None
|
|
|
|
|
def load_predictor( weights_link="https://rostlab.org/~deepppi/prostt5/cnn_chkpnt/model.pt" , device=torch.device("cpu")):
    """Instantiate the 3Di CNN head and load its trained weights.

    The checkpoint is cached under ./cnn_chkpnt/model.pt and downloaded
    from `weights_link` on first use. Returns the model in eval mode,
    moved to `device`.
    """
    predictor = CNN()

    checkpoint_p = Path.cwd() / "cnn_chkpnt" / "model.pt"
    if not checkpoint_p.exists():
        download_file(weights_link, checkpoint_p)

    state = torch.load(checkpoint_p, map_location=device)
    predictor.load_state_dict(state["state_dict"])

    predictor.eval()
    return predictor.to(device)
|
|
|
|
|
def get_3di_sequences( seq_dict, model_dir, device,
                    max_residues=4000, max_seq_len=1000, max_batch=100,report_fn=print,error_fn=print,half_precision=False):
    """Predict 3Di structure states for a set of amino-acid sequences.

    Parameters
    ----------
    seq_dict : dict
        Mapping of identifier -> upper-case amino-acid sequence.
    model_dir : str
        Path or HuggingFace repo id of the ProstT5 encoder.
    device : torch.device
        Device for both the encoder and the CNN head.
    max_residues : int
        Upper bound on the summed sequence lengths per batch.
    max_seq_len : int
        Sequences longer than this trigger an immediate batch flush.
    max_batch : int
        Upper bound on the number of sequences per batch.
    report_fn, error_fn : callable
        Sinks for progress / error messages (default: print).
    half_precision : bool
        If True, run encoder and CNN head in fp16.

    Returns
    -------
    dict mapping identifier -> numpy array of per-residue class
    indices (0-19); batches that fail with a RuntimeError are skipped.
    """

    predictions = dict()

    # ProstT5 flag token marking amino-acid input
    # (as opposed to "<fold2AA>" for 3Di input)
    prefix = "<AA2fold>"

    model, vocab = get_T5_model(model_dir)
    predictor = load_predictor(device=device)

    if half_precision:
        model = model.half()
        predictor = predictor.half()

    report_fn('Total number of sequences: {}'.format(len(seq_dict)))

    avg_length = sum([ len(seq) for _, seq in seq_dict.items()]) / len(seq_dict)
    n_long = sum([ 1 for _, seq in seq_dict.items() if len(seq)>max_seq_len])

    # Sort longest-first so each batch holds similarly sized sequences
    # and padding overhead stays small. NOTE: seq_dict is a list of
    # (id, seq) tuples from here on.
    seq_dict = sorted( seq_dict.items(), key=lambda kv: len( seq_dict[kv[0]] ), reverse=True )

    report_fn("Average sequence length: {}".format(avg_length))
    report_fn("Number of sequences >{}: {}".format(max_seq_len, n_long))

    start = time.time()
    batch = list()
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict,1):

        # Map rare/ambiguous amino acids to X; the tokenizer expects
        # space-separated single letters after the prefix token.
        seq = seq.replace('U','X').replace('Z','X').replace('O','X')
        seq_len = len(seq)
        seq = prefix + ' ' + ' '.join(list(seq))
        batch.append((pdb_id,seq,seq_len))

        # Flush when batch/residue limits are hit, on the last sequence,
        # or for over-long sequences (to bound peak memory).
        # NOTE(review): seq_len is added on top of the batch total even
        # though the sequence is already in `batch`, so the residue limit
        # is conservative — presumably intentional headroom.
        n_res_batch = sum([ s_len for _, _, s_len in batch ]) + seq_len
        if len(batch) >= max_batch or n_res_batch>=max_residues or seq_idx==len(seq_dict) or seq_len>max_seq_len:
            pdb_ids, seqs, seq_lens = zip(*batch)
            batch = list()

            token_encoding = vocab.batch_encode_plus(seqs,
                                                     add_special_tokens=True,
                                                     padding="longest",
                                                     return_tensors='pt'
                                                     ).to(device)
            try:
                with torch.no_grad():
                    embedding_repr = model(token_encoding.input_ids,
                                           attention_mask=token_encoding.attention_mask
                                           )
            except RuntimeError:
                # typically CUDA OOM; skip this whole batch and keep going
                error_fn("RuntimeError during embedding for {} (L={})".format(
                    pdb_id, seq_len)
                )
                continue

            # Zero the mask at the token following the last residue
            # (position 0 is the prefix token, 1..s_len are residues,
            # s_len+1 is the special end token added by the tokenizer).
            for idx, s_len in enumerate(seq_lens):
                token_encoding.attention_mask[idx,s_len+1] = 0

            residue_embedding = embedding_repr.last_hidden_state.detach()

            # zero out embeddings at padded/masked positions
            residue_embedding = residue_embedding*token_encoding.attention_mask.unsqueeze(dim=-1)

            # drop the prefix token so position i maps to residue i
            residue_embedding = residue_embedding[:,1:]

            prediction = predictor(residue_embedding)
            # argmax over the class dimension -> per-residue 3Di indices
            prediction = toCPU(torch.max( prediction, dim=1, keepdim=True )[1] ).astype(np.byte)

            for batch_idx, identifier in enumerate(pdb_ids):
                s_len = seq_lens[batch_idx]

                # trim padding back to the true sequence length
                predictions[identifier] = prediction[batch_idx,:, 0:s_len].squeeze()
                # NOTE(review): the second assert operand calls error_fn
                # eagerly, so the message is emitted even when the lengths
                # match; the assert itself then compares against None.
                assert s_len == len(predictions[identifier]), error_fn(f"Length mismatch for {identifier}: is:{len(predictions[identifier])} vs should:{s_len}")

    end = time.time()
    report_fn('Total number of predictions: {}'.format(len(predictions)))
    report_fn('Total time: {:.2f}[s]; time/prot: {:.4f}[s]; avg. len= {:.2f}'.format(
        end-start, (end-start)/len(predictions), avg_length))

    return predictions
|
|
|
|
|
def create_arg_parser():
    """Create and return the command-line argument parser."""
    parser = argparse.ArgumentParser(
        description=(
            'embed.py creates ProstT5-Encoder embeddings for a given text '
            ' file containing sequence(s) in FASTA-format.'
            'Example: python predict_3Di.py --input /path/to/some_AA_sequences.fasta --output /path/to/some_3Di_sequences.fasta --half 1'
        )
    )

    # required I/O paths
    parser.add_argument('-i', '--input', required=True, type=str,
                        help='A path to a fasta-formatted text file containing protein sequence(s).')
    parser.add_argument('-o', '--output', required=True, type=str,
                        help='A path for saving the created embeddings as NumPy npz file.')

    # model source: local checkpoint directory or HuggingFace repo id
    parser.add_argument('--model', required=False, type=str,
                        default="Rostlab/ProstT5",
                        help='Either a path to a directory holding the checkpoint for a pre-trained model or a huggingface repository link.')

    # FASTA-header parsing controls
    parser.add_argument('--split_char', type=str,
                        default='!',
                        help='The character for splitting the FASTA header in order to retrieve '
                             "the protein identifier. Should be used in conjunction with --id."
                             "Default: '!' ")
    parser.add_argument('--id', type=int,
                        default=0,
                        help='The index for the uniprot identifier field after splitting the '
                             "FASTA header after each symbole in ['|', '#', ':', ' ']."
                             'Default: 0')

    parser.add_argument('--half', type=int,
                        default=1,
                        help="Whether to use half_precision or not. Default: 1 (half-precision)")

    return parser
|
|
|
def main():
    """CLI entry point: read AA FASTA, predict 3Di states, write FASTA."""
    parser = create_arg_parser()
    args = parser.parse_args()

    seq_path = Path(args.input)
    out_path = Path(args.output)
    model_dir = args.model

    if out_path.is_file():
        print("Output file is already existing and will be overwritten ...")

    split_char = args.split_char
    id_field = args.id

    half_precision = False if int(args.half) == 0 else True
    # Original bug: `assert not (...), print(...)` evaluated the print
    # eagerly (message on every run, assert message None) and compared a
    # torch.device to the string "cpu". Compare the device *type* and
    # raise explicitly instead.
    if half_precision and device.type == "cpu":
        raise ValueError("Running fp16 on CPU is not supported, yet")

    seq_dict = read_fasta(seq_path, split_char, id_field)
    if seq_dict is None:
        # read_fasta returns None for lower-case (3Di) input
        return

    # Original bugs: the required `device` argument was missing
    # (TypeError on every run) and --half was parsed but never forwarded.
    predictions = get_3di_sequences(
        seq_dict,
        model_dir,
        device,
        half_precision=half_precision,
    )

    print("Writing results now to disk ...")
    write_predictions(predictions, out_path)


if __name__ == '__main__':
    main()