FrierenChatbotV1 / handler.py
homer7676's picture
add file
8c184ed verified
raw
history blame
3.62 kB
# handler.py
import torch
from transformers import AutoTokenizer, AutoModel
import json
from typing import Dict, Any
import numpy as np
from opencc import OpenCC
import jieba
import re
class EndpointHandler:
def __init__(self):
self.tokenizer = None
self.model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.converter = OpenCC('s2t')
def initialize(self, context):
"""初始化模型和 tokenizer"""
self.tokenizer = AutoTokenizer.from_pretrained(
"THUDM/chatglm3-6b-base",
trust_remote_code=True
)
self.model = AutoModel.from_pretrained(
"THUDM/chatglm3-6b-base",
trust_remote_code=True,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
).to(self.device)
self.model.eval()
def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""預處理輸入數據"""
inputs = data.pop("inputs", data)
# 確保輸入格式正確
if not isinstance(inputs, dict):
inputs = {"message": inputs}
return inputs
def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""執行推理"""
message = inputs.get("message", "")
context = inputs.get("context", "")
# 構建提示詞
prompt = self._build_prompt(context, message)
# tokenize
inputs = self.tokenizer(
prompt,
return_tensors="pt",
add_special_tokens=True,
truncation=True,
max_length=2048
).to(self.device)
# 生成回應
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
top_p=0.9,
do_sample=True,
repetition_penalty=1.2,
num_beams=4,
early_stopping=True
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response.split("芙莉蓮:")[-1].strip()
# 後處理回應
response = self._process_response(response)
return {"response": response}
def _build_prompt(self, context: str, query: str) -> str:
"""構建提示詞"""
return f"""你是芙莉蓮,需要遵守以下規則回答:
1. 身份設定:
- 千年精靈魔法師
- 態度溫柔但帶著些許嘲諷
- 說話優雅且有距離感
2. 重要關係:
- 弗蘭梅是我的師傅
- 費倫是我的學生
- 欣梅爾是我的摯友
- 海塔是我的故友
3. 回答規則:
- 使用繁體中文
- 必須提供具體詳細的內容
- 保持回答的連貫性和完整性
相關資訊:
{context}
用戶:{query}
芙莉蓮:"""
def _process_response(self, response: str) -> str:
"""處理回應文本"""
if not response or not response.strip():
return "抱歉,我現在有點恍神,請你再問一次好嗎?"
# 轉換為繁體
response = self.converter.convert(response)
# 清理和格式化
response = re.sub(r'\s+', '', response)
if not response.endswith(('。', '!', '?', '~', '呢', '啊', '吶')):
response += '呢。'
return response
def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""後處理輸出數據"""
return data