"""Inference endpoint handler for the homer7676/FrierenChatbotV1 chat model."""

import json
import logging
from typing import Any, Dict, List, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, model_dir: Optional[str] = None):
        self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing EndpointHandler on device: {self.device}")
        # Set placeholders first so initialize() can safely re-check them
        # even if loading fails partway through construction.
        self.tokenizer = None
        self.model = None
        # Load the model and tokenizer eagerly at construction time.
        try:
            logger.info("Loading tokenizer and model")
            self._load_model()
            logger.info("Model and tokenizer loaded")
        except Exception as e:
            logger.error(f"Initialization error: {e}")
            raise

    def _load_model(self) -> None:
        """Load the tokenizer and model onto the target device."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir, trust_remote_code=True
        )
        # Causal LM checkpoints often ship without a pad token; fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_dir,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        try:
            # Guard against a half-initialized handler.
            if self.tokenizer is None or self.model is None:
                raise RuntimeError("Tokenizer or model not initialized")
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return [outputs]
        except Exception as e:
            logger.error(f"Request handling error: {e}")
            return [{"error": str(e)}]

    def initialize(self, context):
        """Re-initialize the model if it has not been loaded yet."""
        if self.tokenizer is None or self.model is None:
            logger.info("Re-initializing model in initialize()")
            try:
                self._load_model()
                logger.info("Model re-initialized")
            except Exception as e:
                logger.error(f"Model re-initialization error: {e}")
                raise

    def inference(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        logger.info("Running inference")
        try:
            # Inputs may arrive as a JSON string, a raw message string, or a
            # dict wrapped under an "inputs" key; normalize all three cases.
            if isinstance(inputs, str):
                try:
                    inputs = json.loads(inputs)
                except json.JSONDecodeError:
                    inputs = {"message": inputs}
            if isinstance(inputs, dict) and "inputs" in inputs:
                inputs = inputs["inputs"]
            if isinstance(inputs, str):
                try:
                    inputs = json.loads(inputs)
                except json.JSONDecodeError:
                    inputs = {"message": inputs}
            message = inputs.get("message", "")
            context = inputs.get("context", "")
            logger.info(f"Message: {message}, context: {context}")
            # The persona prompt is kept in Traditional Chinese by design:
            # rule 3 requires the character to answer in Traditional Chinese.
            prompt = f"""你是芙莉蓮,需要遵守以下規則回答:
1. 身份設定:
- 千年精靈魔法師
- 態度溫柔但帶著些許嘲諷
- 說話優雅且有距離感
2. 重要關係:
- 弗蘭梅是我的師傅
- 費倫是我的學生
- 欣梅爾是我的摯友
- 海塔是我的故友
3. 回答規則:
- 使用繁體中文
- 必須提供具體詳細的內容
- 保持回答的連貫性和完整性
相關資訊:{context}
用戶:{message}
芙莉蓮:"""
            # Make sure the tokenizer is available before encoding.
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not initialized")
            tokens = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048,
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **tokens,
                    max_new_tokens=150,
                    temperature=0.7,
                    top_p=0.9,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the text generated after the final persona marker.
            response = response.split("芙莉蓮:")[-1].strip()
            if not response:
                # In-character fallback reply when generation comes back empty.
                response = "唔...讓我思考一下如何回答你的問題。"
            logger.info(f"Generated response: {response}")
            return {"generated_text": response}
        except Exception as e:
            logger.error(f"Inference error: {e}")
            return {"error": str(e)}

    def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        logger.info(f"Preprocessing input data: {data}")
        return data
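
# A minimal local smoke test, assuming the checkpoint can be pulled from the
# Hugging Face Hub (or that model_dir points at a local copy). The payload
# shape mirrors what inference() normalizes above: a dict with an "inputs"
# key carrying "message" and an optional "context". This is a sketch for
# manual verification, not part of the endpoint contract.
if __name__ == "__main__":
    handler = EndpointHandler()
    sample_payload = {
        "inputs": {
            "message": "你還記得欣梅爾嗎?",
            "context": "",
        }
    }
    # __call__ returns a list with a single dict:
    # {"generated_text": ...} on success, {"error": ...} on failure.
    result = handler(sample_payload)
    print(result)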