Delete handler.py
handler.py +0 -71
DELETED
@@ -1,71 +0,0 @@
-from typing import Dict, List, Any
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import re
-
-class EndpointHandler():
-    def __init__(self, path="meyandrei/bankchat"):
-        # Load the model and tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left', use_safetensors=True)
-        self.model = AutoModelForCausalLM.from_pretrained(path, use_safetensors=True)
-        self.context_token = self.tokenizer.encode('<|context|>', return_tensors='pt')
-        self.endofcontext_token = self.tokenizer.encode(' <|endofcontext|>', return_tensors='pt')
-
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        data args:
-            inputs (:obj:`str`)
-            context (:obj:`list` of `str`)
-        Return:
-            A :obj:`dict` that will be serialized and returned
-        """
-        user_input = data.get('inputs', '')
-        history = data.get('context', [])
-
-        if history == []:
-            context_tokenized = torch.LongTensor(history)
-        else:
-            history_str = self.tokenizer.decode(history[0])
-            turns = re.split(r'<\|system\|>|<\|user\|>', history_str)[1:]
-
-            for i in range(0, len(turns) - 1, 2):  # re-attach the turn markers that re.split strips
-                turns[i] = '<|user|>' + turns[i]
-                turns[i + 1] = '<|system|>' + turns[i + 1]
-
-            context_tokenized = self.tokenizer.encode(''.join(turns), return_tensors='pt')
-
-        user_input_tokenized = self.tokenizer.encode(' ' + user_input, return_tensors='pt')
-        model_input = torch.cat([self.context_token, context_tokenized, user_input_tokenized, self.endofcontext_token], dim=-1)
-        attention_mask = torch.ones_like(model_input)
-
-        out_tokenized = self.model.generate(model_input, max_length=1024, eos_token_id=50258, pad_token_id=50260, attention_mask=attention_mask).tolist()[0]
-        out_str = self.tokenizer.decode(out_tokenized)
-        out_str = out_str.split('\n')[0]
-
-        generated_substring = out_str.split('<|endofcontext|>')[1]  # belief, actions, system_response
-
-        beliefs_start_index = generated_substring.find('<|belief|>') + len('<|belief|>')
-        beliefs_end_index = generated_substring.find('<|endofbelief|>', beliefs_start_index)
-
-        actions_start_index = generated_substring.find('<|action|>') + len('<|action|>')
-        actions_end_index = generated_substring.find('<|endofaction|>', actions_start_index)
-
-        response_start_index = generated_substring.find('<|response|>') + len('<|response|>')
-        response_end_index = generated_substring.find('<|endofresponse|>', response_start_index)
-
-        beliefs_str = generated_substring[beliefs_start_index:beliefs_end_index]
-        actions_str = generated_substring[actions_start_index:actions_end_index]
-        system_response_str = generated_substring[response_start_index:response_end_index]
-
-        system_resp_tokenized = self.tokenizer.encode(' ' + system_response_str, return_tensors='pt')
-        history = torch.cat([torch.LongTensor(history), user_input_tokenized, system_resp_tokenized], dim=-1).tolist()
-
-        # Prepare the output
-        model_outputs = {
-            'response': system_response_str,
-            'context': history,
-            'beliefs': beliefs_str,
-            'actions': actions_str
-        }
-
-        return model_outputs
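
For reference, the deleted file follows the custom handler contract used by Hugging Face Inference Endpoints: the toolkit instantiates EndpointHandler once at startup and calls it with the deserialized JSON payload of each request. A minimal local smoke test might look like the sketch below; it is an illustration, not part of the repository, and it assumes the meyandrei/bankchat checkpoint (with the special tokens referenced above) can be loaded. The example utterances are made up.

from handler import EndpointHandler

handler = EndpointHandler(path="meyandrei/bankchat")

# First turn: no prior context.
out = handler({"inputs": "I want to check my account balance", "context": []})
print(out["response"])  # reply text taken from between <|response|> and <|endofresponse|>
print(out["beliefs"])   # belief-state string from between <|belief|> and <|endofbelief|>

# Follow-up turn: feed the returned token-id history back in as context.
out = handler({"inputs": "And how do I close the account?", "context": out["context"]})
print(out["response"])

Against a deployed endpoint the same payload travels as JSON over HTTP, e.g. requests.post(url, headers={"Authorization": f"Bearer {token}"}, json={"inputs": "...", "context": []}) with placeholder url and token values.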