# FusionDTI / utils / metric_learning_models_att_maps.py
# ZhaohanM
# Initial commit
# 0312a01
import logging
import os
import sys
# Add the parent directory to the import search path. The entry is relative,
# so it is resolved against the current working directory at import time.
sys.path.append("../")
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import autocast
from torch.nn import Module
from tqdm import tqdm
from torch.nn.utils.weight_norm import weight_norm
from torch.utils.data import Dataset
# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)
class FusionDTI(nn.Module):
    """Fuse protein and drug embeddings and score the pair.

    The fusion strategy is selected by ``args.fusion``:

    * ``"CAN"`` - both inputs are projected to 512 features, fused by
      ``CAN_Layer`` and scored by an MLP with a 1024-wide input.
    * ``"BAN"`` - both inputs are projected to 512 features, fused by a
      weight-normalised ``BANLayer`` and scored by an MLP with a 256-wide input.
    * ``"Nan"`` - no fusion module: both inputs are averaged over dimension 1,
      concatenated and scored by an MLP with a 1214-wide input.

    Args:
        prot_out_dim: Size of the last dimension of the protein embeddings.
        disease_out_dim: Size of the last dimension of the drug embeddings.
        args: Namespace providing ``fusion`` (and, for ``"CAN"``, whatever
            ``CAN_Layer`` reads from it).
    """

    def __init__(self, prot_out_dim, disease_out_dim, args):
        super(FusionDTI, self).__init__()
        self.fusion = args.fusion
        self.drug_reg = nn.Linear(disease_out_dim, 512)
        self.prot_reg = nn.Linear(prot_out_dim, 512)

        if self.fusion == "CAN":
            self.can_layer = CAN_Layer(hidden_dim=512, num_heads=8, args=args)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=1024)
        elif self.fusion == "BAN":
            self.ban_layer = weight_norm(BANLayer(512, 512, 256, 2), name='h_mat', dim=None)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=256)
        elif self.fusion == "Nan":
            # NOTE(review): 1214 is presumably prot_out_dim + disease_out_dim
            # for the encoders this was trained with - confirm before reuse.
            self.mlp_classifier_nan = MlPdecoder_CAN(input_dim=1214)

    def forward(self, prot_embed, drug_embed, prot_mask, drug_mask):
        """Score a batch of protein/drug pairs.

        Args:
            prot_embed: Protein embeddings; assumed (batch, length, prot_out_dim).
            drug_embed: Drug embeddings; assumed (batch, length, disease_out_dim).
            prot_mask: Protein token mask, only used by the ``"CAN"`` fusion.
            drug_mask: Drug token mask, only used by the ``"CAN"`` fusion.

        Returns:
            ``(score, att)`` where ``score`` is the classifier output and
            ``att`` is the attention map returned by the fusion layer, or
            ``None`` in ``"Nan"`` mode (which has no attention).

        Raises:
            ValueError: If ``self.fusion`` is not one of the supported modes.
        """
        # The "Nan" path has no attention map; previously `att` was left
        # unbound there and the final return raised UnboundLocalError.
        att = None

        if self.fusion == "Nan":
            prot_embed = prot_embed.mean(1)  # query : [batch_size, hidden]
            drug_embed = drug_embed.mean(1)  # query : [batch_size, hidden]
            joint_embed = torch.cat([prot_embed, drug_embed], dim=1)
            score = self.mlp_classifier_nan(joint_embed)
            return score, att

        if self.fusion not in ("CAN", "BAN"):
            raise ValueError("Unsupported fusion mode: {!r}".format(self.fusion))

        prot_embed = self.prot_reg(prot_embed)
        drug_embed = self.drug_reg(drug_embed)

        if self.fusion == "CAN":
            joint_embed, att = self.can_layer(prot_embed, drug_embed, prot_mask, drug_mask)
        else:  # "BAN"
            joint_embed, att = self.ban_layer(prot_embed, drug_embed)

        score = self.mlp_classifier(joint_embed)
        return score, att
