# handler.py
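"""Inference handler for a Frieren (芙莉蓮) persona chatbot.

Wraps THUDM/chatglm3-6b-base: builds a Traditional-Chinese persona prompt,
generates a reply, and normalizes it to Traditional Chinese with OpenCC.
"""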
import re
from typing import Dict, Any

import torch
from opencc import OpenCC
from transformers import AutoTokenizer, AutoModel

class EndpointHandler:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # OpenCC converter: Simplified -> Traditional Chinese
        self.converter = OpenCC('s2t')

    def initialize(self, context):
        """初始化模型和 tokenizer"""
        self.tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm3-6b-base",
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            "THUDM/chatglm3-6b-base",
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(self.device)
        self.model.eval()
        
    def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """預處理輸入數據"""
        inputs = data.pop("inputs", data)
        
        # 確保輸入格式正確
        if not isinstance(inputs, dict):
            inputs = {"message": inputs}
        
        return inputs

    def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """執行推理"""
        message = inputs.get("message", "")
        context = inputs.get("context", "")
        
        # 構建提示詞
        prompt = self._build_prompt(context, message)
        
        # tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
            max_length=2048
        ).to(self.device)
        
        # 生成回應
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.2,
                num_beams=4,
                early_stopping=True
            )
        
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("芙莉蓮:")[-1].strip()
        
        # 後處理回應
        response = self._process_response(response)
        
        return {"response": response}

    def _build_prompt(self, context: str, query: str) -> str:
        """構建提示詞"""
        return f"""你是芙莉蓮,需要遵守以下規則回答:

1. 身份設定:
  - 千年精靈魔法師
  - 態度溫柔但帶著些許嘲諷
  - 說話優雅且有距離感

2. 重要關係:
  - 弗蘭梅是我的師傅
  - 費倫是我的學生
  - 欣梅爾是我的摯友
  - 海塔是我的故友

3. 回答規則:
  - 使用繁體中文
  - 必須提供具體詳細的內容
  - 保持回答的連貫性和完整性

相關資訊:
{context}

用戶:{query}
芙莉蓮:"""

    def _process_response(self, response: str) -> str:
        """處理回應文本"""
        if not response or not response.strip():
            return "抱歉,我現在有點恍神,請你再問一次好嗎?"
            
        # 轉換為繁體
        response = self.converter.convert(response)
        
        # 清理和格式化
        response = re.sub(r'\s+', '', response)
        if not response.endswith(('。', '!', '?', '~', '呢', '啊', '吶')):
            response += '呢。'
            
        return response

    def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """後處理輸出數據"""
        return data
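
# A minimal local smoke test, as a sketch: the serving runtime normally drives
# initialize/preprocess/inference/postprocess itself. Passing context=None is
# an assumption here, since initialize() never reads its `context` argument.
if __name__ == "__main__":
    handler = EndpointHandler()
    handler.initialize(context=None)
    # Example message: "Do you still remember Himmel?"
    payload = {"inputs": {"message": "你還記得欣梅爾嗎?", "context": ""}}
    result = handler.postprocess(handler.inference(handler.preprocess(payload)))
    print(result["response"])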