Spaces:
Runtime error
Runtime error
lilingxi01
committed on
Commit
•
68f573d
1
Parent(s):
cd78976
[General] Rebuild folders with a simpler layout.
Browse files- .gitignore +1 -1
- ercbcm/ERCBCM.py +14 -0
- ercbcm/__init__.py +33 -0
- {modules/prediction → ercbcm}/model_loader.py +1 -1
- {modules/prediction → model}/ERCBCM.py +1 -0
- {modules/prediction → model}/__init__.py +8 -6
- model/model_loader.py +38 -0
.gitignore
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
.DS_Store
|
2 |
-
|
3 |
venv/
|
4 |
__pycache__/
|
|
|
1 |
.DS_Store
|
2 |
+
.idea/
|
3 |
venv/
|
4 |
__pycache__/
|
ercbcm/ERCBCM.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
from transformers import BertForSequenceClassification
|
3 |
+
|
4 |
+
class ERCBCM(nn.Module):
    """Thin ``nn.Module`` wrapper around a BERT sequence classifier.

    The wrapped classifier is stored as ``self.bert_base`` and is created
    from the ``'bert-base-uncased'`` pretrained weights at construction time.
    """

    def __init__(self):
        """Announce construction on stdout and load the pretrained classifier."""
        super().__init__()
        print('>>> ERCBCM Init!')

        self.bert_base = BertForSequenceClassification.from_pretrained('bert-base-uncased')

    def forward(self, text, label):
        """Run the wrapped classifier and return its first two outputs.

        Args:
            text: Passed positionally to the wrapped classifier
                (presumably a batch of token ids; confirm against callers).
            label: Passed to the wrapped classifier as ``labels``.

        Returns:
            A ``(loss, text_fea)`` pair: the first two items of the
            classifier's output.
        """
        outputs = self.bert_base(text, labels=label)
        loss, text_fea = outputs[:2]
        return loss, text_fea
