File size: 6,145 Bytes
9c0f93c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
#! python3
# -*- encoding: utf-8 -*-
import torch
import torch.nn.functional as F
import pandas as pd
import sys
import os
from transformers.utils.hub import cached_file
# Per-level loss weights used by HTCLoss, one entry per hierarchy level (5).
htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]

# Candidate table for the hierarchical decoder, resolved through transformers'
# `cached_file` for the repo 'JunhongLou/G2PTL' when this module is imported.
# NOTE(security): pd.read_pickle unpickles the fetched file, and unpickling can
# execute arbitrary code -- only point this at a repository you trust.
resolved_module_file = cached_file('JunhongLou/G2PTL', 'htc_mask_dict.pkl')
htc_mask_dict = pd.read_pickle(resolved_module_file)
import numpy as np
import operator
def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len = 6):
    """Count prefix-exact matches of 5-level hierarchical codes, per level.

    Labels come in groups of 5 consecutive values, one group per code. For
    each code, level ``c`` (0..4) counts as correct when the first ``c + 1``
    predicted labels all equal the first ``c + 1`` true labels. A code stops
    contributing at the first level whose true label is ``-100``. That level
    and every deeper level are left out of both counts.

    Args:
        predicted_htc: flat sequence of predicted labels (5 per code), e.g.
            the ``predict_res`` list produced by ``HTCLoss``; must be
            reshapeable to ``(-1, sequence_len, 5)``.
        y: torch tensor of true labels, reshapeable to
            ``(-1, sequence_len, 5)``; ``-100`` marks ignored levels.
        sequence_len: number of codes per batch entry. It only fixes the
            expected layout; the counts are accumulated over all codes.

    Returns:
        ``(acc_cnt, total_cnt)``: two integer numpy arrays of length 5.
        ``acc_cnt[c]`` is the number of codes whose first ``c + 1`` levels are
        predicted exactly. ``total_cnt[c]`` is the number of codes whose first
        ``c + 1`` true levels contain no ``-100``.

    Raises:
        ValueError: if predictions and labels do not describe the same number
            of codes (previously this either raised IndexError or silently
            ignored the surplus predictions).
    """
    # Validate the (batch, sequence_len, 5) layout, then flatten to one row
    # per code. ``reshape`` (unlike ``view``) also accepts non-contiguous
    # label tensors; ``tolist`` moves the labels to host memory.
    labels = np.asarray(y.reshape(-1, sequence_len, 5).tolist()).reshape(-1, 5)
    preds = np.asarray(predicted_htc).reshape(-1, sequence_len, 5).reshape(-1, 5)
    if labels.shape != preds.shape:
        # Guard explicitly: the vectorised comparison below would otherwise
        # broadcast e.g. a single predicted row against every label row.
        raise ValueError(
            'predicted_htc holds {} codes but y holds {}'.format(
                preds.shape[0], labels.shape[0]))

    # valid[r, c]: no -100 among the first c + 1 true labels of code r.
    valid = np.logical_and.accumulate(labels != -100, axis=1)
    # correct[r, c]: first c + 1 labels match exactly, restricted to valid levels.
    correct = valid & np.logical_and.accumulate(labels == preds, axis=1)
    return correct.sum(axis=0), valid.sum(axis=0)
