TAAS / utils.py
zy414775's picture
Update utils.py
54a9a32
raw
history blame
4.84 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import operator
import pickle
import numpy as np
import pandas as pd
# Lowercase hexadecimal digit character -> integer value (0-15).
s2_label_dict = {char: value for value, char in enumerate('0123456789abcdef')}
# Inverse table: integer value (0-15) -> lowercase hexadecimal character.
s2_label_decode_dict = dict(enumerate('0123456789abcdef'))
# 33 per-position weights. Presumably they weight the 33-slot vectors built
# by generate_s2_index (not referenced in this file -- confirm with callers).
# The last three positions carry weight 0.0.
s2_weights = (
    [0.025] * 9
    + [0.0325] * 3
    + [0.035] * 3
    + [0.0375] * 3
    + [0.04] * 3
    + [0.0425] * 3
    + [0.045, 0.045, 0.0475]
    + [0.025] * 3
    + [0.0] * 3
)
def generate_s2_index(s2_label):
    """Encode a hexadecimal S2 token string as a list of exactly 33 ints.

    Each character of *s2_label* is mapped through ``s2_label_dict`` and
    written to the matching position; positions past the end of the label
    keep the value 0.

    Raises:
        KeyError: a character is not a key of ``s2_label_dict``.
        IndexError: the label is longer than 33 characters.
    """
    encoded = [0] * 33
    for pos, ch in enumerate(s2_label):
        encoded[pos] = s2_label_dict[ch]
    return encoded
def decode_s2(x):
    """Turn a sequence of integer digit values back into a hex string.

    Each element of *x* is looked up in ``s2_label_decode_dict``. A value
    outside 0-15 raises ``KeyError``. The resulting characters are joined
    in order.
    """
    return ''.join([s2_label_decode_dict[value] for value in x])
def sample_csv2pkl(csv_path, pkl_path, sequence_len=6):
    """Convert a '^'-separated CSV of node sequences into a pickled list.

    For every ``i`` in ``1..sequence_len`` each CSV row must provide the
    columns ``node_t{i}``, ``poi_address_mask{i}`` and ``node{i}``. The
    ``node{i}`` column must hold a string accepted by ``generate_s2_index``.
    Each row becomes a list of ``sequence_len`` node records::

        [node_t, poi_address_mask, node, s2_index, label]

    Here ``s2_index`` is ``generate_s2_index(node)`` and ``label`` is the
    concatenation of every node's ``s2_index`` in that row, taken in column
    order. All records of one row share the same ``label`` list object.

    Args:
        csv_path: path of the input CSV (read as ``utf_8_sig``, so an
            optional BOM is stripped).
        pkl_path: path the resulting list of rows is pickled to.
        sequence_len: number of node column groups per row (default 6).
    """
    df = pd.read_csv(csv_path, sep='^', encoding="utf_8_sig")
    data = []
    for _, row in df.iterrows():
        # Build the records in column order so that missing columns and
        # bad S2 strings fail at the same point as before.
        nodes = []
        for i in range(1, sequence_len + 1):
            s2_token = row['node{}'.format(i)]
            nodes.append([
                row['node_t{}'.format(i)],
                row['poi_address_mask{}'.format(i)],
                s2_token,
                generate_s2_index(s2_token),
            ])
        # One label list per row, shared by every record of that row.
        label = []
        for node in nodes:
            label.extend(node[3])
        for node in nodes:
            node.append(label)
        data.append(nodes)
    with open(pkl_path, 'wb') as f:
        pickle.dump(data, f)
def calculate_multi_s2_acc(predicted_s2, y):
    """Count, per nested S2 window, how many rows are predicted exactly.

    Both arguments are reshaped with ``.view(-1, 33)`` and converted with
    ``.tolist()``, so they must be tensor-like (e.g. torch tensors). This
    gives one row of 33 digits per sample. Seven windows, all starting at
    digit 12, are checked: ``[12:15]``, ``[12:18]``, ..., ``[12:33]``. A row
    counts toward window ``c`` when every digit in that window matches.

    Returns:
        numpy integer array of length 7; entry ``c`` is the number of rows
        whose window ``c`` matches.
    """
    acc_cnt = np.array([0] * 7)
    true_rows = y.view(-1, 33).tolist()
    pred_rows = predicted_s2.view(-1, 33).tolist()
    for row_idx in range(len(true_rows)):
        truth = true_rows[row_idx]
        guess = pred_rows[row_idx]
        # Window ends 15, 18, ..., 33: each level adds three more digits.
        for level, end in enumerate(range(15, 34, 3)):
            if truth[12:end] == guess[12:end]:
                acc_cnt[level] += 1
    return acc_cnt
