# FrierenChatbotV1 / handler.py — Hugging Face Inference Endpoints custom handler
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Dict, Any, List
import logging
# Module-wide logging: INFO level so load/inference progress is visible in endpoint logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference Endpoints handler for the FrierenChatbotV1 causal LM.

    Loads the tokenizer and model eagerly at construction time and implements
    the Hugging Face custom-handler protocol: ``__call__`` takes the request
    payload dict and returns a one-element list containing a result dict
    (``{"generated_text": ...}`` on success, ``{"error": ...}`` on failure).
    """

    def __init__(self, model_dir: str = None):
        """Resolve model location and device, then load model + tokenizer.

        :param model_dir: local path or hub repo id; defaults to the
            ``homer7676/FrierenChatbotV1`` hub repository.
        :raises Exception: re-raises any loading failure after logging it.
        """
        self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"初始化 EndpointHandler,使用設備: {self.device}")
        # Ensure the attributes exist even if loading raises partway, so the
        # None-checks in __call__/initialize cannot hit AttributeError.
        self.tokenizer = None
        self.model = None
        self._load()

    def _load(self) -> None:
        """Load tokenizer and model onto ``self.device``.

        Shared by ``__init__`` and ``initialize`` (previously duplicated).
        Logs and re-raises on failure.
        """
        try:
            logger.info("開始載入 tokenizer 和模型")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_dir,
                trust_remote_code=True,
            )
            # generate() needs a pad token; fall back to EOS when none is set.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_dir,
                trust_remote_code=True,
                # fp16 only on GPU; CPU inference stays in fp32.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            ).to(self.device)
            self.model.eval()
            logger.info("模型和 tokenizer 載入完成")
        except Exception as e:
            logger.error(f"初始化錯誤: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Entry point for the endpoint: preprocess, run inference, wrap result.

        Never raises — any failure is logged and returned as an error dict.
        """
        try:
            if self.tokenizer is None or self.model is None:
                raise RuntimeError("Tokenizer or model not initialized")
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return [outputs]
        except Exception as e:
            logger.error(f"處理過程錯誤: {str(e)}")
            return [{"error": str(e)}]

    def initialize(self, context):
        """Serving-framework hook: reload the model only if it is missing.

        :param context: framework-provided context object (unused here).
        """
        if self.tokenizer is None or self.model is None:
            logger.info("在 initialize 中重新初始化模型")
            try:
                self._load()
                logger.info("模型重新初始化完成")
            except Exception as e:
                logger.error(f"模型重新初始化錯誤: {str(e)}")
                raise

    @staticmethod
    def _parse_payload(inputs):
        """Best-effort coercion of a payload to a dict.

        Strings are decoded as JSON; a string that is not valid JSON is
        treated as a bare chat message: ``{"message": <string>}``.
        Non-string values are returned unchanged.
        """
        if isinstance(inputs, str):
            import json
            try:
                inputs = json.loads(inputs)
            except json.JSONDecodeError:
                inputs = {"message": inputs}
        return inputs

    def inference(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Build the persona prompt, generate a reply, and post-process it.

        Accepts either a dict, a JSON string, or the Hugging Face request
        envelope ``{"inputs": ...}`` (whose value may itself be a JSON
        string). Returns ``{"generated_text": ...}`` or ``{"error": ...}``.
        """
        logger.info("開始執行推理")
        try:
            # Coerce the raw payload, then unwrap the {"inputs": ...}
            # envelope and coerce its value the same way.
            inputs = self._parse_payload(inputs)
            if isinstance(inputs, dict) and "inputs" in inputs:
                inputs = self._parse_payload(inputs["inputs"])
            message = inputs.get("message", "")
            context = inputs.get("context", "")
            logger.info(f"處理消息: {message}, 上下文: {context}")
            prompt = f"""你是芙莉蓮,需要遵守以下規則回答:
1. 身份設定:
- 千年精靈魔法師
- 態度溫柔但帶著些許嘲諷
- 說話優雅且有距離感
2. 重要關係:
- 弗蘭梅是我的師傅
- 費倫是我的學生
- 欣梅爾是我的摯友
- 海塔是我的故友
3. 回答規則:
- 使用繁體中文
- 必須提供具體詳細的內容
- 保持回答的連貫性和完整性
相關資訊:{context}
用戶:{message}
芙莉蓮:"""
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not initialized")
            tokens = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048,
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **tokens,
                    max_new_tokens=150,
                    temperature=0.7,
                    top_p=0.9,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # The decoded text contains the whole prompt; keep only the text
            # after the final persona marker.
            response = response.split("芙莉蓮:")[-1].strip()
            if not response:
                response = "唔...讓我思考一下如何回答你的問題。"
            logger.info(f"生成回應: {response}")
            return {"generated_text": response}
        except Exception as e:
            logger.error(f"推理過程錯誤: {str(e)}")
            return {"error": str(e)}

    def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Request-preprocessing hook; currently a logged pass-through."""
        logger.info(f"預處理輸入數據: {data}")
        return data