class HTCLoss(torch.nn.Module):
    """Level-weighted cross-entropy over 5-level hierarchical codes.

    ``logits`` are reshaped to ``(-1, 5, 100)``: every code has 5 levels and
    each level is scored over 100 classes. Flat layouts such as
    ``(n_codes * 5, 100)`` or ``(n_codes, 500)`` are accepted too. ``target``
    holds one class index per level, and ``-100`` marks ignored levels.

    With ``using_htc=True`` the logits are first constrained by a greedy
    top-down decode over the candidate table ``htc_mask_dict`` (see
    ``_decode_hierarchy``). In that mode ``forward`` also returns the decoded
    codes.
    """

    # Format used to append the code chosen at level i to the lookup key of
    # the candidates for level i + 1. Concatenated, these produce the keys
    # '{:02d}', '{:02d}{:02d}', '{:02d}{:02d}{:02d}' and
    # '{:02d}{:02d}{:02d}{:01d}'.
    _KEY_FORMATS = ('{:02d}', '{:02d}', '{:02d}', '{:01d}')

    def __init__(self, device, reduction='mean', using_htc = True):
        """
        Args:
            device: device for the per-level weights and candidate tensors.
            reduction: 'mean' or 'sum' over non-ignored levels; any other
                value returns the unreduced per-level loss (0 where ignored).
            using_htc: apply hierarchical candidate masking before the loss.
        """
        super().__init__()
        self.reduction = reduction
        self.htc_weights = htc_weights
        self.device = device
        self.using_htc = using_htc
        # Build a private tensor copy instead of converting the shared
        # module-level dict in place. Otherwise every instance would share
        # the same tensors, and a later instance would re-convert and move
        # them out from under earlier ones.
        self.htc_mask_dict = {
            key: torch.as_tensor(value, device=device)
            for key, value in htc_mask_dict.items()
        }

    def forward(self, logits, target):
        """Return ``(loss, predict_res)``.

        ``predict_res`` is the flat list of decoded codes (5 ints per code)
        when ``using_htc`` is set, and an empty list otherwise.
        """
        target = target.reshape(-1, 1)
        target_mask = (target != -100).reshape(-1)
        # -100 is not a valid class index. Point ignored rows at class 0 so the
        # gather below stays in range; target_mask zeroes those rows afterwards.
        safe_target = target.masked_fill(target == -100, 0)

        predict_res = []
        if self.using_htc:
            logits_new, predict_res = self._decode_hierarchy(logits.reshape(-1, 5, 100))
            flat_logits = logits_new.reshape(-1, 100)
        else:
            # Normalise over the 100 classes of each level whatever the input
            # layout, exactly like the HTC branch does.
            flat_logits = logits.reshape(-1, 100)
        neg_log_prob = -F.log_softmax(flat_logits, dim=1)
        # Negative log-likelihood of each level's target class.
        loss = neg_log_prob.gather(1, safe_target).squeeze(1) * target_mask

        # One weight per level, repeated for every code. The extra factor 5
        # presumably keeps the average weight at 1 (the module's default
        # weights sum to 1.0).
        level_weights = torch.tensor(self.htc_weights, dtype=torch.float32, device=self.device)
        loss = loss * level_weights.repeat(loss.shape[0] // 5) * 5

        # Reduce over non-ignored levels. Selecting by the mask rather than by
        # `loss > 0` keeps valid levels whose loss is exactly 0 in the mean.
        if self.reduction == 'mean':
            loss = loss[target_mask].mean()
        elif self.reduction == 'sum':
            loss = loss[target_mask].sum()
        return loss, predict_res

    def _decode_hierarchy(self, logits_reshaped):
        """Greedily decode one code per row of ``logits_reshaped`` ``(n, 5, 100)``.

        Level 0 takes the arg-max over classes 1..31. Each deeper level takes
        the arg-max over the candidates that ``htc_mask_dict`` lists for the
        codes *predicted* at the levels above (not the true ones).

        Returns ``(logits_new, predict_res)``. ``logits_new`` keeps the original
        logits at the allowed positions and -5 everywhere else.
        ``predict_res`` is the flat list ``[aa, bb, cc, d, ee, ...]`` of ints.
        """
        # Non-candidates get a constant -5 rather than -inf, presumably so the
        # loss stays finite when the true class lies outside the predicted
        # branch (TODO confirm intent).
        logits_new = torch.full_like(logits_reshaped, -5.0, device=self.device)
        top_logits = logits_reshaped[:, 0, 1:32]
        logits_new[:, 0, 1:32] = top_logits
        # +1 because the level-0 window starts at class 1.
        aa_predicted = torch.max(top_logits, 1)[1] + 1

        predict_res = []
        for sample_idx, aa in enumerate(aa_predicted):
            codes = [aa]
            key = ''
            for level, key_fmt in enumerate(self._KEY_FORMATS, start=1):
                # Extend the lookup key with the code just chosen one level up.
                key += key_fmt.format(codes[-1])
                candidates = self.htc_mask_dict[key]
                level_logits = logits_reshaped[sample_idx, level, candidates]
                logits_new[sample_idx, level, candidates] = level_logits
                best = torch.max(level_logits, 0)[1]
                codes.append(candidates[best])
            predict_res.extend(code.item() for code in codes)
        return logits_new, predict_res

    def get_htc_code(self, logits):
        """Return the greedily decoded codes as a flat list of ints (5 per code)."""
        return self._decode_hierarchy(logits.reshape(-1, 5, 100))[1]
|