class Pre_encoded(nn.Module):
    """Wrapper that runs a protein encoder and a drug encoder side by side."""

    def __init__(self, prot_encoder, drug_encoder, args):
        """Store the two encoders.

        Args:
            prot_encoder: Callable protein encoder whose output exposes ``.logits``.
            drug_encoder: Callable drug encoder whose output exposes
                ``.last_hidden_state``.
            args: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.prot_encoder = prot_encoder
        self.drug_encoder = drug_encoder

    def encoding(self, prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask):
        """Encode one batch with both encoders.

        Returns:
            Tuple ``(prot_embed, drug_embed)``: the protein encoder's
            ``logits`` and the drug encoder's ``last_hidden_state``.
        """
        prot_output = self.prot_encoder(
            input_ids=prot_input_ids,
            attention_mask=prot_attention_mask,
            return_dict=True,
        )
        prot_embed = prot_output.logits

        drug_output = self.drug_encoder(
            input_ids=drug_input_ids,
            attention_mask=drug_attention_mask,
            return_dict=True,
        )
        drug_embed = drug_output.last_hidden_state

        return prot_embed, drug_embed
class CAN_Layer(nn.Module):
    """Multi-head cross-attention fusion of protein and drug token embeddings.

    Tokens of each modality are first averaged in consecutive groups of
    ``args.group_size``. Every group then attends to the groups of its own
    modality and of the other modality; the two attended results are averaged.
    The per-modality outputs are pooled according to ``args.agg_mode``
    (``"cls"``, ``"mean_all_tok"`` or ``"mean"``) and concatenated.

    Args:
        hidden_dim: Feature size of the incoming embeddings.
        num_heads: Number of attention heads; ``hidden_dim // num_heads`` is
            used as the per-head size.
        args: Namespace providing ``agg_mode`` and ``group_size``.
    """

    def __init__(self, hidden_dim, num_heads, args):
        super(CAN_Layer, self).__init__()
        self.agg_mode = args.agg_mode
        self.group_size = args.group_size  # Control Fusion Scale
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads

        self.query_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_p = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.query_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_d = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def alpha_logits(self, logits, mask_row, mask_col, inf=1e6):
        """Turn raw attention logits into masked attention weights.

        Args:
            logits: Tensor of shape (N, L1, L2, H).
            mask_row: Query-side mask with N * L1 elements; nonzero means valid.
            mask_col: Key-side mask with N * L2 elements; nonzero means valid.
            inf: Value subtracted from the logits of invalid (query, key) pairs.

        Returns:
            Tensor of shape (N, L1, L2, H), softmax-normalised over the key
            axis, with the rows of invalid queries set to zero.
        """
        N, L1, L2, H = logits.shape
        # Broadcasted boolean masks replace the former repeat + einsum, which
        # relied on einsum accepting bool operands.
        row_valid = mask_row.reshape(N, L1, 1, 1).bool()
        col_valid = mask_col.reshape(N, 1, L2, 1).bool()
        pair_valid = row_valid & col_valid

        logits = torch.where(pair_valid, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        # A masked query still gets a (uniform) softmax row; zero it out.
        alpha = torch.where(row_valid, alpha, torch.zeros_like(alpha))
        return alpha

    def apply_heads(self, x, n_heads, n_ch):
        """Split the last dimension of ``x`` into ``(n_heads, n_ch)``."""
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)

    def group_embeddings(self, x, mask, group_size):
        """Average consecutive tokens in groups of ``group_size``.

        Args:
            x: Tensor of shape (N, L, D). L must be a multiple of
                ``group_size``, otherwise the reshape fails.
            mask: Token mask with N * L elements.
            group_size: Number of consecutive tokens per group.

        Returns:
            ``(x_grouped, mask_grouped)`` with shapes (N, L // group_size, D)
            and (N, L // group_size); a group is valid if any of its tokens is.
        """
        N, L, D = x.shape
        groups = L // group_size
        x_grouped = x.reshape(N, groups, group_size, D).mean(dim=2)
        mask_grouped = mask.reshape(N, groups, group_size).any(dim=2)
        return x_grouped, mask_grouped

    def forward(self, protein, drug, mask_prot, mask_drug):
        """Fuse protein and drug embeddings.

        Args:
            protein: Protein embeddings of shape (N, Lp, hidden_dim).
            drug: Drug embeddings of shape (N, Ld, hidden_dim).
            mask_prot: Protein token mask of shape (N, Lp).
            mask_drug: Drug token mask of shape (N, Ld).

        Returns:
            ``(query_embed, att)``. ``query_embed`` is the concatenation of
            the pooled protein and drug embeddings along dim 1. ``att`` has
            shape (N, 1, Gp + Gd, Gp + Gd), where Gp / Gd are the grouped
            lengths, and holds the head-averaged attention weights laid out
            as ``[[prot->prot, prot->drug], [drug->prot, drug->drug]]``.

        Raises:
            NotImplementedError: If ``agg_mode`` is not a supported value.
        """
        # Group embeddings before applying multi-head attention
        protein_grouped, mask_prot_grouped = self.group_embeddings(protein, mask_prot, self.group_size)
        drug_grouped, mask_drug_grouped = self.group_embeddings(drug, mask_drug, self.group_size)

        # Compute queries, keys, values for both protein and drug after grouping
        query_prot = self.apply_heads(self.query_p(protein_grouped), self.num_heads, self.head_size)
        key_prot = self.apply_heads(self.key_p(protein_grouped), self.num_heads, self.head_size)
        value_prot = self.apply_heads(self.value_p(protein_grouped), self.num_heads, self.head_size)

        query_drug = self.apply_heads(self.query_d(drug_grouped), self.num_heads, self.head_size)
        key_drug = self.apply_heads(self.key_d(drug_grouped), self.num_heads, self.head_size)
        value_drug = self.apply_heads(self.value_d(drug_grouped), self.num_heads, self.head_size)

        # Compute attention scores
        logits_pp = torch.einsum('blhd, bkhd->blkh', query_prot, key_prot)
        logits_pd = torch.einsum('blhd, bkhd->blkh', query_prot, key_drug)
        logits_dp = torch.einsum('blhd, bkhd->blkh', query_drug, key_prot)
        logits_dd = torch.einsum('blhd, bkhd->blkh', query_drug, key_drug)

        alpha_pp = self.alpha_logits(logits_pp, mask_prot_grouped, mask_prot_grouped)
        alpha_pd = self.alpha_logits(logits_pd, mask_prot_grouped, mask_drug_grouped)
        alpha_dp = self.alpha_logits(logits_dp, mask_drug_grouped, mask_prot_grouped)
        alpha_dd = self.alpha_logits(logits_dd, mask_drug_grouped, mask_drug_grouped)

        prot_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_pp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_pd, value_drug).flatten(-2)) / 2
        drug_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_dp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_dd, value_drug).flatten(-2)) / 2

        # Continue as usual with the aggregation mode
        if self.agg_mode == "cls":
            prot_embed = prot_embedding[:, 0]  # query : [batch_size, hidden]
            drug_embed = drug_embedding[:, 0]  # query : [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            prot_embed = prot_embedding.mean(1)  # query : [batch_size, hidden]
            drug_embed = drug_embedding.mean(1)  # query : [batch_size, hidden]
        elif self.agg_mode == "mean":
            # Clamp the valid-group count so a fully masked sample yields
            # zeros instead of 0/0 = NaN.
            prot_count = mask_prot_grouped.sum(-1).unsqueeze(-1).clamp(min=1)
            drug_count = mask_drug_grouped.sum(-1).unsqueeze(-1).clamp(min=1)
            prot_embed = (prot_embedding * mask_prot_grouped.unsqueeze(-1)).sum(1) / prot_count
            drug_embed = (drug_embedding * mask_drug_grouped.unsqueeze(-1)).sum(1) / drug_count
        else:
            raise NotImplementedError()

        query_embed = torch.cat([prot_embed, drug_embed], dim=1)

        # Size the attention map from the actual batch and grouped lengths.
        # The former fixed (1, 1, 1024, 1024) buffer only worked for batch
        # size 1 with 512 groups per modality; that case gives the same result.
        # The buffer is still created with torch.zeros defaults, as before.
        batch_size = prot_embedding.size(0)
        n_prot = prot_embedding.size(1)
        n_drug = drug_embedding.size(1)
        total = n_prot + n_drug
        att = torch.zeros(batch_size, 1, total, total)
        att[:, 0, :n_prot, :n_prot] = alpha_pp.mean(dim=-1)  # Protein to Protein
        att[:, 0, :n_prot, n_prot:] = alpha_pd.mean(dim=-1)  # Protein to Drug
        att[:, 0, n_prot:, :n_prot] = alpha_dp.mean(dim=-1)  # Drug to Protein
        att[:, 0, n_prot:, n_prot:] = alpha_dd.mean(dim=-1)  # Drug to Drug

        return query_embed, att
class MlPdecoder_CAN(nn.Module):
    """Three-stage MLP head that maps a feature vector to one sigmoid score.

    The hidden widths are ``input_dim``, ``input_dim // 2`` and
    ``input_dim // 4``; every hidden stage is Linear -> ReLU -> BatchNorm1d.
    """

    def __init__(self, input_dim):
        super().__init__()
        half_dim = input_dim // 2
        quarter_dim = input_dim // 4

        self.fc1 = nn.Linear(input_dim, input_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.fc2 = nn.Linear(input_dim, half_dim)
        self.bn2 = nn.BatchNorm1d(half_dim)
        self.fc3 = nn.Linear(half_dim, quarter_dim)
        self.bn3 = nn.BatchNorm1d(quarter_dim)
        self.output = nn.Linear(quarter_dim, 1)

    def forward(self, x):
        """Return sigmoid scores with a trailing dimension of size 1."""
        hidden = x
        stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, norm in stages:
            # Batch norm is applied after the activation in every stage.
            hidden = norm(torch.relu(linear(hidden)))
        return torch.sigmoid(self.output(hidden))
class MLPdecoder_BAN(nn.Module):
    """MLP head with explicit hidden/output widths and a sigmoid output.

    Args:
        in_dim: Size of the input features.
        hidden_dim: Width of the first two hidden stages.
        out_dim: Width of the third hidden stage.
        binary: Number of output units of the final linear layer.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, binary=1):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)

        self.fc4 = nn.Linear(out_dim, binary)

    def forward(self, x):
        """Run three Linear -> ReLU -> BatchNorm1d stages, then a sigmoid head."""
        hidden = x
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2), (self.fc3, self.bn3)):
            hidden = norm(F.relu(linear(hidden)))
        scores = self.fc4(hidden)
        return torch.sigmoid(scores)