|
ercbcm/__init__.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
# Put the parent of this file's directory at the front of sys.path; presumably
# so the absolute `ercbcm.*` / `modules.*` imports below resolve - confirm.
myPath = os.path.dirname(os.path.abspath(__file__))
|
4 |
+
sys.path.insert(0, myPath + '/../')
|
5 |
+
|
6 |
+
# ==========
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from ercbcm.model_loader import load_checkpoint
|
11 |
+
from ercbcm.ERCBCM import ERCBCM
|
12 |
+
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
|
13 |
+
|
14 |
+
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
+
print('>>> GPU Available?', torch.cuda.is_available())
|
16 |
+
|
17 |
+
# ==========
|
18 |
+
|
19 |
+
# Executed at import time: build the model on `device`, then restore weights
# from 'ercbcm/model.pt' (a relative path, resolved against the working dir).
model_for_predict = ERCBCM().to(device)
|
20 |
+
load_checkpoint('ercbcm/model.pt', model_for_predict, device)
|
21 |
+
|
22 |
+
def predict(sentence, name):
    """Classify ``sentence`` with respect to ``name`` as CALLING or MENTIONING.

    The sentence is normalized with ``normalize_v2(sentence, name)``, encoded
    with the module-level ``tokenizer``, right-padded with ``PAD_TOKEN_ID`` to
    128 token ids and fed through the module-level ``model_for_predict``.

    Args:
        sentence: Raw input text.
        name: Name forwarded to ``normalize_v2`` together with the sentence.

    Returns:
        ``'CALLING'`` when the arg-max class index is 1, otherwise
        ``'MENTIONING'``.
    """
    # The model's forward() needs a label to compute a loss; a dummy class 0
    # is supplied and the resulting loss is discarded.
    label = torch.tensor([0], dtype=torch.long, device=device)

    token_ids = tokenizer.encode(normalize_v2(sentence, name))
    # Right-pad to 128 ids. A negative repeat count yields an empty list, so
    # inputs longer than 128 ids pass through unpadded and untruncated.
    token_ids = token_ids + [PAD_TOKEN_ID] * (128 - len(token_ids))
    text = torch.tensor([token_ids], dtype=torch.long, device=device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        _, output = model_for_predict(text, label)

    pred = torch.argmax(output, 1).tolist()[0]
    return 'CALLING' if pred == 1 else 'MENTIONING'
|
{modules/prediction → ercbcm}/model_loader.py
RENAMED
@@ -1,6 +1,6 @@
|
|
1 |
import torch
|
2 |
|
3 |
-
# Save and Load Functions
|
4 |
|
5 |
def save_checkpoint(save_path, model, valid_loss):
|
6 |
if save_path == None:
|
|
|
1 |
import torch
|
2 |
|
3 |
+
# Save and Load Functions.
|
4 |
|
5 |
def save_checkpoint(save_path, model, valid_loss):
|
6 |
if save_path == None:
|
{modules/prediction → model}/ERCBCM.py
RENAMED
@@ -5,6 +5,7 @@ class ERCBCM(nn.Module):
|
|
5 |
|
6 |
def __init__(self):
|
7 |
super(ERCBCM, self).__init__()
|
|
|
8 |
|
9 |
self.encoder = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
10 |
|
|
|
5 |
|
6 |
def __init__(self):
|
7 |
super(ERCBCM, self).__init__()
|
8 |
+
print('>>> ERCBCM Init!')
|
9 |
|
10 |
self.encoder = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
11 |
|
{modules/prediction → model}/__init__.py
RENAMED
@@ -14,13 +14,12 @@ from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
|
|
14 |
erc_root_folder = './model'
|
15 |
|
16 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
17 |
|
18 |
# ==========
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
def prepare():
|
23 |
-
load_checkpoint(erc_root_folder + '/model.pt', model_for_evaluate, device)
|
24 |
|
25 |
def predict(sentence, name):
|
26 |
label = torch.tensor([0])
|
@@ -31,6 +30,9 @@ def predict(sentence, name):
|
|
31 |
text = torch.tensor([text])
|
32 |
text = text.type(torch.LongTensor)
|
33 |
text = text.to(device)
|
34 |
-
_, output =
|
35 |
pred = torch.argmax(output, 1).tolist()[0]
|
36 |
-
return 'CALLING' if pred == 1 else 'MENTIONING'
|
|
|
|
|
|
|
|
14 |
erc_root_folder = './model'
|
15 |
|
16 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
17 |
+
print('>>> GPU Available?', torch.cuda.is_available())
|
18 |
|
19 |
# ==========
|
20 |
|
21 |
+
model_for_predict = ERCBCM().to(device)
|
22 |
+
load_checkpoint(erc_root_folder + '/model.pt', model_for_predict, device)
|
|
|
|
|
23 |
|
24 |
def predict(sentence, name):
|
25 |
label = torch.tensor([0])
|
|
|
30 |
text = torch.tensor([text])
|
31 |
text = text.type(torch.LongTensor)
|
32 |
text = text.to(device)
|
33 |
+
_, output = model_for_predict(text, label)
|
34 |
pred = torch.argmax(output, 1).tolist()[0]
|
35 |
+
return 'CALLING' if pred == 1 else 'MENTIONING'
|
36 |
+
|
37 |
+
print(predict('are you okay, jimmy', 'jimmy'))
|
38 |
+
print(predict('jimmy is good', 'jimmy'))
|
model/model_loader.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
# Save and Load Functions.
|
6 |
+
|
7 |
+
def save_checkpoint(save_path, model, valid_loss):
    """Write a model's weights and its validation loss to ``save_path``.

    Args:
        save_path: Destination path; when ``None`` nothing is written.
        model: Object exposing ``state_dict()``.
        valid_loss: Validation loss stored alongside the weights.

    Returns:
        None.
    """
    # Identity check: `== None` would defer to a custom __eq__ on the path.
    if save_path is None:
        return
    state_dict = {
        'model_state_dict': model.state_dict(),
        'valid_loss': valid_loss,
    }
    torch.save(state_dict, save_path)
    print('[SAVE] Model has been saved successfully to \'{}\''.format(save_path))
|
14 |
+
|
15 |
+
def load_checkpoint(load_path, model, device):
    """Restore weights saved by ``save_checkpoint`` into ``model``.

    Args:
        load_path: Checkpoint path; when ``None`` nothing is loaded.
        model: Object exposing ``load_state_dict()``; updated in place.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The ``'valid_loss'`` value stored in the checkpoint, or ``None`` when
        ``load_path`` is ``None``.
    """
    if load_path is None:
        return None
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(state_dict['model_state_dict'])
    # Report success only once the weights have actually been applied.
    print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
    return state_dict['valid_loss']
|
23 |
+
|
24 |
+
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Write the training-metric lists to ``save_path``.

    Args:
        save_path: Destination path; when ``None`` nothing is written.
        train_loss_list: Stored under the ``'train_loss_list'`` key.
        valid_loss_list: Stored under the ``'valid_loss_list'`` key.
        global_steps_list: Stored under the ``'global_steps_list'`` key.

    Returns:
        None.
    """
    if save_path is None:
        return
    state_dict = {
        'train_loss_list': train_loss_list,
        'valid_loss_list': valid_loss_list,
        'global_steps_list': global_steps_list,
    }
    torch.save(state_dict, save_path)
    print('[SAVE] Model with metrics has been saved successfully to \'{}\''.format(save_path))
|
32 |
+
|
33 |
+
def load_metrics(load_path, device):
    """Read back the metric lists written by ``save_metrics``.

    Args:
        load_path: Metrics file path; when ``None`` nothing is loaded.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        A ``(train_loss_list, valid_loss_list, global_steps_list)`` tuple, or
        ``None`` when ``load_path`` is ``None``.
    """
    if load_path is None:
        return None
    state_dict = torch.load(load_path, map_location=device)
    print('[LOAD] Model with metrics has been loaded successfully from \'{}\''.format(load_path))
    return (
        state_dict['train_loss_list'],
        state_dict['valid_loss_list'],
        state_dict['global_steps_list'],
    )
|