def calculate_multi_s2_acc_batch(predicted_s2, y, sequence_len=6):
    """Batched variant of the nested-window S2 match counter.

    Both arguments are reshaped with ``.view(-1, sequence_len, 33)`` and
    converted with ``.tolist()`` (tensor-like inputs, e.g. torch tensors).
    The result is a list of samples, each holding ``sequence_len`` rows of
    33 digits. For every row, seven windows starting at digit 12 are
    compared: ``[12:15]``, ``[12:18]``, ..., ``[12:33]``.

    Returns:
        numpy integer array of length 7; entry ``c`` is the number of rows
        (summed over all samples) whose window ``c`` matches exactly.
    """
    acc_cnt = np.array([0] * 7)
    truth = y.view(-1, sequence_len, 33).tolist()
    guess = predicted_s2.view(-1, sequence_len, 33).tolist()
    for b in range(len(truth)):
        for n, true_digits in enumerate(truth[b]):
            pred_digits = guess[b][n]
            # Window ends 15, 18, ..., 33: each level adds three more digits.
            for level, end in enumerate(range(15, 34, 3)):
                if true_digits[12:end] == pred_digits[12:end]:
                    acc_cnt[level] += 1
    return acc_cnt
def calculate_alias_acc(predicted, y):
    """Count agreement statistics for binary predictions, treating 1 as positive.

    Args:
        predicted: indexable sequence with one prediction per element of *y*;
            each element must be convertible with ``int()``.
        y: iterable of ground-truth labels, each convertible with ``int()``.

    Returns:
        ``(tp, fn, acc)``:

        - ``acc`` is the number of positions where ``int(label)`` equals
          ``int(prediction)``.
        - ``fn`` is the number of labels equal to 1. Despite its name it is
          not a false-negative count.
        - ``tp`` is the number of labels equal to 1 that were also predicted
          as 1, given the reconstructed nesting described below.

    NOTE(review): this copy of the file lost its indentation. The nesting of
    the ``predicted == 1`` check inside the ``label == 1`` branch is a
    reconstruction. Confirm it against upstream before relying on ``tp``.
    """
    tp, fp, fn, tn = 0, 0, 0, 0  # fp and tn are never updated or returned
    acc = 0
    for index, label in enumerate(y):
        if int(label) == int(predicted[index]):
            acc += 1
        if int(label) == 1:
            fn += 1
            if int(predicted[index]) == 1:
                tp += 1
    # `precision` is computed but never returned. With the nesting above it
    # equals tp / (#labels == 1) * 100, i.e. recall as a percentage.
    if fn == 0:
        precision = 0
    else:
        precision = tp / fn * 100
    return tp, fn, acc
def calculate_aoi_acc(predicted, y):
    """Count agreement statistics for binary predictions, treating 0 as positive.

    Args:
        predicted: indexable sequence with one prediction per element of *y*;
            each element must be convertible with ``int()``.
        y: iterable of ground-truth labels, each convertible with ``int()``.

    Returns:
        ``(tp, fn, acc)``:

        - ``acc`` is the number of positions where ``int(label)`` equals
          ``int(prediction)``.
        - ``fn`` is the number of labels equal to 0. Despite its name it is
          not a false-negative count.
        - ``tp`` is the number of labels equal to 0 that were also predicted
          as 0, given the reconstructed nesting described below.

    NOTE(review): this copy of the file lost its indentation. The nesting of
    the ``predicted == 0`` check inside the ``label == 0`` branch is a
    reconstruction. Confirm it against upstream before relying on ``tp``.
    """
    tp, fp, fn, tn = 0, 0, 0, 0  # fp and tn are never updated or returned
    acc = 0
    for index, label in enumerate(y):
        if int(label) == int(predicted[index]):
            acc += 1
        if int(label) == 0:
            fn += 1
            if int(predicted[index]) == 0:
                tp += 1
    # `precision` is computed but never returned. With the nesting above it
    # equals tp / (#labels == 0) * 100, i.e. recall of class 0 as a percentage.
    if fn == 0:
        precision = 0
    else:
        precision = tp / fn * 100
    return tp, fn, acc