|
|
|
|
|
""" |
|
# @File : utils.py |
|
# @Author : 刘建林(霜旻) |
|
# @Email : [email protected] |
|
# @Time : 2022/10/27 下午8:52 |
|
""" |
|
import operator |
|
import pickle |
|
import numpy as np |
|
import pandas as pd |
|
|
|
# Vocabulary for S2 token characters: each lowercase hex digit maps to its
# integer value ('0' -> 0, ..., '9' -> 9, 'a' -> 10, ..., 'f' -> 15).
s2_label_dict = {char: value for value, char in enumerate('0123456789abcdef')}

# Inverse vocabulary: integer value (0..15) -> lowercase hex character.
s2_label_decode_dict = {value: char for char, value in s2_label_dict.items()}
|
|
|
# One weight per slot of the 33-slot encoding built by generate_s2_index
# (the lengths match). Presumably per-position loss weights; it is not
# referenced in this module -- NOTE(review): confirm against the training code.
s2_weights = (
    [0.025] * 9
    + [0.0325] * 3
    + [0.035] * 3
    + [0.0375] * 3
    + [0.04] * 3
    + [0.0425] * 3
    + [0.045, 0.045, 0.0475]
    + [0.025] * 3
    + [0.0] * 3
)
|
|
|
def generate_s2_index(s2_label):
    """Encode an S2 token string as a fixed-length list of 33 integer ids.

    Each character is looked up in ``s2_label_dict`` (hex digit -> 0..15) and
    stored at the same position; slots past the end of the token stay 0.

    Args:
        s2_label: iterable of characters, normally the token string.

    Returns:
        list of 33 ints.

    Raises:
        KeyError: if a character is not a key of ``s2_label_dict``.
        IndexError: if the token has more than 33 characters.
    """
    encoded = [0] * 33
    position = 0
    for ch in s2_label:
        encoded[position] = s2_label_dict[ch]
        position += 1
    return encoded
|
|
|
|
|
def decode_s2(x):
    """Decode a sequence of integer ids (0..15) back into a hex string.

    Every id is mapped through ``s2_label_decode_dict``. Zero padding is not
    stripped, so padded slots come back as '0' characters.

    Raises:
        KeyError: if an id is not a key of ``s2_label_decode_dict``.
    """
    return ''.join(s2_label_decode_dict[code] for code in x)
|
|
|
|
|
def sample_csv2pkl(csv_path, pkl_path, sequence_len=6):
    """Convert a '^'-separated CSV of node sequences into a pickled sample list.

    For every CSV row, one record is built per node ``i`` in
    ``1..sequence_len`` from the columns ``node_t{i}``,
    ``poi_address_mask{i}`` and ``node{i}``::

        [node_t, poi_address_mask, token, generate_s2_index(token), label]

    ``label`` is the concatenation of the index lists of all nodes in the row.
    The same ``label`` list object is shared by every record of that row.

    Args:
        csv_path: input CSV path ('^' separator, UTF-8 with optional BOM).
        pkl_path: output path; the list of rows (each a list of node
            records) is written there with ``pickle.dump``.
        sequence_len: number of node column groups per row (default 6).
    """
    node_cols = [f'node{i}' for i in range(1, sequence_len + 1)]
    # generate_s2_index iterates the token's characters, so the token columns
    # must stay strings; otherwise pandas parses an all-digit column as
    # int/float (which also drops leading zeros).
    df = pd.read_csv(csv_path, sep='^', encoding="utf_8_sig",
                     dtype={col: str for col in node_cols})

    data = []
    for _, row in df.iterrows():
        node_s = []
        label = []
        for i in range(1, sequence_len + 1):
            token = row[f'node{i}']
            s2_index = generate_s2_index(token)
            node_s.append([row[f'node_t{i}'], row[f'poi_address_mask{i}'],
                           token, s2_index])
            label.extend(s2_index)
        # Every node of the row gets the same (shared) concatenated label list.
        for node in node_s:
            node.append(label)
        data.append(node_s)

    with open(pkl_path, 'wb') as f:
        pickle.dump(data, f)
