import logging
import sys

sys.path.append("../")

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils.weight_norm import weight_norm
from torch.utils.data import Dataset

LOGGER = logging.getLogger(__name__)

class FusionDTI(nn.Module):
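    """Drug-target interaction head that fuses protein and drug token
    embeddings with a selectable fusion module: "CAN" (cross-attention
    network), "BAN" (bilinear attention network), or "Nan" (plain
    concatenation of mean-pooled embeddings, no attention)."""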
    def __init__(self, prot_out_dim, disease_out_dim, args):
        super(FusionDTI, self).__init__()
        self.fusion = args.fusion
        # disease_out_dim is the drug encoder's output width; the name is
        # kept from the original signature.
        self.drug_reg = nn.Linear(disease_out_dim, 512)
        self.prot_reg = nn.Linear(prot_out_dim, 512)
        if self.fusion == "CAN":
            self.can_layer = CAN_Layer(hidden_dim=512, num_heads=8, args=args)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=1024)
        elif self.fusion == "BAN":
            self.ban_layer = weight_norm(BANLayer(512, 512, 256, 2), name='h_mat', dim=None)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=256)
        elif self.fusion == "Nan":
            self.mlp_classifier_nan = MlPdecoder_CAN(input_dim=1214)

    def forward(self, prot_embed, drug_embed, prot_mask, drug_mask):
        if self.fusion == "Nan":
            prot_embed = prot_embed.mean(1)  # [batch_size, hidden]
            drug_embed = drug_embed.mean(1)  # [batch_size, hidden]
            joint_embed = torch.cat([prot_embed, drug_embed], dim=1)
            score = self.mlp_classifier_nan(joint_embed)
            att = None  # no attention map in the concatenation-only branch
        else:
            prot_embed = self.prot_reg(prot_embed)
            drug_embed = self.drug_reg(drug_embed)
            if self.fusion == "CAN":
                joint_embed, att = self.can_layer(prot_embed, drug_embed, prot_mask, drug_mask)
            elif self.fusion == "BAN":
                joint_embed, att = self.ban_layer(prot_embed, drug_embed)
            score = self.mlp_classifier(joint_embed)
        return score, att

class Pre_encoded(nn.Module):
    def __init__(self, prot_encoder, drug_encoder, args):
        """Constructor for the model.

        Args:
            prot_encoder: protein structure-aware sequence encoder.
            drug_encoder: drug SELFIES encoder.
            args: parsed command-line arguments.
        """
        super(Pre_encoded, self).__init__()
        self.prot_encoder = prot_encoder
        self.drug_encoder = drug_encoder

    def encoding(self, prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask):
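        """Run both encoders and return token-level embeddings.

        The protein encoder output is read through ``.logits`` and the drug
        encoder through ``.last_hidden_state``; both are assumed to be
        HuggingFace-style model outputs (hence ``return_dict=True``)."""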
        prot_embed = self.prot_encoder(
            input_ids=prot_input_ids, attention_mask=prot_attention_mask, return_dict=True
        ).logits
        drug_embed = self.drug_encoder(
            input_ids=drug_input_ids, attention_mask=drug_attention_mask, return_dict=True
        ).last_hidden_state
        return prot_embed, drug_embed

class CAN_Layer(nn.Module):
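    """Cross-Attention Network (CAN) fusion layer.

    Tokens are first mean-pooled into groups of ``args.group_size`` to
    control the fusion scale; the protein and drug streams then attend to
    themselves and to each other with multi-head attention, the paired
    attention outputs are averaged, and the result is aggregated according
    to ``args.agg_mode``."""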
    def __init__(self, hidden_dim, num_heads, args):
        super(CAN_Layer, self).__init__()
        self.agg_mode = args.agg_mode
        self.group_size = args.group_size  # controls the fusion scale
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads
        self.query_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.query_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_d = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def alpha_logits(self, logits, mask_row, mask_col, inf=1e6):
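        """Masked softmax over the key dimension: masked key positions are
        pushed to (effectively) -inf before the softmax, and rows whose
        query position is masked are zeroed out afterwards."""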
        N, L1, L2, H = logits.shape
        mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
        mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
        mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)
        logits = torch.where(mask_pair, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
        alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
        return alpha

    def apply_heads(self, x, n_heads, n_ch):
        # Reshape the last dimension into (n_heads, head_size).
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)

    def group_embeddings(self, x, mask, group_size):
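        """Mean-pool tokens into non-overlapping groups of ``group_size``
        (assumes the sequence length is divisible by ``group_size``); a
        group stays valid in the mask if any of its tokens was valid."""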
        N, L, D = x.shape
        groups = L // group_size
        x_grouped = x.view(N, groups, group_size, D).mean(dim=2)
        mask_grouped = mask.view(N, groups, group_size).any(dim=2)
        return x_grouped, mask_grouped
    def forward(self, protein, drug, mask_prot, mask_drug):
        # Group embeddings before applying multi-head attention
        protein_grouped, mask_prot_grouped = self.group_embeddings(protein, mask_prot, self.group_size)
        drug_grouped, mask_drug_grouped = self.group_embeddings(drug, mask_drug, self.group_size)
        # Compute queries, keys, values for both protein and drug after grouping
        query_prot = self.apply_heads(self.query_p(protein_grouped), self.num_heads, self.head_size)
        key_prot = self.apply_heads(self.key_p(protein_grouped), self.num_heads, self.head_size)
        value_prot = self.apply_heads(self.value_p(protein_grouped), self.num_heads, self.head_size)
        query_drug = self.apply_heads(self.query_d(drug_grouped), self.num_heads, self.head_size)
        key_drug = self.apply_heads(self.key_d(drug_grouped), self.num_heads, self.head_size)
        value_drug = self.apply_heads(self.value_d(drug_grouped), self.num_heads, self.head_size)
        # Attention logits for every stream pair (protein/drug x protein/drug)
        logits_pp = torch.einsum('blhd, bkhd->blkh', query_prot, key_prot)
        logits_pd = torch.einsum('blhd, bkhd->blkh', query_prot, key_drug)
        logits_dp = torch.einsum('blhd, bkhd->blkh', query_drug, key_prot)
        logits_dd = torch.einsum('blhd, bkhd->blkh', query_drug, key_drug)
        alpha_pp = self.alpha_logits(logits_pp, mask_prot_grouped, mask_prot_grouped)
        alpha_pd = self.alpha_logits(logits_pd, mask_prot_grouped, mask_drug_grouped)
        alpha_dp = self.alpha_logits(logits_dp, mask_drug_grouped, mask_prot_grouped)
        alpha_dd = self.alpha_logits(logits_dd, mask_drug_grouped, mask_drug_grouped)
        # Each stream averages its self-attention and cross-attention outputs
        prot_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_pp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_pd, value_drug).flatten(-2)) / 2
        drug_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_dp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_dd, value_drug).flatten(-2)) / 2
        # Aggregate token embeddings into a single vector per stream
        if self.agg_mode == "cls":
            prot_embed = prot_embedding[:, 0]  # [batch_size, hidden]
            drug_embed = drug_embedding[:, 0]  # [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            prot_embed = prot_embedding.mean(1)  # [batch_size, hidden]
            drug_embed = drug_embedding.mean(1)  # [batch_size, hidden]
        elif self.agg_mode == "mean":
            prot_embed = (prot_embedding * mask_prot_grouped.unsqueeze(-1)).sum(1) / mask_prot_grouped.sum(-1).unsqueeze(-1)
            drug_embed = (drug_embedding * mask_drug_grouped.unsqueeze(-1)).sum(1) / mask_drug_grouped.sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()
        query_embed = torch.cat([prot_embed, drug_embed], dim=1)
        # Stitch the four attention maps into one 1024 x 1024 canvas for
        # visualisation; this assumes batch size 1 and exactly 512 grouped
        # positions per stream.
        att = torch.zeros(1, 1, 1024, 1024, device=query_embed.device)
        att[:, :, :512, :512] = alpha_pp.mean(dim=-1)  # protein -> protein
        att[:, :, :512, 512:] = alpha_pd.mean(dim=-1)  # protein -> drug
        att[:, :, 512:, :512] = alpha_dp.mean(dim=-1)  # drug -> protein
        att[:, :, 512:, 512:] = alpha_dd.mean(dim=-1)  # drug -> drug
        return query_embed, att

class MlPdecoder_CAN(nn.Module):
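    """MLP decoder head: three Linear -> ReLU -> BatchNorm blocks followed
    by a sigmoid output, so the score is a probability in [0, 1]."""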
    def __init__(self, input_dim):
        super(MlPdecoder_CAN, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim // 2)
        self.bn2 = nn.BatchNorm1d(input_dim // 2)
        self.fc3 = nn.Linear(input_dim // 2, input_dim // 4)
        self.bn3 = nn.BatchNorm1d(input_dim // 4)
        self.output = nn.Linear(input_dim // 4, 1)

    def forward(self, x):
        x = self.bn1(torch.relu(self.fc1(x)))
        x = self.bn2(torch.relu(self.fc2(x)))
        x = self.bn3(torch.relu(self.fc3(x)))
        x = torch.sigmoid(self.output(x))
        return x

class MLPdecoder_BAN(nn.Module):
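    """MLP decoder head for the BAN branch; same pattern as MlPdecoder_CAN
    but with configurable hidden and output widths."""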
    def __init__(self, in_dim, hidden_dim, out_dim, binary=1):
        super(MLPdecoder_BAN, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)
        self.fc4 = nn.Linear(out_dim, binary)

    def forward(self, x):
        x = self.bn1(F.relu(self.fc1(x)))
        x = self.bn2(F.relu(self.fc2(x)))
        x = self.bn3(F.relu(self.fc3(x)))
        x = torch.sigmoid(self.fc4(x))
        return x

class BANLayer(nn.Module):
    """Bilinear attention network.

    Modified from https://github.com/peizhenbai/DrugBAN/blob/main/ban.py
    """
    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        super(BANLayer, self).__init__()
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out
        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)
        if h_out <= self.c:
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)
        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if 1 < self.k:
            fusion_logits = fusion_logits.unsqueeze(1)  # b x 1 x d
            fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k  # sum pooling
        return fusion_logits

    def forward(self, v, q, softmax=False):
        v_num = v.size(1)
        q_num = q.size(1)
        if self.h_out <= self.c:
            # Low-rank bilinear attention via a shared head matrix
            v_ = self.v_net(v)
            q_ = self.q_net(q)
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
        else:
            v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x h_dim x v x q
            att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            att_maps = att_maps.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q
        if softmax:
            p = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2)
            att_maps = p.view(-1, self.h_out, v_num, q_num)
        # Pool with every attention head and sum the per-head results
        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for i in range(1, self.h_out):
            logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :])
            logits += logits_i
        logits = self.bn(logits)
        return logits, att_maps

class FCNet(nn.Module):
    """Simple class for a non-linear fully connected network.

    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py
    """
    def __init__(self, dims, act='ReLU', dropout=0):
        super(FCNet, self).__init__()
        layers = []
        for i in range(len(dims) - 2):
            in_dim = dims[i]
            out_dim = dims[i + 1]
            if 0 < dropout:
                layers.append(nn.Dropout(dropout))
            layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
            if '' != act:
                layers.append(getattr(nn, act)())
        if 0 < dropout:
            layers.append(nn.Dropout(dropout))
        layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
        if '' != act:
            layers.append(getattr(nn, act)())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)

class BatchFileDataset_Case(Dataset):
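    """Dataset over pre-encoded batch files: each file is a torch checkpoint
    holding protein/drug embeddings, token ids, attention masks, and labels."""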
    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        batch_file = self.file_list[idx]
        data = torch.load(batch_file)
        return data['prot'], data['drug'], data['prot_ids'], data['drug_ids'], data['prot_mask'], data['drug_mask'], data['y']
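

if __name__ == "__main__":
    # Minimal shape-check sketch (not part of the original module). The
    # Namespace below stands in for the parsed command-line args the model
    # expects, and the dimensions (446 protein / 768 drug) are illustrative
    # assumptions, not the project's fixed values.
    from argparse import Namespace

    args = Namespace(fusion="BAN", agg_mode="mean", group_size=8)
    # eval() so the BatchNorm layers use their running statistics
    model = FusionDTI(prot_out_dim=446, disease_out_dim=768, args=args).eval()
    prot = torch.randn(2, 64, 446)   # [batch, prot_len, prot_out_dim]
    drug = torch.randn(2, 32, 768)   # [batch, drug_len, disease_out_dim]
    prot_mask = torch.ones(2, 64, dtype=torch.bool)
    drug_mask = torch.ones(2, 32, dtype=torch.bool)
    with torch.no_grad():
        score, att = model(prot, drug, prot_mask, drug_mask)
    print(score.shape, att.shape)  # expected: (2, 1) and (2, 2, 64, 32)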