homer7676 committed
Commit 4233874
1 Parent(s): 7de7db9

Update handler.py

Files changed (1)
  1. handler.py +80 -68
handler.py CHANGED
@@ -9,54 +9,92 @@ logger = logging.getLogger(__name__)
 class EndpointHandler:
     def __init__(self, model_dir: str = None):
         self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
-        self.tokenizer = None
-        self.model = None
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Initializing EndpointHandler on device: {self.device}")
-
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
-        try:
-            inputs = self.preprocess(data)
-            outputs = self.inference(inputs)
-            # Make sure the output is not empty
-            if not outputs or "generated_text" not in outputs:
-                raise ValueError("No text was generated")
-            return [outputs]
-        except Exception as e:
-            logger.error(f"Processing error: {str(e)}")
-            return [{"error": str(e)}]
-
-    def initialize(self, context):
-        logger.info("Starting model initialization")
+
+        # Load the model and tokenizer up front, at construction time
         try:
+            logger.info("Loading tokenizer and model")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.model_dir,
                 trust_remote_code=True
             )
-
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
-
+
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.model_dir,
                 trust_remote_code=True,
                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
             ).to(self.device)
-
             self.model.eval()
-            logger.info("Model initialization complete")
+            logger.info("Tokenizer and model loaded")
+
         except Exception as e:
-            logger.error(f"Model loading error: {str(e)}")
+            logger.error(f"Initialization error: {str(e)}")
             raise
 
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
+        try:
+            # Make sure the tokenizer and model are initialized
+            if self.tokenizer is None or self.model is None:
+                raise RuntimeError("Tokenizer or model not initialized")
+
+            inputs = self.preprocess(data)
+            outputs = self.inference(inputs)
+            return [outputs]
+        except Exception as e:
+            logger.error(f"Processing error: {str(e)}")
+            return [{"error": str(e)}]
+
+    def initialize(self, context):
+        """Make sure the model is initialized."""
+        if self.tokenizer is None or self.model is None:
+            logger.info("Re-initializing the model in initialize")
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_dir,
+                    trust_remote_code=True
+                )
+                if self.tokenizer.pad_token is None:
+                    self.tokenizer.pad_token = self.tokenizer.eos_token
+
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_dir,
+                    trust_remote_code=True,
+                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+                ).to(self.device)
+                self.model.eval()
+                logger.info("Model re-initialization complete")
+            except Exception as e:
+                logger.error(f"Model re-initialization error: {str(e)}")
+                raise
+
     def inference(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         logger.info("Starting inference")
         try:
+            # Check the input format
+            if isinstance(inputs, str):
+                try:
+                    import json
+                    inputs = json.loads(inputs)
+                except json.JSONDecodeError:
+                    inputs = {"message": inputs}
+
+            # Extract the message and context
+            if isinstance(inputs, dict) and "inputs" in inputs:
+                inputs = inputs["inputs"]
+                if isinstance(inputs, str):
+                    try:
+                        import json
+                        inputs = json.loads(inputs)
+                    except json.JSONDecodeError:
+                        inputs = {"message": inputs}
+
             message = inputs.get("message", "")
             context = inputs.get("context", "")
-            logger.info(f"Processing message: {message}")
-
-            # Build the prompt
+            logger.info(f"Processing message: {message}, context: {context}")
+
             prompt = f"""你是芙莉蓮,需要遵守以下規則回答:
 1. 身份設定:
 - 千年精靈魔法師
@@ -75,69 +113,43 @@ class EndpointHandler:
 用戶:{message}
 芙莉蓮:"""
 
-            # Log the prompt length
-            logger.info(f"Prompt length: {len(prompt)}")
-
-            # Tokenize
-            encoding = self.tokenizer.encode_plus(
+            # Make sure the tokenizer exists
+            if self.tokenizer is None:
+                raise RuntimeError("Tokenizer not initialized")
+
+            tokens = self.tokenizer(
                 prompt,
-                add_special_tokens=True,
                 return_tensors="pt",
                 padding=True,
                 truncation=True,
                 max_length=2048
-            )
-
-            # Move tensors to the correct device
-            input_ids = encoding["input_ids"].to(self.device)
-            attention_mask = encoding["attention_mask"].to(self.device)
-
-            logger.info(f"Input token count: {input_ids.shape[-1]}")
+            ).to(self.device)
 
-            # Generate the response
             with torch.no_grad():
                 outputs = self.model.generate(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    max_new_tokens=256,
+                    **tokens,
+                    max_new_tokens=150,
                     temperature=0.7,
                     top_p=0.9,
                     top_k=50,
                     do_sample=True,
                     pad_token_id=self.tokenizer.pad_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id,
-                    num_return_sequences=1,
-                    no_repeat_ngram_size=3
+                    eos_token_id=self.tokenizer.eos_token_id
                 )
-
-            logger.info(f"Generated token count: {outputs.shape[-1]}")
 
-            # Decode the response
-            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = response.split("芙莉蓮:")[-1].strip()
 
-            # Separate out the model's reply
-            if "芙莉蓮:" in full_response:
-                response = full_response.split("芙莉蓮:")[-1].strip()
-            else:
-                response = full_response.split("用戶:")[-1].strip()
-
-            logger.info(f"Generated response length: {len(response)}")
-
-            # Make sure the response is not empty
             if not response:
-                response = "抱歉,我似乎有點恍神了。能請你再說一次嗎?"
-
-            return {
-                "generated_text": response
-            }
+                response = "唔...讓我思考一下如何回答你的問題。"
+
+            logger.info(f"Generated response: {response}")
+            return {"generated_text": response}
 
         except Exception as e:
             logger.error(f"Inference error: {str(e)}")
-            raise
+            return {"error": str(e)}
 
     def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
         logger.info(f"Preprocessing input data: {data}")
-        inputs = data.pop("inputs", data)
-        if not isinstance(inputs, dict):
-            inputs = {"message": inputs}
-        return inputs
+        return data
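For reference, the reworked inference path accepts several payload shapes: a dict carrying message and context, the same dict nested under "inputs" (as Hugging Face Inference Endpoints send it), or a bare or JSON-encoded string that gets wrapped as {"message": ...}. A minimal local smoke test, sketched under the assumption that this handler.py is importable and the model weights can be downloaded; the payloads are illustrative, not part of the commit:

# Hypothetical smoke test for the updated handler (illustrative, not from the commit).
from handler import EndpointHandler

handler = EndpointHandler()  # defaults to "homer7676/FrierenChatbotV1"

# Endpoint-style payload: message and context nested under "inputs".
print(handler({"inputs": {"message": "你好,芙莉蓮", "context": ""}}))

# A bare string under "inputs" also works: json.loads() fails on it,
# so inference() falls back to wrapping it as {"message": ...}.
print(handler({"inputs": "你好,芙莉蓮"}))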
 
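Once deployed as a custom Inference Endpoint, the same payload travels over HTTP, and the list of dicts returned by __call__ becomes the JSON response body. A hedged sketch using requests; the endpoint URL and token below are placeholders, not values from this repo:

import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <HF_TOKEN>", "Content-Type": "application/json"}

payload = {"inputs": {"message": "最近過得如何?", "context": ""}}
resp = requests.post(API_URL, headers=HEADERS, json=payload)
print(resp.json())  # e.g. [{"generated_text": "..."}] on success, [{"error": "..."}] otherwise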