|
|
|
|
|
def calculate_multi_s2_acc(predicted_s2, y):
    """Count exact prefix matches between predicted and true 33-slot encodings.

    Both inputs are reshaped with ``.view(-1, 33)`` and converted to nested
    lists (so they are presumably torch tensors -- TODO confirm). Rows are
    compared pairwise. For each row, the slice starting at slot 12 is compared
    at 7 increasing lengths (slots 12..14, 12..17, ..., 12..32). Slots 0..11
    are never compared.

    Args:
        predicted_s2: predicted ids, any shape whose size is a multiple of 33.
        y: ground-truth ids with at least as many rows as ``predicted_s2``
            is expected to cover.

    Returns:
        np.ndarray of 7 ints; entry ``k`` counts rows whose slots
        12 .. 14 + 3k match exactly.
    """
    acc_cnt = np.array([0] * 7)
    true_rows = y.view(-1, 33).tolist()
    pred_rows = predicted_s2.view(-1, 33).tolist()
    # Exclusive slice ends 15, 18, ..., 33 -> seven nested prefixes from slot 12.
    prefix_ends = range(15, 34, 3)

    for row_idx, true_row in enumerate(true_rows):
        pred_row = pred_rows[row_idx]
        for level, end in enumerate(prefix_ends):
            if true_row[12:end] == pred_row[12:end]:
                acc_cnt[level] += 1

    return acc_cnt
|
|
|
def calculate_multi_s2_acc_batch(predicted_s2, y, sequence_len=6):
    """Batched variant of the 7-level prefix-match counter.

    Both inputs are reshaped with ``.view(-1, sequence_len, 33)`` and turned
    into nested lists (presumably torch tensors -- TODO confirm). Every
    (sample, step) encoding is compared against its counterpart: the slice
    starting at slot 12 is checked at 7 lengths (slots 12..14, 12..17, ...,
    12..32).

    Args:
        predicted_s2: predicted ids, size a multiple of ``sequence_len * 33``.
        y: ground-truth ids with the same layout.
        sequence_len: number of 33-slot encodings per sample.

    Returns:
        np.ndarray of 7 ints; entry ``k`` counts encodings whose slots
        12 .. 14 + 3k match exactly, summed over the whole batch.
    """
    acc_cnt = np.array([0] * 7)
    true_batches = y.view(-1, sequence_len, 33).tolist()
    pred_batches = predicted_s2.view(-1, sequence_len, 33).tolist()
    # Exclusive slice ends 15, 18, ..., 33 -> seven nested prefixes from slot 12.
    prefix_ends = range(15, 34, 3)

    for batch_idx, true_seq in enumerate(true_batches):
        pred_seq = pred_batches[batch_idx]
        for step, true_row in enumerate(true_seq):
            pred_row = pred_seq[step]
            for level, end in enumerate(prefix_ends):
                if true_row[12:end] == pred_row[12:end]:
                    acc_cnt[level] += 1

    return acc_cnt
|
|
|
|
|
|
|
def calculate_alias_acc(predicted, y):
    """Count hit statistics for the alias task, treating label 1 as positive.

    Args:
        predicted: predicted labels, indexable in parallel with ``y``; each
            element must be accepted by ``int()``.
        y: iterable of ground-truth labels (``int()``-convertible).

    Returns:
        tuple ``(tp, positives, acc)``:
            tp -- samples with label 1 that were predicted as 1;
            positives -- samples whose label is 1, so ``tp / positives`` is
                the recall of class 1;
            acc -- samples whose prediction equals the label.

    Raises:
        IndexError: if ``predicted`` is shorter than ``y``.
    """
    tp = positives = acc = 0
    for index, label in enumerate(y):
        true_label = int(label)
        pred_label = int(predicted[index])
        if true_label == pred_label:
            acc += 1
        if true_label == 1:
            positives += 1
            if pred_label == 1:
                tp += 1
    return tp, positives, acc
|
|
|
|
|
def calculate_aoi_acc(predicted, y):
    """Count hit statistics for the AOI task, treating label 0 as positive.

    Args:
        predicted: predicted labels, indexable in parallel with ``y``; each
            element must be accepted by ``int()``.
        y: iterable of ground-truth labels (``int()``-convertible).

    Returns:
        tuple ``(tp, positives, acc)``:
            tp -- samples with label 0 that were predicted as 0;
            positives -- samples whose label is 0, so ``tp / positives`` is
                the recall of class 0;
            acc -- samples whose prediction equals the label.

    Raises:
        IndexError: if ``predicted`` is shorter than ``y``.
    """
    tp = positives = acc = 0
    for index, label in enumerate(y):
        true_label = int(label)
        pred_label = int(predicted[index])
        if true_label == pred_label:
            acc += 1
        if true_label == 0:
            positives += 1
            if pred_label == 0:
                tp += 1
    return tp, positives, acc
|
|