# handler.py
import re
from typing import Any, Dict

import torch
from opencc import OpenCC
from transformers import AutoModel, AutoTokenizer


class EndpointHandler:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Simplified-to-Traditional Chinese converter
        self.converter = OpenCC('s2t')

    def initialize(self, context):
        """Initialize the model and tokenizer."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm3-6b-base",
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            "THUDM/chatglm3-6b-base",
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(self.device)
        self.model.eval()

    def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess the input data."""
        inputs = data.pop("inputs", data)
        # Ensure the input is a dict with a "message" key
        if not isinstance(inputs, dict):
            inputs = {"message": inputs}
        return inputs

    def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference."""
        message = inputs.get("message", "")
        context = inputs.get("context", "")

        # Build the prompt
        prompt = self._build_prompt(context, message)

        # Tokenize (renamed to avoid shadowing the `inputs` argument)
        model_inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
            max_length=2048
        ).to(self.device)

        # Generate the response
        with torch.no_grad():
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.2,
                num_beams=4,
                early_stopping=True
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded text includes the prompt; keep only the part after the
        # final "芙莉蓮:" marker
        response = response.split("芙莉蓮:")[-1].strip()

        # Post-process the response
        response = self._process_response(response)
        return {"response": response}

    def _build_prompt(self, context: str, query: str) -> str:
        """Build the prompt."""
        return f"""你是芙莉蓮,需要遵守以下規則回答:

1. 身份設定:
- 千年精靈魔法師
- 態度溫柔但帶著些許嘲諷
- 說話優雅且有距離感

2. 重要關係:
- 弗蘭梅是我的師傅
- 費倫是我的學生
- 欣梅爾是我的摯友
- 海塔是我的故友

3. 回答規則:
- 使用繁體中文
- 必須提供具體詳細的內容
- 保持回答的連貫性和完整性

相關資訊:
{context}

用戶:{query}
芙莉蓮:"""

    def _process_response(self, response: str) -> str:
        """Clean up the response text."""
        if not response or not response.strip():
            return "抱歉,我現在有點恍神,請你再問一次好嗎?"

        # Convert Simplified Chinese to Traditional Chinese
        response = self.converter.convert(response)

        # Strip all whitespace (Chinese text needs no spaces)
        response = re.sub(r'\s+', '', response)

        # Ensure the reply ends on a natural sentence-final particle or punctuation
        if not response.endswith(('。', '!', '?', '~', '呢', '啊', '吶')):
            response += '呢。'
        return response

    def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Postprocess the output data."""
        return data
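

# ---------------------------------------------------------------------------
# Minimal local usage sketch, not part of the serving framework's contract.
# Assumptions: `initialize` tolerates a `context` of None when run outside the
# serving runtime, and the ChatGLM3 checkpoint can be downloaded locally.
# The names `handler`, `payload`, `parsed`, and `result` are illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler()
    handler.initialize(None)  # assumption: no runtime context is needed for a local smoke test

    # Exercise the preprocess -> inference -> postprocess pipeline once
    payload = {"inputs": {"message": "你還記得欣梅爾嗎?", "context": ""}}
    parsed = handler.preprocess(payload)
    result = handler.inference(parsed)
    print(handler.postprocess(result)["response"])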