lilingxi01 commited on
Commit
c1f0c0e
β€’
1 Parent(s): 68f573d

[Model] Update ERC Model to v2-1216-03

Browse files
{model β†’ ercbcm}/model.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3911af416600534886f75bfee3f855a35ffa6ff85486685e8e8385e63051e7ed
3
- size 438022701
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c5e5b3e35d73c471fb1c5ddaddb8f5c81b8b6d99d61f1a94599a03b9689db8c
3
+ size 438023533
model/ERCBCM.py DELETED
@@ -1,14 +0,0 @@
1
- from torch import nn
2
- from transformers import BertForSequenceClassification
3
-
4
- class ERCBCM(nn.Module):
5
-
6
- def __init__(self):
7
- super(ERCBCM, self).__init__()
8
- print('>>> ERCBCM Init!')
9
-
10
- self.encoder = BertForSequenceClassification.from_pretrained('bert-base-uncased')
11
-
12
- def forward(self, text, label):
13
- loss, text_fea = self.encoder(text, labels=label)[:2]
14
- return loss, text_fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/__init__.py DELETED
@@ -1,38 +0,0 @@
1
- import os, sys
2
-
3
- myPath = os.path.dirname(os.path.abspath(__file__))
4
- sys.path.insert(0, myPath + '/../../')
5
-
6
- # ==========
7
-
8
- import torch
9
-
10
- from modules.prediction.model_loader import load_checkpoint
11
- from modules.prediction.ERCBCM import ERCBCM
12
- from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
13
-
14
- erc_root_folder = './model'
15
-
16
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
- print('>>> GPU Available?', torch.cuda.is_available())
18
-
19
- # ==========
20
-
21
- model_for_predict = ERCBCM().to(device)
22
- load_checkpoint(erc_root_folder + '/model.pt', model_for_predict, device)
23
-
24
- def predict(sentence, name):
25
- label = torch.tensor([0])
26
- label = label.type(torch.LongTensor)
27
- label = label.to(device)
28
- text = tokenizer.encode(normalize_v2(sentence, name))
29
- text += [PAD_TOKEN_ID] * (128 - len(text))
30
- text = torch.tensor([text])
31
- text = text.type(torch.LongTensor)
32
- text = text.to(device)
33
- _, output = model_for_predict(text, label)
34
- pred = torch.argmax(output, 1).tolist()[0]
35
- return 'CALLING' if pred == 1 else 'MENTIONING'
36
-
37
- print(predict('are you okay, jimmy', 'jimmy'))
38
- print(predict('jimmy is good', 'jimmy'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/model_loader.py DELETED
@@ -1,38 +0,0 @@
1
- import torch
2
-
3
- import torch.nn as nn
4
-
5
- # Save and Load Functions.
6
-
7
- def save_checkpoint(save_path, model, valid_loss):
8
- if save_path == None:
9
- return
10
- state_dict = {'model_state_dict': model.state_dict(),
11
- 'valid_loss': valid_loss}
12
- torch.save(state_dict, save_path)
13
- print('[SAVE] Model has been saved successfully to \'{}\''.format(save_path))
14
-
15
- def load_checkpoint(load_path, model, device):
16
- if load_path == None:
17
- return
18
- state_dict = torch.load(load_path, map_location=device)
19
- print('DICT:', state_dict)
20
- print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
21
- model.load_state_dict(state_dict['model_state_dict'])
22
- return state_dict['valid_loss']
23
-
24
- def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
25
- if save_path == None:
26
- return
27
- state_dict = {'train_loss_list': train_loss_list,
28
- 'valid_loss_list': valid_loss_list,
29
- 'global_steps_list': global_steps_list}
30
- torch.save(state_dict, save_path)
31
- print('[SAVE] Model with matrics has been saved successfully to \'{}\''.format(save_path))
32
-
33
- def load_metrics(load_path, device):
34
- if load_path == None:
35
- return
36
- state_dict = torch.load(load_path, map_location=device)
37
- print('[LOAD] Model with matrics has been loaded successfully from \'{}\''.format(load_path))
38
- return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']