# Web file-viewer residue (Spaces status: Running). File size: 8,436 bytes; revision 5e9cd1d.
from fastchat.conversation import Conversation
from configs import LOG_PATH, TEMPERATURE
import fastchat.constants
# Redirect fastchat's log directory to the project's LOG_PATH.
# NOTE(review): presumably this assignment has to run before
# fastchat.serve.base_model_worker is imported below, so that module sees the
# new LOGDIR — confirm before reordering these imports.
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel, root_validator
import fastchat
import asyncio
from server.utils import get_model_worker_config
from typing import Dict, List, Optional
# Public names exported by this module.
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
class ApiConfigParams(BaseModel):
    """Connection and credential settings for an online API.

    The worker configuration is looked up with
    ``get_model_worker_config(worker_name)``. Every field declared on the
    model that is present in that configuration is copied from it, replacing
    any value that was passed in for the same field.
    """

    api_base_url: Optional[str] = None
    api_proxy: Optional[str] = None
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    # for minimax
    group_id: Optional[str] = None
    is_pro: bool = False

    # for xinghuo
    APPID: Optional[str] = None
    APISecret: Optional[str] = None
    is_v2: bool = False

    worker_name: Optional[str] = None

    class Config:
        # Accept and keep fields that are not declared on the model.
        extra = "allow"

    @root_validator(pre=True)
    def validate_config(cls, v: Dict) -> Dict:
        """Merge the worker configuration into the raw input values.

        Runs before field validation. Returns the same dict ``v`` after the
        declared fields found in the worker configuration have been written
        into it.
        """
        worker_config = get_model_worker_config(v.get("worker_name"))
        if not worker_config:
            return v
        v.update({
            field_name: worker_config[field_name]
            for field_name in cls.__fields__
            if field_name in worker_config
        })
        return v

    def load_config(self, worker_name: str):
        """Set ``worker_name`` and reload declared fields from its configuration.

        Returns ``self`` so that calls can be chained.
        """
        self.worker_name = worker_name
        worker_config = get_model_worker_config(worker_name)
        if not worker_config:
            return self
        for field_name in self.__fields__:
            if field_name in worker_config:
                setattr(self, field_name, worker_config[field_name])
        return self
class ApiModelParams(ApiConfigParams):
    """Model configuration parameters (base of the chat and completion params)."""

    version: Optional[str] = None
    version_url: Optional[str] = None

    # for azure
    api_version: Optional[str] = None
    deployment_name: Optional[str] = None
    resource_name: Optional[str] = None

    # Generation settings; temperature defaults to the project-wide TEMPERATURE.
    temperature: float = TEMPERATURE
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
class ApiChatParams(ApiModelParams):
    """Parameters of a chat request."""

    # Conversation history; each entry maps string keys to string values
    # (the worker builds entries with "role" and "content" keys).
    messages: List[Dict[str, str]]

    # for minimax
    system_message: Optional[str] = None
    role_meta: Dict = {}
class ApiCompletionParams(ApiModelParams):
    """Parameters of a completion request: the model params plus a prompt."""

    prompt: str
class ApiEmbeddingsParams(ApiConfigParams):
    """Parameters of an embeddings request."""

    # Texts to embed.
    texts: List[str]
    embed_model: Optional[str] = None

    # for minimax
    to_query: bool = False
