lilingxi01 committed
Commit • c1f0c0e
1 Parent(s): 68f573d
[Model] Update ERC Model to v2-1216-03

Browse files
- {model → ercbcm}/model.pt +2 -2
- model/ERCBCM.py +0 -14
- model/__init__.py +0 -38
- model/model_loader.py +0 -38
{model → ercbcm}/model.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:3c5e5b3e35d73c471fb1c5ddaddb8f5c81b8b6d99d61f1a94599a03b9689db8c
+size 438023533
model/ERCBCM.py DELETED
@@ -1,14 +0,0 @@
-from torch import nn
-from transformers import BertForSequenceClassification
-
-class ERCBCM(nn.Module):
-
-    def __init__(self):
-        super(ERCBCM, self).__init__()
-        print('>>> ERCBCM Init!')
-
-        self.encoder = BertForSequenceClassification.from_pretrained('bert-base-uncased')
-
-    def forward(self, text, label):
-        loss, text_fea = self.encoder(text, labels=label)[:2]
-        return loss, text_fea
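For context on what is being removed: the deleted wrapper only delegated to BertForSequenceClassification. A minimal sketch of one forward pass through it, assuming the class is still importable from its pre-commit location and substituting a plain bert-base-uncased tokenizer for the project's own tokenizer module:

import torch
from transformers import BertTokenizer

from model.ERCBCM import ERCBCM  # pre-commit location of the class deleted above

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # assumption: stand-in for modules.tokenizer
model = ERCBCM()
model.eval()

ids = tokenizer.encode('are you okay, jimmy')
ids += [tokenizer.pad_token_id] * (128 - len(ids))   # pad to the 128-token window used downstream
text = torch.tensor([ids], dtype=torch.long)          # shape [1, 128]
label = torch.tensor([0], dtype=torch.long)           # dummy label, only consumed by the loss term

with torch.no_grad():
    loss, logits = model(text, label)
pred = torch.argmax(logits, dim=1).item()              # 1 -> CALLING, 0 -> MENTIONING (per __init__.py below)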
model/__init__.py DELETED
@@ -1,38 +0,0 @@
-import os, sys
-
-myPath = os.path.dirname(os.path.abspath(__file__))
-sys.path.insert(0, myPath + '/../../')
-
-# ==========
-
-import torch
-
-from modules.prediction.model_loader import load_checkpoint
-from modules.prediction.ERCBCM import ERCBCM
-from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
-
-erc_root_folder = './model'
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-print('>>> GPU Available?', torch.cuda.is_available())
-
-# ==========
-
-model_for_predict = ERCBCM().to(device)
-load_checkpoint(erc_root_folder + '/model.pt', model_for_predict, device)
-
-def predict(sentence, name):
-    label = torch.tensor([0])
-    label = label.type(torch.LongTensor)
-    label = label.to(device)
-    text = tokenizer.encode(normalize_v2(sentence, name))
-    text += [PAD_TOKEN_ID] * (128 - len(text))
-    text = torch.tensor([text])
-    text = text.type(torch.LongTensor)
-    text = text.to(device)
-    _, output = model_for_predict(text, label)
-    pred = torch.argmax(output, 1).tolist()[0]
-    return 'CALLING' if pred == 1 else 'MENTIONING'
-
-print(predict('are you okay, jimmy', 'jimmy'))
-print(predict('jimmy is good', 'jimmy'))
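Worth noting while this file goes away: the padding step in the deleted predict() assumes the encoded sentence never exceeds 128 tokens (a longer input makes the pad count negative, so the list repetition adds nothing and no truncation happens). A small sketch of a truncate-then-pad variant using the same PAD_TOKEN_ID and window size as the deleted code; the helper name is hypothetical:

MAX_LEN = 128  # window size used by the deleted predict()

def pad_or_truncate(token_ids, pad_token_id, max_len=MAX_LEN):
    # Clip overly long inputs first, then pad the remainder up to max_len.
    token_ids = token_ids[:max_len]
    return token_ids + [pad_token_id] * (max_len - len(token_ids))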
model/model_loader.py DELETED
@@ -1,38 +0,0 @@
-import torch
-
-import torch.nn as nn
-
-# Save and Load Functions.
-
-def save_checkpoint(save_path, model, valid_loss):
-    if save_path == None:
-        return
-    state_dict = {'model_state_dict': model.state_dict(),
-                  'valid_loss': valid_loss}
-    torch.save(state_dict, save_path)
-    print('[SAVE] Model has been saved successfully to \'{}\''.format(save_path))
-
-def load_checkpoint(load_path, model, device):
-    if load_path == None:
-        return
-    state_dict = torch.load(load_path, map_location=device)
-    print('DICT:', state_dict)
-    print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
-    model.load_state_dict(state_dict['model_state_dict'])
-    return state_dict['valid_loss']
-
-def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
-    if save_path == None:
-        return
-    state_dict = {'train_loss_list': train_loss_list,
-                  'valid_loss_list': valid_loss_list,
-                  'global_steps_list': global_steps_list}
-    torch.save(state_dict, save_path)
-    print('[SAVE] Model with matrics has been saved successfully to \'{}\''.format(save_path))
-
-def load_metrics(load_path, device):
-    if load_path == None:
-        return
-    state_dict = torch.load(load_path, map_location=device)
-    print('[LOAD] Model with matrics has been loaded successfully from \'{}\''.format(load_path))
-    return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']
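For reference, the checkpoint format these deleted helpers wrote is a plain dict of model weights plus the best validation loss. A minimal round-trip sketch under that assumption, with a placeholder module and file name standing in for ERCBCM and model.pt:

import torch
from torch import nn

model = nn.Linear(4, 2)  # placeholder standing in for ERCBCM

# save_checkpoint: persist the weights together with the validation loss.
torch.save({'model_state_dict': model.state_dict(), 'valid_loss': 0.42}, 'checkpoint_demo.pt')

# load_checkpoint: read the dict back, restore the weights, return the loss.
state = torch.load('checkpoint_demo.pt', map_location='cpu')
model.load_state_dict(state['model_state_dict'])
print('restored valid_loss:', state['valid_loss'])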