class BANLayer(nn.Module):
""" Bilinear attention network
Modified from https://github.com/peizhenbai/DrugBAN/blob/main/ban.py
"""
def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
super(BANLayer, self).__init__()
self.c = 32
self.k = k
self.v_dim = v_dim
self.q_dim = q_dim
self.h_dim = h_dim
self.h_out = h_out
self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
# self.dropout = nn.Dropout(dropout[1])
if 1 < k:
self.p_net = nn.AvgPool1d(self.k, stride=self.k)
if h_out <= self.c:
self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
else:
self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)
self.bn = nn.BatchNorm1d(h_dim)
def attention_pooling(self, v, q, att_map):
fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
if 1 < self.k:
fusion_logits = fusion_logits.unsqueeze(1) # b x 1 x d
fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k # sum-pooling
return fusion_logits
def forward(self, v, q, softmax=False):
v_num = v.size(1)
q_num = q.size(1)
# print("v_num", v_num)
# print("v_num ", v_num)
if self.h_out <= self.c:
v_ = self.v_net(v)
q_ = self.q_net(q)
# print("v_", v_.shape)
# print("q_ ", q_.shape)
att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
# print("Attention map_1",att_maps.shape)
else:
v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
d_ = torch.matmul(v_, q_) # b x h_dim x v x q
att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3)) # b x v x q x h_out
att_maps = att_maps.transpose(2, 3).transpose(1, 2) # b x h_out x v x q
# print("Attention map_2",att_maps.shape)
if softmax:
p = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2)
att_maps = p.view(-1, self.h_out, v_num, q_num)
# print("Attention map_softmax", att_maps.shape)
logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
for i in range(1, self.h_out):
logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :])
logits += logits_i
logits = self.bn(logits)
return logits, att_maps
class FCNet(nn.Module):
    """Simple class for non-linear fully connect network
    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py

    For every consecutive pair in ``dims`` the network stacks an optional
    Dropout, a weight-normalised Linear layer and an optional activation.

    Args:
        dims: Sequence of layer widths; at least two entries are required.
        act: Name of a ``torch.nn`` activation class, or ``''`` for none.
        dropout: Dropout probability; no Dropout modules are added when it is
            not greater than 0.
    """

    def __init__(self, dims, act='ReLU', dropout=0):
        super().__init__()

        use_dropout = 0 < dropout
        use_activation = '' != act

        # Inner layers first, then the final (dims[-2] -> dims[-1]) layer;
        # every layer is assembled the same way.
        layer_shapes = list(zip(dims[:-2], dims[1:-1]))
        layer_shapes.append((dims[-2], dims[-1]))

        modules = []
        for n_in, n_out in layer_shapes:
            if use_dropout:
                modules.append(nn.Dropout(dropout))
            modules.append(weight_norm(nn.Linear(n_in, n_out), dim=None))
            if use_activation:
                modules.append(getattr(nn, act)())

        self.main = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.main(x)
class BatchFileDataset_Case(Dataset):
    """Dataset whose items are pre-saved batch files loaded with ``torch.load``.

    Each file is expected to hold a mapping with the keys ``prot``, ``drug``,
    ``prot_ids``, ``drug_ids``, ``prot_mask``, ``drug_mask`` and ``y``.
    """

    # Keys read from every batch file, in the order they are returned.
    _FIELDS = ('prot', 'drug', 'prot_ids', 'drug_ids', 'prot_mask', 'drug_mask', 'y')

    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        """Number of batch files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load the ``idx``-th batch file and return its fields as a tuple."""
        # NOTE(review): torch.load relies on pickle; only point this dataset
        # at files from a trusted source.
        record = torch.load(self.file_list[idx])
        return tuple(record[field] for field in self._FIELDS)