class ApiModelWorker(BaseModelWorker):
    """Base fastchat model worker that answers requests through an online API.

    Subclasses are expected to override ``do_chat`` (and, where supported,
    ``do_embeddings``), ``make_conv_template`` and optionally
    ``validate_messages``.
    """

    # None means the worker does not support embeddings.
    DEFAULT_EMBED_MODEL: Optional[str] = None

    def __init__(
        self,
        model_names: List[str],
        controller_addr: Optional[str] = None,
        worker_addr: Optional[str] = None,
        context_len: int = 2048,
        no_register: bool = False,
        **kwargs,
    ):
        """Create the worker and optionally start its heart beat.

        Args:
            model_names: Names served by this worker; the first one is used
                in error messages.
            controller_addr: Address of the fastchat controller.
            worker_addr: Address of this worker.
            context_len: Context length stored on the worker.
            no_register: When true, ``init_heart_beat`` is not called.
            **kwargs: Forwarded to ``BaseModelWorker.__init__``. Defaults are
                filled in for ``worker_id``, ``model_path`` and
                ``limit_worker_concurrency``.
        """
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(
            model_names=model_names,
            controller_addr=controller_addr,
            worker_addr=worker_addr,
            **kwargs,
        )
        import fastchat.serve.base_model_worker

        self.logger = fastchat.serve.base_model_worker.logger
        # Restore the standard streams that fastchat replaced.
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

        # Install a fresh event loop before creating the semaphore.
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)

        self.context_len = context_len
        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
        self.version = None

        if not no_register and self.controller_addr:
            self.init_heart_beat()

    def count_token(self, params):
        """Return the length of ``str(params["prompt"])`` as the token count."""
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params: Dict):
        """Yield encoded chat responses for a fastchat generation request.

        ``params["prompt"]`` is split back into chat messages when it looks
        like a joined conversation. Otherwise a single user message asking
        the model to continue the text is sent, which imitates completion
        through chat and carries no history.

        Any exception is reported as one encoded ``error_code`` 500 payload
        instead of being raised.
        """
        self.call_ct += 1
        try:
            prompt = params["prompt"]
            if self._is_chat(prompt):
                messages = self.prompt_to_messages(prompt)
                messages = self.validate_messages(messages)
            else:
                messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]

            # Forward only the settings the request actually supplied, so a
            # missing value falls back to the model default instead of
            # overriding it with None.
            requested = {
                "temperature": params.get("temperature"),
                "top_p": params.get("top_p"),
                "max_tokens": params.get("max_new_tokens"),
            }
            supplied = {key: value for key, value in requested.items() if value is not None}
            chat_params = ApiChatParams(
                messages=messages,
                version=self.version,
                **supplied,
            )

            responses = self.do_chat(chat_params)
            # The default do_chat returns one dict; iterating it directly
            # would stream its keys instead of the payload.
            if isinstance(responses, dict):
                responses = [responses]
            for resp in responses:
                yield self._jsonify(resp)
        except Exception as e:
            yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误:{e}"})

    def generate_gate(self, params):
        """Run ``generate_stream_gate`` to the end and return its last payload.

        Returns:
            The decoded last chunk of the stream, or an ``error_code`` 500
            dict when the stream is empty or decoding fails.
        """
        try:
            last_chunk = None
            for last_chunk in self.generate_stream_gate(params):
                pass
            if last_chunk is None:
                return {"error_code": 500, "text": f"{self.model_names[0]}未返回任何结果"}
            # Strip the trailing NUL separator added by _jsonify.
            return json.loads(last_chunk[:-1].decode())
        except Exception as e:
            return {"error_code": 500, "text": str(e)}

    # Methods meant to be customised by subclasses.
    def do_chat(self, params: ApiChatParams) -> Dict:
        """Perform the chat request.

        Responses must have the form ``{"error_code": int, "text": str}``.
        ``generate_stream_gate`` iterates over the result, so an override may
        also produce a sequence of such dicts. This default reports that
        chat is not implemented.
        """
        return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}

    # NOTE: there is no separate completion hook; generate_stream_gate
    # imitates completion through do_chat.

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        """Perform the embeddings request.

        Required return form:
        ``{"code": int, "data": List[List[float]], "msg": str}``.
        This default reports that embeddings are not implemented.
        """
        return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}

    def get_embeddings(self, params):
        """Debug stub: print the incoming embeddings request.

        Original author's note: fastchat heavily restricts embeddings for
        LLMs and seemingly only supports OpenAI's. Requests sent from the
        front end through OpenAIEmbeddings fail before reaching this method.
        """
        print("get_embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        """Build the conversation template; must be provided by subclasses."""
        raise NotImplementedError

    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        """Hook for APIs that need a special message format.

        Override to replace the default messages. It is kept separate from
        ``prompt_to_messages`` because the two have different use cases and
        parameters. The default returns ``messages`` unchanged.
        """
        return messages

    # Helper methods.
    @property
    def user_role(self):
        """First role of the conversation template."""
        return self.conv.roles[0]

    @property
    def ai_role(self):
        """Second role of the conversation template."""
        return self.conv.roles[1]

    def _jsonify(self, data: Dict) -> bytes:
        """Encode a response in the format used by fastchat's openai-api-server.

        The payload is UTF-8 encoded JSON followed by a NUL byte.
        """
        return json.dumps(data, ensure_ascii=False).encode() + b"\0"

    def _is_chat(self, prompt: str) -> bool:
        """Check whether ``prompt`` was assembled from chat messages.

        TODO: misdetection is possible; passing the original messages
        straight from fastchat may be the better approach.
        """
        key = f"{self.conv.sep}{self.user_role}:"
        return key in prompt

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        """Split a joined prompt string back into chat messages.

        The prompt is split on the conversation separator, and the first and
        last segments are skipped. Segments whose content is empty after
        stripping are dropped.

        Raises:
            RuntimeError: If a segment starts with neither role prefix.
        """
        messages = []
        roles = (self.user_role, self.ai_role)
        for msg in prompt.split(self.conv.sep)[1:-1]:
            for role in roles:
                prefix = role + ":"
                if msg.startswith(prefix):
                    content = msg[len(prefix):].strip()
                    if content:
                        messages.append({"role": role, "content": content})
                    break
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return messages

    @classmethod
    def can_embedding(cls):
        """Return True when the worker class declares a default embed model."""
        return cls.DEFAULT_EMBED_MODEL is not None
|