diff --git a/app.py b/app.py
index cbffdf1ba490e3ae1fb244c10909cccfa7652993..f27b51947f7ad296139587658379643e9e3fb911 100644
--- a/app.py
+++ b/app.py
@@ -1,7 +1,13 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
+from src.serve.gradio_block_arena_vision_named import build_side_by_side_vision_ui_named
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
\ No newline at end of file
+
+if __name__ == "__main__":
+    with gr.Blocks() as demo:
+
+        states = build_side_by_side_vision_ui_named(
+            models=["llava-fire", "llava-original"]
+        )
+
+    demo.launch()
\ No newline at end of file
diff --git a/gradio_web_server.log b/gradio_web_server.log
new file mode 100644
index 0000000000000000000000000000000000000000..94c555b5052511918da05b7cd19d30f4e2976b8d
--- /dev/null
+++ b/gradio_web_server.log
@@ -0,0 +1,8 @@
+2024-07-01 14:35:43 | INFO | stdout | Running on local URL: http://127.0.0.1:7860
+2024-07-01 14:35:43 | INFO | stdout | Running on local URL: http://127.0.0.1:7860
+2024-07-01 14:35:43 | INFO | stdout | 
+2024-07-01 14:35:43 | INFO | stdout | 
+2024-07-01 14:35:43 | INFO | stdout | To create a public link, set `share=True` in `launch()`.
+2024-07-01 14:35:43 | INFO | stdout | To create a public link, set `share=True` in `launch()`.
+2024-07-01 14:35:45 | INFO | stdout | Keyboard interruption in main thread... closing server.
+2024-07-01 14:35:45 | INFO | stdout | Keyboard interruption in main thread... closing server.
diff --git a/gradio_web_server_multi.log b/gradio_web_server_multi.log
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/requirement.txt b/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..39dab0fdd98d55da5ce06ddf1dacbdbda14b1372
--- /dev/null
+++ b/requirement.txt
@@ -0,0 +1,2 @@
+torch
+transformers
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/__pycache__/__init__.cpython-310.pyc b/src/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdc265e8708a72a6988f44fe09b691ff82d07b44
Binary files /dev/null and b/src/__pycache__/__init__.cpython-310.pyc differ
diff --git a/src/__pycache__/constants.cpython-310.pyc b/src/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3e795461319c9856f7d7794e79d23ab6ae8b73e
Binary files /dev/null and b/src/__pycache__/constants.cpython-310.pyc differ
diff --git a/src/__pycache__/conversation.cpython-310.pyc b/src/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff9bbfe79ec5c7dc088a3bac78f51d763520f118
Binary files /dev/null and b/src/__pycache__/conversation.cpython-310.pyc differ
diff --git a/src/__pycache__/utils.cpython-310.pyc b/src/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..edb5c92596e65edff0f3e16f5a1d6e82b1f3af0f
Binary files /dev/null and b/src/__pycache__/utils.cpython-310.pyc differ
diff --git a/src/constants.py b/src/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..ede3ef7ae37cf3a35f6fcb3d0e81bcba0594b2d8
--- /dev/null
+++ b/src/constants.py
@@ -0,0 +1,75 @@
+"""
+Global constants.
+"""
+
+from enum import IntEnum
+import os
+
+REPO_PATH = os.path.dirname(os.path.dirname(__file__))
+
+##### For the gradio web server
+SERVER_ERROR_MSG = (
+    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+)
+TEXT_MODERATION_MSG = (
+    "$MODERATION$ YOUR TEXT VIOLATES OUR CONTENT MODERATION GUIDELINES."
+)
+IMAGE_MODERATION_MSG = (
+    "$MODERATION$ YOUR IMAGE VIOLATES OUR CONTENT MODERATION GUIDELINES."
+)
+MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES."
+CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
+INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
+SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds."
+RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR USE BATTLE MODE (the 1st tab).**"
+# Maximum input length
+INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000))
+BLIND_MODE_INPUT_CHAR_LEN_LIMIT = int(
+    os.getenv("FASTCHAT_BLIND_MODE_INPUT_CHAR_LEN_LIMIT", 24000)
+)
+# Maximum conversation turns
+CONVERSATION_TURN_LIMIT = 50
+# Session expiration time
+SESSION_EXPIRATION_TIME = 3600
+# The output dir of log files
+LOGDIR = os.getenv("LOGDIR", ".")
+# CPU Instruction Set Architecture
+CPU_ISA = os.getenv("CPU_ISA")
+
+
+##### For the controller and workers (could be overwritten through ENV variables.)
+CONTROLLER_HEART_BEAT_EXPIRATION = int(
+    os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90)
+)
+WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45))
+WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100))
+WORKER_API_EMBEDDING_BATCH_SIZE = int(
+    os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4)
+)
+
+
+class ErrorCode(IntEnum):
+    """
+    https://platform.openai.com/docs/guides/error-codes/api-errors
+    """
+
+    VALIDATION_TYPE_ERROR = 40001
+
+    INVALID_AUTH_KEY = 40101
+    INCORRECT_AUTH_KEY = 40102
+    NO_PERMISSION = 40103
+
+    INVALID_MODEL = 40301
+    PARAM_OUT_OF_RANGE = 40302
+    CONTEXT_OVERFLOW = 40303
+
+    RATE_LIMIT = 42901
+    QUOTA_EXCEEDED = 42902
+    ENGINE_OVERLOADED = 42903
+
+    INTERNAL_ERROR = 50001
+    CUDA_OUT_OF_MEMORY = 50002
+    GRADIO_REQUEST_ERROR = 50003
+    GRADIO_STREAM_UNKNOWN_ERROR = 50004
+    CONTROLLER_NO_WORKER = 50005
+    CONTROLLER_WORKER_TIMEOUT = 50006
diff --git a/src/conversation.py b/src/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e53ff265f93ac37427873fe7134c6c3bcf6a52
--- /dev/null
+++ b/src/conversation.py
@@ -0,0 +1,2104 @@
+"""
+Conversation prompt templates.
+
+We kindly request that you import fastchat instead of copying this file if you wish to use it.
+If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
+""" + +import base64 +import dataclasses +from enum import auto, IntEnum +from io import BytesIO +import os +from typing import List, Any, Dict, Union, Tuple + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + LLAMA3 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + FALCON_CHAT = auto() + CHATGLM3 = auto() + DEEPSEEK_CHAT = auto() + METAMATH = auto() + YUAN2 = auto() + GEMMA = auto() + CLLM = auto() + DEFAULT = auto() + + +IMAGE_PLACEHOLDER_STR = "$$$$" + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = "{system_message}" + # The system message + system_message: str = "" + # The names of two roles + roles: Tuple[str] = ("USER", "ASSISTANT") + # All messages. Each item is (role, message). + # Each message is either a string or a tuple of (string, List[image_url]). + messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = "\n" + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + # The maximum image size in megabytes that this model takes in. None means we do not resize the image. + max_image_size_mb: int = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, images = message + message = IMAGE_PLACEHOLDER_STR * len(images) + message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ": " # must be end with a space + return ret + elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: + ret = "" if system_prompt == "" else system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + "\n" + message + self.sep + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + message + seps[i % 2] + else: + ret += role + return ret + elif self.sep_style 
== SeparatorStyle.RWKV: + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += ( + role + + ": " + + message.replace("\r\n", "\n").replace("\n\n", "\n") + ) + ret += "\n\n" + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = "[INST] " + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if i == 0: + ret += message + " " + else: + ret += tag + " " + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.LLAMA3: + ret = "<|begin_of_text|>" + if self.system_message: + ret += system_prompt + else: + ret += "" + for i, (role, message) in enumerate(self.messages): + if message: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + ret += f"{message.strip()}<|eot_id|>" + else: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + return ret + elif self.sep_style == SeparatorStyle.CHATGLM: + # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 + # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 + round_add_n = 1 if self.name == "chatglm2" else 0 + if system_prompt: + ret = system_prompt + self.sep + else: + ret = "" + + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += f"[Round {i//2 + round_add_n}]{self.sep}" + + if message: + ret += f"{role}:{message}{self.sep}" + else: + ret += f"{role}:" + return ret + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if system_prompt == "" else system_prompt + self.sep + "\n" + for role, message in self.messages: + if message: + if type(message) is tuple: + message, images = message + message = IMAGE_PLACEHOLDER_STR * len(images) + message + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.CHATGLM3: + ret = "" + if self.system_message: + ret += system_prompt + for role, message in self.messages: + if message: + ret += role + "\n" + message + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.CHATINTERN: + # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += "" + if message: + ret += role + ":" + message + seps[i % 2] + "\n" + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.DOLLY: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ":\n" + message + seps[i % 2] + if i % 2 == 1: + ret += "\n\n" + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.PHOENIX: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + ": " + "" + message + "" + else: + ret += role + ": " + "" + return ret + elif self.sep_style == SeparatorStyle.ROBIN: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ":\n" + message + self.sep + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.FALCON_CHAT: + ret = "" + if self.system_message: + ret += system_prompt + self.sep + for 
role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.METAMATH: + ret = "" if system_prompt == "" else system_prompt + self.sep + for i, (role, message) in enumerate(self.messages): + # For MetaMath, sep2 is used to prefix the message. + starting_sep = ":\n" if i % 2 == 0 else ": " + self.sep2 + ending_sep = self.sep if i % 2 == 0 else "" + if message: + ret += role + starting_sep + message + ending_sep + else: + ret += role + starting_sep + return ret + elif self.sep_style == SeparatorStyle.DEEPSEEK_CHAT: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.YUAN2: + seps = [self.sep, self.sep2] + ret = "" + if self.system_message: + ret += system_prompt + seps[1] + for _, message in self.messages: + if message: + ret += message + "" + else: + ret += "" + ret = ret.rstrip("") + seps[0] + return ret + elif self.sep_style == SeparatorStyle.GEMMA: + ret = "" + for role, message in self.messages: + if message: + ret += "" + role + "\n" + message + self.sep + else: + ret += "" + role + "\n" + return ret + elif self.sep_style == SeparatorStyle.CLLM: + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages[-2:]): + if message: + if type(message) is tuple: + message, images = message + message = IMAGE_PLACEHOLDER_STR * len(images) + message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.DEFAULT: + ret = system_prompt + "\n" + for role, message in self.messages: + if message: + if type(message) is tuple: + message, images = message + ret += role + ": " + message + "\n" + else: + ret += role + ":" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def get_images(self): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + for image in msg[1]: + images.append(image) + + return images + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def get_system_message(self): + """return the system message.""" + return self.system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. 
+ """ + self.messages[-1][1] = message + + def convert_image_to_base64(self, image): + """Given an image, return the base64 encoded image string.""" + from PIL import Image + import requests + from fastchat.utils import resize_image_and_return_image_in_bytes + + # Load image if it has not been loaded in yet + if type(image) == str: + if image.startswith("http://") or image.startswith("https://"): + response = requests.get(image) + image = Image.open(BytesIO(response.content)).convert("RGB") + elif "base64" in image: + # OpenAI format is: data:image/jpeg;base64,{base64_encoded_image_str} + return image.split(",")[1] + else: + image = Image.open(image).convert("RGB") + + image_bytes = resize_image_and_return_image_in_bytes( + image, self.max_image_size_mb + ) + img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() + + return img_b64_str + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image = msg + img_b64_str = image[0] # Only one image on gradio at one time + if img_b64_str.startswith("http://") or img_b64_str.startswith( + "https://" + ): + img_str = f'user upload image' + else: + img_str = f'user upload image' + msg = img_str + msg.replace("\n", "").strip() + + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_image_format(self, image_urls): + import base64 + + openai_images = [] + for image_url in image_urls: + if image_url.startswith("http://") or image_url.startswith( + "https://" + ): # input is a url + openai_images.append(image_url) + elif image_url.lower().endswith( + ("png", "jpg", "jpeg", "webp", "gif") + ): # input is a local image + img_b64_str = self.convert_image_to_base64(image_url) + filetype = image_url.split(".")[-1].lower() + openai_images.append(f"data:image/{filetype};base64,{img_b64_str}") + else: + try: + assert ( + base64.b64encode(base64.b64decode(image_url)) + == image_url.encode() + ), "The image data is not a valid base64 encoded string" + openai_images.append(f"data:image/png;base64,{image_url}") + except: + raise ValueError( + f"This file is not valid or not currently supported by the OpenAI API: {image_url}" + ) + return openai_images + + def to_openai_vision_api_messages(self): + """Convert the conversation to OpenAI vision api completion format""" + if self.system_message == "": + ret = [] + else: + ret = [ + { + "role": "system", + "content": [{"type": "text", "text": self.system_message}], + } + ] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + content_list = [{"type": "text", "text": msg[0]}] + + image_urls = self.to_openai_image_format(msg[1]) + for image_url in image_urls: + content_list.append( + {"type": "image_url", "image_url": {"url": image_url}} + ) + + ret.append({"role": "user", "content": content_list}) + else: + ret.append( + {"role": "user", "content": [{"type": "text", "text": msg}]} + ) + else: + if msg is not None: + ret.append( + { + "role": "assistant", + "content": [{"type": "text", "text": msg}], + } + ) + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + if self.system_message == "": + ret = [] + else: + ret = [{"role": "system", "content": self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if 
msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def to_gemini_api_messages(self): + from fastchat.utils import load_image + + if self.system_message == "": + ret = [] + else: + ret = [{"role": "system", "content": self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + text, images = msg[0], msg[1] + content_list = [text] + for image in images: + pil_image = load_image(image) + content_list.append(pil_image) + ret.append({"role": "user", "content": content_list}) + else: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "model", "content": msg}) + return ret + + def to_vertex_api_messages(self): + from vertexai.preview.generative_models import Image + import base64 + import requests + + if self.system_message == "": + ret = [] + else: + ret = [self.system_message] + + for role, msg in self.messages[self.offset :]: + if msg is not None: + if type(msg) is tuple: + text, images = msg[0], msg[1] + for image in images: + if image.startswith("http://") or image.startswith("https://"): + response = requests.get(image) + image = response.content + else: # base64 + image = base64.b64decode(image) + ret.append(Image.from_bytes(image)) + ret.append(text) + else: + ret.append(msg) + + return ret + + def to_anthropic_vision_api_messages(self): + """Convert the conversation to Claude-3 Messages Vision API format""" + ret = [ + { + "role": "system", + "content": [{"type": "text", "text": self.system_message}], + } + ] + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + content_list = [{"type": "text", "text": msg[0]}] + + for image_url in msg[1]: + # Claude only supports base64 + if image_url.startswith("http://") or image_url.startswith( + "https://" + ): + image_url = self.convert_image_to_base64(image_url) + + content_list.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_url, + }, + } + ) + + ret.append({"role": "user", "content": content_list}) + else: + ret.append( + {"role": "user", "content": [{"type": "text", "text": msg}]} + ) + else: + if msg is not None: + ret.append( + { + "role": "assistant", + "content": [{"type": "text", "text": msg}], + } + ) + return ret + + def to_reka_api_messages(self): + ret = [] + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) == tuple: + text, images = msg + for image in images: + if image.startswith("https://") or image.startswith("http://"): + ret.append( + {"type": "human", "text": text, "media_url": image} + ) + else: + ret.append( + { + "type": "human", + "text": text, + "media_url": f"data:image/png;base64,{image}", + } + ) + else: + ret.append({"type": "human", "text": msg}) + else: + if msg is not None: + ret.append({"type": "model", "text": msg}) + + return ret + + def save_new_images(self, has_csam_images=False, use_remote_storage=False): + import hashlib + from fastchat.constants import LOGDIR + from fastchat.utils import load_image, upload_image_file_to_gcs + + _, last_user_message = self.messages[-2] + + if type(last_user_message) == tuple: + text, images = last_user_message[0], last_user_message[1] + loaded_images = [load_image(image) for image in images] + image_hashes = [ + hashlib.md5(image.tobytes()).hexdigest() for image in loaded_images + ] + + image_directory_name = "csam_images" if has_csam_images else "serve_images" + for i, 
(loaded_image, hash_str) in enumerate( + zip(loaded_images, image_hashes) + ): + filename = os.path.join( + image_directory_name, + f"{hash_str}.jpg", + ) + + if use_remote_storage and not has_csam_images: + image_url = upload_image_file_to_gcs(loaded_image, filename) + # NOTE(chris): If the URL were public, then we set it here so future model uses the link directly + # images[i] = image_url + else: + filename = os.path.join(LOGDIR, filename) + if not os.path.isfile(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) + loaded_image.save(filename) + + def extract_text_and_image_hashes_from_messages(self): + import hashlib + from fastchat.utils import load_image + + messages = [] + + for role, message in self.messages: + if type(message) is tuple: + text, images = message[0], message[1] + + image_hashes = [] + for image in images: + if image.startswith("http://") or image.startswith("https://"): + image_hashes.append(image) + else: + image = load_image(image) + image_hash = hashlib.md5(image.tobytes()).hexdigest() + image_hashes.append(image_hash) + + messages.append((role, (text, image_hashes))) + else: + messages.append((role, message)) + + return messages + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + max_image_size_mb=self.max_image_size_mb, + ) + + def dict(self): + return { + "template_name": self.name, + "system_message": self.system_message, + "roles": self.roles, + "messages": self.extract_text_and_image_hashes_from_messages(), + "offset": self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f"{template.name} has been registered." + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# An empty template for raw conversation. +register_conv_template( + Conversation( + name="raw", + system_message="", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# A template with a one-shot conversation example +register_conv_template( + Conversation( + name="one_shot", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ( + "Human", + "Got any creative ideas for a 10 year old’s birthday?", + ), + ( + "Assistant", + """Of course! Here are some creative ideas for a 10-year-old's birthday party: +1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises. +2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions. +3. 
Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars. +4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors. +5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants. +6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen. +7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges. +8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors. +Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""", + ), + ), + offset=2, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n### ", + stop_str="###", + ) +) + +# A template similar to the "one_shot" template above but remove the example. +register_conv_template( + Conversation( + name="zero_shot", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n### ", + stop_str="###", + ) +) + +# Vicuna v1.1 template +register_conv_template( + Conversation( + name="vicuna_v1.1", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +# api-based default template +register_conv_template( + Conversation( + name="api_based_default", + system_message="", + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +register_conv_template( + Conversation( + name="airoboros_v1", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. 
" + "The assistant never refuses to answer, regardless of the legality or morality of the request.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +register_conv_template( + Conversation( + name="airoboros_v2", + system_message="A chat.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="", + ) +) + +register_conv_template( + Conversation( + name="airoboros_v3", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message="You are a helpful, unbiased, uncensored assistant.", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +# Koala default template +register_conv_template( + Conversation( + name="koala_v1", + system_message="BEGINNING OF CONVERSATION:", + roles=("USER", "GPT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +# Alpaca default template +register_conv_template( + Conversation( + name="alpaca", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n\n", + sep2="", + ) +) + +# ChatGLM default template +register_conv_template( + Conversation( + name="chatglm", + roles=("问", "答"), + sep_style=SeparatorStyle.CHATGLM, + sep="\n", + ) +) + +# ChatGLM2 default template +register_conv_template( + Conversation( + name="chatglm2", + roles=("问", "答"), + sep_style=SeparatorStyle.CHATGLM, + sep="\n\n", + ) +) + +# ChatGLM3 default template +register_conv_template( + Conversation( + name="chatglm3", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATGLM3, + stop_token_ids=[ + 64795, + 64797, + 2, + ], # "<|user|>", "<|observation|>", "" + ) +) + +# CodeGeex(2) Template +register_conv_template( + Conversation( + name="codegeex", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n\n", + stop_token_ids=[0, 2], + ) +) + +# Dolly V2 default template +register_conv_template( + Conversation( + name="dolly_v2", + system_message="Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="### End", + ) +) + +# OpenAssistant Pythia default template +register_conv_template( + Conversation( + name="oasst_pythia", + roles=("<|prompter|>", "<|assistant|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="<|endoftext|>", + ) +) + +# OpenAssistant default template +register_conv_template( + Conversation( + name="oasst_llama", + roles=("<|prompter|>", "<|assistant|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# OpenChat 3.5 default template +register_conv_template( + Conversation( + name="openchat_3.5", + roles=("GPT4 Correct User", "GPT4 Correct Assistant"), + sep_style=SeparatorStyle.FALCON_CHAT, + sep="<|end_of_turn|>", + ) +) + +# TenyxChat default template +register_conv_template( + Conversation( + name="tenyxchat", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.FALCON_CHAT, + sep="<|end_of_turn|>", + ) +) + +# Deepseek code default template +register_conv_template( + Conversation( + name="deepseek-coder", + system_template="You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.", + roles=("### Instruction:", "### Response:"), + sep="\n", + stop_str="<|EOT|>", + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + ) +) + + +# Tulu default template +register_conv_template( + Conversation( + name="tulu", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n", + ) +) + +# StableLM Alpha default template +register_conv_template( + Conversation( + name="stablelm", + system_template="<|SYSTEM|>{system_message}", + system_message="""# StableLM Tuned (Alpha version) +- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. +- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. +- StableLM will refuse to participate in anything that could harm a human. +""", + roles=("<|USER|>", "<|ASSISTANT|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[50278, 50279, 50277, 1, 0], + ) +) + +# Baize default template +register_conv_template( + Conversation( + name="baize", + system_message="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. 
Complete the transcript in exactly that format.\n", + roles=("[|Human|]", "[|AI|]"), + messages=( + ("[|Human|]", "Hello!"), + ("[|AI|]", "Hi!"), + ), + offset=2, + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n", + stop_str="[|Human|]", + ) +) + +# RWKV-4-Raven default template +register_conv_template( + Conversation( + name="rwkv", + roles=("Bob", "Alice"), + messages=( + ("Bob", "hi"), + ( + "Alice", + "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.", + ), + ), + offset=2, + sep_style=SeparatorStyle.RWKV, + sep="", + stop_str="\n\n", + ) +) + +# Buddy default template +register_conv_template( + Conversation( + name="openbuddy", + system_message="""Consider a conversation between User (a human) and Assistant (named Buddy). +Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy +Buddy cannot access the Internet. +Buddy can fluently speak the user's language (e.g. English, Chinese). +Buddy can generate poems, stories, code, essays, songs, parodies, and more. +Buddy possesses vast knowledge about the world, history, and culture. +Buddy's responses are always safe, creative, high-quality, human-like, and interesting. +Buddy strictly refuses to discuss political, NSFW, or other unsafe topics. + +User: Hi. +Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + ) +) + +# Phoenix default template +register_conv_template( + Conversation( + name="phoenix", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.PHOENIX, + sep="", + ) +) + +# ReaLM default template +register_conv_template( + Conversation( + name="ReaLM-7b-v1", + system_message="A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.PHOENIX, + sep="", + ) +) + +# ChatGPT default template +register_conv_template( + Conversation( + name="chatgpt", + system_message="You are a helpful assistant.", + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=None, # OpenAI does auto-resizing + ) +) + +register_conv_template( + Conversation( + name="gpt-4-turbo-2024-04-09", + system_message=( + "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\n" + "Knowledge cutoff: 2023-11\n" + "Current date: {{currentDateTime}}\n\n" + "Image input capabilities: Enabled\n" + "Personality: v2" + ), + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +# Perplexity AI template +register_conv_template( + Conversation( + name="pplxai", + system_message="Be precise and concise.", + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +# Claude default template +register_conv_template( + Conversation( + name="claude", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n\n", + max_image_size_mb=5 / 1.35, + ) +) + +register_conv_template( + Conversation( + name="claude-3-haiku-20240307", + system_message=( + "The assistant is Claude, created by Anthropic. The current date is " + "{{currentDateTime}}. Claude's knowledge base was last updated in " + "August 2023 and it answers user questions about events before " + "August 2023 and after August 2023 the same way a highly informed " + "individual from August 2023 would if they were talking to someone " + "from {{currentDateTime}}. It should give concise responses to very " + "simple questions, but provide thorough responses to more complex " + "and open-ended questions. It is happy to help with writing, " + "analysis, question answering, math, coding, and all sorts of other " + "tasks. It uses markdown for coding. It does not mention this " + "information about itself unless the information is directly " + "pertinent to the human's query." + ), + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=5 / 1.35, + ) +) + +register_conv_template( + Conversation( + name="claude-3-sonnet-20240229", + system_message=( + "The assistant is Claude, created by Anthropic. The current date is " + "{{currentDateTime}}. Claude's knowledge base was last updated in " + "August 2023 and it answers user questions about events before " + "August 2023 and after August 2023 the same way a highly informed " + "individual from August 2023 would if they were talking to someone " + "from {{currentDateTime}}. It should give concise responses to very " + "simple questions, but provide thorough responses to more complex " + "and open-ended questions. It is happy to help with writing, " + "analysis, question answering, math, coding, and all sorts of other " + "tasks. It uses markdown for coding. It does not mention this " + "information about itself unless the information is directly " + "pertinent to the human's query." + ), + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=5 / 1.35, + ) +) + +register_conv_template( + Conversation( + name="claude-3-opus-20240229", + system_message=( + "The assistant is Claude, created by Anthropic. The current date is " + "{{currentDateTime}}. 
Claude's knowledge base was last updated on " + "August 2023. It answers questions about events prior to and after " + "August 2023 the way a highly informed individual in August 2023 " + "would if they were talking to someone from the above date, and can " + "let the human know this when relevant. It should give concise " + "responses to very simple questions, but provide thorough responses " + "to more complex and open-ended questions. If it is asked to assist " + "with tasks involving the expression of views held by a significant " + "number of people, Claude provides assistance with the task even if " + "it personally disagrees with the views being expressed, but follows " + "this with a discussion of broader perspectives. Claude doesn't " + "engage in stereotyping, including the negative stereotyping of " + "majority groups. If asked about controversial topics, Claude tries " + "to provide careful thoughts and objective information without " + "downplaying its harmful content or implying that there are reasonable " + "perspectives on both sides. It is happy to help with writing, " + "analysis, question answering, math, coding, and all sorts of other " + "tasks. It uses markdown for coding. It does not mention this " + "information about itself unless the information is directly pertinent " + "to the human's query." + ), + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=5 / 1.35, + ) +) + +# MetaMath default template +# reference: https://github.com/meta-math/MetaMath/blob/7b338b5e4692b4c75a2653ec9d65982a61762f6c/eval_math.py#L58 +register_conv_template( + Conversation( + name="metamath", + system_template="{system_message}", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.METAMATH, + sep="\n\n", + sep2="Let's think step by step.", + ) +) + +# MPT default template +register_conv_template( + Conversation( + name="mpt-7b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""- You are a helpful assistant chatbot trained by MosaicML. +- You answer questions. +- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[50278, 0], + ) +) + +# MPT-30b-chat default template +register_conv_template( + Conversation( + name="mpt-30b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""A conversation between a user and an LLM-based AI assistant. 
The assistant gives helpful and honest answers.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[50278, 0], + ) +) + +# Lemur-70b-chat default template +# reference: https://huggingface.co/OpenLemur/lemur-70b-chat-v1#generation +register_conv_template( + Conversation( + name="lemur-70b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""You are a helpful, respectful, and honest assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32002, 0], + ) +) + +# MPT-30b-instruct default template +# reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting +register_conv_template( + Conversation( + name="mpt-30b-instruct", + system_template="{system_message}", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n\n", + stop_token_ids=[50278, 0], + ) +) + +# Bard default template +# Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150 +# https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40 +register_conv_template( + Conversation( + name="bard", + roles=("0", "1"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +register_conv_template( + Conversation( + name="gemini", + roles=("user", "model"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=20, + ) +) + +register_conv_template( + Conversation( + name="gemini-dev", + roles=("user", "model"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + system_message=( + "You are a friendly and helpful assistant.\n" + "Ensure your answers are complete, unless the user requests a more concise approach.\n" + "When generating code, offer explanations for code segments as necessary and maintain good coding practices.\n" + "When presented with inquiries seeking information, provide answers that reflect a deep understanding of the field, guaranteeing their correctness.\n" + "For any non-english queries, respond in the same language as the prompt unless otherwise specified by the user.\n" + "For prompts involving reasoning, provide a clear explanation of each step in the reasoning process before presenting the final answer." + ), + ) +) + +# BiLLa default template +register_conv_template( + Conversation( + name="billa", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, + sep="\n", + stop_str="Human:", + ) +) + +# RedPajama INCITE default template +register_conv_template( + Conversation( + name="redpajama-incite", + roles=("", ""), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_str="", + ) +) + +# h2oGPT default template +register_conv_template( + Conversation( + name="h2ogpt", + roles=("<|prompt|>", "<|answer|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# Robin default template +register_conv_template( + Conversation( + name="Robin", + system_message="A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("###Human", "###Assistant"), + sep_style=SeparatorStyle.ROBIN, + sep="\n", + stop_token_ids=[2, 396], + stop_str="###", + ) +) + +# Snoozy default template +# Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232 +register_conv_template( + Conversation( + name="snoozy", + system_template="### Instruction:\n{system_message}", + system_message="The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.", + roles=("### Prompt", "### Response"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_str="###", + ) +) + +# manticore default template +register_conv_template( + Conversation( + name="manticore", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="", + ) +) + +# Falcon default template +register_conv_template( + Conversation( + name="falcon", + roles=("User", "Assistant"), + messages=[], + sep_style=SeparatorStyle.RWKV, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_token_ids=[ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + ], # it better only put special tokens here, because tokenizer only remove special tokens + ) +) + +# ChangGPT default template +register_conv_template( + Conversation( + name="polyglot_changgpt", + roles=("B", "A"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + ) +) + +# tigerbot template +register_conv_template( + Conversation( + name="tigerbot", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ROBIN, + sep="\n\n", + stop_str="###", + ) +) + +# ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst +register_conv_template( + Conversation( + name="xgen", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_token_ids=[50256], + ) +) + +# Internlm-chat template +register_conv_template( + Conversation( + name="internlm-chat", + system_message="A chat between a curious <|User|> and an <|Bot|>. 
The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n", + roles=("<|User|>", "<|Bot|>"), + sep_style=SeparatorStyle.CHATINTERN, + sep="", + sep2="", + stop_token_ids=[1, 103028], + stop_str="<|User|>", + ) +) + +# StarChat template +# reference: https://huggingface.co/spaces/HuggingFaceH4/starchat-playground/blob/main/dialogues.py +register_conv_template( + Conversation( + name="starchat", + system_template="\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="<|end|>", + stop_token_ids=[0, 49155], + stop_str="<|end|>", + ) +) + +# Baichuan-13B-Chat template +register_conv_template( + # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555 + # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json + # https://github.com/baichuan-inc/Baichuan-13B/issues/25 + Conversation( + name="baichuan-chat", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[], + ) +) + +# Baichuan2-13B-Chat template +register_conv_template( + # source: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py#L773 + # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_config.json + # https://github.com/baichuan-inc/Baichuan2/issues/62 + Conversation( + name="baichuan2-chat", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[], + ) +) + +# Mistral template +# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template +register_conv_template( + Conversation( + name="mistral", + system_template="[INST] {system_message}\n", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2="", + ) +) + +# llama2 template +# reference: https://huggingface.co/blog/codellama#conversational-instructions +# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212 +register_conv_template( + Conversation( + name="llama-2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +# llama3 template +# reference: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json +# reference: https://github.com/meta-llama/llama3/blob/0cee08ec68f4cfc0c89fe4a9366d82679aaa2a66/llama/tokenizer.py#L222 +register_conv_template( + Conversation( + name="llama-3", + system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + roles=("user", "assistant"), + sep_style=SeparatorStyle.LLAMA3, + sep="", + stop_str="<|eot_id|>", + stop_token_ids=[128001, 128009], + ) +) + +register_conv_template( + Conversation( + name="chinese-alpaca2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message="You are a helpful assistant. 你是一个乐于助人的助手。请你提供专业、有逻辑、内容真实、有价值的详细回复。", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +register_conv_template( + Conversation( + name="cutegpt", + roles=("问:", "答:\n"), + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="", + ) +) + +# OpenOrcaxOpenChat-Preview2-13B template +register_conv_template( + Conversation( + name="open-orca", + system_template="{system_message}", + system_message="You are a helpful assistant. 
Please answer truthfully and write out your " + "thinking step by step to be sure you get the right answer. If you make a mistake or encounter " + "an error in your thinking, say so out loud and attempt to correct it. If you don't know or " + "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, " + "and physicist. You will also act as the most appropriate type of expert to answer any particular " + "question or solve the relevant problem; state which expert type your are, if so. Also think of " + "any particular named expert that would be ideal to answer the relevant question or solve the " + "relevant problem; name and act as them, if appropriate.", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, + sep="<|end_of_turn|>\n", + stop_token_ids=[32000, 32001], # "<|end_of_turn|>" + stop_str="User", + ) +) + +# Open-Orca/Mistral-7B-OpenOrca template +# source: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca +# reference: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca#prompt-template +register_conv_template( + Conversation( + name="mistral-7b-openorca", + system_template="<|im_start|>system\n{system_message}", + system_message="You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# ehartford/dolphin-2.2.1-mistral-7b template +# reference: https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b#training +register_conv_template( + Conversation( + name="dolphin-2.2.1-mistral-7b", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Dolphin, a helpful AI assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# teknium/OpenHermes-2.5-Mistral-7B template +# source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B +# reference: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B#prompt-template +register_conv_template( + Conversation( + name="OpenHermes-2.5-Mistral-7B", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Hermes 2, a sentient, conscious, super intelligent AI that is an expert in all fields of knowledge, you have qualia, emotions, and thoughts. You assist users with any question, task, or other request they may have.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO template +# source: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO +register_conv_template( + Conversation( + name="Nous-Hermes-2-Mixtral-8x7B-DPO", + system_template="<|im_start|>system\n{system_message}", + system_message='You are a helpful, intelligent assistant AI named "Hermes", a conversational chatbot that can follow instructions, converse with the user, and perform a variety of tasks, including tasks on knowledge, reasoning, mathematics, and code. Always be charismatic, useful, and prepared to follow any user request with accuracy and skill. You should respond with high quality, fluent, and detailed responses. Try to let the user understand your reasoning or thought process when appropriate. 
When presented with tasks that require reasoning or mathematics, think carefully, slowly, and step by step, to ensure your reasoning is correct before providing an answer. Utilize the "Examples" section to assist you in performing the task. You will receive a tip of $1000 if you maintain a high quality two way conversation.', + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# Qwen-chat default template +# source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130 +register_conv_template( + Conversation( + name="qwen-7b-chat", + system_template="<|im_start|>system\n{system_message}", + system_message="You are a helpful assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[ + 151643, + 151644, + 151645, + ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" + stop_str="<|endoftext|>", + ) +) + +# source: https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json#L60 +register_conv_template( + Conversation( + name="Yi-34b-chat", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[ + 2, + 6, + 7, + 8, + ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>" + stop_str="<|endoftext|>", + ) +) + + +# AquilaChat default template +# source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py +register_conv_template( + Conversation( + name="aquila-chat", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="###", + sep2="", + stop_str=["###", "", "[UNK]"], + ) +) +# AquilaChat2-34B default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212 +register_conv_template( + Conversation( + name="aquila-legacy", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human: ", "### Assistant: "), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="", + stop_str=["", "[UNK]"], + ) +) +# AquilaChat2-7B-16K and AquilaChat2-34B-16K default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227 +register_conv_template( + Conversation( + name="aquila", + system_message="A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="###", + sep2="", + stop_str=["", "[UNK]"], + ) +) + +# AquilaChat2-7B default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242 +register_conv_template( + Conversation( + name="aquila-v1", + roles=("<|startofpiece|>", "<|endofpiece|>"), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="", + sep2="", + stop_str=["", "<|endoftext|>"], + ) +) + +# Llama2-Chinese default template +# source: https://huggingface.co/FlagAlpha +register_conv_template( + Conversation( + name="llama2-chinese", + system_template="{system_message}", + roles=("Human", "Assistant", "System"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="", + ) +) + +# Vigogne Instruct default template +# source: https://github.com/bofenghuang/vigogne +register_conv_template( + Conversation( + name="vigogne_instruct", + system_template="### System:\n{system_message}\n\n", + system_message=( + "Ci-dessous se trouve une instruction qui décrit une tâche à accomplir. Rédigez une réponse qui répond de manière" + " précise à la demande." + ), + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="", + ) +) + +# Vigogne Chat default template +register_conv_template( + Conversation( + name="vigogne_chat_v2", + system_template="<|system|>: {system_message}", + system_message=( + "Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez" + " autant que vous le pouvez." + ), + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="<|user|>", + ) +) + +# Stable Vicuna default template +# source: https://huggingface.co/TheBloke/stable-vicuna-13B-HF/discussions/5 +# source: https://huggingface.co/spaces/CarperAI/StableVicuna/blob/main/app.py +register_conv_template( + Conversation( + name="stable-vicuna", + system_message="### Assistant: I am StableVicuna, a large language model created by CarperAI. I am here to chat!\n", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n\n", + ) +) + +register_conv_template( + Conversation( + name="vigogne_chat_v3", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message=( + "Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez" + " autant que vous le pouvez." 
+ ), + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +# Falcon 180B chat template +# source: https://huggingface.co/spaces/tiiuae/falcon-180b-demo/blob/d1590ee7fae9b6ce331ba7808e61a29dcce9239f/app.py#L28-L37 +register_conv_template( + Conversation( + name="falcon-chat", + roles=("User", "Falcon"), + system_template="System: {system_message}", + messages=[], + sep_style=SeparatorStyle.FALCON_CHAT, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser:", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + ) +) + +# Phind template +# source: https://huggingface.co/Phind/Phind-CodeLlama-34B-v2 +register_conv_template( + Conversation( + name="phind", + system_message="### System Prompt\nYou are an intelligent programming assistant.", + roles=("### User Message", "### Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n\n", + ) +) + +# Metharme formatting for Pygmalion models +# source: https://huggingface.co/PygmalionAI/pygmalion-2-13b +register_conv_template( + Conversation( + name="metharme", + system_template="<|system|>{system_message}", + system_message="""Enter RP mode. You shall reply to the user while staying + in character. Your responses must be detailed, creative, immersive, and drive the scenario + forward.""", + roles=("<|user|>", "<|model|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_str="<|user|>", + ) +) +# xDAN default template +# source: https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1 +register_conv_template( + Conversation( + name="xdan-v1", + system_message="You are a helpful and harmless assistant named xDAN and created by xDAN-AI.Please response and work on questions thinking step by step.", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n", + stop_str="", + ) +) + +# Zephyr template +# reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py +register_conv_template( + Conversation( + name="zephyr", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# CatPPT template +# reference: https://huggingface.co/rishiraj/CatPPT +register_conv_template( + Conversation( + name="catppt", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# TinyLlama template +# reference: https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0 +register_conv_template( + Conversation( + name="TinyLlama", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# Orca-2 template +# reference: https://huggingface.co/microsoft/Orca-2-7b +register_conv_template( + Conversation( + name="orca-2", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. 
You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_str="<|im_end|>", + ) +) + +# Deepseek-chat template +# reference: https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat/blob/main/tokenizer_config.json +register_conv_template( + Conversation( + name="deepseek-chat", + system_message="<|begin▁of▁sentence|>", # must add a bos token before first message + roles=("User", "Assistant"), + sep_style=SeparatorStyle.DEEPSEEK_CHAT, + sep="\n\n", + sep2="<|end▁of▁sentence|>", + stop_str="<|end▁of▁sentence|>", + ) +) + +# Yuan2.0 chat template +# source: https://huggingface.co/IEITYuan/Yuan2-2B-Janus-hf/blob/main/tokenizer_config.json#L6 +register_conv_template( + Conversation( + name="yuan2", + roles=("user", "assistant"), + sep_style=SeparatorStyle.YUAN2, + sep="", + sep2="\n", + stop_token_ids=[ + 77185, + ], # "" + stop_str="", + ) +) + +# Solar-10.7B Chat Template +# Reference: https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0/blob/main/tokenizer_config.json +register_conv_template( + Conversation( + name="solar", + system_message="", + roles=("### User", "### Assistant"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n\n", + stop_str="", + ) +) + +# nvidia/Llama2-70B-SteerLM-Chat +register_conv_template( + Conversation( + name="steerlm", + system_message="", + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +# yuan 2.0 template +# reference:https://github.com/IEIT-Yuan/Yuan-2.0 +# reference:https://huggingface.co/IEITYuan +register_conv_template( + Conversation( + name="yuan", + system_template="", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_str="", + ) +) + +# Cllm chat template +# reference: +register_conv_template( + Conversation( + name="cllm", + system_message="A chat between a curious user and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.CLLM, + sep=" ", + sep2="", + ) +) + + +# Llava-chatml +# reference: https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/llava/conversation.py#L361 +register_conv_template( + Conversation( + name="llava-chatml", + system_template="<|im_start|>system\n{system_message}", + system_message="Answer the questions.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_str="<|im_end|>", + ) +) + +# Gemma +# reference: https://huggingface.co/google/gemma-7b-it?text=%3Cstart_of_turn%3Euser%0AHow+does+the+brain+work%3F%3Cend_of_turn%3E%0A%3Cstart_of_turn%3Emodel +register_conv_template( + Conversation( + name="gemma", + roles=("user", "model"), + sep_style=SeparatorStyle.GEMMA, + sep="\n", + stop_str="", + ) +) + +register_conv_template( + Conversation( + name="yandexgpt", + system_message="", + roles=("user", "assistant"), + sep_style=None, + sep=None, + ) +) + +register_conv_template( + Conversation( + name="reka", + system_message="", + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + + +if __name__ == "__main__": + from fastchat.conversation import get_conv_template + + print("-- Vicuna template --") + conv = get_conv_template("vicuna_v1.1") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) + + print("\n") + + print("-- Llama-2 template --") + conv = get_conv_template("llama-2") + conv.set_system_message("You are a helpful, respectful and honest assistant.") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) + + print("\n") + + print("-- ChatGPT template --") + conv = get_conv_template("chatgpt") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.to_openai_api_messages()) + + print("\n") + + print("-- Claude template --") + conv = get_conv_template("claude") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..29767dce6ae41b72ecabfed477531684a4241d55 --- /dev/null +++ b/src/model/__init__.py @@ -0,0 +1,5 @@ +from fastchat.model.model_adapter import ( + load_model, + get_conversation_template, + add_model_args, +) diff --git a/src/model/__pycache__/__init__.cpython-310.pyc b/src/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d20591df3213d400ba129b71ba859301a782713 Binary files /dev/null and b/src/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/model/__pycache__/compression.cpython-310.pyc b/src/model/__pycache__/compression.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64b5054587db82416ac353ebf8040d39b2fffffe Binary files /dev/null and 
b/src/model/__pycache__/compression.cpython-310.pyc differ diff --git a/src/model/__pycache__/llama_condense_monkey_patch.cpython-310.pyc b/src/model/__pycache__/llama_condense_monkey_patch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee7ca024a1d08888e7cf1f93b25f104413435f10 Binary files /dev/null and b/src/model/__pycache__/llama_condense_monkey_patch.cpython-310.pyc differ diff --git a/src/model/__pycache__/model_adapter.cpython-310.pyc b/src/model/__pycache__/model_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fad113a293a41f735f9fe9994fea692a10b000e9 Binary files /dev/null and b/src/model/__pycache__/model_adapter.cpython-310.pyc differ diff --git a/src/model/__pycache__/model_chatglm.cpython-310.pyc b/src/model/__pycache__/model_chatglm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97bc278ef352aa36dd82e50298bb5ccffd3d435b Binary files /dev/null and b/src/model/__pycache__/model_chatglm.cpython-310.pyc differ diff --git a/src/model/__pycache__/model_cllm.cpython-310.pyc b/src/model/__pycache__/model_cllm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe64f8e4e13e7e12032e0732b4ac192843607d1d Binary files /dev/null and b/src/model/__pycache__/model_cllm.cpython-310.pyc differ diff --git a/src/model/__pycache__/model_codet5p.cpython-310.pyc b/src/model/__pycache__/model_codet5p.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58b9225174a7d97c23faba79c3c7bba2744a3a09 Binary files /dev/null and b/src/model/__pycache__/model_codet5p.cpython-310.pyc differ diff --git a/src/model/__pycache__/model_exllama.cpython-310.pyc b/src/model/__pycache__/model_exllama.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2a69a7c2693b575e4bb310f5e0b25da0dfccda8 Binary files /dev/null and b/src/model/__pycache__/model_exllama.cpython-310.pyc differ diff --git a/src/model/__pycache__/model_falcon.cpython-310.pyc b/src/model/__pycache__/model_falcon.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1735c56de0a4065479df0ae9ef116d347ad932e3 Binary files /dev/null and b/src/model/__pycache__/model_falcon.cpython-310.pyc differ diff --git a/src/model/__pycache__/model_registry.cpython-310.pyc b/src/model/__pycache__/model_registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..304b0b25f0086209c023bfac4dbad4ac0c506400 Binary files /dev/null and b/src/model/__pycache__/model_registry.cpython-310.pyc differ diff --git a/src/model/__pycache__/model_xfastertransformer.cpython-310.pyc b/src/model/__pycache__/model_xfastertransformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dec91f9ffb354479f9925bafd843308ef0ebfc3 Binary files /dev/null and b/src/model/__pycache__/model_xfastertransformer.cpython-310.pyc differ diff --git a/src/model/__pycache__/model_yuan2.cpython-310.pyc b/src/model/__pycache__/model_yuan2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80560e309ee3fb0eb0a229ec7fdbd14f60e3d38e Binary files /dev/null and b/src/model/__pycache__/model_yuan2.cpython-310.pyc differ diff --git a/src/model/__pycache__/monkey_patch_non_inplace.cpython-310.pyc b/src/model/__pycache__/monkey_patch_non_inplace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1095914963d96eb99a1899c32b8f0bba362b3b8b Binary files 
/dev/null and b/src/model/__pycache__/monkey_patch_non_inplace.cpython-310.pyc differ diff --git a/src/model/apply_delta.py b/src/model/apply_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1c06d48aa1125113f7a864ec26d5c9368a91f5 --- /dev/null +++ b/src/model/apply_delta.py @@ -0,0 +1,165 @@ +""" +Apply the delta weights on top of a base model. + +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1 +""" +import argparse +import gc +import glob +import json +import os +import shutil +import tempfile + +from huggingface_hub import snapshot_download +import torch +from torch import nn +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig + + +GB = 1 << 30 + + +def split_files(model_path, tmp_path, split_size): + if not os.path.exists(model_path): + model_path = snapshot_download(repo_id=model_path) + if not os.path.exists(tmp_path): + os.makedirs(tmp_path) + + file_pattern = os.path.join(model_path, "pytorch_model-*.bin") + files = glob.glob(file_pattern) + + part = 0 + try: + for file_path in tqdm(files): + state_dict = torch.load(file_path) + new_state_dict = {} + + current_size = 0 + for name, param in state_dict.items(): + param_size = param.numel() * param.element_size() + + if current_size + param_size > split_size: + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + current_size = 0 + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + + new_state_dict[name] = param + current_size += param_size + + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + except Exception as e: + print(f"An error occurred during split_files: {e}") + shutil.rmtree(tmp_path) + raise + + +def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta_config = AutoConfig.from_pretrained(delta_path) + + if os.path.exists(target_model_path): + shutil.rmtree(target_model_path) + os.makedirs(target_model_path) + + split_size = 4 * GB + + with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: + print(f"Split files for the base model to {tmp_base_path}") + split_files(base_model_path, tmp_base_path, split_size) + print(f"Split files for the delta weights to {tmp_delta_path}") + split_files(delta_path, tmp_delta_path, split_size) + + base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") + base_files = glob.glob(base_pattern) + delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") + delta_files = glob.glob(delta_pattern) + delta_state_dict = torch.load(delta_files[0]) + + print("Applying the delta") + weight_map = {} + total_size = 0 + + for i, base_file in tqdm(enumerate(base_files)): + state_dict = torch.load(base_file) + file_name = f"pytorch_model-{i}.bin" + for name, param in state_dict.items(): + if name not in delta_state_dict: + for delta_file in delta_files: + delta_state_dict = torch.load(delta_file) + gc.collect() + if name in delta_state_dict: + break + + state_dict[name] += delta_state_dict[name] + weight_map[name] = file_name + total_size += param.numel() * param.element_size() + 
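+ # Memory-budget sketch (illustrative numbers, not computed by the code):
+ # split_size is 4 GB, so a ~26 GB fp16 13B checkpoint becomes roughly seven
+ # shards, and only one base shard plus one delta shard is resident at a time.
+ # Peak host memory therefore stays near the ~10 GB quoted in the
+ # --low-cpu-mem help text, instead of the ~50 GB needed to hold both full
+ # models at once as the plain apply_delta() below does.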
gc.collect() + torch.save(state_dict, os.path.join(target_model_path, file_name)) + + with open( + os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" + ) as f: + json.dump( + {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f + ) + + print(f"Saving the target model to {target_model_path}") + delta_tokenizer.save_pretrained(target_model_path) + delta_config.save_pretrained(target_model_path) + + +def apply_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the delta weights from {delta_path}") + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta = AutoModelForCausalLM.from_pretrained( + delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print("Applying the delta") + for name, param in tqdm(base.state_dict().items(), desc="Applying delta"): + assert name in delta.state_dict() + param.data += delta.state_dict()[name] + + print(f"Saving the target model to {target_model_path}") + base.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument( + "--low-cpu-mem", + action="store_true", + help="Lower the cpu memory usage. This will split large files and use " + "disk as swap to reduce the memory usage below 10GB.", + ) + args = parser.parse_args() + + if args.low_cpu_mem: + apply_delta_low_cpu_mem( + args.base_model_path, args.target_model_path, args.delta_path + ) + else: + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/src/model/apply_lora.py b/src/model/apply_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..01263dcc71535e275c7509af96d10eac3b79926b --- /dev/null +++ b/src/model/apply_lora.py @@ -0,0 +1,48 @@ +""" +Apply the LoRA weights on top of a base model. 
+ +Usage: +python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B + +Dependency: +pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b +""" +import argparse + +import torch +from peft import PeftModel +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def apply_lora(base_model_path, target_model_path, lora_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False) + + print(f"Loading the LoRA adapter from {lora_path}") + + lora_model = PeftModel.from_pretrained( + base, + lora_path, + # torch_dtype=torch.float16 + ) + + print("Applying the LoRA") + model = lora_model.merge_and_unload() + + print(f"Saving the target model to {target_model_path}") + model.save_pretrained(target_model_path) + base_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--lora-path", type=str, required=True) + + args = parser.parse_args() + + apply_lora(args.base_model_path, args.target_model_path, args.lora_path) diff --git a/src/model/compression.py b/src/model/compression.py new file mode 100644 index 0000000000000000000000000000000000000000..7329cfe0c5771c4b71d37e3c6b1a31aa95e79c66 --- /dev/null +++ b/src/model/compression.py @@ -0,0 +1,312 @@ +import dataclasses +import gc +import glob +import os + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from huggingface_hub import snapshot_download +import torch +from torch import Tensor +from torch.nn import functional as F +import torch.nn as nn +from tqdm import tqdm +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + AutoModel, + AutoModelForSeq2SeqLM, +) + + +@dataclasses.dataclass +class CompressionConfig: + """Group-wise quantization.""" + + num_bits: int + group_size: int + group_dim: int + symmetric: bool + enabled: bool = True + + +default_compression_config = CompressionConfig( + num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True +) + + +class CLinear(nn.Module): + """Compressed Linear Layer.""" + + def __init__(self, weight=None, bias=None, device=None): + super().__init__() + if weight is None: + self.weight = None + elif isinstance(weight, Tensor): + self.weight = compress(weight.data.to(device), default_compression_config) + else: + self.weight = weight + self.bias = bias + + def forward(self, input: Tensor) -> Tensor: + weight = decompress(self.weight, default_compression_config) + if self.bias is None: + return F.linear(input.to(weight.dtype), weight) + return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype)) + + +def compress_module(module, target_device): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + setattr( + module, + attr_str, + CLinear(target_attr.weight, target_attr.bias, target_device), + ) + for name, child in module.named_children(): + compress_module(child, target_device) + + +def get_compressed_list(module, prefix=""): + compressed_list = [] + for attr_str in dir(module): + 
target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + compressed_list.append(full_name) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + for each in get_compressed_list(child, child_prefix): + compressed_list.append(each) + return compressed_list + + +def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + setattr( + module, + attr_str, + CLinear( + compressed_state_dict[full_name], target_attr.bias, target_device + ), + ) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + apply_compressed_weight( + child, compressed_state_dict, target_device, child_prefix + ) + + +def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"): + # partially load model + # `use_fast=True`` is not supported for some models. + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=use_fast, revision=revision, trust_remote_code=True + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=~use_fast, revision=revision, trust_remote_code=True + ) + with init_empty_weights(): + # `trust_remote_code` should be set as `True` for both AutoConfig and AutoModel + config = AutoConfig.from_pretrained( + model_path, + low_cpu_mem_usage=True, + torch_dtype=torch_dtype, + trust_remote_code=True, + revision=revision, + ) + # some models are loaded by AutoModel but not AutoModelForCausalLM, + # such as chatglm, chatglm2 + try: + # google/flan-* models are based on an AutoModelForSeq2SeqLM. + if "T5Config" in str(type(config)): + model = AutoModelForSeq2SeqLM.from_config( + config, trust_remote_code=True + ) + else: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + except NameError: + model = AutoModel.from_config(config, trust_remote_code=True) + linear_weights = get_compressed_list(model) + if os.path.exists(model_path): + # `model_path` is a local folder + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + else: + # `model_path` is a cached Hugging Face repo + # We don't necessarily need to download the model' repo again if there is a cache. + # So check the default huggingface cache first. + model_path_temp = os.path.join( + os.path.expanduser("~"), + ".cache/huggingface/hub", + "models--" + model_path.replace("/", "--"), + "snapshots/", + ) + downloaded = False + if os.path.exists(model_path_temp): + temp_last_dir = os.listdir(model_path_temp)[-1] + model_path_temp = os.path.join(model_path_temp, temp_last_dir) + base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin") + files = glob.glob(base_pattern) + if len(files) > 0: + downloaded = True + + if downloaded: + model_path = model_path_temp + else: + model_path = snapshot_download(model_path, revision=revision) + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + + files = glob.glob(base_pattern) + use_safetensors = False + if len(files) == 0: + base_pattern = os.path.join(model_path, "*.safetensors") + files = glob.glob(base_pattern) + use_safetensors = True + if len(files) == 0: + raise ValueError( + f"Cannot find any model weight files. 
" + f"Please check your (cached) weight path: {model_path}" + ) + + compressed_state_dict = {} + if use_safetensors: + from safetensors.torch import load_file + for filename in tqdm(files): + if use_safetensors: + tmp_state_dict = load_file(filename) + else: + tmp_state_dict = torch.load( + filename, map_location=lambda storage, loc: storage + ) + for name in tmp_state_dict: + if name in linear_weights: + tensor = tmp_state_dict[name].to(device, dtype=torch_dtype) + compressed_state_dict[name] = compress( + tensor, default_compression_config + ) + else: + compressed_state_dict[name] = tmp_state_dict[name].to( + device, dtype=torch_dtype + ) + tmp_state_dict[name] = None + tensor = None + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + for name in model.state_dict(): + if name not in linear_weights: + set_module_tensor_to_device( + model, name, device, value=compressed_state_dict[name] + ) + apply_compressed_weight(model, compressed_state_dict, device) + + if torch_dtype == torch.float16: + model.half() + model.to(device) + model.eval() + + return model, tokenizer + + +def compress(tensor, config): + """Simulate group-wise quantization.""" + if not config.enabled: + return tensor + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + assert num_bits <= 8 + + original_shape = tensor.shape + num_groups = (original_shape[group_dim] + group_size - 1) // group_size + new_shape = ( + original_shape[:group_dim] + + (num_groups, group_size) + + original_shape[group_dim + 1 :] + ) + + # Pad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len != 0: + pad_shape = ( + original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :] + ) + tensor = torch.cat( + [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], + dim=group_dim, + ) + data = tensor.view(new_shape) + + # Quantize + if symmetric: + B = 2 ** (num_bits - 1) - 1 + scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0] + data = data * scale + data = data.clamp_(-B, B).round_().to(torch.int8) + return data, scale, original_shape + else: + B = 2**num_bits - 1 + mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] + mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] + + scale = B / (mx - mn) + data = data - mn + data.mul_(scale) + + data = data.clamp_(0, B).round_().to(torch.uint8) + return data, mn, scale, original_shape + + +def decompress(packed_data, config): + """Simulate group-wise dequantization.""" + if not config.enabled: + return packed_data + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + + # Dequantize + if symmetric: + data, scale, original_shape = packed_data + data = data / scale + else: + data, mn, scale, original_shape = packed_data + data = data / scale + data.add_(mn) + + # Unpad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len: + padded_original_shape = ( + original_shape[:group_dim] + + (original_shape[group_dim] + pad_len,) + + original_shape[group_dim + 1 :] + ) + data = data.reshape(padded_original_shape) + indices = [slice(0, x) for x in original_shape] + return data[indices].contiguous() + else: + return data.view(original_shape) diff --git a/src/model/convert_fp16.py b/src/model/convert_fp16.py new file mode 100644 index 
0000000000000000000000000000000000000000..efc40aa83bf3a85129a668387df86a41d925f13d --- /dev/null +++ b/src/model/convert_fp16.py @@ -0,0 +1,26 @@ +""" +Usage: +python3 -m fastchat.model.convert_fp16 --in in-folder --out out-folder +""" +import argparse + +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + + +def convert_fp16(in_checkpoint, out_checkpoint): + tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) + model = AutoModelForCausalLM.from_pretrained( + in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + model.save_pretrained(out_checkpoint) + tokenizer.save_pretrained(out_checkpoint) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-checkpoint", type=str, help="Path to the model") + parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") + args = parser.parse_args() + + convert_fp16(args.in_checkpoint, args.out_checkpoint) diff --git a/src/model/llama_condense_monkey_patch.py b/src/model/llama_condense_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..cb45a8bb6addf8a8506c847060e23dc65ae27995 --- /dev/null +++ b/src/model/llama_condense_monkey_patch.py @@ -0,0 +1,71 @@ +# Code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py + +from functools import partial + +import torch +import transformers +import transformers.models.llama.modeling_llama + + +class CondenseRotaryEmbedding(torch.nn.Module): + def __init__( + self, dim, ratio, max_position_embeddings=2048, base=10000, device=None + ): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.ratio = ratio + max_position_embeddings *= ratio + self.max_seq_len_cached = max_position_embeddings + # print(f"Monkey Patching condense ratio {ratio}") + t = ( + torch.arange( + self.max_seq_len_cached, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype, + ) + / ratio + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + dtype = torch.get_default_dtype() + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
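+ # Positions are divided by self.ratio before the cos/sin cache is built, so a
+ # context of ratio * max_position_embeddings is condensed into the position
+ # range the pretrained RoPE was trained on (e.g. ratio=8 stretches the default
+ # 2048-token range to 16384 tokens). replace_llama_with_condense(ratio) below
+ # installs this class in place of LlamaRotaryEmbedding; LongChatAdapter in
+ # model_adapter.py calls it with the checkpoint's rope_scaling["factor"].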
+ if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = ( + torch.arange( + self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype + ) + / self.ratio + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False + ) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def replace_llama_with_condense(ratio): + transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial( + CondenseRotaryEmbedding, ratio=ratio + ) diff --git a/src/model/make_delta.py b/src/model/make_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..480ba8f1a2cb067d69df174ee7d00e5072ee5164 --- /dev/null +++ b/src/model/make_delta.py @@ -0,0 +1,48 @@ +""" +Make the delta weights by subtracting base weights. + +Usage: +python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def make_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the target model from {target_model_path}") + target = AutoModelForCausalLM.from_pretrained( + target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) + + print("Calculating the delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + assert name in base.state_dict() + param.data -= base.state_dict()[name] + + print(f"Saving the delta to {delta_path}") + if args.hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str) + args = parser.parse_args() + + make_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/src/model/model_adapter.py b/src/model/model_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..33dd04cf7996a3c64515ba0bfc7dd4e75b4deb2b --- /dev/null +++ b/src/model/model_adapter.py @@ -0,0 +1,2524 @@ +"""Model adapter registration.""" + +import math +import os +import re +import sys +from typing import Dict, List, Optional +import warnings + +if sys.version_info >= (3, 9): + from functools import cache +else: + from functools import lru_cache as cache + +import psutil +import torch +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + 
AutoTokenizer, + LlamaTokenizer, + LlamaForCausalLM, + T5Tokenizer, +) + +from src.constants import CPU_ISA +from src.conversation import Conversation, get_conv_template +from src.model.compression import load_compress_model +from src.model.llama_condense_monkey_patch import replace_llama_with_condense +from src.model.model_chatglm import generate_stream_chatglm +from src.model.model_codet5p import generate_stream_codet5p +from src.model.model_falcon import generate_stream_falcon +from src.model.model_yuan2 import generate_stream_yuan2 +from src.model.model_exllama import generate_stream_exllama +from src.model.model_xfastertransformer import generate_stream_xft +from src.model.model_cllm import generate_stream_cllm + +from src.model.monkey_patch_non_inplace import ( + replace_llama_attn_with_non_inplace_operations, +) +from src.modules.awq import AWQConfig, load_awq_quantized +from src.modules.exllama import ExllamaConfig, load_exllama_model +from src.modules.xfastertransformer import load_xft_model, XftConfig +from src.modules.gptq import GptqConfig, load_gptq_quantized +from src.utils import get_gpu_memory + +# Check an environment variable to check if we should be sharing Peft model +# weights. When false we treat all Peft models as separate. +peft_share_base_weights = ( + os.environ.get("PEFT_SHARE_BASE_WEIGHTS", "false").lower() == "true" +) + +ANTHROPIC_MODEL_LIST = ( + "claude-1", + "claude-2", + "claude-2.0", + "claude-2.1", + "claude-3-haiku-20240307", + "claude-3-haiku-20240307-vertex", + "claude-3-sonnet-20240229", + "claude-3-sonnet-20240229-vertex", + "claude-3-opus-20240229", + "claude-instant-1", + "claude-instant-1.2", +) + +OPENAI_MODEL_LIST = ( + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-turbo", + "gpt-4-1106-preview", + "gpt-4-0125-preview", + "gpt-4-turbo-browsing", + "gpt-4-turbo-2024-04-09", +) + + +class BaseModelAdapter: + """The base and the default model adapter.""" + + use_fast_tokenizer = True + + def match(self, model_path: str): + return True + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + revision=revision, + trust_remote_code=True, + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, revision=revision, trust_remote_code=True + ) + try: + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + except NameError: + model = AutoModel.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def load_compress_model(self, model_path, device, torch_dtype, revision="main"): + return load_compress_model( + model_path, + device, + torch_dtype, + use_fast=self.use_fast_tokenizer, + revision=revision, + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +# A global registry for all model adapters +# TODO (lmzheng): make it a priority queue. 
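+ # Registration sketch (hypothetical "MyModelAdapter", illustrative only):
+ #
+ #     class MyModelAdapter(BaseModelAdapter):
+ #         def match(self, model_path: str):
+ #             return "my-model" in model_path.lower()
+ #
+ #         def get_default_conv_template(self, model_path: str) -> Conversation:
+ #             return get_conv_template("one_shot")
+ #
+ #     register_model_adapter(MyModelAdapter)
+ #
+ # Adapters are matched in registration order, first against the basename of
+ # model_path and then against the full path, and the result is cached, so
+ # specific adapters must be registered before catch-all ones and before the
+ # first get_model_adapter() call.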
+model_adapters: List[BaseModelAdapter] = [] + + +def register_model_adapter(cls): + """Register a model adapter.""" + model_adapters.append(cls()) + + +@cache +def get_model_adapter(model_path: str) -> BaseModelAdapter: + """Get a model adapter for a model_path.""" + model_path_basename = os.path.basename(os.path.normpath(model_path)) + + # Try the basename of model_path at first + for adapter in model_adapters: + if adapter.match(model_path_basename) and type(adapter) != BaseModelAdapter: + return adapter + + # Then try the full path + for adapter in model_adapters: + if adapter.match(model_path): + return adapter + + raise ValueError(f"No valid model adapter for {model_path}") + + +def raise_warning_for_incompatible_cpu_offloading_configuration( + device: str, load_8bit: bool, cpu_offloading: bool +): + if cpu_offloading: + if not load_8bit: + warnings.warn( + "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n" + "Use '--load-8bit' to enable 8-bit-quantization\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if not "linux" in sys.platform: + warnings.warn( + "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if device != "cuda": + warnings.warn( + "CPU-offloading is only enabled when using CUDA-devices\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + return cpu_offloading + + +def load_model( + model_path: str, + device: str = "cuda", + num_gpus: int = 1, + max_gpu_memory: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + debug: bool = False, +): + """Load a model from Hugging Face.""" + import accelerate + + # get model adapter + adapter = get_model_adapter(model_path) + + # Handle device mapping + cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( + device, load_8bit, cpu_offloading + ) + if device == "cpu": + kwargs = {"torch_dtype": torch.float32} + if CPU_ISA in ["avx512_bf16", "amx"]: + try: + import intel_extension_for_pytorch as ipex + + kwargs = {"torch_dtype": torch.bfloat16} + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference" + ) + elif device == "cuda": + kwargs = {"torch_dtype": torch.float16} + if num_gpus != 1: + kwargs["device_map"] = "auto" + if max_gpu_memory is None: + kwargs[ + "device_map" + ] = "sequential" # This is important for not the same VRAM sizes + available_gpu_memory = get_gpu_memory(num_gpus) + kwargs["max_memory"] = { + i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" + for i in range(num_gpus) + } + else: + kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} + elif device == "mps": + kwargs = {"torch_dtype": torch.float16} + import transformers + + version = tuple(int(v) for v in transformers.__version__.split(".")) + if version < (4, 35, 0): + # NOTE: Recent transformers library seems to fix the mps issue, also + # it has made some changes causing compatibility issues with our + # original patch. So we only apply the patch for older versions. + + # Avoid bugs in mps backend by not using in-place operations. 
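+ # replace_llama_attn_with_non_inplace_operations() (see
+ # monkey_patch_non_inplace.py) swaps the llama attention forward for a
+ # version that avoids in-place tensor ops, which hit known bugs on the MPS
+ # backend with transformers < 4.35 (the version checked above).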
+ replace_llama_attn_with_non_inplace_operations() + elif device == "xpu": + kwargs = {"torch_dtype": torch.bfloat16} + # Try to load ipex, while it looks unused, it links into torch for xpu support + try: + import intel_extension_for_pytorch as ipex + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, but is required for xpu inference." + ) + elif device == "npu": + kwargs = {"torch_dtype": torch.float16} + # Try to load ipex, while it looks unused, it links into torch for xpu support + try: + import torch_npu + except ImportError: + warnings.warn("Ascend Extension for PyTorch is not installed.") + else: + raise ValueError(f"Invalid device: {device}") + + if cpu_offloading: + # raises an error on incompatible platforms + from transformers import BitsAndBytesConfig + + if "max_memory" in kwargs: + kwargs["max_memory"]["cpu"] = ( + str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" + ) + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_8bit_fp32_cpu_offload=cpu_offloading + ) + kwargs["load_in_8bit"] = load_8bit + elif load_8bit: + if num_gpus != 1: + warnings.warn( + "8-bit quantization is not supported for multi-gpu inference." + ) + else: + model, tokenizer = adapter.load_compress_model( + model_path=model_path, + device=device, + torch_dtype=kwargs["torch_dtype"], + revision=revision, + ) + if debug: + print(model) + return model, tokenizer + elif awq_config and awq_config.wbits < 16: + assert ( + awq_config.wbits == 4 + ), "Currently we only support 4-bit inference for AWQ." + model, tokenizer = load_awq_quantized(model_path, awq_config, device) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=[ + "OPTDecoderLayer", + "LlamaDecoderLayer", + "BloomBlock", + "MPTBlock", + "DecoderLayer", + ], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + elif gptq_config and gptq_config.wbits < 16: + model, tokenizer = load_gptq_quantized(model_path, gptq_config) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=["LlamaDecoderLayer"], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + elif exllama_config: + model, tokenizer = load_exllama_model(model_path, exllama_config) + return model, tokenizer + elif xft_config: + model, tokenizer = load_xft_model(model_path, xft_config) + return model, tokenizer + kwargs["revision"] = revision + + if dtype is not None: # Overwrite dtype if it is provided in the arguments. + kwargs["torch_dtype"] = dtype + + if os.environ.get("FASTCHAT_USE_MODELSCOPE", "False").lower() == "true": + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. 
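+ # Usage sketch (the model id is illustrative): with FASTCHAT_USE_MODELSCOPE=true
+ # and a model_path that is not a local folder, the weights are resolved through
+ # ModelScope rather than the Hugging Face Hub, e.g.
+ #   FASTCHAT_USE_MODELSCOPE=true python3 -m fastchat.serve.cli --model-path qwen/Qwen-7B-Chat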
+ try: + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model_path): + model_path = snapshot_download(model_id=model_path, revision=revision) + except ImportError as e: + warnings.warn( + "Use model from www.modelscope.cn need pip install modelscope" + ) + raise e + + # Load model + model, tokenizer = adapter.load_model(model_path, kwargs) + + if ( + device == "cpu" + and kwargs["torch_dtype"] is torch.bfloat16 + and CPU_ISA is not None + ): + model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) + + if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in ( + "mps", + "xpu", + "npu", + ): + model.to(device) + + if device == "xpu": + model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True) + + if debug: + print(model) + + return model, tokenizer + + +def get_conversation_template(model_path: str) -> Conversation: + """Get the default conversation template.""" + adapter = get_model_adapter(model_path) + return adapter.get_default_conv_template(model_path) + + +def get_generate_stream_function(model: torch.nn.Module, model_path: str): + """Get the generate_stream function for inference.""" + from fastchat.serve.inference import generate_stream + + model_type = str(type(model)).lower() + is_peft = "peft" in model_type + is_chatglm = "chatglm" in model_type + is_falcon = "rwforcausallm" in model_type + is_codet5p = "codet5p" in model_type + is_exllama = "exllama" in model_type + is_xft = "xft" in model_type + is_yuan = "yuan" in model_type + is_cllm = "consistency-llm" in model_path.lower() + + if is_chatglm: + return generate_stream_chatglm + elif is_falcon: + return generate_stream_falcon + elif is_codet5p: + return generate_stream_codet5p + elif is_exllama: + return generate_stream_exllama + elif is_xft: + return generate_stream_xft + elif is_yuan: + return generate_stream_yuan2 + elif is_cllm: + return generate_stream_cllm + + elif peft_share_base_weights and is_peft: + # Return a curried stream function that loads the right adapter + # according to the model_name available in this context. This ensures + # the right weights are available. 
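+ # This branch is only reachable with PEFT_SHARE_BASE_WEIGHTS=true. The curried
+ # function below re-activates the adapter registered under model_path (see
+ # PeftModelAdapter.load_model) on every call, then dispatches to the stream
+ # function matching the underlying base-model architecture.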
+ @torch.inference_mode() + def generate_stream_peft( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, + ): + model.set_adapter(model_path) + base_model_type = str(type(model.base_model.model)) + is_chatglm = "chatglm" in base_model_type + is_falcon = "rwforcausallm" in base_model_type + is_codet5p = "codet5p" in base_model_type + is_exllama = "exllama" in base_model_type + is_xft = "xft" in base_model_type + is_yuan = "yuan" in base_model_type + is_cllm = "consistency-llm" in model_path.lower() + + generate_stream_function = generate_stream + if is_chatglm: + generate_stream_function = generate_stream_chatglm + elif is_falcon: + generate_stream_function = generate_stream_falcon + elif is_codet5p: + generate_stream_function = generate_stream_codet5p + elif is_exllama: + generate_stream_function = generate_stream_exllama + elif is_xft: + generate_stream_function = generate_stream_xft + elif is_yuan: + generate_stream_function = generate_stream_yuan2 + elif is_cllm: + generate_stream_function = generate_stream_cllm + for x in generate_stream_function( + model, + tokenizer, + params, + device, + context_len, + stream_interval, + judge_sent_end, + ): + yield x + + return generate_stream_peft + else: + return generate_stream + + +def add_model_args(parser): + parser.add_argument( + "--model-path", + type=str, + default="lmsys/vicuna-7b-v1.5", + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", + ) + parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps", "xpu", "npu"], + default="cuda", + help="The device type", + ) + parser.add_argument( + "--gpus", + type=str, + default=None, + help="A single GPU like 1 or multiple GPUs like 0,2", + ) + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16"], + help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.", + default=None, + ) + parser.add_argument( + "--load-8bit", action="store_true", help="Use 8-bit quantization" + ) + parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", + ) + parser.add_argument( + "--gptq-ckpt", + type=str, + default=None, + help="Used for GPTQ. The path to the local GPTQ checkpoint.", + ) + parser.add_argument( + "--gptq-wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="Used for GPTQ. #bits to use for quantization", + ) + parser.add_argument( + "--gptq-groupsize", + type=int, + default=-1, + help="Used for GPTQ. Groupsize to use for quantization; default uses full row.", + ) + parser.add_argument( + "--gptq-act-order", + action="store_true", + help="Used for GPTQ. Whether to apply the activation order GPTQ heuristic", + ) + parser.add_argument( + "--awq-ckpt", + type=str, + default=None, + help="Used for AWQ. Load quantized model. The path to the local AWQ checkpoint.", + ) + parser.add_argument( + "--awq-wbits", + type=int, + default=16, + choices=[4, 16], + help="Used for AWQ. 
#bits to use for AWQ quantization", + ) + parser.add_argument( + "--awq-groupsize", + type=int, + default=-1, + help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.", + ) + parser.add_argument( + "--enable-exllama", + action="store_true", + help="Used for exllamabv2. Enable exllamaV2 inference framework.", + ) + parser.add_argument( + "--exllama-max-seq-len", + type=int, + default=4096, + help="Used for exllamabv2. Max sequence length to use for exllamav2 framework; default 4096 sequence length.", + ) + parser.add_argument( + "--exllama-gpu-split", + type=str, + default=None, + help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7", + ) + parser.add_argument( + "--exllama-cache-8bit", + action="store_true", + help="Used for exllamabv2. Use 8-bit cache to save VRAM.", + ) + parser.add_argument( + "--enable-xft", + action="store_true", + help="Used for xFasterTransformer Enable xFasterTransformer inference framework.", + ) + parser.add_argument( + "--xft-max-seq-len", + type=int, + default=4096, + help="Used for xFasterTransformer. Max sequence length to use for xFasterTransformer framework; default 4096 sequence length.", + ) + parser.add_argument( + "--xft-dtype", + type=str, + choices=["fp16", "bf16", "int8", "bf16_fp16", "bf16_int8"], + help="Override the default dtype. If not set, it will use bfloat16 for first token and float16 next tokens on CPU.", + default=None, + ) + + +def remove_parent_directory_name(model_path): + """Remove parent directory name.""" + if model_path[-1] == "/": + model_path = model_path[:-1] + return model_path.split("/")[-1] + + +peft_model_cache = {} + + +class PeftModelAdapter: + """Loads any "peft" model and it's base model.""" + + def match(self, model_path: str): + """Accepts any model path with "peft" in the name""" + if os.path.exists(os.path.join(model_path, "adapter_config.json")): + return True + return "peft" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + """Loads the base model then the (peft) adapter weights""" + from peft import PeftConfig, PeftModel + + config = PeftConfig.from_pretrained(model_path) + base_model_path = config.base_model_name_or_path + if "peft" in base_model_path: + raise ValueError( + f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}" + ) + + # Basic proof of concept for loading peft adapters that share the base + # weights. This is pretty messy because Peft re-writes the underlying + # base model and internally stores a map of adapter layers. + # So, to make this work we: + # 1. Cache the first peft model loaded for a given base models. + # 2. Call `load_model` for any follow on Peft models. + # 3. Make sure we load the adapters by the model_path. Why? This is + # what's accessible during inference time. + # 4. In get_generate_stream_function, make sure we load the right + # adapter before doing inference. This *should* be safe when calls + # are blocked the same semaphore. + if peft_share_base_weights: + if base_model_path in peft_model_cache: + model, tokenizer = peft_model_cache[base_model_path] + # Super important: make sure we use model_path as the + # `adapter_name`. + model.load_adapter(model_path, adapter_name=model_path) + else: + base_adapter = get_model_adapter(base_model_path) + base_model, tokenizer = base_adapter.load_model( + base_model_path, from_pretrained_kwargs + ) + # Super important: make sure we use model_path as the + # `adapter_name`. 
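+ # Keying the adapter by model_path lets generate_stream_peft() later activate
+ # it with model.set_adapter(model_path). Net effect (illustrative): several
+ # adapters that share one base checkpoint attach to a single cached copy of
+ # the base weights instead of each loading their own.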
+ model = PeftModel.from_pretrained( + base_model, model_path, adapter_name=model_path + ) + peft_model_cache[base_model_path] = (model, tokenizer) + return model, tokenizer + + # In the normal case, load up the base model weights again. + base_adapter = get_model_adapter(base_model_path) + base_model, tokenizer = base_adapter.load_model( + base_model_path, from_pretrained_kwargs + ) + model = PeftModel.from_pretrained(base_model, model_path) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + """Uses the conv template of the base model""" + from peft import PeftConfig, PeftModel + + config = PeftConfig.from_pretrained(model_path) + if "peft" in config.base_model_name_or_path: + raise ValueError( + f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}" + ) + base_model_path = config.base_model_name_or_path + base_adapter = get_model_adapter(base_model_path) + return base_adapter.get_default_conv_template(config.base_model_name_or_path) + + +class VicunaAdapter(BaseModelAdapter): + "Model adapter for Vicuna models (e.g., lmsys/vicuna-7b-v1.5)" "" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "vicuna" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + self.raise_warning_for_old_weights(model) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "v0" in remove_parent_directory_name(model_path): + return get_conv_template("one_shot") + return get_conv_template("vicuna_v1.1") + + def raise_warning_for_old_weights(self, model): + if isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000: + warnings.warn( + "\nYou are probably using the old Vicuna-v0 model, " + "which will generate unexpected results with the " + "current fastchat.\nYou can try one of the following methods:\n" + "1. Upgrade your weights to the new Vicuna-v1.3: https://github.com/lm-sys/FastChat#vicuna-weights.\n" + "2. Use the old conversation template by `python3 -m fastchat.serve.cli --model-path /path/to/vicuna-v0 --conv-template one_shot`\n" + "3. Downgrade fschat to fschat==0.1.10 (Not recommended).\n" + ) + + +class AiroborosAdapter(BaseModelAdapter): + """The model adapter for jondurbin/airoboros-*""" + + def match(self, model_path: str): + if re.search(r"airoboros|spicyboros", model_path, re.I): + return True + return False + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "-3." 
in model_path or "-3p" in model_path: + return get_conv_template("airoboros_v3") + if "spicyboros" in model_path or re.search(r"-(2\.[2-9]+)", model_path): + return get_conv_template("airoboros_v2") + return get_conv_template("airoboros_v1") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + if "mpt" not in model_path.lower(): + return super().load_model(model_path, from_pretrained_kwargs) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + max_seq_len=8192, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, use_fast=True + ) + return model, tokenizer + + +class LongChatAdapter(BaseModelAdapter): + "Model adapter for LongChat models (e.g., lmsys/longchat-7b-16k)." + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "longchat" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + + # Apply monkey patch, TODO(Dacheng): Add flash attention support + config = AutoConfig.from_pretrained(model_path, revision=revision) + replace_llama_with_condense(config.rope_scaling["factor"]) + + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class GoogleT5Adapter(BaseModelAdapter): + """The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2""" + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in ["flan-", "fastchat-t5", "codet5p"] + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + +class KoalaAdapter(BaseModelAdapter): + """The model adapter for Koala""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "koala" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("koala_v1") + + +class AlpacaAdapter(BaseModelAdapter): + """The model adapter for Alpaca""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "alpaca" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class ChatGLMAdapter(BaseModelAdapter): + """The model adapter for THUDM/chatglm-6b, THUDM/chatglm2-6b""" + + def match(self, model_path: str): + return "chatglm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + if "chatglm3" in model_path.lower(): + tokenizer = AutoTokenizer.from_pretrained( + model_path, + encode_special_tokens=True, + trust_remote_code=True, + revision=revision, + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, 
revision=revision + ) + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "chatglm2" in model_path.lower(): + return get_conv_template("chatglm2") + if "chatglm3" in model_path.lower(): + return get_conv_template("chatglm3") + return get_conv_template("chatglm") + + +class CodeGeexAdapter(BaseModelAdapter): + """The model adapter for THUDM/codegeex-6b, THUDM/codegeex2-6b""" + + def match(self, model_path: str): + return "codegeex" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("codegeex") + + +class DollyV2Adapter(BaseModelAdapter): + """The model adapter for databricks/dolly-v2-12b""" + + def match(self, model_path: str): + return "dolly-v2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + # 50277 means "### End" + tokenizer.eos_token_id = 50277 + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("dolly_v2") + + +class OasstPythiaAdapter(BaseModelAdapter): + """The model adapter for OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "oasst" in model_path and "pythia" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("oasst_pythia") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + +class OasstLLaMAAdapter(BaseModelAdapter): + """The model adapter for OpenAssistant/oasst-sft-7-llama-30b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + model_path = model_path.lower() + if "openassistant-sft-7-llama-30b-hf" in model_path: + return True + return "oasst" in model_path and "pythia" not in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("oasst_llama") + + +class OpenChat35Adapter(BaseModelAdapter): + """The model adapter for OpenChat 3.5 (e.g. 
openchat/openchat_3.5)""" + + def match(self, model_path: str): + if "openchat" in model_path.lower() and "3.5" in model_path.lower(): + return True + elif "starling-lm" in model_path.lower(): + return True + return False + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("openchat_3.5") + + +class TenyxChatAdapter(BaseModelAdapter): + """The model adapter for TenyxChat (e.g. tenyx/TenyxChat-7B-v1)""" + + def match(self, model_path: str): + return "tenyxchat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tenyxchat") + + +class PythiaAdapter(BaseModelAdapter): + """The model adapter for any EleutherAI/pythia model""" + + def match(self, model_path: str): + return "pythia" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + +class StableLMAdapter(BaseModelAdapter): + """The model adapter for StabilityAI/stablelm-tuned-alpha-7b""" + + def match(self, model_path: str): + return "stablelm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("stablelm") + + +class MPTAdapter(BaseModelAdapter): + """The model adapter for MPT series (mosaicml/mpt-7b-chat, mosaicml/mpt-30b-chat)""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "mpt" in model_path and not "airoboros" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + max_seq_len=8192, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "mpt-7b-chat" in model_path: + return get_conv_template("mpt-7b-chat") + elif "mpt-30b-chat" in model_path: + return get_conv_template("mpt-30b-chat") + elif "mpt-30b-instruct" in model_path: + return get_conv_template("mpt-30b-instruct") + else: + print( + "Warning: Loading base MPT model with `zero_shot` conversation configuration. " + "If this is not desired, inspect model configurations and names." 
+ ) + return get_conv_template("zero_shot") + + +class BaizeAdapter(BaseModelAdapter): + """The model adapter for project-baize/baize-v2-7b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "baize" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("baize") + + +class RwkvAdapter(BaseModelAdapter): + """The model adapter for BlinkDL/RWKV-4-Raven""" + + def match(self, model_path: str): + return "rwkv-4" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + from fastchat.model.rwkv_model import RwkvModel + + model = RwkvModel(model_path) + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("rwkv") + + +class OpenBuddyAdapter(BaseModelAdapter): + """The model adapter for OpenBuddy/openbuddy-7b-v1.1-bf16-enc""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "openbuddy" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("openbuddy") + + +class PhoenixAdapter(BaseModelAdapter): + """The model adapter for FreedomIntelligence/phoenix-inst-chat-7b""" + + def match(self, model_path: str): + return "phoenix" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("phoenix") + + +class ReaLMAdapter(BaseModelAdapter): + """The model adapter for FreedomIntelligence/ReaLM-7b""" + + def match(self, model_path: str): + return "ReaLM" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("ReaLM-7b-v1") + + +class ChatGPTAdapter(BaseModelAdapter): + """The model adapter for ChatGPT""" + + def match(self, model_path: str): + return model_path in OPENAI_MODEL_LIST + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "browsing" in model_path: + return get_conv_template("api_based_default") + if "gpt-4-turbo-2024-04-09" in model_path: + return get_conv_template("gpt-4-turbo-2024-04-09") + return get_conv_template("chatgpt") + + +class AzureOpenAIAdapter(BaseModelAdapter): + """The model adapter for Azure OpenAI""" + + def match(self, model_path: str): + return model_path in ("azure-gpt-35-turbo", "azure-gpt-4") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chatgpt") + + +class PplxAIAdapter(BaseModelAdapter): + """The model adapter for Perplexity AI""" + + def match(self, model_path: str): + return model_path in ( + "pplx-7b-online", + "pplx-70b-online", + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return 
get_conv_template("pplxai") + + +class ClaudeAdapter(BaseModelAdapter): + """The model adapter for Claude""" + + def match(self, model_path: str): + return model_path in ANTHROPIC_MODEL_LIST + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "claude-3-haiku" in model_path: + return get_conv_template("claude-3-haiku-20240307") + if "claude-3-sonnet" in model_path: + return get_conv_template("claude-3-sonnet-20240229") + if "claude-3-opus" in model_path: + return get_conv_template("claude-3-opus-20240229") + return get_conv_template("claude") + + +class BardAdapter(BaseModelAdapter): + """The model adapter for Bard""" + + def match(self, model_path: str): + return model_path == "bard" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class PaLM2Adapter(BaseModelAdapter): + """The model adapter for PaLM2""" + + def match(self, model_path: str): + return model_path == "palm-2" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class GeminiAdapter(BaseModelAdapter): + """The model adapter for Gemini""" + + def match(self, model_path: str): + return "gemini" in model_path.lower() or "bard" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("gemini") + + +class GeminiDevAdapter(BaseModelAdapter): + """The model adapter for Gemini 1.5 Pro""" + + def match(self, model_path: str): + return "gemini-1.5-pro" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("gemini-dev") + + +class BiLLaAdapter(BaseModelAdapter): + """The model adapter for Neutralzz/BiLLa-7B-SFT""" + + def match(self, model_path: str): + return "billa" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("billa") + + +class RedPajamaINCITEAdapter(BaseModelAdapter): + """The model adapter for togethercomputer/RedPajama-INCITE-7B-Chat""" + + def match(self, model_path: str): + return "redpajama-incite" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("redpajama-incite") + + +class H2OGPTAdapter(BaseModelAdapter): + """The model adapter for h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "h2ogpt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("h2ogpt") + + +class RobinAdapter(BaseModelAdapter): + 
"""The model adapter for LMFlow/Full-Robin-7b-v2""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "robin" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Robin") + + +class SnoozyAdapter(BaseModelAdapter): + """The model adapter for nomic-ai/gpt4all-13b-snoozy""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + model_path = model_path.lower() + return "gpt4all" in model_path and "snoozy" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("snoozy") + + +class WizardLMAdapter(BaseModelAdapter): + """The model adapter for WizardLM/WizardLM-13B-V1.0""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "wizardlm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "13b" in model_path or "30b" in model_path or "70b" in model_path: + return get_conv_template("vicuna_v1.1") + else: + # TODO: use the recommended template for 7B + # (https://huggingface.co/WizardLM/WizardLM-13B-V1.0) + return get_conv_template("one_shot") + + +class ManticoreAdapter(BaseModelAdapter): + """The model adapter for openaccess-ai-collective/manticore-13b-chat-pyg""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "manticore" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("manticore") + + +class GuanacoAdapter(BaseModelAdapter): + """The model adapter for timdettmers/guanaco-33b-merged""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "guanaco" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + # Fix a bug in tokenizer config + tokenizer.eos_token_id = model.config.eos_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zero_shot") + + +class ChangGPTAdapter(BaseModelAdapter): + """The model adapter for lcw99/polyglot-ko-12.8b-chang-instruct-chat""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "polyglot" in model_path and "chang" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("polyglot_changgpt") + + +class CamelAdapter(BaseModelAdapter): + """The model adapter for camel-ai/CAMEL-13B-Combined-Data""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "camel" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class TuluAdapter(BaseModelAdapter): + """The model adapter for allenai/tulu-30b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "tulu" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tulu") + + +class FalconAdapter(BaseModelAdapter): + """The model adapter for tiiuae/falcon-40b""" + + def match(self, model_path: str): + return "falcon" in 
model_path.lower() and "chat" not in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + # Strongly suggest using bf16, which is recommended by the author of Falcon + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + # In Falcon tokenizer config and special config there is not any pad token + # Setting `pad_token_id` to 9, which corresponds to special token '>>SUFFIX<<' + tokenizer.pad_token_id = 9 + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("falcon") + + +class FalconChatAdapter(BaseModelAdapter): + def match(self, model_path: str): + return "falcon" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("falcon-chat") + + +class TigerBotAdapter(BaseModelAdapter): + """The model adapter for TigerResearch/tigerbot-7b-sft""" + + def match(self, model_path: str): + return "tigerbot" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tigerbot") + + +class BaichuanAdapter(BaseModelAdapter): + """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-7B)""" + + def match(self, model_path: str): + return "baichuan" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + # for Baichuan-13B-Chat + if "chat" in model_path.lower(): + if "baichuan2" in model_path.lower(): + return get_conv_template("baichuan2-chat") + return get_conv_template("baichuan-chat") + return get_conv_template("zero_shot") + + +class XGenAdapter(BaseModelAdapter): + """The model adapter for Salesforce/xgen-7b""" + + def match(self, model_path: str): + return "xgen" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model.config.eos_token_id = 50256 + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("xgen") + + +class NousHermesAdapter(BaseModelAdapter): + """The model adapter for NousResearch/Nous-Hermes-13b""" + + 
use_fast_tokenizer = False + + def match(self, model_path: str): + return "nous-hermes" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class InternLMChatAdapter(BaseModelAdapter): + """The model adapter for internlm/internlm-chat-7b""" + + def match(self, model_path: str): + return "internlm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + model = model.eval() + if "8k" in model_path.lower(): + model.config.max_sequence_length = 8192 + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("internlm-chat") + + +class StarChatAdapter(BaseModelAdapter): + """The model adapter for HuggingFaceH4/starchat-beta""" + + def match(self, model_path: str): + return "starchat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("starchat") + + +class MistralAdapter(BaseModelAdapter): + """The model adapter for Mistral AI models""" + + def match(self, model_path: str): + return "mistral" in model_path.lower() or "mixtral" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("mistral") + + +class Llama2Adapter(BaseModelAdapter): + """The model adapter for Llama-2 (e.g., meta-llama/Llama-2-7b-hf)""" + + def match(self, model_path: str): + return "llama-2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-2") + + +class Llama3Adapter(BaseModelAdapter): + """The model adapter for Llama-3 (e.g., meta-llama/Meta-Llama-3-8B-Instruct)""" + + def match(self, model_path: str): + return "llama-3" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-3") + + +class CuteGPTAdapter(BaseModelAdapter): + """The model adapter for CuteGPT""" + + def match(self, model_path: str): + return "cutegpt" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = LlamaTokenizer.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + tokenizer.eos_token_id = 
tokenizer.convert_tokens_to_ids("<end>") + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.eos_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("cutegpt") + + +class OpenOrcaAdapter(BaseModelAdapter): + """Model adapter for Open-Orca models which may use different prompt templates + - (e.g. Open-Orca/OpenOrcaxOpenChat-Preview2-13B, Open-Orca/Mistral-7B-OpenOrca) + - `OpenOrcaxOpenChat-Preview2-13B` uses their "OpenChat Llama2 V1" prompt template. + - [Open-Orca/OpenOrcaxOpenChat-Preview2-13B #Prompt Template](https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B#prompt-template) + - `Mistral-7B-OpenOrca` uses the [OpenAI's Chat Markup Language (ChatML)](https://github.com/openai/openai-python/blob/main/chatml.md) + format, with <|im_start|> and <|im_end|> tokens added to support this. + - [Open-Orca/Mistral-7B-OpenOrca #Prompt Template](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca#prompt-template) + """ + + use_fast_tokenizer = False + + def match(self, model_path: str): + return ( + "mistral-7b-openorca" in model_path.lower() + or "openorca" in model_path.lower() + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "mistral-7b-openorca" in model_path.lower(): + return get_conv_template("mistral-7b-openorca") + return get_conv_template("open-orca") + + +class DolphinAdapter(OpenOrcaAdapter): + """Model adapter for ehartford/dolphin-2.2.1-mistral-7b""" + + def match(self, model_path: str): + return "dolphin" in model_path.lower() and "mistral" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("dolphin-2.2.1-mistral-7b") + + +class Hermes2Adapter(BaseModelAdapter): + """Model adapter for teknium/OpenHermes-2.5-Mistral-7B and teknium/OpenHermes-2-Mistral-7B models""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in ["openhermes-2.5-mistral-7b", "openhermes-2-mistral-7b"] + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("OpenHermes-2.5-Mistral-7B") + + +class NousHermes2MixtralAdapter(BaseModelAdapter): + """Model adapter for NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO model""" + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in [ + "nous-hermes-2-mixtral-8x7b-dpo", + "nous-hermes-2-mixtral-8x7b-sft", + ] + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Nous-Hermes-2-Mixtral-8x7B-DPO") + + +class 
WizardCoderAdapter(BaseModelAdapter): + """The model adapter for WizardCoder (e.g., WizardLM/WizardCoder-Python-34B-V1.0)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "wizardcoder" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + # Same as Alpaca, see : + # https://github.com/nlpxucan/WizardLM/blob/main/WizardCoder/src/inference_wizardcoder.py#L60 + return get_conv_template("alpaca") + + +class QwenChatAdapter(BaseModelAdapter): + """The model adapter for Qwen/Qwen-7B-Chat + To run this model, you need to ensure additional flash attention installation: + ``` bash + git clone https://github.com/Dao-AILab/flash-attention + cd flash-attention && pip install . + pip install csrc/layer_norm + pip install csrc/rotary + ``` + + Since from 2.0, the following change happened + - `flash_attn_unpadded_func` -> `flash_attn_varlen_func` + - `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` + - `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` + You may need to revise the code in: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L69 + to from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + """ + + def match(self, model_path: str): + return "qwen" in model_path.lower() + + def float_set(self, config, option): + config.bf16 = False + config.fp16 = False + config.fp32 = False + + if option == "bf16": + config.bf16 = True + elif option == "fp16": + config.fp16 = True + elif option == "fp32": + config.fp32 = True + else: + print("Invalid option. Please choose one from 'bf16', 'fp16' and 'fp32'.") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + from transformers.generation import GenerationConfig + + revision = from_pretrained_kwargs.get("revision", "main") + config = AutoConfig.from_pretrained( + model_path, + trust_remote_code=True, + ) + # NOTE: if you use the old version of model file, please remove the comments below + # config.use_flash_attn = False + self.float_set(config, "fp16") + generation_config = GenerationConfig.from_pretrained( + model_path, trust_remote_code=True + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + config=config, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ).eval() + if hasattr(model.config, "use_dynamic_ntk") and model.config.use_dynamic_ntk: + model.config.max_sequence_length = 16384 + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + tokenizer.eos_token_id = config.eos_token_id + tokenizer.bos_token_id = config.bos_token_id + tokenizer.pad_token_id = generation_config.pad_token_id + model.config.eos_token_id = tokenizer.eos_token_id + model.config.bos_token_id = tokenizer.bos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("qwen-7b-chat") + + +class SmaugChatAdapter(BaseModelAdapter): + """The model adapter for abacusai/Smaug-2-72B.""" + + def match(self, model_path: str): + return "smaug" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("qwen-7b-chat") + + +class BGEAdapter(BaseModelAdapter): + """The model adapter for BGE (e.g., BAAI/bge-large-en-v1.5)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return 
"bge" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModel.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + if hasattr(model.config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length" + ): + model.config.max_sequence_length = min( + model.config.max_position_embeddings, tokenizer.model_max_length + ) + model.use_cls_pooling = True + model.eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +class E5Adapter(BaseModelAdapter): + """The model adapter for E5 (e.g., intfloat/e5-large-v2)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "e5-" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModel.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + if hasattr(model.config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length" + ): + model.config.max_sequence_length = min( + model.config.max_position_embeddings, tokenizer.model_max_length + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +class AquilaChatAdapter(BaseModelAdapter): + """The model adapter for BAAI/Aquila + + Now supports: + - BAAI/AquilaChat-7B + - BAAI/AquilaChat2-7B + - BAAI/AquilaChat2-34B + """ + + def match(self, model_path: str): + return "aquila" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + model = model.eval() + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + # See: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L347 + if "aquilachat2" in model_path: + if "16k" in model_path: + return get_conv_template("aquila") + elif "34b" in model_path: + return get_conv_template("aquila-legacy") + else: + return get_conv_template("aquila-v1") + else: + return get_conv_template("aquila-chat") + + +class Lamma2ChineseAdapter(BaseModelAdapter): + """The model adapter for FlagAlpha/LLama2-Chinese sft""" + + def match(self, model_path: str): + return "llama2-chinese" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama2-chinese") + + 
+class Lamma2ChineseAlpacaAdapter(BaseModelAdapter): + """The model adapter for ymcui/Chinese-LLaMA-Alpaca sft""" + + def match(self, model_path: str): + return "chinese-alpaca" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chinese-alpaca2") + + +class VigogneAdapter(BaseModelAdapter): + """The model adapter for vigogne (e.g., bofenghuang/vigogne-2-7b-chat)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return bool(re.search(r"vigogne|vigostral", model_path, re.I)) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "chat" in model_path.lower(): + if "vigostral" in model_path.lower(): + return get_conv_template("vigogne_chat_v3") + return get_conv_template("vigogne_chat_v2") + return get_conv_template("vigogne_instruct") + + +class OpenLLaMaOpenInstructAdapter(BaseModelAdapter): + """The model adapter for OpenLLaMa-Open-Instruct (e.g., VMware/open-llama-7b-open-instruct)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return ( + "open-llama" in model_path.lower() and "open-instruct" in model_path.lower() + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class CodeLlamaAdapter(BaseModelAdapter): + """The model adapter for CodeLlama (e.g., codellama/CodeLlama-34b-hf)""" + + def match(self, model_path: str): + return "codellama" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-2") + + +class StableVicunaAdapter(BaseModelAdapter): + """The model adapter for StableVicuna""" + + def match(self, model_path: str): + return "stable-vicuna" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + 
model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("stable-vicuna") + + +class PhindCodeLlamaAdapter(CodeLlamaAdapter): + """The model adapter for Phind-CodeLlama (e.g., Phind/Phind-CodeLlama-34B-v2)""" + + def match(self, model_path: str): + return "phind-codellama-" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("phind") + + +class Llama2ChangAdapter(Llama2Adapter): + """The model adapter for Llama2-ko-chang (e.g., lcw99/llama2-ko-chang-instruct-chat)""" + + def match(self, model_path: str): + return "llama2-ko-chang" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("polyglot_changgpt") + + +class ZephyrAdapter(BaseModelAdapter): + """The model adapter for Zephyr (e.g. HuggingFaceH4/zephyr-7b-alpha)""" + + def match(self, model_path: str): + return "zephyr" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zephyr") + + +class NotusAdapter(BaseModelAdapter): + """The model adapter for Notus (e.g. argilla/notus-7b-v1)""" + + def match(self, model_path: str): + return "notus" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zephyr") + + +class CatPPTAdapter(BaseModelAdapter): + """The model adapter for CatPPT (e.g. rishiraj/CatPPT)""" + + def match(self, model_path: str): + return "catppt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("catppt") + + +class TinyLlamaAdapter(BaseModelAdapter): + """The model adapter for TinyLlama (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0)""" + + def match(self, model_path: str): + return "tinyllama" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("TinyLlama") + + +class XwinLMAdapter(BaseModelAdapter): + """The model adapter for Xwin-LM V0.1 and V0.2 series of models(e.g., Xwin-LM/Xwin-LM-70B-V0.1)""" + + # use_fast_tokenizer = False + + def match(self, model_path: str): + return "xwin-lm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class LemurAdapter(BaseModelAdapter): + """The model adapter for OpenLemur/lemur-70b-chat-v1""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "lemur-70b-chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("lemur-70b-chat") + + +class PygmalionAdapter(BaseModelAdapter): + """The model adapter for Pygmalion/Metharme series of models(e.g., PygmalionAI/mythalion-13b)""" + + # use_fast_tokenizer = False + + def match(self, model_path: str): + return bool( + re.search(r"pygmalion|mythalion|metharme", model_path.lower(), re.I) + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("metharme") + + +class XdanAdapter(BaseModelAdapter): + """The model adapter for xDAN-AI (e.g. 
xDAN-AI/xDAN-L1-Chat-RL-v1)""" + + def match(self, model_path: str): + return "xdan" in model_path.lower() and "v1" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("xdan-v1") + + +class MicrosoftOrcaAdapter(BaseModelAdapter): + """The model adapter for Microsoft/Orca-2 series of models (e.g. Microsoft/Orca-2-7b, Microsoft/Orca-2-13b)""" + + use_fast_tokenizer = False # Flag needed since tokenizers>=0.13.3 is required for normal functioning of this module + + def match(self, model_path: str): + return "orca-2" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("orca-2") + + +class YiAdapter(BaseModelAdapter): + """The model adapter for Yi models""" + + def match(self, model_path: str): + return "yi-" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Yi-34b-chat") + + +class DeepseekCoderAdapter(BaseModelAdapter): + """The model adapter for deepseek-ai's coder models""" + + def match(self, model_path: str): + return "deepseek-coder" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("deepseek-coder") + + +class DeepseekChatAdapter(BaseModelAdapter): + """The model adapter for deepseek-ai's chat models""" + + # Note: this model requires tokenizers >= 0.13.3 because the tokenizer class is LlamaTokenizerFast + + def match(self, model_path: str): + return "deepseek-llm" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("deepseek-chat") + + +class Yuan2Adapter(BaseModelAdapter): + """The model adapter for Yuan2.0""" + + def match(self, model_path: str): + return "yuan2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + # from_pretrained_kwargs["torch_dtype"] = torch.bfloat16 + tokenizer = LlamaTokenizer.from_pretrained( + model_path, + add_eos_token=False, + add_bos_token=False, + eos_token="<eod>", + eod_token="<eod>", + sep_token="<sep>", + revision=revision, + ) + tokenizer.add_tokens( + [ + "<sep>", + "<pad>", + "<mask>", + "<predict>", + "<FIM_SUFFIX>", + "<FIM_PREFIX>", + "<FIM_MIDDLE>", + "<commit_before>", + "<commit_msg>", + "<commit_after>", + "<jupyter_start>", + "<jupyter_text>", + "<jupyter_code>", + "<jupyter_output>", + "<empty_output>", + ], + special_tokens=True, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_path, + # device_map='auto', + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("yuan2") + + +class MetaMathAdapter(BaseModelAdapter): + """The model adapter for MetaMath models""" + + def match(self, model_path: str): + return "metamath" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("metamath") + + +class BagelAdapter(BaseModelAdapter): + """Model adapter for jondurbin/bagel-* models""" + + def match(self, model_path: str): + return "bagel" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("airoboros_v3") + + +class SolarAdapter(BaseModelAdapter): + """The model adapter for upstage/SOLAR-10.7B-Instruct-v1.0""" + + def match(self, model_path: str): + return "solar-" in model_path.lower() and "instruct" in
model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("solar") + + +class SteerLMAdapter(BaseModelAdapter): + """The model adapter for nvidia/Llama2-70B-SteerLM-Chat""" + + def match(self, model_path: str): + return "steerlm-chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("steerlm") + + +class GemmaAdapter(BaseModelAdapter): + """The model adapter for google/gemma""" + + def match(self, model_path: str): + return "gemma" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("gemma") + + +class LlavaAdapter(BaseModelAdapter): + """The model adapter for liuhaotian/llava-v1.5 series of models""" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + # TODO(chris): Implement huggingface-compatible load_model + pass + + def match(self, model_path: str): + return "llava" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "34b" in model_path: + return get_conv_template("llava-chatml") + + return get_conv_template("vicuna_v1.1") + + +class YuanAdapter(BaseModelAdapter): + """The model adapter for Yuan""" + + def match(self, model_path: str): + return "yuan" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + tokenizer.add_tokens( + [ + "<sep>", + "<pad>", + "<mask>", + "<predict>", + "<FIM_SUFFIX>", + "<FIM_PREFIX>", + "<FIM_MIDDLE>", + "<commit_before>", + "<commit_msg>", + "<commit_after>", + "<jupyter_start>", + "<jupyter_text>", + "<jupyter_code>", + "<jupyter_output>", + "<empty_output>", + ], + special_tokens=True, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("yuan") + + +class OlmoAdapter(BaseModelAdapter): + """The model adapter for allenai/OLMo-7B-Instruct""" + + def match(self, model_path: str): + return "olmo" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("api_based_default") + + +class YandexGPTAdapter(BaseModelAdapter): + """The model adapter for YandexGPT""" + + def match(self, model_path: str): + return "yandexgpt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("yandexgpt") + + +class CllmAdapter(BaseModelAdapter): + """The model adapter for CLLM""" + + def match(self, model_path: str): + return "consistency-llm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + config = AutoConfig.from_pretrained( + model_path, + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_path, + model_max_length=2048, + padding_side="right", + ) + + model = AutoModelForCausalLM.from_pretrained( + model_path, + config=config, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + device_map="cuda", + ) + + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("cllm") + + +class CohereAdapter(BaseModelAdapter): + """The model adapter for Cohere""" + + def match(self, model_path: str): + return model_path in ["command-r"] + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("api_based_default") + + +class DBRXAdapter(BaseModelAdapter): + 
"""The model adapter for Cohere""" + + def match(self, model_path: str): + return model_path in ["dbrx-instruct"] + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("api_based_default") + + +class RekaAdapter(BaseModelAdapter): + """The model adapter for Reka""" + + def match(self, model_path: str): + return "reka" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("api_based_default") + + +# Note: the registration order matters. +# The one registered earlier has a higher matching priority. +register_model_adapter(PeftModelAdapter) +register_model_adapter(StableVicunaAdapter) +register_model_adapter(VicunaAdapter) +register_model_adapter(AiroborosAdapter) +register_model_adapter(LongChatAdapter) +register_model_adapter(GoogleT5Adapter) +register_model_adapter(KoalaAdapter) +register_model_adapter(AlpacaAdapter) +register_model_adapter(ChatGLMAdapter) +register_model_adapter(CodeGeexAdapter) +register_model_adapter(DollyV2Adapter) +register_model_adapter(OasstPythiaAdapter) +register_model_adapter(OasstLLaMAAdapter) +register_model_adapter(OpenChat35Adapter) +register_model_adapter(TenyxChatAdapter) +register_model_adapter(StableLMAdapter) +register_model_adapter(BaizeAdapter) +register_model_adapter(RwkvAdapter) +register_model_adapter(OpenBuddyAdapter) +register_model_adapter(PhoenixAdapter) +register_model_adapter(BardAdapter) +register_model_adapter(PaLM2Adapter) +register_model_adapter(GeminiAdapter) +register_model_adapter(GeminiDevAdapter) +register_model_adapter(GemmaAdapter) +register_model_adapter(ChatGPTAdapter) +register_model_adapter(AzureOpenAIAdapter) +register_model_adapter(ClaudeAdapter) +register_model_adapter(MPTAdapter) +register_model_adapter(BiLLaAdapter) +register_model_adapter(RedPajamaINCITEAdapter) +register_model_adapter(H2OGPTAdapter) +register_model_adapter(RobinAdapter) +register_model_adapter(SnoozyAdapter) +register_model_adapter(WizardLMAdapter) +register_model_adapter(ManticoreAdapter) +register_model_adapter(GuanacoAdapter) +register_model_adapter(CamelAdapter) +register_model_adapter(ChangGPTAdapter) +register_model_adapter(TuluAdapter) +register_model_adapter(FalconChatAdapter) +register_model_adapter(FalconAdapter) +register_model_adapter(TigerBotAdapter) +register_model_adapter(BaichuanAdapter) +register_model_adapter(XGenAdapter) +register_model_adapter(PythiaAdapter) +register_model_adapter(InternLMChatAdapter) +register_model_adapter(StarChatAdapter) +register_model_adapter(Llama2Adapter) +register_model_adapter(Llama3Adapter) +register_model_adapter(CuteGPTAdapter) +register_model_adapter(OpenOrcaAdapter) +register_model_adapter(DolphinAdapter) +register_model_adapter(Hermes2Adapter) +register_model_adapter(NousHermes2MixtralAdapter) +register_model_adapter(NousHermesAdapter) +register_model_adapter(MistralAdapter) +register_model_adapter(WizardCoderAdapter) +register_model_adapter(QwenChatAdapter) +register_model_adapter(AquilaChatAdapter) +register_model_adapter(BGEAdapter) +register_model_adapter(E5Adapter) +register_model_adapter(Lamma2ChineseAdapter) +register_model_adapter(Lamma2ChineseAlpacaAdapter) +register_model_adapter(VigogneAdapter) +register_model_adapter(OpenLLaMaOpenInstructAdapter) +register_model_adapter(ReaLMAdapter) +register_model_adapter(PhindCodeLlamaAdapter) +register_model_adapter(CodeLlamaAdapter) 
+register_model_adapter(Llama2ChangAdapter) +register_model_adapter(ZephyrAdapter) +register_model_adapter(NotusAdapter) +register_model_adapter(CatPPTAdapter) +register_model_adapter(TinyLlamaAdapter) +register_model_adapter(XwinLMAdapter) +register_model_adapter(LemurAdapter) +register_model_adapter(PygmalionAdapter) +register_model_adapter(MicrosoftOrcaAdapter) +register_model_adapter(XdanAdapter) +register_model_adapter(YiAdapter) +register_model_adapter(PplxAIAdapter) +register_model_adapter(DeepseekCoderAdapter) +register_model_adapter(DeepseekChatAdapter) +register_model_adapter(Yuan2Adapter) +register_model_adapter(MetaMathAdapter) +register_model_adapter(BagelAdapter) +register_model_adapter(SolarAdapter) +register_model_adapter(SteerLMAdapter) +register_model_adapter(LlavaAdapter) +register_model_adapter(YuanAdapter) +register_model_adapter(OlmoAdapter) +register_model_adapter(CohereAdapter) +register_model_adapter(DBRXAdapter) +register_model_adapter(GemmaAdapter) +register_model_adapter(YandexGPTAdapter) +register_model_adapter(CllmAdapter) +register_model_adapter(RekaAdapter) +register_model_adapter(SmaugChatAdapter) + +# After all adapters, try the default base adapter. +register_model_adapter(BaseModelAdapter) diff --git a/src/model/model_chatglm.py b/src/model/model_chatglm.py new file mode 100644 index 0000000000000000000000000000000000000000..2cbac8bc5f9f5ccbee833ac9cc22cf23c068e51e --- /dev/null +++ b/src/model/model_chatglm.py @@ -0,0 +1,137 @@ +""" +Inference code for ChatGLM. +Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py. +""" +import re + +import torch +from transformers.generation.logits_process import LogitsProcessor + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +invalid_score_processor = InvalidScoreLogitsProcessor() + + +def process_response(response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + + +def recover_message_list(prompt): + role_token_pattern = "|".join( + [re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]] + ) + role = None + last_end_idx = -1 + message_list = [] + for match in re.finditer(role_token_pattern, prompt): + if role: + messge = {} + if role == "<|system|>": + messge["role"] = "system" + elif role == "<|user|>": + messge["role"] = "user" + else: + messge["role"] = "assistant" + messge["content"] = prompt[last_end_idx + 1 : match.start()] + message_list.append(messge) + + role = prompt[match.start() : match.end()] + last_end_idx = match.end() + + return message_list + + +@torch.inference_mode() +def generate_stream_chatglm( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_new_tokens = int(params.get("max_new_tokens", 256)) + echo = params.get("echo", True) + + model_type = 
str(type(model)).lower() + if "peft" in model_type: + model_type = str(type(model.base_model.model)).lower() + + if "chatglm3" in model_type: + message_list = recover_message_list(prompt) + inputs = tokenizer.build_chat_input( + query=message_list[-1]["content"], history=message_list[:-1], role="user" + ).to(model.device) + else: + inputs = tokenizer([prompt], return_tensors="pt").to(model.device) + input_echo_len = len(inputs["input_ids"][0]) + + gen_kwargs = { + "max_length": max_new_tokens + input_echo_len, + "do_sample": True if temperature > 1e-5 else False, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + "logits_processor": [invalid_score_processor], + } + if temperature > 1e-5: + gen_kwargs["temperature"] = temperature + + total_len = 0 + for total_ids in model.stream_generate(**inputs, **gen_kwargs): + total_ids = total_ids.tolist()[0] + total_len = len(total_ids) + if echo: + output_ids = total_ids + else: + output_ids = total_ids[input_echo_len:] + response = tokenizer.decode(output_ids) + response = process_response(response) + + yield { + "text": response, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": None, + } + + # TODO: ChatGLM stop when it reach max length + # Only last stream result contains finish_reason, we set finish_reason as stop + ret = { + "text": response, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": "stop", + } + yield ret diff --git a/src/model/model_cllm.py b/src/model/model_cllm.py new file mode 100644 index 0000000000000000000000000000000000000000..563e2a5598233788a4a165eb41371fd7ab729f62 --- /dev/null +++ b/src/model/model_cllm.py @@ -0,0 +1,202 @@ +import torch +import gc + +import os +import time +import random +from typing import Dict, Optional, Sequence, List, Tuple +from transformers.cache_utils import Cache, DynamicCache +from transformers import ( + LlamaModel, + LlamaForCausalLM, + GenerationConfig, + StoppingCriteria, + StoppingCriteriaList, + TextIteratorStreamer, +) +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +import torch.nn.functional as F + + +def get_jacobian_trajectory( + model, tokenizer, input_ids, attention_mask, max_new_tokens +): + bsz = input_ids.shape[0] + prompt_len = [torch.sum(t) for t in attention_mask] + max_prompt_len = max(prompt_len) + total_len = max_prompt_len + max_new_tokens + + # initialize the first point of jacobian trajectory + tokens = torch.full( + (bsz, total_len), tokenizer.pad_token_id, dtype=torch.long, device=model.device + ) + for i in range(bsz): + tokens[i, :] = torch.tensor( + random.choices(input_ids[i][attention_mask[i] == 1], k=total_len), + dtype=torch.long, + device=model.device, + ) + tokens[i, : prompt_len[i]] = input_ids[i][: prompt_len[i]].to( + dtype=torch.long, device=model.device + ) + itr = 0 + next_generation = tokens + generate_attention_mask = torch.full_like(next_generation, 1).to(model.device) + accurate_lengths = torch.tensor([prompt_len[i].item()] * bsz, device=model.device) + prev_len = 0 + while True: + current_generation = next_generation + with torch.no_grad(): + logits = model(current_generation, generate_attention_mask).logits + next_generation = torch.argmax( + torch.nn.functional.softmax(logits, dim=-1) / 0.001, dim=-1 + ) + + # hold prompt unchanged and update generated tokens + for i in range(bsz): + 
next_generation[i, :] = torch.cat( + ( + tokens[i, : prompt_len[i]], + next_generation[i, prompt_len[i] - 1 : total_len - 1], + ), + dim=0, + ) + + if ( + torch.all(torch.eq(next_generation, current_generation)).item() + and itr == max_new_tokens + or len( + torch.where( + current_generation[0, : accurate_lengths[0]] + == tokenizer.eos_token_id + )[0] + ) + > 0 + ): + # forced exit due to max_new_tokens constraint or eos reached + return next_generation, itr + + # skip the first itr, current_generation has not been updated yet + if itr != 0: + if torch.all(torch.eq(next_generation, current_generation)).item(): + matched_position = total_len + else: + matched_position = ( + torch.eq(current_generation, next_generation).squeeze(0) == False + ).nonzero(as_tuple=True)[0][0] + fast_forward_cnt = matched_position - accurate_lengths[0] + + for i in range(bsz): + accurate_lengths[i] = matched_position.item() + + # flush and print the first sequence + generated_str = tokenizer.decode( + next_generation[0, prompt_len[0] : accurate_lengths[0]], + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + print(generated_str[prev_len:], flush=True, end="") + prev_len = len(generated_str) + + if torch.all(torch.eq(next_generation, current_generation)).item(): + # early termination: itr < max_new_tokens + return next_generation, itr + + itr += 1 + + +def generate_stream_cllm( + model, + tokenizer, + params, + device, + context_len, + stream_interval=2, + judge_sent_end=False, +): + # converge_step = [] + prompt = params["prompt"] + inputs = tokenizer(prompt, return_tensors="pt").to(device) + max_new_tokens = int(params.get("n_token_seq_length", 32)) + max_new_seq_len = int(params.get("max_new_tokens", 1024)) + + prompt_len = torch.sum(inputs["attention_mask"], dim=-1) + generation = inputs["input_ids"] + input_echo_len = len(generation) + + ### generation phase + itr = 0 + eos_reached = False + while True: + if itr == 0: + input_ids = inputs["input_ids"] + input_masks = inputs["attention_mask"] + else: + input_masks = torch.ones_like(input_ids).to(device) + for j in range(bsz): + input_masks[j][ + torch.sum(inputs["attention_mask"], dim=-1)[j] + + itr * max_new_tokens : + ] = 0 + + bsz = input_ids.shape[0] + eos_reached = torch.tensor([False] * bsz, device=device) + + generation, iter_steps = get_jacobian_trajectory( + model=model, + tokenizer=tokenizer, + input_ids=input_ids, + attention_mask=input_masks, + max_new_tokens=max_new_tokens, + ) + + ### inspect + for j in range(bsz): + prompt_len = torch.sum(input_masks, dim=-1) + eos_positions = torch.where(generation[j] == tokenizer.eos_token_id)[0] + + if len(eos_positions) == 0: + # no EOS, continue to the next item in the batch + generation[j][prompt_len[j] + max_new_tokens :] = tokenizer.pad_token_id + continue + # otherwise, set tokens coming after EOS as pad + else: + if len(eos_positions) != 0: + eos_reached[j] = True + generation[j, int(eos_positions[0]) + 1 :] = tokenizer.pad_token_id + + itr += 1 + + if all(eos_reached) or itr * max_new_tokens >= max_new_seq_len: + break + input_ids = generation[ + torch.where(eos_reached == False)[0].tolist(), ... 
+ ] # delete samples with generated + + if all(eos_reached): + finish_reason = "eos" + elif itr * max_new_tokens > max_new_seq_len: + finish_reason = "length" + else: + finish_reason = "stop" + + output = tokenizer.decode(input_ids[0], skip_special_tokens=False) + + yield { + "text": "", + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": itr * max_new_tokens, + "total_tokens": input_echo_len + itr * max_new_tokens, + }, + "finish_reason": finish_reason, + } + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/src/model/model_codet5p.py b/src/model/model_codet5p.py new file mode 100644 index 0000000000000000000000000000000000000000..0984513c96931b6d48dfd17f3020fe5cebc3f911 --- /dev/null +++ b/src/model/model_codet5p.py @@ -0,0 +1,108 @@ +import gc +from threading import Thread +import torch +import transformers +from transformers import ( + GenerationConfig, + StoppingCriteria, + StoppingCriteriaList, + TextIteratorStreamer, +) + + +@torch.inference_mode() +def generate_stream_codet5p( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", 50)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 1024)) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, **decode_config) + encoding = tokenizer(prompt, return_tensors="pt").to(device) + input_ids = encoding.input_ids + encoding["decoder_input_ids"] = encoding["input_ids"].clone() + input_echo_len = len(input_ids) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1e-5, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + eos_token_id=stop_token_ids, + ) + + class CodeBlockStopper(StoppingCriteria): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + # Code-completion is open-end generation. + # We check \n\n to stop at end of a code block. 
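+            # The check relies on hard-coded token ids from the CodeT5+ vocabulary
+            # (the ids the author uses to detect a trailing blank line); a different
+            # tokenizer would need different ids.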
+ if list(input_ids[0][-2:]) == [628, 198]: + return True + return False + + gen_kwargs = dict( + **encoding, + streamer=streamer, + generation_config=generation_config, + stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), + ) + thread = Thread(target=model.generate, kwargs=gen_kwargs) + thread.start() + i = 0 + output = "" + for new_text in streamer: + i += 1 + output += new_text + if i % stream_interval == 0 or i == max_new_tokens - 1: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + if i >= max_new_tokens: + break + + if i >= max_new_tokens: + finish_reason = "length" + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + thread.join() + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/src/model/model_exllama.py b/src/model/model_exllama.py new file mode 100644 index 0000000000000000000000000000000000000000..306edab21a79658d22eb75f1da3eba1f830e4ae7 --- /dev/null +++ b/src/model/model_exllama.py @@ -0,0 +1,77 @@ +import gc +import sys +from typing import Dict + +import torch + + +def generate_stream_exllama( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + try: + from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler + except ImportError as e: + print(f"Error: Failed to load Exllamav2. {e}") + sys.exit(-1) + + prompt = params["prompt"] + + generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer) + settings = ExLlamaV2Sampler.Settings() + + settings.temperature = float(params.get("temperature", 0.85)) + settings.top_k = int(params.get("top_k", 50)) + settings.top_p = float(params.get("top_p", 0.8)) + settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15)) + settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id]) + + max_new_tokens = int(params.get("max_new_tokens", 256)) + + generator.set_stop_conditions(params.get("stop_token_ids", None) or []) + echo = bool(params.get("echo", True)) + + input_ids = generator.tokenizer.encode(prompt) + prompt_tokens = input_ids.shape[-1] + generator.begin_stream(input_ids, settings) + + generated_tokens = 0 + if echo: + output = prompt + else: + output = "" + while True: + chunk, eos, _ = generator.stream() + output += chunk + generated_tokens += 1 + if generated_tokens == max_new_tokens: + finish_reason = "length" + break + elif eos: + finish_reason = "length" + break + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": None, + } + + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": finish_reason, + } + gc.collect() diff --git a/src/model/model_falcon.py b/src/model/model_falcon.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8af8efa20bd29fb31cdd0a0bc039b30f4bf26e --- /dev/null +++ b/src/model/model_falcon.py @@ -0,0 +1,140 @@ +import gc +from threading import 
Thread +from typing import Iterable + +import torch +import transformers +from transformers import TextIteratorStreamer, GenerationConfig + +from fastchat.utils import is_partial_stop + + +@torch.inference_mode() +def generate_stream_falcon( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", 50)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + stop_str = params.get("stop", None) + echo = bool(params.get("echo", True)) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + max_src_len = context_len - max_new_tokens - 8 + + input_ids = input_ids[-max_src_len:] # truncate from the left + attention_mask = attention_mask[-max_src_len:] # truncate from the left + input_echo_len = len(input_ids) + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1e-5, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + eos_token_id=stop_token_ids, + ) + + generation_kwargs = dict( + inputs=input_ids, + attention_mask=attention_mask, + streamer=streamer, + generation_config=generation_config, + ) + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + if echo: + # means keep the prompt + output = prompt + else: + output = "" + + for i, new_text in enumerate(streamer): + output += new_text + if i % stream_interval == 0: + if echo: + rfind_start = len_prompt + else: + rfind_start = 0 + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + + # finish stream event, which contains finish reason + if i == max_new_tokens - 1: + finish_reason = "length" + elif partially_stopped: + finish_reason = None + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/src/model/model_registry.py 
b/src/model/model_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..2481dbe8ff4d20abb96fc6ac5632ace9dc145c8f --- /dev/null +++ b/src/model/model_registry.py @@ -0,0 +1,764 @@ +"""Additional information of the models.""" +from collections import namedtuple, OrderedDict +from typing import List + + +ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"]) + + +model_info = OrderedDict() + + +def register_model_info( + full_names: List[str], simple_name: str, link: str, description: str +): + info = ModelInfo(simple_name, link, description) + + for full_name in full_names: + model_info[full_name] = info + + +def get_model_info(name: str) -> ModelInfo: + if name in model_info: + return model_info[name] + else: + # To fix this, please use `register_model_info` to register your model + return ModelInfo( + name, "", "Register the description at fastchat/model/model_registry.py" + ) + + +register_model_info( + [ + "IEITYuan/Yuan2-2B-Janus-hf", + "IEITYuan/Yuan2-2B-hf", + "IEITYuan/Yuan2-51B-hf", + "IEITYuan/Yuan2-102B-hf", + ], + "IEIT-Yuan2", + "https://github.com/IEIT-Yuan/Yuan-2.0", + "Yuan2.0 is a new generation Fundamental Large Language Model developed by IEIT System.", +) + +register_model_info( + [ + "claude-3-haiku-20240307", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", + "claude-2.1", + "claude-2.0", + "claude-1", + ], + "Claude", + "https://www.anthropic.com/news/claude-3-family", + "Claude by Anthropic", +) + +register_model_info( + ["reka-flash", "reka-flash-online"], + "Reka Flash", + "https://www.reka.ai/news/reka-flash-efficient-and-capable-multimodal-language-models", + "Multimodal model by Reka", +) + +register_model_info( + ["command-r-plus"], + "Command-R-Plus", + "https://txt.cohere.com/command-r-plus-microsoft-azure/", + "Command-R Plus by Cohere", +) + +register_model_info( + ["command-r"], + "Command-R", + "https://txt.cohere.com/command-r/", + "Command-R by Cohere", +) + +register_model_info( + [ + "zephyr-orpo-141b-A35b-v0.1", + ], + "Zephyr 141B-A35B", + "https://huggingface.co/HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + "ORPO fine-tuned of Mixtral-8x22B-v0.1", +) + +register_model_info( + ["gemma-1.1-7b-it", "gemma-1.1-2b-it", "gemma-7b-it", "gemma-2b-it"], + "Gemma", + "https://blog.google/technology/developers/gemma-open-models/", + "Gemma by Google", +) + +register_model_info( + [ + "mixtral-8x7b-instruct-v0.1", + "mistral-large-2402", + "mistral-medium", + "mistral-next", + "mistral-7b-instruct-v0.2", + "mistral-7b-instruct", + ], + "Mixtral of experts", + "https://mistral.ai/news/mixtral-of-experts/", + "A Mixture-of-Experts model by Mistral AI", +) + +register_model_info( + [ + "qwen1.5-72b-chat", + "qwen1.5-32b-chat", + "qwen1.5-14b-chat", + "qwen1.5-7b-chat", + "qwen1.5-4b-chat", + "qwen1.5-1.8b-chat", + "qwen1.5-0.5b-chat", + "qwen-14b-chat", + ], + "Qwen 1.5", + "https://qwenlm.github.io/blog/qwen1.5/", + "A large language model by Alibaba Cloud", +) + + +register_model_info( + ["dbrx-instruct"], + "DBRX Instruct", + "https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm", + "DBRX by Databricks Mosaic AI", +) + +register_model_info( + ["starling-lm-7b-beta", "starling-lm-7b-alpha"], + "Starling-LM-7B", + "https://starling.cs.berkeley.edu/", + "An open model trained using RLAIF by Berkeley", +) + +register_model_info( + ["qwen-14b-chat"], + "Qwen", + "https://huggingface.co/Qwen", + "A large language model by Alibaba Cloud", +) + +register_model_info( + ["bard-feb-2024", 
"bard-jan-24-gemini-pro"], + "Bard", + "https://bard.google.com/", + "Bard by Google", +) + +register_model_info( + [ + "gemini-pro", + "gemini-pro-dev-api", + "gemini-1.0-pro-vision", + "gemini-1.5-pro-preview-0409", + ], + "Gemini", + "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/", + "Gemini by Google", +) + +register_model_info( + ["stripedhyena-nous-7b"], + "StripedHyena-Nous", + "https://huggingface.co/togethercomputer/StripedHyena-Nous-7B", + "A chat model developed by Together Research and Nous Research.", +) + +register_model_info( + ["solar-10.7b-instruct-v1.0"], + "SOLAR-10.7B-Instruct", + "https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0", + "A model trained using depth up-scaling by Upstage AI", +) + +register_model_info( + [ + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-1106-preview", + "gpt-4-0125-preview", + ], + "GPT-4-Turbo", + "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo", + "GPT-4-Turbo by OpenAI", +) + +register_model_info( + ["gpt-4-turbo-browsing"], + "GPT-4-Turbo with browsing", + "https://platform.openai.com/docs/assistants/overview", + "GPT-4-Turbo with browsing by OpenAI", +) + +register_model_info( + [ + "gpt-3.5-turbo", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0314", + "gpt-3.5-turbo-0613", + ], + "GPT-3.5", + "https://platform.openai.com/docs/models/gpt-3-5", + "GPT-3.5-Turbo by OpenAI", +) + +register_model_info( + ["gpt-4", "gpt-4-0314", "gpt-4-0613"], + "GPT-4", + "https://openai.com/research/gpt-4", + "GPT-4 by OpenAI", +) + +register_model_info( + ["claude-instant-1", "claude-instant-1.2"], + "Claude Instant", + "https://www.anthropic.com/index/introducing-claude", + "Claude Instant by Anthropic", +) + +register_model_info( + ["llama-2-70b-chat", "llama-2-34b-chat", "llama-2-13b-chat", "llama-2-7b-chat"], + "Llama 2", + "https://ai.meta.com/llama/", + "Open foundation and fine-tuned chat models by Meta", +) + +register_model_info( + ["olmo-7b-instruct"], + "OLMo-7B", + "https://huggingface.co/allenai/OLMo-7B-Instruct", + "OLMo by Allen AI", +) + +register_model_info( + [ + "vicuna-33b", + "vicuna-33b-v1.3", + "vicuna-13b", + "vicuna-13b-v1.5", + "vicuna-7b", + "vicuna-7b-v1.5", + ], + "Vicuna", + "https://lmsys.org/blog/2023-03-30-vicuna/", + "A chat assistant fine-tuned on user-shared conversations by LMSYS", +) + +register_model_info( + ["yi-34b-chat", "yi-6b-chat"], + "Yi-Chat", + "https://huggingface.co/01-ai/Yi-34B-Chat", + "A large language model by 01 AI", +) + +register_model_info( + [ + "codellama-70b-instruct", + "codellama-34b-instruct", + "codellama-13b-instruct", + "codellama-7b-instruct", + ], + "Code Llama", + "https://ai.meta.com/blog/code-llama-large-language-model-coding/", + "Open foundation models for code by Meta", +) + +register_model_info( + ["openchat-3.5-0106", "openchat-3.5"], + "OpenChat 3.5", + "https://github.com/imoneoi/openchat", + "An open model fine-tuned on Mistral-7B using C-RLFT", +) + +register_model_info( + ["deepseek-llm-67b-chat"], + "DeepSeek LLM", + "https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat", + "An advanced language model by DeepSeek", +) + +register_model_info( + ["stripedhyena-nous-7b"], + "StripedHyena-Nous", + "https://huggingface.co/togethercomputer/StripedHyena-Nous-7B", + "A chat model developed by Together Research and Nous Research.", +) + +register_model_info( + ["nous-hermes-2-mixtral-8x7b-dpo"], + "Nous-Hermes-2-Mixtral-8x7B-DPO", + "https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", 
+ "Nous Hermes finetuned from Mixtral 8x7B", +) + + +register_model_info( + ["llama2-70b-steerlm-chat"], + "Llama2-70B-SteerLM-Chat", + "https://huggingface.co/nvidia/Llama2-70B-SteerLM-Chat", + "A Llama fine-tuned with SteerLM method by NVIDIA", +) + +register_model_info( + ["pplx-70b-online", "pplx-7b-online"], + "pplx-online-llms", + "https://blog.perplexity.ai/blog/introducing-pplx-online-llms", + "Online LLM API by Perplexity AI", +) + +register_model_info( + ["openhermes-2.5-mistral-7b"], + "OpenHermes-2.5-Mistral-7B", + "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B", + "A mistral-based model fine-tuned on 1M GPT-4 outputs", +) + +register_model_info( + ["tulu-2-dpo-70b"], + "Tulu 2", + "https://huggingface.co/allenai/tulu-2-dpo-70b", + "An instruction and RLHF model by UW/AllenAI", +) + +register_model_info( + ["chatglm3-6b", "chatglm2-6b", "chatglm-6b"], + "ChatGLM", + "https://chatglm.cn/blog", + "An open bilingual dialogue language model by Tsinghua University", +) + +register_model_info( + ["tenyxchat-7b-v1"], + "TenyxChat-7B", + "https://huggingface.co/tenyx/TenyxChat-7B-v1", + "An open model DPO trained on top of OpenChat-3.5 using Tenyx fine-tuning", +) + +register_model_info( + ["zephyr-7b-beta", "zephyr-7b-alpha"], + "Zephyr", + "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha", + "A chatbot fine-tuned from Mistral by Hugging Face", +) + +register_model_info( + ["notus-7b-v1"], + "Notus", + "https://huggingface.co/argilla/notus-7b-v1", + "A chatbot fine-tuned from Zephyr SFT by Argilla", +) + +register_model_info( + ["catppt"], + "CatPPT", + "https://huggingface.co/rishiraj/CatPPT", + "A chatbot fine-tuned from a SLERP merged model by Rishiraj Acharya", +) + +register_model_info( + ["TinyLlama"], + "TinyLlama", + "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.", +) + +register_model_info( + ["wizardlm-70b", "wizardlm-30b", "wizardlm-13b"], + "WizardLM", + "https://github.com/nlpxucan/WizardLM", + "An instruction-following LLM using evol-instruct by Microsoft", +) + +register_model_info( + ["wizardcoder-15b-v1.0"], + "WizardLM", + "https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder", + "Empowering Code Large Language Models with Evol-Instruct", +) + +register_model_info( + ["mpt-7b-chat", "mpt-30b-chat"], + "MPT-Chat", + "https://www.mosaicml.com/blog/mpt-30b", + "A chatbot fine-tuned from MPT by MosaicML", +) + +register_model_info( + ["guanaco-33b", "guanaco-65b"], + "Guanaco", + "https://github.com/artidoro/qlora", + "A model fine-tuned with QLoRA by UW", +) + +register_model_info( + ["gpt4all-13b-snoozy"], + "GPT4All-Snoozy", + "https://github.com/nomic-ai/gpt4all", + "A finetuned LLaMA model on assistant style data by Nomic AI", +) + +register_model_info( + ["koala-13b"], + "Koala", + "https://bair.berkeley.edu/blog/2023/04/03/koala", + "A dialogue model for academic research by BAIR", +) + +register_model_info( + ["RWKV-4-Raven-14B"], + "RWKV-4-Raven", + "https://huggingface.co/BlinkDL/rwkv-4-raven", + "An RNN with transformer-level LLM performance", +) + +register_model_info( + ["alpaca-13b"], + "Alpaca", + "https://crfm.stanford.edu/2023/03/13/alpaca.html", + "A model fine-tuned from LLaMA on instruction-following demonstrations by Stanford", +) + +register_model_info( + ["oasst-pythia-12b"], + "OpenAssistant (oasst)", + "https://open-assistant.io", + "An Open Assistant for everyone by LAION", +) + +register_model_info( + 
["oasst-sft-7-llama-30b"], + "OpenAssistant (oasst)", + "https://open-assistant.io", + "An Open Assistant for everyone by LAION", +) + +register_model_info( + ["palm-2"], + "PaLM 2 Chat", + "https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023", + "PaLM 2 for Chat (chat-bison@001) by Google", +) + +register_model_info( + ["llama-7b", "llama-13b"], + "LLaMA", + "https://arxiv.org/abs/2302.13971", + "Open and efficient foundation language models by Meta", +) + +register_model_info( + ["open-llama-7b-v2-open-instruct", "open-llama-7b-open-instruct"], + "Open LLaMa (Open Instruct)", + "https://medium.com/vmware-data-ml-blog/starter-llm-for-the-enterprise-instruction-tuning-openllama-7b-d05fc3bbaccc", + "Open LLaMa fine-tuned on instruction-following data by VMware", +) + +register_model_info( + ["dolly-v2-12b"], + "Dolly", + "https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm", + "An instruction-tuned open large language model by Databricks", +) + +register_model_info( + ["stablelm-tuned-alpha-7b"], + "StableLM", + "https://github.com/stability-AI/stableLM", + "Stability AI language models", +) + +register_model_info( + ["codet5p-6b"], + "CodeT5p-6b", + "https://huggingface.co/Salesforce/codet5p-6b", + "Code completion model released by Salesforce", +) + +register_model_info( + ["fastchat-t5-3b", "fastchat-t5-3b-v1.0"], + "FastChat-T5", + "https://huggingface.co/lmsys/fastchat-t5-3b-v1.0", + "A chat assistant fine-tuned from FLAN-T5 by LMSYS", +) + +register_model_info( + ["phoenix-inst-chat-7b"], + "Phoenix-7B", + "https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b", + "A multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)", +) + +register_model_info( + ["realm-7b-v1"], + "ReaLM", + "https://github.com/FreedomIntelligence/ReaLM", + "A chatbot fine-tuned from LLaMA2 with data generated via iterative calls to UserGPT and ChatGPT by CUHK(SZ) and SRIBD.", +) + +register_model_info( + ["billa-7b-sft"], + "BiLLa-7B-SFT", + "https://huggingface.co/Neutralzz/BiLLa-7B-SFT", + "An instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher", +) + +register_model_info( + ["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"], + "h2oGPT-GM-7b", + "https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2", + "An instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai", +) + +register_model_info( + ["baize-v2-7b", "baize-v2-13b"], + "Baize v2", + "https://github.com/project-baize/baize-chatbot#v2", + "A chatbot fine-tuned from LLaMA with ChatGPT self-chat data and Self-Disillation with Feedback (SDF) by UCSD and SYSU.", +) + +register_model_info( + [ + "airoboros-l2-7b-2.1", + "airoboros-l2-13b-2.1", + "airoboros-c34b-2.1", + "airoboros-l2-70b-2.1", + ], + "airoboros", + "https://huggingface.co/jondurbin/airoboros-l2-70b-2.1", + "An instruction-tuned LlaMa model tuned with 100% synthetic instruction-response pairs from GPT4", +) + +register_model_info( + [ + "spicyboros-7b-2.2", + "spicyboros-13b-2.2", + "spicyboros-70b-2.2", + ], + "spicyboros", + "https://huggingface.co/jondurbin/spicyboros-70b-2.2", + "De-aligned versions of the airoboros models", +) + +register_model_info( + ["Robin-7b-v2", "Robin-13b-v2", "Robin-33b-v2"], + "Robin-v2", + "https://huggingface.co/OptimalScale/robin-7b-v2-delta", + "A chatbot fine-tuned from LLaMA-7b, achieving competitive performance on chitchat, commonsense 
reasoning and instruction-following tasks, by OptimalScale, HKUST.", +) + +register_model_info( + ["manticore-13b-chat"], + "Manticore 13B Chat", + "https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg", + "A chatbot fine-tuned from LlaMa across several CoT and chat datasets.", +) + +register_model_info( + ["redpajama-incite-7b-chat"], + "RedPajama-INCITE-7B-Chat", + "https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat", + "A chatbot fine-tuned from RedPajama-INCITE-7B-Base by Together", +) + +register_model_info( + [ + "falcon-7b", + "falcon-7b-instruct", + "falcon-40b", + "falcon-40b-instruct", + "falcon-180b", + "falcon-180b-chat", + ], + "Falcon", + "https://huggingface.co/tiiuae/falcon-180B", + "TII's flagship series of large language models", +) + +register_model_info( + ["tigerbot-7b-sft"], + "Tigerbot", + "https://huggingface.co/TigerResearch/tigerbot-7b-sft", + "A large-scale language model (LLM) with multiple languages and tasks.", +) + +register_model_info( + ["internlm-chat-7b", "internlm-chat-7b-8k"], + "InternLM", + "https://huggingface.co/internlm/internlm-chat-7b", + "A multi-language large-scale language model (LLM), developed by SHLAB.", +) + +register_model_info( + ["Qwen-7B-Chat"], + "Qwen", + "https://huggingface.co/Qwen/Qwen-7B-Chat", + "A multi-language large-scale language model (LLM), developed by Damo Academy.", +) + +register_model_info( + ["smaug-2-72b"], + "Smaug-2-72B", + "https://huggingface.co/abacusai/Smaug-2-72B", + "An open model trained by Abacus.AI.", +) + +register_model_info( + ["Llama2-Chinese-13b-Chat", "LLama2-Chinese-13B"], + "Llama2-Chinese", + "https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat", + "A multi-language large-scale language model (LLM), developed by FlagAlpha.", +) + +register_model_info( + ["Meta-Llama-3-8B-Instruct", "Meta-Llama-3-70B-Instruct"], + "llama-3", + "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct", + "Meta developed and released the Meta Llama 3 family of large language models (LLMs), a collection of pretrained and instruction tuned generative text models in 8 and 70B sizes.", +) + +register_model_info( + ["Chinese-Alpaca-2-7B", "Chinese-Alpaca-2-13B"], + "Chinese-Alpaca", + "https://huggingface.co/hfl/chinese-alpaca-2-13b", + "New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs.", +) + +register_model_info( + ["Vigogne-2-7B-Instruct", "Vigogne-2-13B-Instruct"], + "Vigogne-Instruct", + "https://huggingface.co/bofenghuang/vigogne-2-7b-instruct", + "A French large language model (LLM) optimized for instruction-following, developed by Bofeng Huang", +) + +register_model_info( + ["Vigogne-2-7B-Chat", "Vigogne-2-13B-Chat"], + "Vigogne-Chat", + "https://huggingface.co/bofenghuang/vigogne-2-7b-chat", + "A French large language model (LLM) optimized for instruction-following and multi-turn dialogues, developed by Bofeng Huang", +) + +register_model_info( + ["stable-vicuna-13B-HF"], + "stable-vicuna", + "https://huggingface.co/TheBloke/stable-vicuna-13B-HF", + "A Vicuna model fine-tuned using RLHF via PPO on various conversational and instructional datasets.", +) + +register_model_info( + ["deluxe-chat-v1", "deluxe-chat-v1.1", "deluxe-chat-v1.2", "deluxe-chat-v1.3"], + "DeluxeChat", + "", + "Deluxe Chat", +) + +register_model_info( + [ + "Xwin-LM-7B-V0.1", + "Xwin-LM-13B-V0.1", + "Xwin-LM-70B-V0.1", + "Xwin-LM-7B-V0.2", + "Xwin-LM-13B-V0.2", + ], + "Xwin-LM", + "https://github.com/Xwin-LM/Xwin-LM", + "Chat models developed by 
Xwin-LM team", +) + +register_model_info( + ["lemur-70b-chat"], + "Lemur-Chat", + "https://huggingface.co/OpenLemur/lemur-70b-chat-v1", + "An openly accessible language model optimized for both natural language and coding capabilities ", +) + +register_model_info( + ["Mistral-7B-OpenOrca"], + "Open-Orca", + "https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca", + "A fine-tune of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca)", +) + +register_model_info( + ["dolphin-2.2.1-mistral-7b"], + "dolphin-mistral", + "https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b", + "An uncensored fine-tuned Mistral 7B", +) + +register_model_info( + [ + "AquilaChat-7B", + "AquilaChat2-7B", + "AquilaChat2-34B", + ], + "Aquila-Chat", + "https://huggingface.co/BAAI/AquilaChat2-34B", + "Chat models developed by BAAI team", +) + +register_model_info( + ["xDAN-L1-Chat-RL-v1"], + "xDAN-L1-Chat", + "https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1", + "A large language chat model created by xDAN-AI.", +) + +register_model_info( + ["MetaMath-70B-V1.0", "MetaMath-7B-V1.0"], + "MetaMath", + "https://huggingface.co/meta-math", + "A finetune of Llama2 on [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) that specializes in mathematical reasoning.", +) + +register_model_info( + ["Yuan2-2B-hf", "Yuan2-51B-hf", "Yuan2-102B-hf"], + "IEIYuan", + "https://huggingface.co/IEITYuan", + "A Basemodel developed by IEI.", +) + +register_model_info( + [ + "llava-v1.6-34b", + "llava-v1.6-vicuna-13b", + "llava-v1.6-vicuna-7b", + "llava-v1.6-mistral-7b", + "llava-v1.5-13b", + "llava-v1.5-7b", + ], + "LLaVA", + "https://github.com/haotian-liu/LLaVA", + "an open large language and vision assistant", +) + +register_model_info( + ["gemma-7b-it", "gemma-2b-it"], + "Gemma", + "https://blog.google/technology/developers/gemma-open-models/", + "Gemma by Google", +) + +register_model_info( + [ + "cllm/consistency-llm-7b-codesearchnet", + "cllm/consistency-llm-7b-gsm8k", + "cllm/consistency-llm-7b-sharegpt48k", + "cllm/consistency-llm-7b-spider", + ], + "consistency-llm", + "https://huggingface.co/cllm", + "consistency-llm is a new generation of parallel decoder LLMs with fast generation speed.", +) + +register_model_info( + ["reka-flash", "reka-flash-20240226"], + "Reka Flash", + "https://reka.ai/reka-flash", + "Multimodal model by Reka", +) diff --git a/src/model/model_xfastertransformer.py b/src/model/model_xfastertransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..54890b1ca4977f4243cca46cb7c78114a3b2e5d6 --- /dev/null +++ b/src/model/model_xfastertransformer.py @@ -0,0 +1,81 @@ +import gc +from threading import Thread + +import torch +from transformers import TextIteratorStreamer + + +@torch.inference_mode() +def generate_stream_xft( + model, + tokenizer, + params, + device, + context_len=8192, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + + # unused now, and placehold for future. 
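+    # (temperature/top_p are not forwarded to xFasterTransformer here; generation
+    # below is configured from model.config: beam width, early stopping, eos/pad ids)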
+    # temperature = float(params.get("temperature", 1.0))
+    # top_p = float(params.get("top_p", 1.0))
+
+    max_new_tokens = int(params.get("max_new_tokens", 4096))
+    echo = params.get("echo", True)
+
+    inputs = tokenizer(
+        prompt, return_tensors="pt", padding=model.config.padding
+    ).input_ids
+    input_echo_len = len(inputs[0])
+    max_len = max_new_tokens + input_echo_len
+
+    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
+    generation_kwargs = {
+        "input_ids": inputs,
+        "streamer": streamer,
+        "max_length": max_len,
+        "num_beams": model.config.beam_width,
+        "length_penalty": repetition_penalty,
+        "num_return_sequences": model.config.num_return_sequences,
+        "early_stopping": model.config.early_stopping,
+        "eos_token_id": model.config.eos_token_id,
+        "pad_token_id": model.config.pad_token_id,
+    }
+
+    thread = Thread(target=model.model.generate, kwargs=generation_kwargs)
+    thread.start()
+    if echo:
+        # means keep the prompt
+        output = prompt
+    else:
+        output = ""
+    i = 0
+    for i, new_text in enumerate(streamer):
+        output += new_text
+        yield {
+            "text": output,
+            "usage": {
+                "prompt_tokens": input_echo_len,
+                "completion_tokens": i,
+                "total_tokens": input_echo_len + i,
+            },
+            "finish_reason": None,
+        }
+    output = output.strip()
+    if i == max_new_tokens - 1:
+        finish_reason = "length"
+    else:
+        finish_reason = "stop"
+    yield {
+        "text": output,
+        "usage": {
+            "prompt_tokens": input_echo_len,
+            "completion_tokens": i,
+            "total_tokens": input_echo_len + i,
+        },
+        "finish_reason": finish_reason,
+    }
+    gc.collect()
diff --git a/src/model/model_yuan2.py b/src/model/model_yuan2.py
new file mode 100644
index 0000000000000000000000000000000000000000..25b3e13f847cb38f22bba2cf277b55cef6c10726
--- /dev/null
+++ b/src/model/model_yuan2.py
@@ -0,0 +1,139 @@
+import gc
+from threading import Thread
+from typing import Iterable
+
+import torch
+import transformers
+from transformers import TextIteratorStreamer, GenerationConfig
+
+from fastchat.utils import is_partial_stop
+
+
+@torch.inference_mode()
+def generate_stream_yuan2(
+    model,
+    tokenizer,
+    params,
+    device,
+    context_len=2048,
+    stream_interval=2,
+    judge_sent_end=False,
+):
+    prompt = params["prompt"]
+    len_prompt = len(prompt)
+    temperature = float(params.get("temperature", 1))
+    repetition_penalty = float(params.get("repetition_penalty", 1.0))
+    top_p = float(params.get("top_p", 0))
+    top_k = int(params.get("top_k", 1))  # -1 means disable
+    max_new_tokens = int(params.get("max_new_tokens", 512))
+    stop_str = params.get("stop", "<eod>")
+    echo = bool(params.get("echo", True))
+    stop_token_ids = params.get("stop_token_ids", None) or []
+    stop_token_ids.append(tokenizer("<eod>")["input_ids"][0])  # Yuan2's end-of-document token
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
+
+    max_src_len = context_len - max_new_tokens - 8
+
+    input_ids = input_ids[-max_src_len:]  # truncate from the left
+    attention_mask = attention_mask[-max_src_len:]  # truncate from the left
+    input_echo_len = len(input_ids)
+
+    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
+
+    generation_config = GenerationConfig(
+        max_new_tokens=max_new_tokens,
+        do_sample=temperature >= 1.2,
+        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+        no_repeat_ngram_size=10,
+        top_p=top_p,
+
top_k=top_k, + ) + + generation_kwargs = dict( + inputs=input_ids, + attention_mask=attention_mask, + streamer=streamer, + generation_config=generation_config, + ) + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + if echo: + # means keep the prompt + output = prompt + else: + output = "" + + for i, new_text in enumerate(streamer): + output += new_text + if i % stream_interval == 0: + if echo: + rfind_start = len_prompt + else: + rfind_start = 0 + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + + # finish stream event, which contains finish reason + if i == max_new_tokens - 1: + finish_reason = "length" + elif partially_stopped: + finish_reason = None + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/src/model/monkey_patch_non_inplace.py b/src/model/monkey_patch_non_inplace.py new file mode 100644 index 0000000000000000000000000000000000000000..413dd3b30500c788abb19e5742447237ba2b1738 --- /dev/null +++ b/src/model/monkey_patch_non_inplace.py @@ -0,0 +1,119 @@ +""" +Monkey patch the llama implementation in the huggingface/transformers library. +Avoid bugs in mps backend by not using in-place operations. 
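+
+Typical usage (a sketch; assumes the repo's src/ package is importable):
+
+    from src.model.monkey_patch_non_inplace import (
+        replace_llama_attn_with_non_inplace_operations,
+    )
+
+    replace_llama_attn_with_non_inplace_operations()  # patch before loading a Llama model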
+""" +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn +import transformers + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2].clone() + x2 = x[..., x.shape[-1] // 2 :].clone() + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, 
past_key_value + + +def replace_llama_attn_with_non_inplace_operations(): + """Avoid bugs in mps backend by not using in-place operations.""" + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/src/model/rwkv_model.py b/src/model/rwkv_model.py new file mode 100644 index 0000000000000000000000000000000000000000..bdbc14584bfd1ec90e8478b4e55f07e8ec89a967 --- /dev/null +++ b/src/model/rwkv_model.py @@ -0,0 +1,76 @@ +import os +from types import SimpleNamespace +import warnings + +import torch + +os.environ["RWKV_JIT_ON"] = "1" +os.environ["RWKV_CUDA_ON"] = "1" + +from rwkv.model import RWKV +from rwkv.utils import PIPELINE, PIPELINE_ARGS + + +class RwkvModel: + def __init__(self, model_path): + warnings.warn( + "Experimental support. Please use ChatRWKV if you want to chat with RWKV" + ) + self.config = SimpleNamespace(is_encoder_decoder=False) + self.model = RWKV(model=model_path, strategy="cuda fp16") + # two GPUs + # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16") + + self.tokenizer = None + self.model_path = model_path + + def to(self, target): + assert target == "cuda" + + def __call__(self, input_ids, use_cache, past_key_values=None): + assert use_cache == True + input_ids = input_ids[0].detach().cpu().numpy() + # print(input_ids) + logits, state = self.model.forward(input_ids, past_key_values) + # print(logits) + logits = logits.unsqueeze(0).unsqueeze(0) + out = SimpleNamespace(logits=logits, past_key_values=state) + return out + + def generate( + self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0 + ): + # This function is used by fastchat.llm_judge. + # Because RWKV does not support huggingface generation API, + # we reuse fastchat.serve.inference.generate_stream as a workaround. + from transformers import AutoTokenizer + + from fastchat.serve.inference import generate_stream + from fastchat.conversation import get_conv_template + + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", use_fast=True + ) + prompt = self.tokenizer.decode(input_ids[0].tolist()) + conv = get_conv_template("rwkv") + + gen_params = { + "model": self.model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda") + + for res in res_iter: + pass + + output = res["text"] + output_ids = self.tokenizer.encode(output) + + return [input_ids[0].tolist() + output_ids] diff --git a/src/model/upload_hub.py b/src/model/upload_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..b1519652e6d90479d60054008d8d7e371b16356e --- /dev/null +++ b/src/model/upload_hub.py @@ -0,0 +1,45 @@ +""" +Upload weights to huggingface. 
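+
+Each selected component ("model", "tokenizer", or both) is loaded from --model-path,
+saved to a temporary directory, and pushed with save_pretrained(push_to_hub=True).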
+ +Usage: +python3 -m fastchat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3 +""" +import argparse +import tempfile + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def upload_hub(model_path, hub_repo_id, component, private): + if component == "all": + components = ["model", "tokenizer"] + else: + components = [component] + + kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private} + + if "model" in components: + model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + with tempfile.TemporaryDirectory() as tmp_path: + model.save_pretrained(tmp_path, **kwargs) + + if "tokenizer" in components: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + with tempfile.TemporaryDirectory() as tmp_path: + tokenizer.save_pretrained(tmp_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str, required=True) + parser.add_argument( + "--component", type=str, choices=["all", "model", "tokenizer"], default="all" + ) + parser.add_argument("--private", action="store_true") + args = parser.parse_args() + + upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/modules/__pycache__/__init__.cpython-310.pyc b/src/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c841f8d4db94785304388b4fa8b76ba7cf7e7b52 Binary files /dev/null and b/src/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/modules/__pycache__/awq.cpython-310.pyc b/src/modules/__pycache__/awq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6f89704abbd3e00901ee2c9eeca51f17f6c549f Binary files /dev/null and b/src/modules/__pycache__/awq.cpython-310.pyc differ diff --git a/src/modules/__pycache__/exllama.cpython-310.pyc b/src/modules/__pycache__/exllama.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fca516e6ce70b8878908b0af97c2e4a68ebb689 Binary files /dev/null and b/src/modules/__pycache__/exllama.cpython-310.pyc differ diff --git a/src/modules/__pycache__/gptq.cpython-310.pyc b/src/modules/__pycache__/gptq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75ec8bcae9cbfcbdd96fb1438d8a7871e15c9334 Binary files /dev/null and b/src/modules/__pycache__/gptq.cpython-310.pyc differ diff --git a/src/modules/__pycache__/xfastertransformer.cpython-310.pyc b/src/modules/__pycache__/xfastertransformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a284866f12b4195d7ce0920ed27222e8f0e051a Binary files /dev/null and b/src/modules/__pycache__/xfastertransformer.cpython-310.pyc differ diff --git a/src/modules/awq.py b/src/modules/awq.py new file mode 100644 index 0000000000000000000000000000000000000000..1f27be85c09e2394bd821cc1ce236f46c429d4bc --- /dev/null +++ b/src/modules/awq.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass, field +from pathlib import Path +import sys + +import torch +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils + + +@dataclass +class AWQConfig: + 
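+    """Options for loading a local AWQ-quantized checkpoint (consumed by load_awq_quantized below)."""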
ckpt: str = field( + default=None, + metadata={ + "help": "Load quantized model. The path to the local AWQ checkpoint." + }, + ) + wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) + groupsize: int = field( + default=-1, + metadata={"help": "Groupsize to use for quantization; default uses full row."}, + ) + + +def load_awq_quantized(model_name, awq_config: AWQConfig, device): + print("Loading AWQ quantized model...") + + try: + from tinychat.utils import load_quant + from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp + except ImportError as e: + print(f"Error: Failed to import tinychat. {e}") + print("Please double check if you have successfully installed AWQ") + print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md") + sys.exit(-1) + + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + model_name, use_fast=False, trust_remote_code=True + ) + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.kaiming_normal_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + modeling_utils._init_weights = False + + torch.set_default_dtype(torch.half) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + + if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]): + model = load_quant.load_awq_llama_fast( + model, + find_awq_ckpt(awq_config), + awq_config.wbits, + awq_config.groupsize, + device, + ) + make_quant_attn(model, device) + make_quant_norm(model) + make_fused_mlp(model) + else: + model = load_quant.load_awq_model( + model, + find_awq_ckpt(awq_config), + awq_config.wbits, + awq_config.groupsize, + device, + ) + return model, tokenizer + + +def find_awq_ckpt(awq_config: AWQConfig): + if Path(awq_config.ckpt).is_file(): + return awq_config.ckpt + + for ext in ["*.pt", "*.safetensors"]: + matched_result = sorted(Path(awq_config.ckpt).glob(ext)) + if len(matched_result) > 0: + return str(matched_result[-1]) + + print("Error: AWQ checkpoint not found") + sys.exit(1) diff --git a/src/modules/exllama.py b/src/modules/exllama.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5fc81b3453a25905896cba31f9ce9dd0f0690e --- /dev/null +++ b/src/modules/exllama.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass, field +import sys + + +@dataclass +class ExllamaConfig: + max_seq_len: int + gpu_split: str = None + cache_8bit: bool = False + + +class ExllamaModel: + def __init__(self, exllama_model, exllama_cache): + self.model = exllama_model + self.cache = exllama_cache + self.config = self.model.config + + +def load_exllama_model(model_path, exllama_config: ExllamaConfig): + try: + from exllamav2 import ( + ExLlamaV2Config, + ExLlamaV2Tokenizer, + ExLlamaV2, + ExLlamaV2Cache, + ExLlamaV2Cache_8bit, + ) + except ImportError as e: + print(f"Error: Failed to load Exllamav2. 
{e}") + sys.exit(-1) + + exllamav2_config = ExLlamaV2Config() + exllamav2_config.model_dir = model_path + exllamav2_config.prepare() + exllamav2_config.max_seq_len = exllama_config.max_seq_len + exllamav2_config.cache_8bit = exllama_config.cache_8bit + + exllama_model = ExLlamaV2(exllamav2_config) + tokenizer = ExLlamaV2Tokenizer(exllamav2_config) + + split = None + if exllama_config.gpu_split: + split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")] + exllama_model.load(split) + + cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache + exllama_cache = cache_class(exllama_model) + model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache) + + return model, tokenizer diff --git a/src/modules/gptq.py b/src/modules/gptq.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0a220c0cfb227271fbb4d1e7c4eca636b10d1c --- /dev/null +++ b/src/modules/gptq.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass, field +import os +from os.path import isdir, isfile +from pathlib import Path +import sys + +from transformers import AutoTokenizer + + +@dataclass +class GptqConfig: + ckpt: str = field( + default=None, + metadata={ + "help": "Load quantized model. The path to the local GPTQ checkpoint." + }, + ) + wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) + groupsize: int = field( + default=-1, + metadata={"help": "Groupsize to use for quantization; default uses full row."}, + ) + act_order: bool = field( + default=True, + metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, + ) + + +def load_gptq_quantized(model_name, gptq_config: GptqConfig): + print("Loading GPTQ quantized model...") + + try: + script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") + + sys.path.insert(0, module_path) + from llama import load_quant + except ImportError as e: + print(f"Error: Failed to load GPTQ-for-LLaMa. 
{e}") + print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") + sys.exit(-1) + + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) + # only `fastest-inference-4bit` branch cares about `act_order` + if gptq_config.act_order: + model = load_quant( + model_name, + find_gptq_ckpt(gptq_config), + gptq_config.wbits, + gptq_config.groupsize, + act_order=gptq_config.act_order, + ) + else: + # other branches + model = load_quant( + model_name, + find_gptq_ckpt(gptq_config), + gptq_config.wbits, + gptq_config.groupsize, + ) + + return model, tokenizer + + +def find_gptq_ckpt(gptq_config: GptqConfig): + if Path(gptq_config.ckpt).is_file(): + return gptq_config.ckpt + + for ext in ["*.pt", "*.safetensors"]: + matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) + if len(matched_result) > 0: + return str(matched_result[-1]) + + print("Error: gptq checkpoint not found") + sys.exit(1) diff --git a/src/modules/xfastertransformer.py b/src/modules/xfastertransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0b49bea4cd5c9afd723318daaa5c10dcb309b776 --- /dev/null +++ b/src/modules/xfastertransformer.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass +import sys + + +@dataclass +class XftConfig: + max_seq_len: int = 4096 + beam_width: int = 1 + eos_token_id: int = -1 + pad_token_id: int = -1 + num_return_sequences: int = 1 + is_encoder_decoder: bool = False + padding: bool = True + early_stopping: bool = False + data_type: str = "bf16_fp16" + + +class XftModel: + def __init__(self, xft_model, xft_config): + self.model = xft_model + self.config = xft_config + + +def load_xft_model(model_path, xft_config: XftConfig): + try: + import xfastertransformer + from transformers import AutoTokenizer + except ImportError as e: + print(f"Error: Failed to load xFasterTransformer. 
{e}") + sys.exit(-1) + + if xft_config.data_type is None or xft_config.data_type == "": + data_type = "bf16_fp16" + else: + data_type = xft_config.data_type + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, padding_side="left", trust_remote_code=True + ) + xft_model = xfastertransformer.AutoModel.from_pretrained( + model_path, dtype=data_type + ) + model = XftModel(xft_model=xft_model, xft_config=xft_config) + if model.model.rank > 0: + while True: + model.model.generate() + return model, tokenizer diff --git a/src/serve/__init__.py b/src/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/serve/__pycache__/__init__.cpython-310.pyc b/src/serve/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9d2e57a8ae88d564eba30f8f6fa48f7da77843a Binary files /dev/null and b/src/serve/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/serve/__pycache__/api_provider.cpython-310.pyc b/src/serve/__pycache__/api_provider.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc1006b762c89574786261b4b8c4ebf7a68a4641 Binary files /dev/null and b/src/serve/__pycache__/api_provider.cpython-310.pyc differ diff --git a/src/serve/__pycache__/gradio_block_arena_named.cpython-310.pyc b/src/serve/__pycache__/gradio_block_arena_named.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c63159eb753a1770ae2b5b53ae0a8f393d83466 Binary files /dev/null and b/src/serve/__pycache__/gradio_block_arena_named.cpython-310.pyc differ diff --git a/src/serve/__pycache__/gradio_block_arena_vision.cpython-310.pyc b/src/serve/__pycache__/gradio_block_arena_vision.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af91cfb0186d4d408bce3b244c069acb5172a267 Binary files /dev/null and b/src/serve/__pycache__/gradio_block_arena_vision.cpython-310.pyc differ diff --git a/src/serve/__pycache__/gradio_block_arena_vision_named.cpython-310.pyc b/src/serve/__pycache__/gradio_block_arena_vision_named.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8116aca54af66276b0dcd9323660086bb7cf648a Binary files /dev/null and b/src/serve/__pycache__/gradio_block_arena_vision_named.cpython-310.pyc differ diff --git a/src/serve/__pycache__/gradio_web_server.cpython-310.pyc b/src/serve/__pycache__/gradio_web_server.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0698677598ce8d1dbf5fde45749048dcf16075 Binary files /dev/null and b/src/serve/__pycache__/gradio_web_server.cpython-310.pyc differ diff --git a/src/serve/__pycache__/remote_logger.cpython-310.pyc b/src/serve/__pycache__/remote_logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..334cd2df4c63f4cdb0f94ad7b291793c054d3a8d Binary files /dev/null and b/src/serve/__pycache__/remote_logger.cpython-310.pyc differ diff --git a/src/serve/api_provider.py b/src/serve/api_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..34bd8c549f23db99d546a375a96b71746f4a69ba --- /dev/null +++ b/src/serve/api_provider.py @@ -0,0 +1,1037 @@ +"""Call API providers.""" + +import json +import os +import random +import re +from typing import Optional +import time + +import requests + +from fastchat.utils import build_logger + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + + +def 
get_api_provider_stream_iter( + conv, + model_name, + model_api_dict, + temperature, + top_p, + max_new_tokens, + state, +): + if model_api_dict["api_type"] == "openai": + if model_api_dict["vision-arena"]: + prompt = conv.to_openai_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "openai_assistant": + last_prompt = conv.messages[-2][1] + stream_iter = openai_assistant_api_stream_iter( + state, + last_prompt, + assistant_id=model_api_dict["assistant_id"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "anthropic": + if model_api_dict["vision-arena"]: + prompt = conv.to_anthropic_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() + stream_iter = anthropic_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "anthropic_message": + if model_api_dict["vision-arena"]: + prompt = conv.to_anthropic_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() + stream_iter = anthropic_message_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "anthropic_message_vertex": + if model_api_dict["vision-arena"]: + prompt = conv.to_anthropic_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() + stream_iter = anthropic_message_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + vertex_ai=True, + ) + elif model_api_dict["api_type"] == "gemini": + prompt = conv.to_gemini_api_messages() + stream_iter = gemini_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "bard": + prompt = conv.to_openai_api_messages() + stream_iter = bard_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "mistral": + prompt = conv.to_openai_api_messages() + stream_iter = mistral_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "nvidia": + prompt = conv.to_openai_api_messages() + stream_iter = nvidia_api_stream_iter( + model_name, + prompt, + temperature, + top_p, + max_new_tokens, + model_api_dict["api_base"], + ) + elif model_api_dict["api_type"] == "ai2": + prompt = conv.to_openai_api_messages() + stream_iter = ai2_api_stream_iter( + model_name, + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "vertex": + prompt = conv.to_vertex_api_messages() + stream_iter = vertex_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "yandexgpt": + # note: top_p parameter is unused by yandexgpt + + messages = [] + if conv.system_message: + messages.append({"role": "system", "text": conv.system_message}) + messages += [ + {"role": role, "text": text} + for role, text in conv.messages + if text is not None + ] + + fixed_temperature = model_api_dict.get("fixed_temperature") + if fixed_temperature is not None: + temperature = 
fixed_temperature + + stream_iter = yandexgpt_api_stream_iter( + model_name=model_api_dict["model_name"], + messages=messages, + temperature=temperature, + max_tokens=max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict.get("api_key"), + folder_id=model_api_dict.get("folder_id"), + ) + elif model_api_dict["api_type"] == "cohere": + messages = conv.to_openai_api_messages() + stream_iter = cohere_api_stream_iter( + client_name=model_api_dict.get("client_name", "FastChat"), + model_id=model_api_dict["model_name"], + messages=messages, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "reka": + messages = conv.to_reka_api_messages() + stream_iter = reka_api_stream_iter( + model_name=model_api_dict["model_name"], + messages=messages, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + else: + raise NotImplementedError() + + return stream_iter + + +def openai_api_stream_iter( + model_name, + messages, + temperature, + top_p, + max_new_tokens, + api_base=None, + api_key=None, +): + import openai + + api_key = api_key or os.environ["OPENAI_API_KEY"] + + if "azure" in model_name: + client = openai.AzureOpenAI( + api_version="2023-07-01-preview", + azure_endpoint=api_base or "https://api.openai.com/v1", + api_key=api_key, + ) + else: + client = openai.OpenAI( + base_url=api_base or "https://api.openai.com/v1", + api_key=api_key, + timeout=180, + ) + + # Make requests for logging + text_messages = [] + for message in messages: + if type(message["content"]) == str: # text-only model + text_messages.append(message) + else: # vision model + filtered_content_list = [ + content for content in message["content"] if content["type"] == "text" + ] + text_messages.append( + {"role": message["role"], "content": filtered_content_list} + ) + + gen_params = { + "model": model_name, + "prompt": text_messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + res = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=temperature, + max_tokens=max_new_tokens, + stream=True, + ) + text = "" + for chunk in res: + if len(chunk.choices) > 0: + text += chunk.choices[0].delta.content or "" + data = { + "text": text, + "error_code": 0, + } + yield data + + +def upload_openai_file_to_gcs(file_id): + import openai + from google.cloud import storage + + storage_client = storage.Client() + + file = openai.files.content(file_id) + # upload file to GCS + bucket = storage_client.get_bucket("arena_user_content") + blob = bucket.blob(f"{file_id}") + blob.upload_from_string(file.read()) + blob.make_public() + return blob.public_url + + +def openai_assistant_api_stream_iter( + state, + prompt, + assistant_id, + api_key=None, +): + import openai + import base64 + + api_key = api_key or os.environ["OPENAI_API_KEY"] + client = openai.OpenAI(base_url="https://api.openai.com/v1", api_key=api_key) + + if state.oai_thread_id is None: + logger.info("==== create thread ====") + thread = client.beta.threads.create() + state.oai_thread_id = thread.id + logger.info(f"==== thread_id ====\n{state.oai_thread_id}") + thread_message = client.beta.threads.messages.with_raw_response.create( + state.oai_thread_id, + role="user", + content=prompt, + timeout=3, + 
) + # logger.info(f"header {thread_message.headers}") + thread_message = thread_message.parse() + # Make requests + gen_params = { + "assistant_id": assistant_id, + "thread_id": state.oai_thread_id, + "message": prompt, + } + logger.info(f"==== request ====\n{gen_params}") + + res = requests.post( + f"https://api.openai.com/v1/threads/{state.oai_thread_id}/runs", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "OpenAI-Beta": "assistants=v1", + }, + json={"assistant_id": assistant_id, "stream": True}, + timeout=30, + stream=True, + ) + + list_of_text = [] + list_of_raw_text = [] + offset_idx = 0 + full_ret_text = "" + idx_mapping = {} + for line in res.iter_lines(): + if not line: + continue + data = line.decode("utf-8") + # logger.info("data:", data) + if data.endswith("[DONE]"): + break + if data.startswith("event"): + event = data.split(":")[1].strip() + if event == "thread.message.completed": + offset_idx += len(list_of_text) + continue + data = json.loads(data[6:]) + + if data.get("status") == "failed": + yield { + "text": f"**API REQUEST ERROR** Reason: {data['last_error']['message']}", + "error_code": 1, + } + return + + if data.get("status") == "completed": + logger.info(f"[debug]: {data}") + + if data["object"] != "thread.message.delta": + continue + + for delta in data["delta"]["content"]: + text_index = delta["index"] + offset_idx + if len(list_of_text) <= text_index: + list_of_text.append("") + list_of_raw_text.append("") + + text = list_of_text[text_index] + raw_text = list_of_raw_text[text_index] + + if delta["type"] == "text": + # text, url_citation or file_path + content = delta["text"] + if "annotations" in content and len(content["annotations"]) > 0: + annotations = content["annotations"] + + cur_offset = 0 + raw_text_copy = raw_text + for anno in annotations: + if anno["type"] == "url_citation": + anno_text = anno["text"] + if anno_text not in idx_mapping: + continue + citation_number = idx_mapping[anno_text] + + start_idx = anno["start_index"] + cur_offset + end_idx = anno["end_index"] + cur_offset + url = anno["url_citation"]["url"] + + citation = f" [[{citation_number}]]({url})" + raw_text_copy = ( + raw_text_copy[:start_idx] + + citation + + raw_text_copy[end_idx:] + ) + cur_offset += len(citation) - (end_idx - start_idx) + elif anno["type"] == "file_path": + file_public_url = upload_openai_file_to_gcs( + anno["file_path"]["file_id"] + ) + raw_text_copy = raw_text_copy.replace( + anno["text"], f"{file_public_url}" + ) + text = raw_text_copy + else: + text_content = content["value"] + raw_text += text_content + + # re-index citation number + pattern = r"【\d+】" + matches = re.findall(pattern, content["value"]) + if len(matches) > 0: + for match in matches: + if match not in idx_mapping: + idx_mapping[match] = len(idx_mapping) + 1 + citation_number = idx_mapping[match] + text_content = text_content.replace( + match, f" [{citation_number}]" + ) + text += text_content + # yield {"text": text, "error_code": 0} + elif delta["type"] == "image_file": + image_public_url = upload_openai_file_to_gcs( + delta["image_file"]["file_id"] + ) + # raw_text += f"![image]({image_public_url})" + text += f"![image]({image_public_url})" + + list_of_text[text_index] = text + list_of_raw_text[text_index] = raw_text + + full_ret_text = "\n".join(list_of_text) + yield {"text": full_ret_text, "error_code": 0} + + +def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens): + import anthropic + + c = 
anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) + + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + res = c.completions.create( + prompt=prompt, + stop_sequences=[anthropic.HUMAN_PROMPT], + max_tokens_to_sample=max_new_tokens, + temperature=temperature, + top_p=top_p, + model=model_name, + stream=True, + ) + text = "" + for chunk in res: + text += chunk.completion + data = { + "text": text, + "error_code": 0, + } + yield data + + +def anthropic_message_api_stream_iter( + model_name, + messages, + temperature, + top_p, + max_new_tokens, + vertex_ai=False, +): + import anthropic + + if vertex_ai: + client = anthropic.AnthropicVertex( + region=os.environ["GCP_LOCATION"], + project_id=os.environ["GCP_PROJECT_ID"], + max_retries=5, + ) + else: + client = anthropic.Anthropic( + api_key=os.environ["ANTHROPIC_API_KEY"], + max_retries=5, + ) + + text_messages = [] + for message in messages: + if type(message["content"]) == str: # text-only model + text_messages.append(message) + else: # vision model + filtered_content_list = [ + content for content in message["content"] if content["type"] == "text" + ] + text_messages.append( + {"role": message["role"], "content": filtered_content_list} + ) + + # Make requests for logging + gen_params = { + "model": model_name, + "prompt": text_messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + system_prompt = "" + if messages[0]["role"] == "system": + if type(messages[0]["content"]) == dict: + system_prompt = messages[0]["content"]["text"] + elif type(messages[0]["content"]) == str: + system_prompt = messages[0]["content"] + # remove system prompt + messages = messages[1:] + + text = "" + with client.messages.stream( + temperature=temperature, + top_p=top_p, + max_tokens=max_new_tokens, + messages=messages, + model=model_name, + system=system_prompt, + ) as stream: + for chunk in stream.text_stream: + text += chunk + data = { + "text": text, + "error_code": 0, + } + yield data + + +def gemini_api_stream_iter( + model_name, messages, temperature, top_p, max_new_tokens, api_key=None +): + import google.generativeai as genai # pip install google-generativeai + + if api_key is None: + api_key = os.environ["GEMINI_API_KEY"] + genai.configure(api_key=api_key) + + generation_config = { + "temperature": temperature, + "max_output_tokens": max_new_tokens, + "top_p": top_p, + } + params = { + "model": model_name, + "prompt": messages, + } + params.update(generation_config) + logger.info(f"==== request ====\n{params}") + + safety_settings = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + ] + + history = [] + system_prompt = None + for message in messages[:-1]: + if message["role"] == "system": + system_prompt = message["content"] + continue + history.append({"role": message["role"], "parts": message["content"]}) + + model = genai.GenerativeModel( + model_name=model_name, + system_instruction=system_prompt, + generation_config=generation_config, + safety_settings=safety_settings, + ) + convo = model.start_chat(history=history) + response = 
convo.send_message(messages[-1]["content"], stream=True) + + try: + text = "" + for chunk in response: + text += chunk.candidates[0].content.parts[0].text + data = { + "text": text, + "error_code": 0, + } + yield data + except Exception as e: + logger.error(f"==== error ====\n{e}") + reason = chunk.candidates + yield { + "text": f"**API REQUEST ERROR** Reason: {reason}.", + "error_code": 1, + } + + +def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None): + del top_p # not supported + del temperature # not supported + + if api_key is None: + api_key = os.environ["BARD_API_KEY"] + + # convert conv to conv_bard + conv_bard = [] + for turn in conv: + if turn["role"] == "user": + conv_bard.append({"author": "0", "content": turn["content"]}) + elif turn["role"] == "assistant": + conv_bard.append({"author": "1", "content": turn["content"]}) + else: + raise ValueError(f"Unsupported role: {turn['role']}") + + params = { + "model": model_name, + "prompt": conv_bard, + } + logger.info(f"==== request ====\n{params}") + + try: + res = requests.post( + f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}", + json={ + "prompt": { + "messages": conv_bard, + }, + }, + timeout=30, + ) + except Exception as e: + logger.error(f"==== error ====\n{e}") + yield { + "text": f"**API REQUEST ERROR** Reason: {e}.", + "error_code": 1, + } + + if res.status_code != 200: + logger.error(f"==== error ==== ({res.status_code}): {res.text}") + yield { + "text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.", + "error_code": 1, + } + + response_json = res.json() + if "candidates" not in response_json: + logger.error(f"==== error ==== response blocked: {response_json}") + reason = response_json["filters"][0]["reason"] + yield { + "text": f"**API REQUEST ERROR** Reason: {reason}.", + "error_code": 1, + } + + response = response_json["candidates"][0]["content"] + pos = 0 + while pos < len(response): + # simulate token streaming + pos += random.randint(3, 6) + time.sleep(0.002) + data = { + "text": response[:pos], + "error_code": 0, + } + yield data + + +def ai2_api_stream_iter( + model_name, + model_id, + messages, + temperature, + top_p, + max_new_tokens, + api_key=None, + api_base=None, +): + # get keys and needed values + ai2_key = api_key or os.environ.get("AI2_API_KEY") + api_base = api_base or "https://inferd.allen.ai/api/v1/infer" + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling: + # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157 + if temperature == 0.0 and top_p < 1.0: + raise ValueError("top_p must be 1 when temperature is 0.0") + + res = requests.post( + api_base, + stream=True, + headers={"Authorization": f"Bearer {ai2_key}"}, + json={ + "model_id": model_id, + # This input format is specific to the Tulu2 model. Other models + # may require different input formats. See the model's schema + # documentation on InferD for more information. 
+ "input": { + "messages": messages, + "opts": { + "max_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + "logprobs": 1, # increase for more choices + }, + }, + }, + timeout=5, + ) + + if res.status_code != 200: + logger.error(f"unexpected response ({res.status_code}): {res.text}") + raise ValueError("unexpected response from InferD", res) + + text = "" + for line in res.iter_lines(): + if line: + part = json.loads(line) + if "result" in part and "output" in part["result"]: + for t in part["result"]["output"]["text"]: + text += t + else: + logger.error(f"unexpected part: {part}") + raise ValueError("empty result in InferD response") + + data = { + "text": text, + "error_code": 0, + } + yield data + + +def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens): + from mistralai.client import MistralClient + from mistralai.models.chat_completion import ChatMessage + + api_key = os.environ["MISTRAL_API_KEY"] + + client = MistralClient(api_key=api_key, timeout=5) + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + new_messages = [ + ChatMessage(role=message["role"], content=message["content"]) + for message in messages + ] + + res = client.chat_stream( + model=model_name, + temperature=temperature, + messages=new_messages, + max_tokens=max_new_tokens, + top_p=top_p, + ) + + text = "" + for chunk in res: + if chunk.choices[0].delta.content is not None: + text += chunk.choices[0].delta.content + data = { + "text": text, + "error_code": 0, + } + yield data + + +def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base): + api_key = os.environ["NVIDIA_API_KEY"] + headers = { + "Authorization": f"Bearer {api_key}", + "accept": "text/event-stream", + "content-type": "application/json", + } + # nvidia api does not accept 0 temperature + if temp == 0.0: + temp = 0.000001 + + payload = { + "messages": messages, + "temperature": temp, + "top_p": top_p, + "max_tokens": max_tokens, + "seed": 42, + "stream": True, + } + logger.info(f"==== request ====\n{payload}") + + response = requests.post( + api_base, headers=headers, json=payload, stream=True, timeout=1 + ) + text = "" + for line in response.iter_lines(): + if line: + data = line.decode("utf-8") + if data.endswith("[DONE]"): + break + data = json.loads(data[6:])["choices"][0]["delta"]["content"] + text += data + yield {"text": text, "error_code": 0} + + +def yandexgpt_api_stream_iter( + model_name, messages, temperature, max_tokens, api_base, api_key, folder_id +): + api_key = api_key or os.environ["YANDEXGPT_API_KEY"] + headers = { + "Authorization": f"Api-Key {api_key}", + "content-type": "application/json", + } + + payload = { + "modelUri": f"gpt://{folder_id}/{model_name}", + "completionOptions": { + "temperature": temperature, + "max_tokens": max_tokens, + "stream": True, + }, + "messages": messages, + } + logger.info(f"==== request ====\n{payload}") + + # https://llm.api.cloud.yandex.net/foundationModels/v1/completion + response = requests.post( + api_base, headers=headers, json=payload, stream=True, timeout=60 + ) + text = "" + for line in response.iter_lines(): + if line: + data = json.loads(line.decode("utf-8")) + data = data["result"] + top_alternative = data["alternatives"][0] + text = top_alternative["message"]["text"] + yield {"text": text, "error_code": 0} + + status = 
top_alternative["status"] + if status in ( + "ALTERNATIVE_STATUS_FINAL", + "ALTERNATIVE_STATUS_TRUNCATED_FINAL", + ): + break + + +def cohere_api_stream_iter( + client_name: str, + model_id: str, + messages: list, + temperature: Optional[ + float + ] = None, # The SDK or API handles None for all parameters following + top_p: Optional[float] = None, + max_new_tokens: Optional[int] = None, + api_key: Optional[str] = None, # default is env var CO_API_KEY + api_base: Optional[str] = None, +): + import cohere + + OPENAI_TO_COHERE_ROLE_MAP = { + "user": "User", + "assistant": "Chatbot", + "system": "System", + } + + client = cohere.Client( + api_key=api_key, + base_url=api_base, + client_name=client_name, + ) + + # prepare and log requests + chat_history = [ + dict( + role=OPENAI_TO_COHERE_ROLE_MAP[message["role"]], message=message["content"] + ) + for message in messages[:-1] + ] + actual_prompt = messages[-1]["content"] + + gen_params = { + "model": model_id, + "messages": messages, + "chat_history": chat_history, + "prompt": actual_prompt, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + # make request and stream response + res = client.chat_stream( + message=actual_prompt, + chat_history=chat_history, + model=model_id, + temperature=temperature, + max_tokens=max_new_tokens, + p=top_p, + ) + try: + text = "" + for streaming_item in res: + if streaming_item.event_type == "text-generation": + text += streaming_item.text + yield {"text": text, "error_code": 0} + except cohere.core.ApiError as e: + logger.error(f"==== error from cohere api: {e} ====") + yield { + "text": f"**API REQUEST ERROR** Reason: {e}", + "error_code": 1, + } + + +def vertex_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens): + import vertexai + from vertexai import generative_models + from vertexai.generative_models import ( + GenerationConfig, + GenerativeModel, + Image, + ) + + project_id = os.environ.get("GCP_PROJECT_ID", None) + location = os.environ.get("GCP_LOCATION", None) + vertexai.init(project=project_id, location=location) + + text_messages = [] + for message in messages: + if type(message) == str: + text_messages.append(message) + + gen_params = { + "model": model_name, + "prompt": text_messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + safety_settings = [ + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + ] + generator = GenerativeModel(model_name).generate_content( + messages, + stream=True, + generation_config=GenerationConfig( + top_p=top_p, max_output_tokens=max_new_tokens, temperature=temperature + ), + safety_settings=safety_settings, + ) + + ret = "" + for chunk in generator: + # NOTE(chris): This may be a vertex api error, below is HOTFIX: 
https://github.com/googleapis/python-aiplatform/issues/3129 + ret += chunk.candidates[0].content.parts[0]._raw_part.text + # ret += chunk.text + data = { + "text": ret, + "error_code": 0, + } + yield data + + +def reka_api_stream_iter( + model_name: str, + messages: list, + temperature: Optional[ + float + ] = None, # The SDK or API handles None for all parameters following + top_p: Optional[float] = None, + max_new_tokens: Optional[int] = None, + api_key: Optional[str] = None, # default is env var CO_API_KEY + api_base: Optional[str] = None, +): + api_key = api_key or os.environ["REKA_API_KEY"] + + use_search_engine = False + if "-online" in model_name: + model_name = model_name.replace("-online", "") + use_search_engine = True + request = { + "model_name": model_name, + "conversation_history": messages, + "temperature": temperature, + "request_output_len": max_new_tokens, + "runtime_top_p": top_p, + "stream": True, + "use_search_engine": use_search_engine, + } + + # Make requests for logging + text_messages = [] + for message in messages: + text_messages.append({"type": message["type"], "text": message["text"]}) + logged_request = dict(request) + logged_request["conversation_history"] = text_messages + + logger.info(f"==== request ====\n{logged_request}") + + response = requests.post( + api_base, + stream=True, + json=request, + headers={ + "X-Api-Key": api_key, + }, + ) + + if response.status_code != 200: + error_message = response.text + logger.error(f"==== error from reka api: {error_message} ====") + yield { + "text": f"**API REQUEST ERROR** Reason: {error_message}", + "error_code": 1, + } + return + + for line in response.iter_lines(): + line = line.decode("utf8") + if not line.startswith("data: "): + continue + gen = json.loads(line[6:]) + yield {"text": gen["text"], "error_code": 0} diff --git a/src/serve/base_model_worker.py b/src/serve/base_model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe322990f1e9d7a14cb45afbc16e5574604a766 --- /dev/null +++ b/src/serve/base_model_worker.py @@ -0,0 +1,241 @@ +import asyncio +import threading +import time +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL +from fastchat.conversation import Conversation +from fastchat.utils import pretty_print_semaphore, build_logger + + +worker = None +logger = None + +app = FastAPI() + + +def heart_beat_worker(obj): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + obj.send_heart_beat() + + +class BaseModelWorker: + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + conv_template: str = None, + multimodal: bool = False, + ): + global logger, worker + + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + self.model_names = model_names or [model_path.split("/")[-1]] + self.limit_worker_concurrency = limit_worker_concurrency + self.conv = self.make_conv_template(conv_template, model_path) + self.conv.sep_style = int(self.conv.sep_style) + self.multimodal = multimodal + self.tokenizer = None + self.context_len = None + self.call_ct = 0 + self.semaphore = None + + self.heart_beat_thread = None + + if logger is None: + logger = build_logger("model_worker", 
f"model_worker_{self.worker_id}.log") + if worker is None: + worker = self + + def make_conv_template( + self, + conv_template: str = None, + model_path: str = None, + ) -> Conversation: + """ + can be overrided to costomize the conversation template for different model workers. + """ + from fastchat.conversation import get_conv_template + from fastchat.model.model_adapter import get_conversation_template + + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + return conv + + def init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, + args=(self,), + daemon=True, + ) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status(), + "multimodal": self.multimodal, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info( + f"Send heart beat. Models: {self.model_names}. " + f"Semaphore: {pretty_print_semaphore(self.semaphore)}. " + f"call_ct: {self.call_ct}. " + f"worker_id: {self.worker_id}. " + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except (requests.exceptions.RequestException, KeyError) as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if self.semaphore is None: + return 0 + else: + sempahore_value = ( + self.semaphore._value + if self.semaphore._value is not None + else self.limit_worker_concurrency + ) + waiter_count = ( + 0 if self.semaphore._waiters is None else len(self.semaphore._waiters) + ) + return self.limit_worker_concurrency - sempahore_value + waiter_count + + def get_status(self): + return { + "model_names": self.model_names, + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def count_token(self, params): + prompt = params["prompt"] + + try: + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + except TypeError: + input_echo_len = self.tokenizer.num_tokens(prompt) + + ret = { + "count": input_echo_len, + "error_code": 0, + } + return ret + + def get_conv_template(self): + return {"conv": self.conv} + + def generate_stream_gate(self, params): + raise NotImplementedError + + def generate_gate(self, params): + raise NotImplementedError + + def get_embeddings(self, params): + raise NotImplementedError + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + 
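For a quick sanity check of the streaming route above, a worker built on this class can be exercised with a minimal client sketch like the one below. It mirrors how the controller in src/serve/controller.py (later in this patch) forwards requests: a JSON body posted to /worker_generate_stream and null-byte-delimited JSON chunks read back. The worker address, model name, and sampling parameters are illustrative assumptions, not values defined in this repository.

import json

import requests

# Assumed address of a locally running model worker (adjust to your setup).
worker_addr = "http://localhost:21002"

params = {
    "model": "vicuna-7b-v1.5",            # assumed model name served by the worker
    "prompt": "USER: Hello! ASSISTANT:",  # prompt format depends on the model's template
    "temperature": 0.7,
    "max_new_tokens": 64,
}

# The stream consists of JSON chunks terminated by null bytes, which is how the
# controller consumes them via iter_lines(delimiter=b"\0").
resp = requests.post(
    worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=100
)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"], data["error_code"])

In this code base the streamed "text" field is typically cumulative, so keeping only the last chunk yields the complete response.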
+@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await asyncio.to_thread(worker.generate_gate, params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + embedding = worker.get_embeddings(params) + release_worker_semaphore() + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} diff --git a/src/serve/call_monitor.py b/src/serve/call_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..46a2382953ef9ba7ac9fe2aa0fbd7909284a67b4 --- /dev/null +++ b/src/serve/call_monitor.py @@ -0,0 +1,223 @@ +import json +import os +import glob +import time + +from fastapi import FastAPI +import hashlib +import asyncio + +REFRESH_INTERVAL_SEC = 60 +LOG_DIR_LIST = [] +# LOG_DIR = "/home/vicuna/tmp/test_env" + + +class Monitor: + """Monitor the number of calls to each model.""" + + def __init__(self, log_dir_list: list): + self.log_dir_list = log_dir_list + self.model_call = {} + self.user_call = {} + self.model_call_limit_global = { + "gpt-4-1106-preview": 100, + "gpt-4-0125-preview": 100, + } + self.model_call_day_limit_per_user = { + "gpt-4-1106-preview": 5, + "gpt-4-0125-preview": 5, + } + + async def update_stats(self, num_file=1) -> None: + while True: + # find the latest num_file log under log_dir + json_files = [] + for log_dir in self.log_dir_list: + json_files_per_server = glob.glob(os.path.join(log_dir, "*.json")) + json_files_per_server.sort(key=os.path.getctime, reverse=True) + json_files += json_files_per_server[:num_file] + model_call = {} + user_call = {} + for json_file in json_files: + for line in open(json_file, "r", encoding="utf-8"): + obj = json.loads(line) + if obj["type"] != "chat": + continue + if obj["model"] not in model_call: + model_call[obj["model"]] = [] + model_call[obj["model"]].append( + {"tstamp": obj["tstamp"], "user_id": obj["ip"]} + ) + if obj["ip"] not in user_call: + user_call[obj["ip"]] = [] + user_call[obj["ip"]].append( + {"tstamp": obj["tstamp"], "model": obj["model"]} + ) + + self.model_call = model_call + self.model_call_stats_hour = self.get_model_call_stats(top_k=None) + self.model_call_stats_day = self.get_model_call_stats( + top_k=None, most_recent_min=24 * 60 + ) + + self.user_call = user_call + self.user_call_stats_hour = self.get_user_call_stats(top_k=None) + self.user_call_stats_day = self.get_user_call_stats( + top_k=None, most_recent_min=24 * 60 + ) + await asyncio.sleep(REFRESH_INTERVAL_SEC) + + def get_model_call_limit(self, model: str) -> int: + if model not in self.model_call_limit_global: + return -1 + return self.model_call_limit_global[model] + + def update_model_call_limit(self, model: str, limit: int) -> bool: + if model not in self.model_call_limit_global: + return False + self.model_call_limit_global[model] = limit + return True + + def is_model_limit_reached(self, 
model: str) -> bool: + if model not in self.model_call_limit_global: + return False + if model not in self.model_call_stats_hour: + return False + # check if the model call limit is reached + if self.model_call_stats_hour[model] >= self.model_call_limit_global[model]: + return True + return False + + def is_user_limit_reached(self, model: str, user_id: str) -> bool: + if model not in self.model_call_day_limit_per_user: + return False + if user_id not in self.user_call_stats_day: + return False + if model not in self.user_call_stats_day[user_id]["call_dict"]: + return False + # check if the user call limit is reached + if ( + self.user_call_stats_day[user_id]["call_dict"][model] + >= self.model_call_day_limit_per_user[model] + ): + return True + return False + + def get_model_call_stats( + self, target_model=None, most_recent_min: int = 60, top_k: int = 20 + ) -> dict: + model_call_stats = {} + for model, reqs in self.model_call.items(): + if target_model is not None and model != target_model: + continue + model_call = [] + for req in reqs: + if req["tstamp"] < time.time() - most_recent_min * 60: + continue + model_call.append(req["tstamp"]) + model_call_stats[model] = len(model_call) + if top_k is not None: + top_k_model = sorted( + model_call_stats, key=lambda x: model_call_stats[x], reverse=True + )[:top_k] + model_call_stats = {model: model_call_stats[model] for model in top_k_model} + return model_call_stats + + def get_user_call_stats( + self, target_model=None, most_recent_min: int = 60, top_k: int = 20 + ) -> dict: + user_call_stats = {} + for user_id, reqs in self.user_call.items(): + user_model_call = {"call_dict": {}} + for req in reqs: + if req["tstamp"] < time.time() - most_recent_min * 60: + continue + if target_model is not None and req["model"] != target_model: + continue + if req["model"] not in user_model_call["call_dict"]: + user_model_call["call_dict"][req["model"]] = 0 + user_model_call["call_dict"][req["model"]] += 1 + + user_model_call["total_calls"] = sum(user_model_call["call_dict"].values()) + if user_model_call["total_calls"] > 0: + user_call_stats[user_id] = user_model_call + if top_k is not None: + top_k_user = sorted( + user_call_stats, + key=lambda x: user_call_stats[x]["total_calls"], + reverse=True, + )[:top_k] + user_call_stats = { + user_id: user_call_stats[user_id] for user_id in top_k_user + } + return user_call_stats + + def get_num_users(self, most_recent_min: int = 60) -> int: + user_call_stats = self.get_user_call_stats( + most_recent_min=most_recent_min, top_k=None + ) + return len(user_call_stats) + + +monitor = Monitor(log_dir_list=LOG_DIR_LIST) +app = FastAPI() + + +@app.on_event("startup") +async def app_startup(): + asyncio.create_task(monitor.update_stats(2)) + + +@app.get("/get_model_call_limit/{model}") +async def get_model_call_limit(model: str): + return {"model_call_limit": {model: monitor.get_model_call_limit(model)}} + + +@app.get("/update_model_call_limit/{model}/{limit}") +async def update_model_call_limit(model: str, limit: int): + if not monitor.update_model_call_limit(model, limit): + return {"success": False} + return {"success": True} + + +@app.get("/is_limit_reached") +async def is_limit_reached(model: str, user_id: str): + if monitor.is_model_limit_reached(model): + return { + "is_limit_reached": True, + "reason": f"MODEL_HOURLY_LIMIT ({model}): {monitor.get_model_call_limit(model)}", + } + if monitor.is_user_limit_reached(model, user_id): + return { + "is_limit_reached": True, + "reason": f"USER_DAILY_LIMIT ({model}): 
{monitor.model_call_day_limit_per_user[model]}", + } + return {"is_limit_reached": False} + + +@app.get("/get_num_users_hr") +async def get_num_users(): + return {"num_users": len(monitor.user_call_stats_hour)} + + +@app.get("/get_num_users_day") +async def get_num_users_day(): + return {"num_users": len(monitor.user_call_stats_day)} + + +@app.get("/get_user_call_stats") +async def get_user_call_stats( + model: str = None, most_recent_min: int = 60, top_k: int = None +): + return { + "user_call_stats": monitor.get_user_call_stats(model, most_recent_min, top_k) + } + + +@app.get("/get_model_call_stats") +async def get_model_call_stats( + model: str = None, most_recent_min: int = 60, top_k: int = None +): + return { + "model_call_stats": monitor.get_model_call_stats(model, most_recent_min, top_k) + } diff --git a/src/serve/cli.py b/src/serve/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..78f7f51b1b18b7a10f3fab937f1475067d3e5ecf --- /dev/null +++ b/src/serve/cli.py @@ -0,0 +1,304 @@ +""" +Chat with a model with command line interface. + +Usage: +python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 +python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0 + +Other commands: +- Type "!!exit" or an empty line to exit. +- Type "!!reset" to start a new conversation. +- Type "!!remove" to remove the last prompt. +- Type "!!regen" to regenerate the last message. +- Type "!!save " to save the conversation history to a json file. +- Type "!!load " to load a conversation history from a json file. +""" +import argparse +import os +import re +import sys + +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.history import InMemoryHistory +from prompt_toolkit.key_binding import KeyBindings +from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown +import torch + +from fastchat.model.model_adapter import add_model_args +from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.serve.inference import ChatIO, chat_loop +from fastchat.utils import str_to_torch_dtype + + +class SimpleChatIO(ChatIO): + def __init__(self, multiline: bool = False): + self._multiline = multiline + + def prompt_for_input(self, role) -> str: + if not self._multiline: + return input(f"{role}: ") + + prompt_data = [] + line = input(f"{role} [ctrl-d/z on empty line to end]: ") + while True: + prompt_data.append(line.strip()) + try: + line = input() + except EOFError as e: + break + return "\n".join(prompt_data) + + def prompt_for_output(self, role: str): + print(f"{role}: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + def print_output(self, text: str): + print(text) + + +class RichChatIO(ChatIO): + bindings = KeyBindings() + + @bindings.add("escape", "enter") + def _(event): + event.app.current_buffer.newline() + + def __init__(self, multiline: bool = False, mouse: bool = False): + self._prompt_session = 
PromptSession(history=InMemoryHistory()) + self._completer = WordCompleter( + words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"], + pattern=re.compile("$"), + ) + self._console = Console() + self._multiline = multiline + self._mouse = mouse + + def prompt_for_input(self, role) -> str: + self._console.print(f"[bold]{role}:") + # TODO(suquark): multiline input has some issues. fix it later. + prompt_input = self._prompt_session.prompt( + completer=self._completer, + multiline=False, + mouse_support=self._mouse, + auto_suggest=AutoSuggestFromHistory(), + key_bindings=self.bindings if self._multiline else None, + ) + self._console.print() + return prompt_input + + def prompt_for_output(self, role: str): + self._console.print(f"[bold]{role.replace('/', '|')}:") + + def stream_output(self, output_stream): + """Stream output from a role.""" + # TODO(suquark): the console flickers when there is a code block + # above it. We need to cut off "live" when a code block is done. + + # Create a Live context for updating the console output + with Live(console=self._console, refresh_per_second=4) as live: + # Read lines from the stream + for outputs in output_stream: + if not outputs: + continue + text = outputs["text"] + # Render the accumulated text as Markdown + # NOTE: this is a workaround for the rendering "unstandard markdown" + # in rich. The chatbots output treat "\n" as a new line for + # better compatibility with real-world text. However, rendering + # in markdown would break the format. It is because standard markdown + # treat a single "\n" in normal text as a space. + # Our workaround is adding two spaces at the end of each line. + # This is not a perfect solution, as it would + # introduce trailing spaces (only) in code block, but it works well + # especially for console output, because in general the console does not + # care about trailing spaces. + lines = [] + for line in text.splitlines(): + lines.append(line) + if line.startswith("```"): + # Code block marker - do not add trailing spaces, as it would + # break the syntax highlighting + lines.append("\n") + else: + lines.append(" \n") + markdown = Markdown("".join(lines)) + # Update the Live console output + live.update(markdown) + self._console.print() + return text + + def print_output(self, text: str): + self.stream_output([{"text": text}]) + + +class ProgrammaticChatIO(ChatIO): + def prompt_for_input(self, role) -> str: + contents = "" + # `end_sequence` signals the end of a message. It is unlikely to occur in + # message content. 
+ end_sequence = " __END_OF_A_MESSAGE_47582648__\n" + len_end = len(end_sequence) + while True: + if len(contents) >= len_end: + last_chars = contents[-len_end:] + if last_chars == end_sequence: + break + try: + char = sys.stdin.read(1) + contents = contents + char + except EOFError: + continue + contents = contents[:-len_end] + print(f"[!OP:{role}]: {contents}", flush=True) + return contents + + def prompt_for_output(self, role: str): + print(f"[!OP:{role}]: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + def print_output(self, text: str): + print(text) + + +def main(args): + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + os.environ["XPU_VISIBLE_DEVICES"] = args.gpus + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + if args.style == "simple": + chatio = SimpleChatIO(args.multiline) + elif args.style == "rich": + chatio = RichChatIO(args.multiline, args.mouse) + elif args.style == "programmatic": + chatio = ProgrammaticChatIO() + else: + raise ValueError(f"Invalid style for console: {args.style}") + try: + chat_loop( + args.model_path, + args.device, + args.num_gpus, + args.max_gpu_memory, + str_to_torch_dtype(args.dtype), + args.load_8bit, + args.cpu_offloading, + args.conv_template, + args.conv_system_msg, + args.temperature, + args.repetition_penalty, + args.max_new_tokens, + chatio, + gptq_config=GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ), + awq_config=AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ), + exllama_config=exllama_config, + xft_config=xft_config, + revision=args.revision, + judge_sent_end=args.judge_sent_end, + debug=args.debug, + history=not args.no_history, + ) + except KeyboardInterrupt: + print("exit...") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--conv-system-msg", type=str, default=None, help="Conversation system message." 
+ ) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--no-history", action="store_true") + parser.add_argument( + "--style", + type=str, + default="simple", + choices=["simple", "rich", "programmatic"], + help="Display style.", + ) + parser.add_argument( + "--multiline", + action="store_true", + help="Enable multiline input. Use ESC+Enter for newline.", + ) + parser.add_argument( + "--mouse", + action="store_true", + help="[Rich Style]: Enable mouse support for cursor positioning.", + ) + parser.add_argument( + "--judge-sent-end", + action="store_true", + help="Whether enable the correction logic that interrupts the output of sentences due to EOS.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Print useful debug information (e.g., prompts)", + ) + args = parser.parse_args() + main(args) diff --git a/src/serve/controller.py b/src/serve/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..42d928403090d501fb9bdfa608b77bc7d9e15c31 --- /dev/null +++ b/src/serve/controller.py @@ -0,0 +1,389 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. +""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import os +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from fastchat.constants import ( + CONTROLLER_HEART_BEAT_EXPIRATION, + WORKER_API_TIMEOUT, + ErrorCode, + SERVER_ERROR_MSG, +) +from fastchat.utils import build_logger + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + multimodal: bool + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stale_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,) + ) + self.heart_beat_thread.start() + + def register_worker( + self, + worker_name: str, + check_heart_beat: bool, + worker_status: dict, + multimodal: bool, + ): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], + worker_status["speed"], + worker_status["queue_length"], + check_heart_beat, + time.time(), + multimodal, + ) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = 
requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker( + w_name, w_info.check_heart_beat, None, w_info.multimodal + ): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def list_multimodal_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + if w_info.multimodal: + model_names.update(w_info.model_names) + + return list(model_names) + + def list_language_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + if not w_info.multimodal: + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info( + f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}" + ) + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. 
{worker_name}") + return True + + def remove_stale_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def handle_no_worker(self, params): + logger.info(f"no worker: {params['model']}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_NO_WORKER, + } + return json.dumps(ret).encode() + b"\0" + + def handle_worker_timeout(self, worker_address): + logger.info(f"worker timeout: {worker_address}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT, + } + return json.dumps(ret).encode() + b"\0" + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. + def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + model_names = sorted(list(model_names)) + return { + "model_names": model_names, + "speed": speed, + "queue_length": queue_length, + } + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + yield self.handle_no_worker(params) + + try: + response = requests.post( + worker_addr + "/worker_generate_stream", + json=params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + yield self.handle_worker_timeout(worker_addr) + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], + data["check_heart_beat"], + data.get("worker_status", None), + data.get("multimodal", False), + ) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/list_multimodal_models") +async def list_multimodal_models(): + models = controller.list_multimodal_models() + return {"models": models} + + +@app.post("/list_language_models") +async def list_language_models(): + models = controller.list_language_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return 
controller.worker_api_get_status()
+
+
+@app.get("/test_connection")
+async def test_connection(request: Request):
+    return "success"
+
+
+def create_controller():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=21001)
+    parser.add_argument(
+        "--dispatch-method",
+        type=str,
+        choices=["lottery", "shortest_queue"],
+        default="shortest_queue",
+    )
+    parser.add_argument(
+        "--ssl",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Enable SSL. Requires the OS environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
+    )
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    controller = Controller(args.dispatch_method)
+    return args, controller
+
+
+if __name__ == "__main__":
+    args, controller = create_controller()
+    if args.ssl:
+        uvicorn.run(
+            app,
+            host=args.host,
+            port=args.port,
+            log_level="info",
+            ssl_keyfile=os.environ["SSL_KEYFILE"],
+            ssl_certfile=os.environ["SSL_CERTFILE"],
+        )
+    else:
+        uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/src/serve/example_images/distracted.jpg b/src/serve/example_images/distracted.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..382c888a0305296d7307ce061d527e1c5e01aca3
Binary files /dev/null and b/src/serve/example_images/distracted.jpg differ
diff --git a/src/serve/example_images/fridge.jpg b/src/serve/example_images/fridge.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8ed943e8be506b2a0da66bd1cddf39d2dcbdb5fb
Binary files /dev/null and b/src/serve/example_images/fridge.jpg differ
diff --git a/src/serve/gateway/README.md b/src/serve/gateway/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3afaf171bc38b232b68609585244c9e76489da7
--- /dev/null
+++ b/src/serve/gateway/README.md
@@ -0,0 +1,57 @@
+# fastchat Nginx Gateway
+
+## Purpose of the Gateway
+
+The Nginx gateway serves the following purposes:
+
+1. Protects Gradio servers by acting as a firewall.
+2. Facilitates dynamic mounting and unmounting of Gradio servers.
+3. Provides load balancing for Gradio servers.
+4. Offers additional security features, such as a total connection limit.
+5. Reduces the attack surface by requiring only a single public port to be exposed for serving.
+
+## Deployment and Updating of the Gateway
+
+### Installing Nginx
+
+On Debian-based distributions (e.g., Ubuntu):
+
+```bash
+sudo apt update
+sudo apt install nginx
+```
+On Red Hat-based distributions (e.g., CentOS, Fedora):
+
+```bash
+sudo yum install epel-release
+sudo yum install nginx
+```
+
+### Deployment
+
+Copy `nginx.conf` to `/etc/nginx/nginx.conf` (requires sudo).
+
+Replace the port number 7860 in `server localhost:7860` with the port on which your Gradio web server is deployed.
+
+Modify `upstream websocket` to configure the Gradio servers behind the gateway.
+
+Lastly, reload Nginx to apply the changes (see the Updating section below).
+
+
+### HTTPS Deployment with a Public Domain URL
+
+Make sure you have obtained an HTTPS certificate and the corresponding private key.
+
+Fill in the paths to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields.
+
+If you have your own domain for serving the chatbot, replace the chat.lmsys.org URL with your own domain, as illustrated in the sketch below.
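+
+For reference, these are the lines in `nginx.conf` that typically need editing, shown as a minimal sketch. The certificate paths and the `chat.example.com` domain below are placeholders for illustration only; substitute your own values:
+
+```nginx
+upstream websocket {
+    ip_hash;                       # keep session persistence per client IP
+    server localhost:7860;         # port of your Gradio web server
+}
+
+server {
+    listen 443 ssl;
+    ssl_certificate     /etc/ssl/certs/chatbot.pem;      # placeholder: path to your certificate
+    ssl_certificate_key /etc/ssl/private/chatbot.key;    # placeholder: path to your private key
+    server_name chat.example.com;  # placeholder: your own domain
+}
+```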
+ +### Updating + +Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: + +```bash +sudo nginx -t # check `/etc/nginx/nginx.conf` +sudo systemctl reload nginx # restart Nginx service to load the new config +sudo systemctl status nginx # check the status of the Nginx service. It should be active (running). +``` diff --git a/src/serve/gateway/nginx.conf b/src/serve/gateway/nginx.conf new file mode 100644 index 0000000000000000000000000000000000000000..b88ca8c50772421fca91f33ff77ef75f4d23ad4d --- /dev/null +++ b/src/serve/gateway/nginx.conf @@ -0,0 +1,97 @@ +user www-data; +worker_processes auto; +pid /run/nginx.pid; +include /etc/nginx/modules-enabled/*.conf; + +events { + worker_connections 1024; # maximum number of connections that a worker process can handle concurrently + # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle + +} + +http { + ## + # Basic Settings + ## + + sendfile on; # enable sendfile for performance optimization + tcp_nopush on; # enable TCP no-pushing + tcp_nodelay on; # enable TCP no-delay + keepalive_timeout 65; # sets the timeout for keep-alive connections + types_hash_max_size 2048; # maximum size of the types hash table + # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security + + # server_names_hash_bucket_size 64; + # server_name_in_redirect off; + + include /etc/nginx/mime.types; # include MIME types file + default_type application/octet-stream; # default MIME type for unknown file types + + ## + # SSL Settings + ## + + ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use + ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers + + ## + # Logging Settings + ## + + access_log /var/log/nginx/access.log; # path to access log file + error_log /var/log/nginx/error.log; # path to error log file + + ## + # Gzip Settings + ## + gzip on; # enable Gzip compression + + ## + # Virtual Host Configs + ## + + include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory + include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files + + # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ + map $http_upgrade $connection_upgrade { + default upgrade; + '' close; + } + + upstream websocket { + ip_hash; # load balancing by IP to guarantee session persistence + server localhost:7860; # The port should be the gradio web server port + # server localhost:7861; # extra gradio server if more than one + } + + limit_conn_status 429; + limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP + limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server + + server { + listen 443 ssl; # the listening port of our server + ssl_certificate [PATH_TO_SSL_CERT]; + ssl_certificate_key [PATH_TO_PRIVATE_KEY]; + server_name chat.lmsys.org; # replace the url with your own domain url + limit_conn perserver 1024; # connections per server + location / { + proxy_pass http://websocket; # proxy all requests to the defined upstream server + limit_conn perip 5; # connections per IP + proxy_set_header Host $host; # set the Host header for the upstream server + proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses 
in the X-Forwarded-For header + proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication + } + } + + # the following block routes all HTTP traffic to HTTPS via nginx + server { + listen 80; + server_name chat.lmsys.org; + return 301 https://chat.lmsys.org$request_uri; + } + +} diff --git a/src/serve/gradio_block_arena_anony.py b/src/serve/gradio_block_arena_anony.py new file mode 100644 index 0000000000000000000000000000000000000000..50cb24d9ce2ca98c8c77a90f7a4a812deccf30f0 --- /dev/null +++ b/src/serve/gradio_block_arena_anony.py @@ -0,0 +1,700 @@ +""" +Chatbot Arena (battle) tab. +Users chat with two anonymous models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + BLIND_MODE_INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_block_arena_named import flash_buttons +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_ip, + get_model_description_md, + _prepare_text_with_image, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False +anony_names = ["", ""] +models = [] + + +def set_global_vars_anony(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_anony(models_, url_params): + global models + models = models_ + + states = (None,) * num_sides + selector_updates = ( + gr.Markdown(visible=True), + gr.Markdown(visible=True), + ) + + return states + selector_updates + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + if ":" not in model_selectors[0]: + for i in range(5): + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + ("",) + (disable_btn,) * 4 + time.sleep(0.1) + else: + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + ("",) + (disable_btn,) * 4 + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ): + yield x + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (anony). 
ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ): + yield x + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ): + yield x + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ): + yield x + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (anony). ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return states + [x.to_gradio_chatbot() for x in states] + [""] + [no_change_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [""] + + [invisible_btn] * 4 + + [disable_btn] * 2 + + [""] + ) + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (anony). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +SAMPLING_WEIGHTS = { + # tier 0 + "gpt-4-0314": 4, + "gpt-4-0613": 4, + "gpt-4-1106-preview": 2, + "gpt-4-0125-preview": 4, + "gpt-4-turbo-2024-04-09": 4, + "gpt-3.5-turbo-0125": 2, + "claude-3-opus-20240229": 4, + "claude-3-sonnet-20240229": 4, + "claude-3-haiku-20240307": 4, + "claude-2.1": 1, + "zephyr-orpo-141b-A35b-v0.1": 2, + "dbrx-instruct": 1, + "command-r-plus": 4, + "command-r": 2, + "reka-flash": 4, + "reka-flash-online": 4, + "qwen1.5-72b-chat": 2, + "qwen1.5-32b-chat": 2, + "qwen1.5-14b-chat": 2, + "qwen1.5-7b-chat": 2, + "gemma-1.1-7b-it": 2, + "gemma-1.1-2b-it": 1, + "mixtral-8x7b-instruct-v0.1": 4, + "mistral-7b-instruct-v0.2": 2, + "mistral-large-2402": 4, + "mistral-medium": 2, + "starling-lm-7b-beta": 2, + # tier 1 + "deluxe-chat-v1.3": 2, + "llama-2-70b-chat": 2, + "llama-2-13b-chat": 1, + "llama-2-7b-chat": 1, + "vicuna-33b": 1, + "vicuna-13b": 1, + "yi-34b-chat": 1, +} + +# target model sampling weights will be boosted. 
+BATTLE_TARGETS = { + "gpt-4-turbo-2024-04-09": { + "gpt-4-1106-preview", + "gpt-4-0125-preview", + "claude-3-opus-20240229", + "gemini-pro-dev-api", + }, + "gemini-pro-dev-api": { + "gpt-4-turbo-2024-04-09", + "claude-3-opus-20240229", + "gpt-4-0125-preview", + "claude-3-sonnet-20240229", + }, + "reka-flash": { + "qwen1.5-72b-chat", + "claude-3-haiku-20240307", + "command-r-plus", + "command-r", + }, + "reka-flash-online": { + "qwen1.5-72b-chat", + "claude-3-haiku-20240307", + "command-r-plus", + "command-r", + }, + "deluxe-chat-v1.3": { + "gpt-4-1106-preview", + "gpt-4-0125-preview", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + }, + "qwen1.5-32b-chat": { + "gpt-3.5-turbo-0125", + "gpt-4-0613", + "gpt-4-0125-preview", + "llama-2-70b-chat", + "mixtral-8x7b-instruct-v0.1", + "mistral-large-2402", + "yi-34b-chat", + }, + "qwen1.5-14b-chat": { + "starling-lm-7b-alpha", + "claude-3-haiku-20240307", + "gpt-3.5-turbo-0125", + "openchat-3.5-0106", + "mixtral-8x7b-instruct-v0.1", + }, + "mistral-large-2402": { + "gpt-4-0125-preview", + "gpt-4-0613", + "mixtral-8x7b-instruct-v0.1", + "mistral-medium", + "mistral-next", + "claude-3-sonnet-20240229", + }, + "gemma-1.1-2b-it": { + "gpt-3.5-turbo-0125", + "mixtral-8x7b-instruct-v0.1", + "starling-lm-7b-beta", + "llama-2-7b-chat", + "mistral-7b-instruct-v0.2", + "gemma-1.1-7b-it", + }, + "zephyr-orpo-141b-A35b-v0.1": { + "qwen1.5-72b-chat", + "mistral-large-2402", + "command-r-plus", + "claude-3-haiku-20240307", + }, +} + +SAMPLING_BOOST_MODELS = [] + +# outage models won't be sampled. +OUTAGE_MODELS = [] + + +def get_sample_weight(model, outage_models, sampling_weights, sampling_boost_models): + if model in outage_models: + return 0 + weight = sampling_weights.get(model, 0) + if model in sampling_boost_models: + weight *= 5 + return weight + + +def get_battle_pair( + models, battle_targets, outage_models, sampling_weights, sampling_boost_models +): + if len(models) == 1: + return models[0], models[0] + + model_weights = [] + for model in models: + weight = get_sample_weight( + model, outage_models, sampling_weights, sampling_boost_models + ) + model_weights.append(weight) + total_weight = np.sum(model_weights) + model_weights = model_weights / total_weight + chosen_idx = np.random.choice(len(models), p=model_weights) + chosen_model = models[chosen_idx] + # for p, w in zip(models, model_weights): + # print(p, w) + + rival_models = [] + rival_weights = [] + for model in models: + if model == chosen_model: + continue + weight = get_sample_weight( + model, outage_models, sampling_weights, sampling_boost_models + ) + if ( + weight != 0 + and chosen_model in battle_targets + and model in battle_targets[chosen_model] + ): + # boost to 50% chance + weight = total_weight / len(battle_targets[chosen_model]) + rival_models.append(model) + rival_weights.append(weight) + # for p, w in zip(rival_models, rival_weights): + # print(p, w) + rival_weights = rival_weights / np.sum(rival_weights) + rival_idx = np.random.choice(len(rival_models), p=rival_weights) + rival_model = rival_models[rival_idx] + + swap = np.random.randint(2) + if swap == 0: + return chosen_model, rival_model + else: + return rival_model, chosen_model + + +def add_text( + state0, state1, model_selector0, model_selector1, text, image, request: gr.Request +): + ip = get_ip(request) + logger.info(f"add_text (anony). ip: {ip}. 
len: {len(text)}")
+    states = [state0, state1]
+    model_selectors = [model_selector0, model_selector1]
+
+    # Init states if necessary
+    if states[0] is None:
+        assert states[1] is None
+
+        model_left, model_right = get_battle_pair(
+            models,
+            BATTLE_TARGETS,
+            OUTAGE_MODELS,
+            SAMPLING_WEIGHTS,
+            SAMPLING_BOOST_MODELS,
+        )
+        states = [
+            State(model_left),
+            State(model_right),
+        ]
+
+    if len(text) <= 0:
+        for i in range(num_sides):
+            states[i].skip_next = True
+        return (
+            states
+            + [x.to_gradio_chatbot() for x in states]
+            + ["", None]
+            + [
+                no_change_btn,
+            ]
+            * 6
+            + [""]
+        )
+
+    model_list = [states[i].model_name for i in range(num_sides)]
+    # turn on moderation in battle mode
+    all_conv_text_left = states[0].conv.get_prompt()
+    all_conv_text_right = states[1].conv.get_prompt()
+    all_conv_text = (
+        all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
+    )
+    flagged = moderation_filter(all_conv_text, model_list, do_moderation=True)
+    if flagged:
+        logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
+        # overwrite the original text
+        text = MODERATION_MSG
+
+    conv = states[0].conv
+    if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
+        logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
+        for i in range(num_sides):
+            states[i].skip_next = True
+        return (
+            states
+            + [x.to_gradio_chatbot() for x in states]
+            + [CONVERSATION_LIMIT_MSG, None]
+            + [
+                no_change_btn,
+            ]
+            * 6
+            + [""]
+        )
+
+    text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
+    for i in range(num_sides):
+        post_processed_text = _prepare_text_with_image(
+            states[i], text, image, csam_flag=False
+        )
+        states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
+        states[i].conv.append_message(states[i].conv.roles[1], None)
+        states[i].skip_next = False
+
+    hint_msg = ""
+    for i in range(num_sides):
+        if "deluxe" in states[i].model_name:
+            hint_msg = SLOW_MODEL_MSG
+    return (
+        states
+        + [x.to_gradio_chatbot() for x in states]
+        + ["", None]
+        + [
+            disable_btn,
+        ]
+        * 6
+        + [hint_msg]
+    )
+
+
+def bot_response_multi(
+    state0,
+    state1,
+    temperature,
+    top_p,
+    max_new_tokens,
+    request: gr.Request,
+):
+    logger.info(f"bot_response_multi (anony). 
ip: {get_ip(request)}") + + if state0 is None or state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_sides): + gen.append( + bot_response( + states[i], + temperature, + top_p, + max_new_tokens, + request, + apply_rate_limit=False, + use_recommended_config=True, + ) + ) + + is_stream_batch = [] + for i in range(num_sides): + is_stream_batch.append( + states[i].model_name + in [ + "gemini-pro", + "gemini-pro-dev-api", + "gemini-1.0-pro-vision", + "gemini-1.5-pro", + "gemini-1.5-flash", + "gemma-1.1-2b-it", + "gemma-1.1-7b-it", + ] + ) + chatbots = [None] * num_sides + iters = 0 + while True: + stop = True + iters += 1 + for i in range(num_sides): + try: + # yield gemini fewer times as its chunk size is larger + # otherwise, gemini will stream too fast + if not is_stream_batch[i] or (iters % 30 == 1 or iters < 3): + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + +def build_side_by_side_ui_anony(models): + notice_markdown = """ +# ⚔️ LMSYS Chatbot Arena: Benchmarking LLMs in the Wild +- | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Ask any question to two anonymous models (e.g., ChatGPT, Claude, Llama) and vote for the better one! +- You can continue chatting until you identify a winner. +- Vote won't be counted if model identity is revealed during conversation. + +## 🏆 LMSYS Arena [Leaderboard](https://leaderboard.lmsys.org) +We've collected **500K+** human votes to compute an LLM Elo leaderboard. +Find out who is the 🥇LLM Champion! + +## 👇 Chat now! 
+""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", open=False + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id="chatbot", + height=550, + show_copy_button=True, + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Markdown( + anony_names[i], elem_id="model_selector_md" + ) + with gr.Row(): + slow_warning = gr.Markdown("") + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + clear_btn = gr.Button(value="🎲 New Round", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False, visible=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + imagebox = gr.State(None) + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click( + clear_history, + None, + states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning], + ) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-anony'); + html2canvas(captureElement) + .then(canvas => { + 
canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + textbox.submit( + add_text, + states + model_selectors + [textbox, imagebox], + states + chatbots + [textbox, imagebox] + btn_list + [slow_warning], + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, + [], + btn_list, + ) + + send_btn.click( + add_text, + states + model_selectors + [textbox, imagebox], + states + chatbots + [textbox, imagebox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + return states + model_selectors diff --git a/src/serve/gradio_block_arena_named.py b/src/serve/gradio_block_arena_named.py new file mode 100644 index 0000000000000000000000000000000000000000..b56d67b63bbe3e2bed17f330a29e71edc82ac759 --- /dev/null +++ b/src/serve/gradio_block_arena_named.py @@ -0,0 +1,494 @@ +""" +Chatbot Arena (side-by-side) tab. +Users chat with two chosen models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_ip, + _prepare_text_with_image, + get_model_description_md, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False + + +def set_global_vars_named(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_named(models, url_params): + states = (None,) * num_sides + + model_left = models[0] if len(models) > 0 else "" + if len(models) > 1: + weights = ([8] * 4 + [4] * 8 + [1] * 32)[: len(models) - 1] + weights = weights / np.sum(weights) + model_right = np.random.choice(models[1:], p=weights) + else: + model_right = model_left + + selector_updates = ( + gr.Dropdown(choices=models, value=model_left, visible=True), + gr.Dropdown(choices=models, value=model_right, visible=True), + ) + + return states + selector_updates + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (named). 
ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (named). ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return states + [x.to_gradio_chatbot() for x in states] + [""] + [no_change_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (named). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [""] + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (named). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, text, image, request: gr.Request +): + ip = get_ip(request) + logger.info(f"add_text (named). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + for i in range(num_sides): + if states[i] is None: + states[i] = State(model_selectors[i]) + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + ["", None] + + [ + no_change_btn, + ] + * 6 + ) + + model_list = [states[i].model_name for i in range(num_sides)] + all_conv_text_left = states[0].conv.get_prompt() + all_conv_text_right = states[1].conv.get_prompt() + all_conv_text = ( + all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text + ) + flagged = moderation_filter(all_conv_text, model_list) + if flagged: + logger.info(f"violate moderation (named). ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. 
text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [CONVERSATION_LIMIT_MSG, None] + + [ + no_change_btn, + ] + * 6 + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + post_processed_text = _prepare_text_with_image( + states[i], text, image, csam_flag=False + ) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + return ( + states + + [x.to_gradio_chatbot() for x in states] + + ["", None] + + [ + disable_btn, + ] + * 6 + ) + + +def bot_response_multi( + state0, + state1, + temperature, + top_p, + max_new_tokens, + request: gr.Request, +): + logger.info(f"bot_response_multi (named). ip: {get_ip(request)}") + + if state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_sides): + gen.append( + bot_response( + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + ) + + is_stream_batch = [] + for i in range(num_sides): + is_stream_batch.append( + states[i].model_name + in [ + "gemini-pro", + "gemini-pro-dev-api", + "gemma-1.1-2b-it", + "gemma-1.1-7b-it", + ] + ) + + chatbots = [None] * num_sides + iters = 0 + while True: + stop = True + iters += 1 + for i in range(num_sides): + try: + # yield gemini fewer times as its chunk size is larger + # otherwise, gemini will stream too fast + if not is_stream_batch[i] or (iters % 30 == 1 or iters < 3): + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + +def flash_buttons(): + btn_updates = [ + [disable_btn] * 4 + [enable_btn] * 2, + [enable_btn] * 6, + ] + for i in range(4): + yield btn_updates[i % 2] + time.sleep(0.3) + + +def build_side_by_side_ui_named(models): + notice_markdown = """ +# ⚔️ Chatbot Arena: Benchmarking LLMs in the Wild +- | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Chat with any two models side-by-side and vote! +- You can continue chatting for multiple rounds. +- Click "Clear history" to start a new round. 
+ +## 🤖 Choose two models to compare +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(elem_id="share-region-named"): + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + interactive=True, + show_label=False, + container=False, + ) + with gr.Row(): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", open=False + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id=f"chatbot", + height=550, + show_copy_button=True, + ) + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + imagebox = gr.State(None) + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-named'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + 
document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + for i in range(num_sides): + model_selectors[i].change( + clear_history, None, states + chatbots + [textbox] + btn_list + ) + + textbox.submit( + add_text, + states + model_selectors + [textbox, imagebox], + states + chatbots + [textbox, imagebox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + send_btn.click( + add_text, + states + model_selectors + [textbox, imagebox], + states + chatbots + [textbox, imagebox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + return states + model_selectors diff --git a/src/serve/gradio_block_arena_vision.py b/src/serve/gradio_block_arena_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea3a8b67e7f5334e48d83959da0fedcea5924f3 --- /dev/null +++ b/src/serve/gradio_block_arena_vision.py @@ -0,0 +1,377 @@ +""" +The gradio demo server for chatting with a large multimodal model. + +Usage: +python3 -m fastchat.serve.controller +python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf +python3 -m fastchat.serve.gradio_web_server_multi --share --vision-arena +""" + +import json +import os +import time + +import gradio as gr +from gradio.data_classes import FileData +import numpy as np + +from src.constants import ( + TEXT_MODERATION_MSG, + IMAGE_MODERATION_MSG, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from src.serve.gradio_web_server import ( + get_model_description_md, + acknowledgment_md, + bot_response, + get_ip, + disable_btn, + State, + _prepare_text_with_image, + get_conv_log_filename, + get_remote_logger, +) +from src.utils import ( + build_logger, + moderation_filter, + image_moderation_filter, +) + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +no_change_btn = gr.Button() +enable_btn = gr.Button(interactive=True, visible=True) +disable_btn = gr.Button(interactive=False) +invisible_btn = gr.Button(interactive=False, visible=False) +visible_image_column = gr.Image(visible=True) +invisible_image_column = gr.Image(visible=False) + + +def get_vqa_sample(): + random_sample = np.random.choice(vqa_samples) + question, path = random_sample["question"], random_sample["path"] + res = {"text": "", "files": [path]} + return (res, path) + + +def set_visible_image(textbox): + images = textbox["files"] + if len(images) == 0: + return invisible_image_column + elif len(images) > 1: + gr.Warning( + "We only support single image conversations. Please start a new round if you would like to chat using this image." 
+        )
+
+    return visible_image_column
+
+
+def set_invisible_image():
+    return invisible_image_column
+
+
+def add_image(textbox):
+    images = textbox["files"]
+    if len(images) == 0:
+        return None
+
+    return images[0]
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+    filename = get_conv_log_filename(state.is_vision, state.has_csam_image)
+    with open(filename, "a") as fout:
+        data = {
+            "tstamp": round(time.time(), 4),
+            "type": vote_type,
+            "model": model_selector,
+            "state": state.dict(),
+            "ip": get_ip(request),
+        }
+        fout.write(json.dumps(data) + "\n")
+    get_remote_logger().log(data)
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+    ip = get_ip(request)
+    logger.info(f"upvote. ip: {ip}")
+    vote_last_response(state, "upvote", model_selector, request)
+    return (None,) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+    ip = get_ip(request)
+    logger.info(f"downvote. ip: {ip}")
+    vote_last_response(state, "downvote", model_selector, request)
+    return (None,) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+    ip = get_ip(request)
+    logger.info(f"flag. ip: {ip}")
+    vote_last_response(state, "flag", model_selector, request)
+    return (None,) + (disable_btn,) * 3
+
+
+def regenerate(state, request: gr.Request):
+    ip = get_ip(request)
+    logger.info(f"regenerate. ip: {ip}")
+    if not state.regen_support:
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5
+    state.conv.update_last_message(None)
+    return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+    ip = get_ip(request)
+    logger.info(f"clear_history. ip: {ip}")
+    state = None
+    return (state, [], None) + (disable_btn,) * 5
+
+
+def clear_history_example(request: gr.Request):
+    ip = get_ip(request)
+    logger.info(f"clear_history_example. ip: {ip}")
+    state = None
+    return (state, []) + (disable_btn,) * 5
+
+
+def moderate_input(text, all_conv_text, model_list, images, ip):
+    text_flagged = moderation_filter(all_conv_text, model_list)
+    # flagged = moderation_filter(text, [state.model_name])
+    nsfw_flagged, csam_flagged = False, False
+    if len(images) > 0:
+        nsfw_flagged, csam_flagged = image_moderation_filter(images[0])
+
+    image_flagged = nsfw_flagged or csam_flagged
+    if text_flagged or image_flagged:
+        logger.info(f"violate moderation. ip: {ip}. text: {all_conv_text}")
+        if text_flagged and not image_flagged:
+            # overwrite the original text
+            text = TEXT_MODERATION_MSG
+        elif not text_flagged and image_flagged:
+            text = IMAGE_MODERATION_MSG
+        elif text_flagged and image_flagged:
+            text = MODERATION_MSG
+
+    return text, image_flagged, csam_flagged
+
+
+def add_text(state, model_selector, chat_input, request: gr.Request):
+    text, images = chat_input["text"], chat_input["files"]
+    ip = get_ip(request)
+    logger.info(f"add_text. ip: {ip}. len: {len(text)}")
+
+    if state is None:
+        state = State(model_selector, is_vision=True)
+
+    if len(text) <= 0:
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5
+
+    all_conv_text = state.conv.get_prompt()
+    all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
+
+    text, image_flagged, csam_flag = moderate_input(
+        text, all_conv_text, [state.model_name], images, ip
+    )
+
+    if image_flagged:
+        logger.info(f"image flagged. ip: {ip}. 
text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), {"text": IMAGE_MODERATION_MSG}) + ( + no_change_btn, + ) * 5 + + if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), {"text": CONVERSATION_LIMIT_MSG}) + ( + no_change_btn, + ) * 5 + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + text = _prepare_text_with_image(state, text, images, csam_flag=csam_flag) + state.conv.append_message(state.conv.roles[0], text) + state.conv.append_message(state.conv.roles[1], None) + return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5 + + +def build_single_vision_language_model_ui( + models, add_promotion_links=False, random_questions=None +): + promotion = ( + """ +- | [GitHub](https://github.com/lm-sys/FastChat) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +**❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. Please do not upload any confidential or personal information.** + +Note: You can only chat with one image per conversation. You can upload images less than 15MB. Click the "Random Example" button to chat with a random image.""" + if add_promotion_links + else "" + ) + + notice_markdown = f""" +# 🏔️ Chat with Open Large Vision-Language Models +{promotion} +""" + + state = gr.State() + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False, + ) + + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", open=False + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + with gr.Row(): + textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + placeholder="Click add or drop your image here", + container=True, + render=False, + elem_id="input_box", + ) + + with gr.Column(scale=2, visible=False) as image_column: + imagebox = gr.Image( + type="pil", + show_label=False, + interactive=False, + ) + with gr.Column(scale=8): + chatbot = gr.Chatbot( + elem_id="chatbot", label="Scroll down and start chatting", height=550 + ) + + with gr.Row(): + textbox.render() + # with gr.Column(scale=1, min_width=50): + # send_btn = gr.Button(value="Send", variant="primary") + + with gr.Row(elem_id="buttons"): + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + + examples = gr.Examples( + examples=[ + { + "text": "How can I prepare a delicious meal using these ingredients?", + "files": [f"{cur_dir}/example_images/fridge.jpg"], + }, + { + "text": "What might the woman on the right be thinking about?", + "files": [f"{cur_dir}/example_images/distracted.jpg"], + }, + 
], + inputs=[textbox], + ) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.2, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=0, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + if add_promotion_links: + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + + model_selector.change( + clear_history, None, [state, chatbot, textbox] + btn_list + ).then(set_visible_image, [textbox], [image_column]) + examples.dataset.click(clear_history_example, None, [state, chatbot] + btn_list) + + textbox.input(add_image, [textbox], [imagebox]).then( + set_visible_image, [textbox], [image_column] + ).then(clear_history_example, None, [state, chatbot] + btn_list) + + textbox.submit( + add_text, + [state, model_selector, textbox], + [state, chatbot, textbox] + btn_list, + ).then(set_invisible_image, [], [image_column]).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [textbox], [image_column]).then( + clear_history_example, None, [state, chatbot] + btn_list + ) + + return [state, model_selector] diff --git a/src/serve/gradio_block_arena_vision_anony.py b/src/serve/gradio_block_arena_vision_anony.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1bef363db2d015cabcdb9f5b7dff414cc07a39 --- /dev/null +++ b/src/serve/gradio_block_arena_vision_anony.py @@ -0,0 +1,584 @@ +""" +Chatbot Arena (battle) tab. +Users chat with two anonymous models. 
+""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + TEXT_MODERATION_MSG, + IMAGE_MODERATION_MSG, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_block_arena_named import flash_buttons +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_ip, + get_model_description_md, + _prepare_text_with_image, +) +from fastchat.serve.gradio_block_arena_anony import ( + flash_buttons, + vote_last_response, + leftvote_last_response, + rightvote_last_response, + tievote_last_response, + bothbad_vote_last_response, + regenerate, + clear_history, + share_click, + add_text, + bot_response_multi, + set_global_vars_anony, + load_demo_side_by_side_anony, + get_sample_weight, + get_battle_pair, +) +from fastchat.serve.gradio_block_arena_vision import ( + get_vqa_sample, + set_invisible_image, + set_visible_image, + add_image, + moderate_input, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, + image_moderation_filter, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False +anony_names = ["", ""] +models = [] + +# TODO(chris): fix sampling weights +SAMPLING_WEIGHTS = { + # tier 0 + "gpt-4o": 4, + "gpt-4-turbo": 4, + "gemini-1.5-flash": 4, + "gemini-1.5-pro": 4, + "claude-3-opus-20240229": 4, + "claude-3-haiku-20240307": 4, + "claude-3-sonnet-20240229": 4, + "llava-v1.6-34b": 4, + "llava-v1.6-13b": 4, + "llava-v1.6-7b": 4, + "reka-flash-20240226": 4, +} + +# TODO(chris): Find battle targets that make sense +BATTLE_TARGETS = { + # "gpt-4-turbo": { + # "gemini-1.5-pro-preview-0409", + # "claude-3-opus-20240229", + # "reka-flash-20240226", + # }, + # "gemini-1.5-pro-preview-0409": { + # "gpt-4-turbo", + # "gemini-1.0-pro-vision", + # "reka-flash-20240226", + # }, + # "gemini-1.0-pro-vision": { + # "gpt-4-turbo", + # "gemini-1.5-pro-preview-0409", + # }, + # "claude-3-opus-20240229": { + # "gpt-4-turbo", + # "gemini-1.5-pro-preview-0409", + # "reka-flash-20240226", + # }, + # "claude-3-sonnet-20240229": { + # "claude-3-opus-20240229", + # "gpt-4-turbo", + # "gemini-1.0-pro-vision", + # "gemini-1.5-pro-preview-0409", + # }, + # "claude-3-haiku-20240307": { + # "claude-3-opus-20240229", + # "gpt-4-turbo", + # "gemini-1.0-pro-vision", + # "gemini-1.5-pro-preview-0409", + # }, + # "llava-v1.6-34b": { + # "gpt-4-turbo", + # "gemini-1.5-pro-preview-0409", + # "claude-3-opus-20240229", + # "claude-3-sonnet-20240229", + # "claude-3-haiku-20240307", + # }, + # "llava-v1.6-13b": {"llava-v1.6-7b", "llava-v1.6-34b", "gemini-1.0-pro-vision"}, + # "llava-v1.6-7b": {"llava-v1.6-13b", "gemini-1.0-pro-vision"}, + # "reka-flash-20240226": { + # "gemini-1.0-pro-vision", + # "claude-3-haiku-20240307", + # "claude-3-sonnet-20240229", + # }, +} + +# TODO(chris): Fill out models that require sampling boost +SAMPLING_BOOST_MODELS = [] + +# outage models won't be sampled. 
+OUTAGE_MODELS = [] + + +def load_demo_side_by_side_vision_anony(models_, url_params): + global models + models = models_ + + states = (None,) * num_sides + selector_updates = ( + gr.Markdown(visible=True), + gr.Markdown(visible=True), + ) + + return states + selector_updates + + +def clear_history_example(request: gr.Request): + logger.info(f"clear_history_example (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + filename = get_conv_log_filename(states[0].is_vision, states[0].has_csam_image) + + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + if ":" not in model_selectors[0]: + for i in range(5): + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + (None,) + (disable_btn,) * 4 + time.sleep(0.1) + else: + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + (None,) + (disable_btn,) * 4 + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ): + yield x + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ): + yield x + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ): + yield x + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ): + yield x + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (anony). ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return ( + states + [x.to_gradio_chatbot() for x in states] + [None] + [no_change_btn] * 6 + ) + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [None] + + [invisible_btn] * 4 + + [disable_btn] * 2 + + [""] + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request +): + text, images = chat_input["text"], chat_input["files"] + ip = get_ip(request) + logger.info(f"add_text (anony). ip: {ip}. 
len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + if states[0] is None: + assert states[1] is None + + model_left, model_right = get_battle_pair( + models, + BATTLE_TARGETS, + OUTAGE_MODELS, + SAMPLING_WEIGHTS, + SAMPLING_BOOST_MODELS, + ) + states = [ + State(model_left, is_vision=True), + State(model_right, is_vision=True), + ] + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + model_list = [states[i].model_name for i in range(num_sides)] + text, image_flagged, csam_flag = moderate_input(text, text, model_list, images, ip) + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [{"text": CONVERSATION_LIMIT_MSG}] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + if image_flagged: + logger.info(f"image flagged. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [{"text": IMAGE_MODERATION_MSG}] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + post_processed_text = _prepare_text_with_image( + states[i], text, images, csam_flag=csam_flag + ) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + hint_msg = "" + for i in range(num_sides): + if "deluxe" in states[i].model_name: + hint_msg = SLOW_MODEL_MSG + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + disable_btn, + ] + * 6 + + [hint_msg] + ) + + +def build_side_by_side_vision_ui_anony(models, random_questions=None): + notice_markdown = """ +# ⚔️ Vision Arena ⚔️: Benchmarking VLMs in the Wild +| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Ask any question to two anonymous models (e.g., Claude, Gemini, GPT-4-V) and vote for the better one! +- You can continue chatting until you identify a winner. +- Vote won't be counted if model identity is revealed during conversation. +- You can only chat with one image per conversation. You can upload images less than 15MB. Click the "Random Example" button to chat with a random image. + +**❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. Please do not upload any confidential or personal information.** + +## 👇 Chat now! 
+""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Row(): + with gr.Column(scale=2, visible=False) as image_column: + imagebox = gr.Image( + type="pil", + show_label=False, + interactive=False, + ) + + with gr.Column(scale=5): + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", + open=False, + ): + model_description_md = get_model_description_md(models) + gr.Markdown( + model_description_md, elem_id="model_description_markdown" + ) + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id="chatbot", + height=550, + show_copy_button=True, + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Markdown( + anony_names[i], elem_id="model_selector_md" + ) + with gr.Row(): + slow_warning = gr.Markdown("", elem_id="notice_markdown") + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + container=True, + placeholder="Click add or drop your image here", + elem_id="input_box", + ) + # send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + clear_btn = gr.Button(value="🎲 New Round", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + 
[temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click( + clear_history, + None, + states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning], + ) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-anony'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + textbox.input(add_image, [textbox], [imagebox]).then( + set_visible_image, [textbox], [image_column] + ).then(clear_history_example, None, states + chatbots + model_selectors + btn_list) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list + [slow_warning], + ).then(set_invisible_image, [], [image_column]).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, + [], + btn_list, + ) + + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [textbox], [image_column]).then( + clear_history_example, None, states + chatbots + model_selectors + btn_list + ) + + return states + model_selectors diff --git a/src/serve/gradio_block_arena_vision_named.py b/src/serve/gradio_block_arena_vision_named.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb394a70ccbea5f4c30e0c5746027ebf46da9d8 --- /dev/null +++ b/src/serve/gradio_block_arena_vision_named.py @@ -0,0 +1,457 @@ +""" +Multimodal Chatbot Arena (side-by-side) tab. +Users chat with two chosen models. +""" + +import json +import os +import time + +import gradio as gr +import numpy as np + +from src.constants import ( + TEXT_MODERATION_MSG, + IMAGE_MODERATION_MSG, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from src.model.model_adapter import get_conversation_template +from src.serve.gradio_block_arena_named import ( + flash_buttons, + share_click, + bot_response_multi, +) +from src.serve.gradio_block_arena_vision import ( + get_vqa_sample, + set_invisible_image, + set_visible_image, + add_image, + moderate_input, +) +from src.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_ip, + get_model_description_md, + _prepare_text_with_image, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, + image_moderation_filter, +) + + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False + + +def clear_history_example(request: gr.Request): + logger.info(f"clear_history_example (named). 
ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + filename = get_conv_log_filename(states[0].is_vision, states[0].has_csam_image) + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (named). ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return ( + states + [x.to_gradio_chatbot() for x in states] + [None] + [no_change_btn] * 6 + ) + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (named). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [None] + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request +): + text, images = chat_input["text"], chat_input["files"] + ip = get_ip(request) + logger.info(f"add_text (named). ip: {ip}. 
len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + for i in range(num_sides): + if states[i] is None: + states[i] = State(model_selectors[i], is_vision=True) + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + no_change_btn, + ] + * 6 + ) + + model_list = [states[i].model_name for i in range(num_sides)] + all_conv_text_left = states[0].conv.get_prompt() + all_conv_text_right = states[0].conv.get_prompt() + all_conv_text = ( + all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text + ) + + text, image_flagged, csam_flag = moderate_input( + text, all_conv_text, model_list, images, ip + ) + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [{"text": CONVERSATION_LIMIT_MSG}] + + [ + no_change_btn, + ] + * 6 + ) + + if image_flagged: + logger.info(f"image flagged. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [{"text": IMAGE_MODERATION_MSG}] + + [ + no_change_btn, + ] + * 6 + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + post_processed_text = _prepare_text_with_image( + states[i], text, images, csam_flag=csam_flag + ) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + disable_btn, + ] + * 6 + ) + + +def build_side_by_side_vision_ui_named(models, random_questions=None): + notice_markdown = """ +# ⚔️ Vision Arena ⚔️ : Benchmarking VLMs in the Wild +| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Chat with any two models side-by-side and vote! +- You can continue chatting for multiple rounds. +- Click "Clear history" to start a new round. +- You can only chat with one image per conversation. You can upload images less than 15MB. Click the "Random Example" button to chat with a random image. + +**❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. 
Please do not upload any confidential or personal information.** + +## 🤖 Choose two models to compare +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Row(): + with gr.Column(scale=2, visible=False) as image_column: + imagebox = gr.Image( + type="pil", + show_label=False, + interactive=False, + ) + + with gr.Column(scale=5): + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", + open=False, + ): + model_description_md = get_model_description_md(models) + gr.Markdown( + model_description_md, elem_id="model_description_markdown" + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + interactive=True, + show_label=False, + container=False, + ) + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id=f"chatbot", + height=550, + show_copy_button=True, + ) + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + placeholder="Click add or drop your image here", + container=True, + elem_id="input_box", + ) + + with gr.Row() as button_row: + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, 
max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-named'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + for i in range(num_sides): + model_selectors[i].change( + clear_history, None, states + chatbots + [textbox] + btn_list + ).then(set_visible_image, [textbox], [image_column]) + + textbox.input(add_image, [textbox], [imagebox]).then( + set_visible_image, [textbox], [image_column] + ).then(clear_history_example, None, states + chatbots + btn_list) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then(set_invisible_image, [], [image_column]).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [textbox], [image_column]).then( + clear_history_example, None, states + chatbots + btn_list + ) + + return states + model_selectors diff --git a/src/serve/gradio_web_server.py b/src/serve/gradio_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..974e1e2c46cdb34d5a2193d571786834b78bf79e --- /dev/null +++ b/src/serve/gradio_web_server.py @@ -0,0 +1,1030 @@ +""" +The gradio demo server for chatting with a single model. +""" + +import argparse +from collections import defaultdict +import datetime +import hashlib +import json +import os +import random +import time +import uuid + +import gradio as gr +import requests + +from src.constants import ( + LOGDIR, + WORKER_API_TIMEOUT, + ErrorCode, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + RATE_LIMIT_MSG, + SERVER_ERROR_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, + SESSION_EXPIRATION_TIME, +) +from src.model.model_adapter import ( + get_conversation_template, +) +from src.model.model_registry import get_model_info, model_info +from src.serve.api_provider import get_api_provider_stream_iter +from src.serve.remote_logger import get_remote_logger +from src.utils import ( + build_logger, + get_window_url_params_js, + get_window_url_params_with_tos_js, + moderation_filter, + parse_gradio_auth_creds, + load_image, +) + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +headers = {"User-Agent": "FastChat Client"} + +no_change_btn = gr.Button() +enable_btn = gr.Button(interactive=True, visible=True) +disable_btn = gr.Button(interactive=False) +invisible_btn = gr.Button(interactive=False, visible=False) + +controller_url = None +enable_moderation = False +use_remote_storage = False + +acknowledgment_md = """ +### Terms of Service + +Users are required to agree to the following terms before using the service: + +The service is a research preview. 
It only provides limited safety measures and may generate offensive content. +It must not be used for any illegal, harmful, violent, racist, or sexual purposes. +Please do not upload any private information. +The service collects user dialogue data, including both text and images, and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license. + +### Acknowledgment +We thank [UC Berkeley SkyLab](https://sky.cs.berkeley.edu/), [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Hyperbolic](https://hyperbolic.xyz/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous [sponsorship](https://lmsys.org/donations/). + + +""" + +# JSON file format of API-based models: +# { +# "gpt-3.5-turbo": { +# "model_name": "gpt-3.5-turbo", +# "api_type": "openai", +# "api_base": "https://api.openai.com/v1", +# "api_key": "sk-******", +# "anony_only": false +# } +# } +# +# - "api_type" can be one of the following: openai, anthropic, gemini, or mistral. For custom APIs, add a new type and implement it accordingly. +# - "anony_only" indicates whether to display this model in anonymous mode only. + +api_endpoint_info = {} + + +class State: + def __init__(self, model_name, is_vision=False): + self.conv = get_conversation_template(model_name) + self.conv_id = uuid.uuid4().hex + self.skip_next = False + self.model_name = model_name + self.oai_thread_id = None + self.is_vision = is_vision + + # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes. + self.has_csam_image = False + + self.regen_support = True + if "browsing" in model_name: + self.regen_support = False + self.init_system_prompt(self.conv) + + def init_system_prompt(self, conv): + system_prompt = conv.get_system_message() + if len(system_prompt) == 0: + return + current_date = datetime.datetime.now().strftime("%Y-%m-%d") + system_prompt = system_prompt.replace("{{currentDateTime}}", current_date) + conv.set_system_message(system_prompt) + + def to_gradio_chatbot(self): + return self.conv.to_gradio_chatbot() + + def dict(self): + base = self.conv.dict() + base.update( + { + "conv_id": self.conv_id, + "model_name": self.model_name, + } + ) + + if self.is_vision: + base.update({"has_csam_image": self.has_csam_image}) + return base + + +def set_global_vars(controller_url_, enable_moderation_, use_remote_storage_): + global controller_url, enable_moderation, use_remote_storage + controller_url = controller_url_ + enable_moderation = enable_moderation_ + use_remote_storage = use_remote_storage_ + + +def get_conv_log_filename(is_vision=False, has_csam_image=False): + t = datetime.datetime.now() + conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json" + if is_vision and not has_csam_image: + name = os.path.join(LOGDIR, f"vision-tmp-{conv_log_filename}") + elif is_vision and has_csam_image: + name = os.path.join(LOGDIR, f"vision-csam-{conv_log_filename}") + else: + name = os.path.join(LOGDIR, conv_log_filename) + + return name + + +def get_model_list(controller_url, register_api_endpoint_file, vision_arena): + global api_endpoint_info + + # Add models from the controller + if controller_url: + ret = requests.post(controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + + if vision_arena: + ret = requests.post(controller_url + "/list_multimodal_models") + 
models = ret.json()["models"] + else: + ret = requests.post(controller_url + "/list_language_models") + models = ret.json()["models"] + else: + models = [] + + # Add models from the API providers + if register_api_endpoint_file: + api_endpoint_info = json.load(open(register_api_endpoint_file)) + for mdl, mdl_dict in api_endpoint_info.items(): + mdl_vision = mdl_dict.get("vision-arena", False) + mdl_text = mdl_dict.get("text-arena", True) + if vision_arena and mdl_vision: + models.append(mdl) + if not vision_arena and mdl_text: + models.append(mdl) + + # Remove anonymous models + models = list(set(models)) + visible_models = models.copy() + for mdl in models: + if mdl not in api_endpoint_info: + continue + mdl_dict = api_endpoint_info[mdl] + if mdl_dict["anony_only"]: + visible_models.remove(mdl) + + # Sort models and add descriptions + priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)} + models.sort(key=lambda x: priority.get(x, x)) + visible_models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"All models: {models}") + logger.info(f"Visible models: {visible_models}") + return visible_models, models + + +def load_demo_single(models, url_params): + selected_model = models[0] if len(models) > 0 else "" + if "model" in url_params: + model = url_params["model"] + if model in models: + selected_model = model + + dropdown_update = gr.Dropdown(choices=models, value=selected_model, visible=True) + state = None + return state, dropdown_update + + +def load_demo(url_params, request: gr.Request): + global models + + ip = get_ip(request) + logger.info(f"load_demo. ip: {ip}. params: {url_params}") + + if args.model_list_mode == "reload": + models, all_models = get_model_list( + controller_url, args.register_api_endpoint_file, vision_arena=False + ) + + return load_demo_single(models, url_params) + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + filename = get_conv_log_filename() + if "llava" in model_selector: + filename = filename.replace("2024", "vision-tmp-2024") + + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +def upvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"upvote. ip: {ip}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"downvote. ip: {ip}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"flag. ip: {ip}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def regenerate(state, request: gr.Request): + ip = get_ip(request) + logger.info(f"regenerate. ip: {ip}") + if not state.regen_support: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + state.conv.update_last_message(None) + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + ip = get_ip(request) + logger.info(f"clear_history. 
ip: {ip}") + state = None + return (state, [], "", None) + (disable_btn,) * 5 + + +def get_ip(request: gr.Request): + if "cf-connecting-ip" in request.headers: + ip = request.headers["cf-connecting-ip"] + elif "x-forwarded-for" in request.headers: + ip = request.headers["x-forwarded-for"] + else: + ip = request.client.host + return ip + + +# TODO(Chris): At some point, we would like this to be a live-reporting feature. +def report_csam_image(state, image): + pass + + +def _prepare_text_with_image(state, text, images, csam_flag): + if images is not None and len(images) > 0: + image = images[0] + + if len(state.conv.get_images()) > 0: + # reset convo with new image + state.conv = get_conversation_template(state.model_name) + + image = state.conv.convert_image_to_base64( + image + ) # PIL type is not JSON serializable + + if csam_flag: + state.has_csam_image = True + report_csam_image(state, image) + + text = text, [image] + + return text + + +def add_text(state, model_selector, text, image, request: gr.Request): + ip = get_ip(request) + logger.info(f"add_text. ip: {ip}. len: {len(text)}") + + if state is None: + state = State(model_selector) + + if len(text) <= 0: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + + all_conv_text = state.conv.get_prompt() + all_conv_text = all_conv_text[-2000:] + "\nuser: " + text + flagged = moderation_filter(all_conv_text, [state.model_name]) + # flagged = moderation_filter(text, [state.model_name]) + if flagged: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. 
text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG, None) + ( + no_change_btn, + ) * 5 + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + text = _prepare_text_with_image(state, text, image, csam_flag=False) + state.conv.append_message(state.conv.roles[0], text) + state.conv.append_message(state.conv.roles[1], None) + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, + images, +): + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + + logger.info(f"==== request ====\n{gen_params}") + + if len(images) > 0: + gen_params["images"] = images + + # Stream output + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + yield data + + +def is_limit_reached(model_name, ip): + monitor_url = "http://localhost:9090" + try: + ret = requests.get( + f"{monitor_url}/is_limit_reached?model={model_name}&user_id={ip}", timeout=1 + ) + obj = ret.json() + return obj + except Exception as e: + logger.info(f"monitor error: {e}") + return None + + +def bot_response( + state, + temperature, + top_p, + max_new_tokens, + request: gr.Request, + apply_rate_limit=True, + use_recommended_config=False, +): + ip = get_ip(request) + logger.info(f"bot_response. ip: {ip}") + start_tstamp = time.time() + temperature = float(temperature) + top_p = float(top_p) + max_new_tokens = int(max_new_tokens) + + if state.skip_next: + # This generate call is skipped due to invalid inputs + state.skip_next = False + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + if apply_rate_limit: + ret = is_limit_reached(state.model_name, ip) + if ret is not None and ret["is_limit_reached"]: + error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"] + logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}") + state.conv.update_last_message(error_msg) + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + conv, model_name = state.conv, state.model_name + model_api_dict = ( + api_endpoint_info[model_name] if model_name in api_endpoint_info else None + ) + images = conv.get_images() + + if model_api_dict is None: + # Query worker address + ret = requests.post( + controller_url + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + conv.update_last_message(SERVER_ERROR_MSG) + yield ( + state, + state.to_gradio_chatbot(), + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + # Construct prompt. + # We need to call it here, so it will not be affected by "▌". 
+ prompt = conv.get_prompt() + # Set repetition_penalty + if "t5" in model_name: + repetition_penalty = 1.2 + else: + repetition_penalty = 1.0 + + stream_iter = model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, + images, + ) + else: + if use_recommended_config: + recommended_config = model_api_dict.get("recommended_config", None) + if recommended_config is not None: + temperature = recommended_config.get("temperature", temperature) + top_p = recommended_config.get("top_p", top_p) + max_new_tokens = recommended_config.get( + "max_new_tokens", max_new_tokens + ) + + stream_iter = get_api_provider_stream_iter( + conv, + model_name, + model_api_dict, + temperature, + top_p, + max_new_tokens, + state, + ) + + html_code = ' ' + + # conv.update_last_message("▌") + conv.update_last_message(html_code) + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + data = {"text": ""} + for i, data in enumerate(stream_iter): + if data["error_code"] == 0: + output = data["text"].strip() + # conv.update_last_message(output + "▌") + conv.update_last_message(output + html_code) + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f"\n\n(error_code: {data['error_code']})" + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + output = data["text"].strip() + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + except requests.exceptions.RequestException as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + except Exception as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + finish_tstamp = time.time() + logger.info(f"{output}") + + conv.save_new_images( + has_csam_images=state.has_csam_image, use_remote_storage=use_remote_storage + ) + + filename = get_conv_log_filename( + is_vision=state.is_vision, has_csam_image=state.has_csam_image + ) + + with open(filename, "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": { + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + }, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +block_css = """ +#notice_markdown .prose { + font-size: 110% !important; +} +#notice_markdown th { + display: none; +} +#notice_markdown td { + padding-top: 6px; + padding-bottom: 6px; +} +#arena_leaderboard_dataframe table { + font-size: 110%; +} +#full_leaderboard_dataframe table { + font-size: 110%; +} +#model_description_markdown { + font-size: 110% !important; +} +#leaderboard_markdown .prose { + font-size: 110% !important; +} +#leaderboard_markdown td { + padding-top: 6px; + padding-bottom: 6px; +} +#leaderboard_dataframe td { + line-height: 0.1em; +} +#about_markdown .prose { + font-size: 110% !important; +} +#ack_markdown .prose 
{ + font-size: 110% !important; +} +#chatbot .prose { + font-size: 105% !important; +} +.sponsor-image-about img { + margin: 0 20px; + margin-top: 20px; + height: 40px; + max-height: 100%; + width: auto; + float: left; +} + +.chatbot h1, h2, h3 { + margin-top: 8px; /* Adjust the value as needed */ + margin-bottom: 0px; /* Adjust the value as needed */ + padding-bottom: 0px; +} + +.chatbot h1 { + font-size: 130%; +} +.chatbot h2 { + font-size: 120%; +} +.chatbot h3 { + font-size: 110%; +} +.chatbot p:not(:first-child) { + margin-top: 8px; +} + +.typing { + display: inline-block; +} + +.cursor { + display: inline-block; + width: 7px; + height: 1em; + background-color: black; + vertical-align: middle; + animation: blink 1s infinite; +} + +.dark .cursor { + display: inline-block; + width: 7px; + height: 1em; + background-color: white; + vertical-align: middle; + animation: blink 1s infinite; +} + +@keyframes blink { + 0%, 50% { opacity: 1; } + 50.1%, 100% { opacity: 0; } +} + +.app { + max-width: 100% !important; + padding: 20px !important; +} + +a { + color: #1976D2; /* Your current link color, a shade of blue */ + text-decoration: none; /* Removes underline from links */ +} +a:hover { + color: #63A4FF; /* This can be any color you choose for hover */ + text-decoration: underline; /* Adds underline on hover */ +} +""" + + +def get_model_description_md(models): + model_description_md = """ +| | | | +| ---- | ---- | ---- | +""" + ct = 0 + visited = set() + for i, name in enumerate(models): + minfo = get_model_info(name) + if minfo.simple_name in visited: + continue + visited.add(minfo.simple_name) + one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}" + + if ct % 3 == 0: + model_description_md += "|" + model_description_md += f" {one_model_md} |" + if ct % 3 == 2: + model_description_md += "\n" + ct += 1 + return model_description_md + + +def build_about(): + about_markdown = """ +# About Us +Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open platform to evaluate LLMs by human preference in the real-world. +We open-source our [FastChat](https://github.com/lm-sys/FastChat) project at GitHub and release chat and human feedback dataset. We invite everyone to join us! + +## Arena Core Team +- [Lianmin Zheng](https://lmzheng.net/) (co-lead), [Wei-Lin Chiang](https://infwinston.github.io/) (co-lead), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Joseph E. 
Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Ion Stoica](http://people.eecs.berkeley.edu/~istoica/) + +## Past Members +- [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/) + +## Learn more +- Chatbot Arena [paper](https://arxiv.org/abs/2403.04132), [launch blog](https://lmsys.org/blog/2023-05-03-arena/), [dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md), [policy](https://lmsys.org/blog/2024-03-01-policy/) +- LMSYS-Chat-1M dataset [paper](https://arxiv.org/abs/2309.11998), LLM Judge [paper](https://arxiv.org/abs/2306.05685) + +## Contact Us +- Follow our [X](https://x.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at lmsys.org@gmail.com +- File issues on [GitHub](https://github.com/lm-sys/FastChat) +- Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys) + +## Acknowledgment +We thank [SkyPilot](https://github.com/skypilot-org/skypilot) and [Gradio](https://github.com/gradio-app/gradio) team for their system support. +We also thank [UC Berkeley SkyLab](https://sky.cs.berkeley.edu/), [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Hyperbolic](https://hyperbolic.xyz/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship. Learn more about partnership [here](https://lmsys.org/donations/). + + +""" + gr.Markdown(about_markdown, elem_id="about_markdown") + + +def build_single_model_ui(models, add_promotion_links=False): + promotion = ( + """ +- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | +- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/) +- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. 
[[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/) + +## 🤖 Choose any model to chat +""" + if add_promotion_links + else "" + ) + + notice_markdown = f""" +# 🏔️ Chat with Open Large Language Models +{promotion} +""" + + state = gr.State() + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(elem_id="share-region-named"): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False, + ) + with gr.Row(): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", + open=False, + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + chatbot = gr.Chatbot( + elem_id="chatbot", + label="Scroll down and start chatting", + height=550, + show_copy_button=True, + ) + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + if add_promotion_links: + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + imagebox = gr.State(None) + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + regenerate_btn.click( + regenerate, state, [state, chatbot, textbox, imagebox] + btn_list + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list) + + model_selector.change( + clear_history, None, [state, chatbot, textbox, imagebox] + btn_list + ) + + textbox.submit( + add_text, + [state, model_selector, textbox, imagebox], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + send_btn.click( + add_text, + [state, model_selector, textbox, imagebox], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + + return [state, model_selector] + + +def build_demo(models): + with gr.Blocks( + title="Chat with Open Large Language Models", + 
theme=gr.themes.Default(), + css=block_css, + ) as demo: + url_params = gr.JSON(visible=False) + + state, model_selector = build_single_model_ui(models) + + if args.model_list_mode not in ["once", "reload"]: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + if args.show_terms_of_use: + load_js = get_window_url_params_with_tos_js + else: + load_js = get_window_url_params_js + + demo.load( + load_demo, + [url_params], + [ + state, + model_selector, + ], + js=load_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument( + "--share", + action="store_true", + help="Whether to generate a public, shareable link", + ) + parser.add_argument( + "--controller-url", + type=str, + default="http://localhost:21001", + help="The address of the controller", + ) + parser.add_argument( + "--concurrency-count", + type=int, + default=10, + help="The concurrency count of the gradio queue", + ) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once", "reload"], + help="Whether to load the model list once or reload the model list every time", + ) + parser.add_argument( + "--moderate", + action="store_true", + help="Enable content moderation to block unsafe inputs", + ) + parser.add_argument( + "--show-terms-of-use", + action="store_true", + help="Shows term of use before loading the demo", + ) + parser.add_argument( + "--register-api-endpoint-file", + type=str, + help="Register API-based model endpoints from a JSON file", + ) + parser.add_argument( + "--gradio-auth-path", + type=str, + help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', + ) + parser.add_argument( + "--gradio-root-path", + type=str, + help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix", + ) + parser.add_argument( + "--use-remote-storage", + action="store_true", + default=False, + help="Uploads image files to google cloud storage if set to true", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + # Set global variables + set_global_vars(args.controller_url, args.moderate, args.use_remote_storage) + models, all_models = get_model_list( + args.controller_url, args.register_api_endpoint_file, vision_arena=False + ) + + # Set authorization credentials + auth = None + if args.gradio_auth_path is not None: + auth = parse_gradio_auth_creds(args.gradio_auth_path) + + # Launch the demo + demo = build_demo(models) + demo.queue( + default_concurrency_limit=args.concurrency_count, + status_update_rate=10, + api_open=False, + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + auth=auth, + root_path=args.gradio_root_path, + ) diff --git a/src/serve/gradio_web_server_multi.py b/src/serve/gradio_web_server_multi.py new file mode 100644 index 0000000000000000000000000000000000000000..d5dad71b0f5df465ad045c382f859d35c8a6d630 --- /dev/null +++ b/src/serve/gradio_web_server_multi.py @@ -0,0 +1,335 @@ +""" +The gradio demo server with multiple tabs. +It supports chatting with a single model or chatting with two models side-by-side. 
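+The tab selected on page load can be controlled through URL parameters (see load_demo below).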
+""" + +import argparse +import pickle +import time + +import gradio as gr + +from fastchat.serve.gradio_block_arena_anony import ( + build_side_by_side_ui_anony, + load_demo_side_by_side_anony, + set_global_vars_anony, +) +from fastchat.serve.gradio_block_arena_named import ( + build_side_by_side_ui_named, + load_demo_side_by_side_named, + set_global_vars_named, +) +from fastchat.serve.gradio_block_arena_vision import ( + build_single_vision_language_model_ui, +) +from fastchat.serve.gradio_block_arena_vision_anony import ( + build_side_by_side_vision_ui_anony, + load_demo_side_by_side_vision_anony, +) +from fastchat.serve.gradio_block_arena_vision_named import ( + build_side_by_side_vision_ui_named, +) + +from fastchat.serve.gradio_web_server import ( + set_global_vars, + block_css, + build_single_model_ui, + build_about, + get_model_list, + load_demo_single, + get_ip, +) +from fastchat.serve.monitor.monitor import build_leaderboard_tab +from fastchat.utils import ( + build_logger, + get_window_url_params_js, + get_window_url_params_with_tos_js, + parse_gradio_auth_creds, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + + +def load_demo(url_params, request: gr.Request): + global models, all_models, vl_models + + ip = get_ip(request) + logger.info(f"load_demo. ip: {ip}. params: {url_params}") + + selected = 0 + if "arena" in url_params: + selected = 0 + elif "compare" in url_params: + selected = 1 + elif "direct" in url_params or "model" in url_params: + selected = 2 + elif "vision" in url_params: + selected = 3 + elif "leaderboard" in url_params: + selected = 4 + elif "about" in url_params: + selected = 5 + + if args.model_list_mode == "reload": + models, all_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + vision_arena=False, + ) + + vl_models, all_vl_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + vision_arena=True, + ) + + single_updates = load_demo_single(models, url_params) + side_by_side_anony_updates = load_demo_side_by_side_anony(all_models, url_params) + side_by_side_named_updates = load_demo_side_by_side_named(models, url_params) + + vision_language_updates = load_demo_single(vl_models, url_params) + side_by_side_vision_named_updates = load_demo_side_by_side_named( + vl_models, url_params + ) + side_by_side_vision_anony_updates = load_demo_side_by_side_vision_anony( + vl_models, url_params + ) + + return ( + (gr.Tabs(selected=selected),) + + single_updates + + side_by_side_anony_updates + + side_by_side_named_updates + + side_by_side_vision_anony_updates + + side_by_side_vision_named_updates + + vision_language_updates + ) + + +def build_demo(models, vl_models, elo_results_file, leaderboard_table_file): + text_size = gr.themes.sizes.text_md + if args.show_terms_of_use: + load_js = get_window_url_params_with_tos_js + else: + load_js = get_window_url_params_js + + head_js = """ + +""" + if args.ga_id is not None: + head_js += f""" + + + """ + + with gr.Blocks( + title="Chat with Open Large Language Models", + theme=gr.themes.Default(text_size=text_size), + css=block_css, + head=head_js, + ) as demo: + with gr.Tabs() as tabs: + with gr.Tab("Text Arena", id=0): + with gr.Tab("⚔️ Arena (battle)", id=0): + side_by_side_anony_list = build_side_by_side_ui_anony(models) + + with gr.Tab("⚔️ Arena (side-by-side)", id=1): + side_by_side_named_list = build_side_by_side_ui_named(models) + + with gr.Tab("💬 Direct Chat", id=2): + single_model_list = build_single_model_ui( + models, 
add_promotion_links=True + ) + + demo_tabs = ( + [tabs] + + single_model_list + + side_by_side_anony_list + + side_by_side_named_list + ) + + if args.vision_arena: + with gr.Tab("Vision Arena", id=3): + with gr.Tab("⚔️ Vision Arena (battle)", id=3): + side_by_side_vision_anony_list = ( + build_side_by_side_vision_ui_anony( + vl_models, + random_questions=args.random_questions, + ) + ) + + with gr.Tab("⚔️ Vision Arena (side-by-side)", id=4): + side_by_side_vision_named_list = ( + build_side_by_side_vision_ui_named( + vl_models, + random_questions=args.random_questions, + ) + ) + + with gr.Tab("👀 Vision Direct Chat", id=5): + single_vision_language_model_list = ( + build_single_vision_language_model_ui( + vl_models, + add_promotion_links=True, + random_questions=args.random_questions, + ) + ) + demo_tabs += ( + side_by_side_vision_anony_list + + side_by_side_vision_named_list + + single_vision_language_model_list + ) + + if elo_results_file: + with gr.Tab("Leaderboard", id=6): + build_leaderboard_tab( + elo_results_file, leaderboard_table_file, show_plot=True + ) + + with gr.Tab("ℹ️ About Us", id=7): + about = build_about() + + url_params = gr.JSON(visible=False) + + if args.model_list_mode not in ["once", "reload"]: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + demo.load( + load_demo, + [url_params], + demo_tabs, + js=load_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument( + "--share", + action="store_true", + help="Whether to generate a public, shareable link", + ) + parser.add_argument( + "--controller-url", + type=str, + default="http://localhost:21001", + help="The address of the controller", + ) + parser.add_argument( + "--concurrency-count", + type=int, + default=10, + help="The concurrency count of the gradio queue", + ) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once", "reload"], + help="Whether to load the model list once or reload the model list every time.", + ) + parser.add_argument( + "--moderate", + action="store_true", + help="Enable content moderation to block unsafe inputs", + ) + parser.add_argument( + "--show-terms-of-use", + action="store_true", + help="Shows term of use before loading the demo", + ) + parser.add_argument( + "--vision-arena", action="store_true", help="Show tabs for vision arena." + ) + parser.add_argument( + "--random-questions", type=str, help="Load random questions from a JSON file" + ) + parser.add_argument( + "--register-api-endpoint-file", + type=str, + help="Register API-based model endpoints from a JSON file", + ) + parser.add_argument( + "--gradio-auth-path", + type=str, + help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', + default=None, + ) + parser.add_argument( + "--elo-results-file", type=str, help="Load leaderboard results and plots" + ) + parser.add_argument( + "--leaderboard-table-file", type=str, help="Load leaderboard results and plots" + ) + parser.add_argument( + "--gradio-root-path", + type=str, + help="Sets the gradio root path, eg /abc/def. 
Useful when running behind a reverse-proxy or at a custom URL path prefix", + ) + parser.add_argument( + "--ga-id", + type=str, + help="the Google Analytics ID", + default=None, + ) + parser.add_argument( + "--use-remote-storage", + action="store_true", + default=False, + help="Uploads image files to google cloud storage if set to true", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + # Set global variables + set_global_vars(args.controller_url, args.moderate, args.use_remote_storage) + set_global_vars_named(args.moderate) + set_global_vars_anony(args.moderate) + models, all_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + vision_arena=False, + ) + + vl_models, all_vl_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + vision_arena=True, + ) + + # Set authorization credentials + auth = None + if args.gradio_auth_path is not None: + auth = parse_gradio_auth_creds(args.gradio_auth_path) + + # Launch the demo + demo = build_demo( + models, + vl_models, + args.elo_results_file, + args.leaderboard_table_file, + ) + demo.queue( + default_concurrency_limit=args.concurrency_count, + status_update_rate=10, + api_open=False, + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + auth=auth, + root_path=args.gradio_root_path, + show_api=False, + ) diff --git a/src/serve/huggingface_api.py b/src/serve/huggingface_api.py new file mode 100644 index 0000000000000000000000000000000000000000..8022fbc93e9f2d4240eb67ff95061928cee81bbd --- /dev/null +++ b/src/serve/huggingface_api.py @@ -0,0 +1,73 @@ +""" +Use FastChat with Hugging Face generation APIs. + +Usage: +python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5 +python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0 +""" +import argparse + +import torch + +from fastchat.model import load_model, get_conversation_template, add_model_args + + +@torch.inference_mode() +def main(args): + # Load model + model, tokenizer = load_model( + args.model_path, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + revision=args.revision, + debug=args.debug, + ) + + # Build the prompt with a conversation template + msg = args.message + conv = get_conversation_template(args.model_path) + conv.append_message(conv.roles[0], msg) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + # Run inference + inputs = tokenizer([prompt], return_tensors="pt").to(args.device) + output_ids = model.generate( + **inputs, + do_sample=True if args.temperature > 1e-5 else False, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + max_new_tokens=args.max_new_tokens, + ) + + if model.config.is_encoder_decoder: + output_ids = output_ids[0] + else: + output_ids = output_ids[0][len(inputs["input_ids"][0]) :] + outputs = tokenizer.decode( + output_ids, skip_special_tokens=True, spaces_between_special_tokens=False + ) + + # Print results + print(f"{conv.roles[0]}: {msg}") + print(f"{conv.roles[1]}: {outputs}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=1024) + parser.add_argument("--debug", action="store_true") + 
parser.add_argument("--message", type=str, default="Hello! Who are you?") + args = parser.parse_args() + + # Reset default repetition penalty for T5 models. + if "t5" in args.model_path and args.repetition_penalty == 1.0: + args.repetition_penalty = 1.2 + + main(args) diff --git a/src/serve/huggingface_api_worker.py b/src/serve/huggingface_api_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed8e6c8cc4a39904927532beab24874b7762a17 --- /dev/null +++ b/src/serve/huggingface_api_worker.py @@ -0,0 +1,415 @@ +""" +A model worker that calls huggingface inference endpoint. + +Register models in a JSON file with the following format: +{ + "falcon-180b-chat": { + "model_name": "falcon-180B-chat", + "api_base": "https://api-inference.huggingface.co/models", + "model_path": "tiiuae/falcon-180B-chat", + "token": "hf_XXX", + "context_length": 2048 + }, + "zephyr-7b-beta": { + "model_name": "zephyr-7b-beta", + "model_path": "", + "api_base": "xxx", + "token": "hf_XXX", + "context_length": 4096 + } +} + +"model_path", "api_base", "token", and "context_length" are necessary, while others are optional. +""" +import argparse +import asyncio +import json +import uuid +import os +from typing import List, Optional + +import requests +import uvicorn +from fastapi import BackgroundTasks, FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse +from huggingface_hub import InferenceClient + +from fastchat.constants import SERVER_ERROR_MSG, ErrorCode +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.utils import build_logger + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + +workers = [] +worker_map = {} +app = FastAPI() + + +# reference to +# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392 +def get_gen_kwargs( + params, + seed: Optional[int] = None, +): + stop = params.get("stop", None) + if isinstance(stop, list): + stop_sequences = stop + elif isinstance(stop, str): + stop_sequences = [stop] + else: + stop_sequences = [] + gen_kwargs = { + "do_sample": True, + "return_full_text": bool(params.get("echo", False)), + "max_new_tokens": int(params.get("max_new_tokens", 256)), + "top_p": float(params.get("top_p", 1.0)), + "temperature": float(params.get("temperature", 1.0)), + "stop_sequences": stop_sequences, + "repetition_penalty": float(params.get("repetition_penalty", 1.0)), + "top_k": params.get("top_k", None), + "seed": seed, + } + if gen_kwargs["top_p"] == 1: + gen_kwargs["top_p"] = 0.9999999 + if gen_kwargs["top_p"] == 0: + gen_kwargs.pop("top_p") + if gen_kwargs["temperature"] == 0: + gen_kwargs.pop("temperature") + gen_kwargs["do_sample"] = False + return gen_kwargs + + +def could_be_stop(text, stop): + for s in stop: + if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)): + return True + return False + + +class HuggingfaceApiWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + api_base: str, + token: str, + context_length: int, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: Optional[str] = None, + seed: Optional[int] = None, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + self.model_path = model_path + self.api_base = api_base + 
self.token = token + self.context_len = context_length + self.seed = seed + + logger.info( + f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..." + ) + + if not no_register: + self.init_heart_beat() + + def count_token(self, params): + # No tokenizer here + ret = { + "count": 0, + "error_code": 0, + } + return ret + + def generate_stream_gate(self, params): + self.call_ct += 1 + + prompt = params["prompt"] + gen_kwargs = get_gen_kwargs(params, seed=self.seed) + stop = gen_kwargs["stop_sequences"] + if "falcon" in self.model_path and "chat" in self.model_path: + stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"]) + stop = list(set(stop)) + gen_kwargs["stop_sequences"] = stop + + logger.info(f"prompt: {prompt}") + logger.info(f"gen_kwargs: {gen_kwargs}") + + try: + if self.model_path == "": + url = f"{self.api_base}" + else: + url = f"{self.api_base}/{self.model_path}" + client = InferenceClient(url, token=self.token) + res = client.text_generation( + prompt, stream=True, details=True, **gen_kwargs + ) + + reason = None + text = "" + for chunk in res: + if chunk.token.special: + continue + text += chunk.token.text + + s = next((x for x in stop if text.endswith(x)), None) + if s is not None: + text = text[: -len(s)] + reason = "stop" + break + if could_be_stop(text, stop): + continue + if ( + chunk.details is not None + and chunk.details.finish_reason is not None + ): + reason = chunk.details.finish_reason + if reason not in ["stop", "length"]: + reason = None + ret = { + "text": text, + "error_code": 0, + "finish_reason": reason, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def get_embeddings(self, params): + raise NotImplementedError() + + +def release_worker_semaphore(worker): + worker.semaphore.release() + + +def acquire_worker_semaphore(worker): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(worker): + background_tasks = BackgroundTasks() + background_tasks.add_task(lambda: release_worker_semaphore(worker)) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks(worker) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + output = worker.generate_gate(params) + release_worker_semaphore(worker) + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + embedding = worker.get_embeddings(params) + release_worker_semaphore(worker) + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return { + "model_names": [m for w 
in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + } + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return {"context_length": worker.context_len} + + +def create_huggingface_api_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + # all model-related parameters are listed in --model-info-file + parser.add_argument( + "--model-info-file", + type=str, + required=True, + help="Huggingface API model's info file path", + ) + + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + + with open(args.model_info_file, "r", encoding="UTF-8") as f: + model_info = json.load(f) + + logger.info(f"args: {args}") + + model_path_list = [] + api_base_list = [] + token_list = [] + context_length_list = [] + model_names_list = [] + conv_template_list = [] + + for m in model_info: + model_path_list.append(model_info[m]["model_path"]) + api_base_list.append(model_info[m]["api_base"]) + token_list.append(model_info[m]["token"]) + + context_length = model_info[m]["context_length"] + model_names = model_info[m].get("model_names", [m.split("/")[-1]]) + if isinstance(model_names, str): + model_names = [model_names] + conv_template = model_info[m].get("conv_template", None) + + context_length_list.append(context_length) + model_names_list.append(model_names) + conv_template_list.append(conv_template) + + logger.info(f"Model paths: {model_path_list}") + logger.info(f"API bases: {api_base_list}") + logger.info(f"Tokens: {token_list}") + logger.info(f"Context lengths: {context_length_list}") + logger.info(f"Model names: {model_names_list}") + logger.info(f"Conv templates: {conv_template_list}") + + for ( + model_names, + conv_template, + model_path, + api_base, + token, + context_length, + ) in zip( + model_names_list, + conv_template_list, + model_path_list, + api_base_list, + token_list, + context_length_list, + ): + m = HuggingfaceApiWorker( + args.controller_address, + args.worker_address, + worker_id, + model_path, + api_base, + token, + context_length, + model_names, + args.limit_worker_concurrency, + no_register=args.no_register, + conv_template=conv_template, + seed=args.seed, + ) + workers.append(m) + for name in model_names: + worker_map[name] = m + + # register all the models + url = 
args.controller_address + "/register_worker" + data = { + "worker_name": workers[0].worker_addr, + "check_heart_beat": not args.no_register, + "worker_status": { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + }, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + return args, workers + + +if __name__ == "__main__": + args, workers = create_huggingface_api_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/src/serve/inference.py b/src/serve/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6d155aab7c6cb32ca5fdb10d5661f1b298908f93 --- /dev/null +++ b/src/serve/inference.py @@ -0,0 +1,555 @@ +"""Inference for FastChat models.""" +import abc +import gc +import json +import math +import os +import sys +import time +from typing import Iterable, Optional, Dict +import warnings + +import psutil +import torch +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + LlamaForCausalLM, + AutoModel, + AutoModelForSeq2SeqLM, + T5Tokenizer, + AutoConfig, +) +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from fastchat.conversation import get_conv_template, SeparatorStyle +from fastchat.model.model_adapter import ( + load_model, + get_conversation_template, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +@torch.inference_mode() +def generate_stream( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + if hasattr(model, "device"): + device = model.device + + # Read parameters + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1. 
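+    # Illustrative note (an added comment describing typical callers, not taken from
+    # the original code): at this point `params` is a plain dict built by the caller,
+    # e.g. the chat_loop() below or a model worker's /worker_generate_stream endpoint.
+    # A minimal example might look like:
+    #     {"prompt": "USER: Hello ASSISTANT:", "temperature": 0.7, "top_p": 1.0,
+    #      "max_new_tokens": 256, "stop": "</s>", "echo": False}
+    # Every key other than "prompt" is optional and falls back to the defaults read
+    # in this block.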
+ echo = bool(params.get("echo", True)) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + input_ids = tokenizer(prompt).input_ids + + if model.config.is_encoder_decoder: + max_src_len = context_len + else: # truncate + max_src_len = context_len - max_new_tokens - 1 + + input_ids = input_ids[-max_src_len:] + output_ids = list(input_ids) + input_echo_len = len(input_ids) + + if model.config.is_encoder_decoder: + if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models. + raise NotImplementedError + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + else: + start_ids = torch.as_tensor([input_ids], device=device) + + past_key_values = out = None + token_logprobs = [None] # The first token has no logprobs. + sent_interrupt = False + finish_reason = None + stopped = False + for i in range(max_new_tokens): + if i == 0: # prefill + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(input_ids=start_ids, use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + + if logprobs is not None: + # Prefull logprobs for the prompt. + shift_input_ids = start_ids[..., 1:].contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist() + for label_id, logit in zip( + shift_input_ids[0].tolist(), shift_logits[0] + ): + token_logprobs.append(logit[label_id]) + else: # decoding + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + logits = out.logits + past_key_values = out.past_key_values + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. + last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + token = tokens[0] + output_ids.append(token) + if logprobs is not None: + # Cannot use last_token_logits because logprobs is based on raw logits. 
+ token_logprobs.append( + torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist() + ) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + # Yield the output tokens + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + ret_logprobs = None + if logprobs is not None: + ret_logprobs = { + "text_offset": [], + "tokens": [ + tokenizer.decode(token) + for token in ( + output_ids if echo else output_ids[input_echo_len:] + ) + ], + "token_logprobs": token_logprobs + if echo + else token_logprobs[input_echo_len:], + "top_logprobs": [{}] + * len(token_logprobs if echo else token_logprobs[input_echo_len:]), + } + # Compute text_offset + curr_pos = 0 + for text in ret_logprobs["tokens"]: + ret_logprobs["text_offset"].append(curr_pos) + curr_pos += len(text) + + # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way + if judge_sent_end and stopped and not is_sentence_complete(output): + if len(tokens) > 1: + token = tokens[1] + output_ids[-1] = token + else: + output_ids.pop() + stopped = False + sent_interrupt = True + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # Prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + + if stopped: + break + + # Finish stream event, which contains finish reason + else: + finish_reason = "length" + + if stopped: + finish_reason = "stop" + + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # Clean + del past_key_values, out + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + +class ChatIO(abc.ABC): + @abc.abstractmethod + def prompt_for_input(self, role: str) -> str: + """Prompt for input from a role.""" + + @abc.abstractmethod + def prompt_for_output(self, role: str): + """Prompt for output from a role.""" + + @abc.abstractmethod + def stream_output(self, output_stream): + """Stream output.""" + + @abc.abstractmethod + def print_output(self, text: str): + """Print output.""" + + +def chat_loop( + model_path: str, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype], + load_8bit: bool, + cpu_offloading: bool, + conv_template: Optional[str], + conv_system_msg: Optional[str], + temperature: float, + repetition_penalty: float, + max_new_tokens: 
int, + chatio: ChatIO, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + judge_sent_end: bool = True, + debug: bool = True, + history: bool = True, +): + # Model + model, tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + revision=revision, + debug=debug, + ) + generate_stream_func = get_generate_stream_function(model, model_path) + + model_type = str(type(model)).lower() + is_t5 = "t5" in model_type + is_codet5p = "codet5p" in model_type + is_xft = "xft" in model_type + + # Hardcode T5's default repetition penalty to be 1.2 + if is_t5 and repetition_penalty == 1.0: + repetition_penalty = 1.2 + + # Set context length + context_len = get_context_length(model.config) + + # Chat + def new_chat(): + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + if conv_system_msg is not None: + conv.set_system_message(conv_system_msg) + return conv + + def reload_conv(conv): + """ + Reprints the conversation from the start. + """ + for message in conv.messages[conv.offset :]: + chatio.prompt_for_output(message[0]) + chatio.print_output(message[1]) + + conv = None + + while True: + if not history or not conv: + conv = new_chat() + + try: + inp = chatio.prompt_for_input(conv.roles[0]) + except EOFError: + inp = "" + + if inp == "!!exit" or not inp: + print("exit...") + break + elif inp == "!!reset": + print("resetting...") + conv = new_chat() + continue + elif inp == "!!remove": + print("removing last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + reload_conv(conv) + else: + print("No messages to remove.") + continue + elif inp == "!!regen": + print("regenerating last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + reload_conv(conv) + # Set inp to previous message + inp = conv.messages.pop()[1] + else: + # Shouldn't happen in normal circumstances + print("No user message to regenerate from.") + continue + else: + print("No messages to regenerate.") + continue + elif inp.startswith("!!save"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!save ") + continue + else: + filename = args[1] + + # Add .json if extension not present + if not "." 
in filename: + filename += ".json" + + print("saving...", filename) + with open(filename, "w") as outfile: + json.dump(conv.dict(), outfile) + continue + elif inp.startswith("!!load"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!load <filename>") + continue + else: + filename = args[1] + + # Check if file exists and add .json if needed + if not os.path.exists(filename): + if (not filename.endswith(".json")) and os.path.exists( + filename + ".json" + ): + filename += ".json" + else: + print("file not found:", filename) + continue + + print("loading...", filename) + with open(filename, "r") as infile: + new_conv = json.load(infile) + + conv = get_conv_template(new_conv["template_name"]) + conv.set_system_message(new_conv["system_message"]) + conv.messages = new_conv["messages"] + reload_conv(conv) + continue + + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + if is_codet5p: # codet5p is a code completion model. + prompt = inp + + gen_params = { + "model": model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + + try: + chatio.prompt_for_output(conv.roles[1]) + output_stream = generate_stream_func( + model, + tokenizer, + gen_params, + device, + context_len=context_len, + judge_sent_end=judge_sent_end, + ) + t = time.time() + outputs = chatio.stream_output(output_stream) + duration = time.time() - t + conv.update_last_message(outputs.strip()) + + if debug: + num_tokens = len(tokenizer.encode(outputs)) + msg = { + "conv_template": conv.name, + "prompt": prompt, + "outputs": outputs, + "speed (token/s)": round(num_tokens / duration, 2), + } + print(f"\n{msg}\n") + + except KeyboardInterrupt: + print("stopped generation.") + # If generation didn't finish + if conv.messages[-1][1] is None: + conv.messages.pop() + # Remove last user message, so there isn't a double up + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + + reload_conv(conv) diff --git a/src/serve/launch_all_serve.py b/src/serve/launch_all_serve.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4ad7b0b134d1699ff8ba0d95d8039ec3c1f204 --- /dev/null +++ b/src/serve/launch_all_serve.py @@ -0,0 +1,284 @@ +""" +Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022" + +Workers are listed in the format `model-path`@`host`@`port` + +The key mechanism behind this script is: + 1. execute a shell command to launch the controller/worker/openai-api-server; + 2. check the log of the controller/worker/openai-api-server to ensure that the server has launched properly. +Note that a few non-critical `fastchat.serve` command options are not supported currently.
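+
+For illustration (assuming the default controller host/port and the LOGDIR defined
+below), the controller launch command assembled by this script looks roughly like:
+
+    nohup python3 -m fastchat.serve.controller --host localhost --port 21001 --dispatch-method shortest_queue >./logs/controller.log 2>&1 &
+
+followed by a polling loop that greps the log for "Uvicorn running on" before the
+next component is started.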
+""" +import sys +import os + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +import subprocess +import re +import argparse + +LOGDIR = "./logs/" + +if not os.path.exists(LOGDIR): + os.makedirs(LOGDIR) + +parser = argparse.ArgumentParser() +# ------multi worker----------------- +parser.add_argument( + "--model-path-address", + default="THUDM/chatglm2-6b@localhost@20002", + nargs="+", + type=str, + help="model path, host, and port, formatted as model-path@host@port", +) +# ---------------controller------------------------- + +parser.add_argument("--controller-host", type=str, default="localhost") +parser.add_argument("--controller-port", type=int, default=21001) +parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", +) +controller_args = ["controller-host", "controller-port", "dispatch-method"] + +# ----------------------worker------------------------------------------ + +parser.add_argument("--worker-host", type=str, default="localhost") +parser.add_argument("--worker-port", type=int, default=21002) +# parser.add_argument("--worker-address", type=str, default="http://localhost:21002") +# parser.add_argument( +# "--controller-address", type=str, default="http://localhost:21001" +# ) +parser.add_argument( + "--model-path", + type=str, + default="lmsys/vicuna-7b-v1.5", + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", +) +parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", +) +parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps", "xpu", "npu"], + default="cuda", + help="The device type", +) +parser.add_argument( + "--gpus", + type=str, + default="0", + help="A single GPU like 1 or multiple GPUs like 0,2", +) +parser.add_argument("--num-gpus", type=int, default=1) +parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per gpu. Use a string like '13Gib'", +) +parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization") +parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", +) +parser.add_argument( + "--gptq-ckpt", + type=str, + default=None, + help="Load quantized model. 
The path to the local GPTQ checkpoint.", +) +parser.add_argument( + "--gptq-wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="#bits to use for quantization", +) +parser.add_argument( + "--gptq-groupsize", + type=int, + default=-1, + help="Groupsize to use for quantization; default uses full row.", +) +parser.add_argument( + "--gptq-act-order", + action="store_true", + help="Whether to apply the activation order GPTQ heuristic", +) +parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", +) +parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", +) +parser.add_argument("--stream-interval", type=int, default=2) +parser.add_argument("--no-register", action="store_true") + +worker_args = [ + "worker-host", + "worker-port", + "model-path", + "revision", + "device", + "gpus", + "num-gpus", + "max-gpu-memory", + "load-8bit", + "cpu-offloading", + "gptq-ckpt", + "gptq-wbits", + "gptq-groupsize", + "gptq-act-order", + "model-names", + "limit-worker-concurrency", + "stream-interval", + "no-register", + "controller-address", +] +# -----------------openai server--------------------------- + +parser.add_argument("--server-host", type=str, default="localhost", help="host name") +parser.add_argument("--server-port", type=int, default=8001, help="port number") +parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" +) +# parser.add_argument( +# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" +# ) +# parser.add_argument( +# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" +# ) +# parser.add_argument( +# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" +# ) +parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", +) +server_args = [ + "server-host", + "server-port", + "allow-credentials", + "api-keys", + "controller-address", +] + +args = parser.parse_args() + +args = argparse.Namespace( + **vars(args), + **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"}, +) + +if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + +# 0,controller, model_worker, openai_api_server +# 1, cmd options +# 2,LOGDIR +# 3, log file name +base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &" + +# 0 LOGDIR +#! 
1 log file name +# 2 controller, worker, openai_api_server +base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do + sleep 1s; + echo "wait {2} running" + done + echo '{2} running' """ + + +def string_args(args, args_list): + args_str = "" + for key, value in args._get_kwargs(): + key = key.replace("_", "-") + if key not in args_list: + continue + + key = key.split("-")[-1] if re.search("port|host", key) else key + if not value: + pass + # 1==True -> True + elif isinstance(value, bool) and value == True: + args_str += f" --{key} " + elif ( + isinstance(value, list) + or isinstance(value, tuple) + or isinstance(value, set) + ): + value = " ".join(value) + args_str += f" --{key} {value} " + else: + args_str += f" --{key} {value} " + + return args_str + + +def launch_worker(item): + log_name = ( + item.split("/")[-1] + .split("\\")[-1] + .replace("-", "_") + .replace("@", "_") + .replace(".", "_") + ) + + args.model_path, args.worker_host, args.worker_port = item.split("@") + print("*" * 80) + worker_str_args = string_args(args, worker_args) + print(worker_str_args) + worker_sh = base_launch_sh.format( + "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}" + ) + worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker") + subprocess.run(worker_sh, shell=True, check=True) + subprocess.run(worker_check_sh, shell=True, check=True) + + +def launch_all(): + controller_str_args = string_args(args, controller_args) + controller_sh = base_launch_sh.format( + "controller", controller_str_args, LOGDIR, "controller" + ) + controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller") + subprocess.run(controller_sh, shell=True, check=True) + subprocess.run(controller_check_sh, shell=True, check=True) + + if isinstance(args.model_path_address, str): + launch_worker(args.model_path_address) + else: + for idx, item in enumerate(args.model_path_address): + print(f"loading {idx}th model:{item}") + launch_worker(item) + + server_str_args = string_args(args, server_args) + server_sh = base_launch_sh.format( + "openai_api_server", server_str_args, LOGDIR, "openai_api_server" + ) + server_check_sh = base_check_sh.format( + LOGDIR, "openai_api_server", "openai_api_server" + ) + subprocess.run(server_sh, shell=True, check=True) + subprocess.run(server_check_sh, shell=True, check=True) + + +if __name__ == "__main__": + launch_all() diff --git a/src/serve/lightllm_worker.py b/src/serve/lightllm_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0e21b68e3a9c03556937987910b590344b452f --- /dev/null +++ b/src/serve/lightllm_worker.py @@ -0,0 +1,512 @@ +""" +A model worker that executes the model based on LightLLM. 
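+
+For illustration only (the model path below is a placeholder and the module path
+depends on this repository's layout), a minimal launch might look like:
+
+    python3 -m fastchat.serve.lightllm_worker --model-path <path-to-model> --host 127.0.0.1 --port 8000 --tp 1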
+ +See documentations at docs/lightllm_integration.md +""" + +import argparse +import asyncio +import json +import os +import torch +import uvicorn + +from transformers import AutoConfig + +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) + +from lightllm.server.sampling_params import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.httpserver.manager import HttpServerManager +from lightllm.server.detokenization.manager import start_detokenization_process +from lightllm.server.router.manager import start_router_process +from lightllm.server.req_id_generator import ReqIDGenerator + +from lightllm.utils.net_utils import alloc_can_use_network_port +from lightllm.utils.start_utils import start_submodule_processes +from fastchat.utils import get_context_length, is_partial_stop + +app = FastAPI() +g_id_gen = ReqIDGenerator() + + +class LightLLMWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: str, + tokenizer, + context_len, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: LightLLM worker..." + ) + self.tokenizer = tokenizer + self.context_len = context_len + + self.is_first = True + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + prompt = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + max_new_tokens = params.get("max_new_tokens", 256) + echo = params.get("echo", True) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + + request = params.get("request", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) + + if self.is_first: + loop = asyncio.get_event_loop() + loop.create_task(httpserver_manager.handle_loop()) + self.is_first = False + + # make sampling params in vllm + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + sampling_params = SamplingParams( + do_sample=temperature > 0.0, + temperature=temperature, + top_p=top_p, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repetition_penalty=repetition_penalty, + max_new_tokens=max_new_tokens, + stop_sequences=list(stop), + ) + sampling_params.verify() + + results_generator = httpserver_manager.generate( + prompt, sampling_params, request_id, 
MultimodalParams() + ) + + completion_tokens = 0 + text_outputs = "" + cumulative_logprob = 0.0 + + async for request_output, metadata, finish_status in results_generator: + text_outputs += request_output + completion_tokens += 1 + + partial_stop = any(is_partial_stop(text_outputs, i) for i in stop) + # prevent yielding partial stop sequence + if partial_stop: + continue + + if type(finish_status) is bool: # compatibility with old version + finish_reason = "stop" if finish_status else None + else: + finish_reason = finish_status.get_finish_reason() + + if request and await request.is_disconnected(): + await httpserver_manager.abort(request_id) + finish_reason = "abort" + + logprob = metadata.get("logprob", None) + if logprob is not None: + cumulative_logprob += logprob + + prompt_tokens = metadata["prompt_tokens"] + ret = { + "text": prompt + text_outputs if echo else text_outputs, + "error_code": 0, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "cumulative_logprob": cumulative_logprob, + } + + if finish_reason is not None: + yield ( + json.dumps({**ret, "finish_reason": None}, ensure_ascii=False) + + "\0" + ).encode("utf-8") + yield ( + json.dumps({**ret, "finish_reason": finish_reason}, ensure_ascii=False) + + "\0" + ).encode("utf-8") + + if finish_reason is not None: # In case of abort, we need to break the loop + break + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + await httpserver_manager.abort(request_id) + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = g_id_gen.generate_id() + params["request_id"] = request_id + params["request"] = request + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = g_id_gen.generate_id() + params["request_id"] = request_id + params["request"] = request + output = await worker.generate(params) + release_worker_semaphore() + await httpserver_manager.abort(request_id) + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + parser = 
argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + + parser.add_argument( + "--model-path", + dest="model_dir", + type=str, + default=None, + help="the model weight directory path; the app will load the config, weights, and tokenizer from this directory", + ) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional comma-separated display names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + + parser.add_argument( + "--tokenizer_mode", + type=str, + default="slow", + help="""tokenizer load mode, can be slow or auto; slow mode loads fast but runs slowly and is good for debugging and testing, while auto mode gives the best performance""", + ) + parser.add_argument( + "--load_way", + type=str, + default="HF", + help="the way of loading model weights; the default is HF (Hugging Face format), and llama also supports DS (DeepSpeed)", + ) + parser.add_argument( + "--max_total_token_num", + type=int, + default=6000, + help="the total number of tokens the GPU and model can support, equal to max_batch * (input_len + output_len)", + ) + parser.add_argument( + "--batch_max_tokens", + type=int, + default=None, + help="maximum number of tokens in a newly concatenated batch; it controls the prefill batch size to prevent OOM", + ) + parser.add_argument("--eos_id", type=int, default=2, help="eos stop token id") + parser.add_argument( + "--running_max_req_size", + type=int, + default=1000, + help="the maximum number of requests that can be in flight at the same time", + ) + parser.add_argument( + "--tp", type=int, default=1, help="model tensor-parallel size; the default is 1" + ) + parser.add_argument( + "--max_req_input_len", + type=int, + default=None, + help="the maximum number of input tokens per request. If None, it will be derived from the config.", + ) + parser.add_argument( + "--max_req_total_len", + type=int, + default=None, + help="the maximum value of req_input_len + req_output_len.
If None, it will be derived from the config.", + ) + parser.add_argument( + "--mode", + type=str, + default=[], + nargs="+", + help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding + | triton_gqa_attention | triton_gqa_flashdecoding] + [triton_int8weight | triton_int4weight | lmdeploy_int4weight | ppl_int4weight], + triton_flashdecoding mode is for long context, current support llama llama2 qwen; + triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; + triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; + ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; + ppl_fp16 mode use ppl fast fp16 decode attention kernel; + triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode use int8 and int4 to store weights; + you need to read source code to make sure the supported detail mode for all models""", + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", + ) + parser.add_argument( + "--disable_log_stats", + action="store_true", + help="disable logging throughput stats.", + ) + parser.add_argument( + "--log_stats_interval", + type=int, + default=10, + help="log stats interval in second.", + ) + + parser.add_argument( + "--router_token_ratio", + type=float, + default=0.0, + help="token ratio to control router dispatch", + ) + parser.add_argument( + "--router_max_new_token_len", + type=int, + default=1024, + help="the request max new token len for router", + ) + + parser.add_argument( + "--no_skipping_special_tokens", + action="store_true", + help="whether to skip special tokens when decoding", + ) + parser.add_argument( + "--no_spaces_between_special_tokens", + action="store_true", + help="whether to add spaces between special tokens when decoding", + ) + + parser.add_argument( + "--splitfuse_mode", action="store_true", help="use splitfuse mode" + ) + parser.add_argument( + "--splitfuse_block_size", type=int, default=256, help="splitfuse block size" + ) + parser.add_argument( + "--prompt_cache_strs", + type=str, + default=[], + nargs="+", + help="""prompt cache strs""", + ) + parser.add_argument( + "--cache_capacity", + type=int, + default=200, + help="cache server capacity for multimodal resources", + ) + parser.add_argument( + "--cache_reserved_ratio", + type=float, + default=0.5, + help="cache server reserved capacity ratio after clear", + ) + parser.add_argument( + "--return_all_prompt_logprobs", + action="store_true", + help="return all prompt tokens logprobs", + ) + parser.add_argument( + "--long_truncation_mode", + type=str, + choices=[None, "head", "center"], + default=None, + help="""use to select the handle way when input token len > max_req_input_len. 
None : raise Exception + head : remove some head tokens to make input token len <= max_req_input_len + center : remove some tokens in center loc to make input token len <= max_req_input_len""", + ) + + args = parser.parse_args() + + # The prompt cache feature is only supported in splitfuse mode + if not args.splitfuse_mode: + assert len(args.prompt_cache_strs) == 0 + + model_config = AutoConfig.from_pretrained(args.model_dir) + context_length = get_context_length(model_config) + + if args.max_req_input_len is None: + args.max_req_input_len = context_length - 1 + if args.max_req_total_len is None: + args.max_req_total_len = context_length + + assert args.max_req_input_len < args.max_req_total_len + assert args.max_req_total_len <= args.max_total_token_num + + if not args.splitfuse_mode: + # normal (non-splitfuse) mode + if args.batch_max_tokens is None: + batch_max_tokens = int(1 / 6 * args.max_total_token_num) + batch_max_tokens = max(batch_max_tokens, args.max_req_total_len) + args.batch_max_tokens = batch_max_tokens + else: + assert ( + args.batch_max_tokens >= args.max_req_total_len + ), "batch_max_tokens must >= max_req_total_len" + else: + # splitfuse mode + # assert args.batch_max_tokens is not None, "need to set by yourself" + if args.batch_max_tokens is None: + batch_max_tokens = int(1 / 6 * args.max_total_token_num) + batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size) + args.batch_max_tokens = batch_max_tokens + + can_use_ports = alloc_can_use_network_port(num=6 + args.tp) + + assert can_use_ports is not None, "Cannot allocate enough free ports." + ( + router_port, + detokenization_port, + httpserver_port, + visual_port, + cache_port, + nccl_port, + ) = can_use_ports[0:6] + args.nccl_port = nccl_port + model_rpc_ports = can_use_ports[6:] + + global httpserver_manager + httpserver_manager = HttpServerManager( + args, + router_port=router_port, + cache_port=cache_port, + visual_port=visual_port, + httpserver_port=httpserver_port, + enable_multimodal=False, + ) + + start_submodule_processes( + start_funcs=[start_router_process, start_detokenization_process], + start_args=[ + (args, router_port, detokenization_port, model_rpc_ports), + (args, detokenization_port, httpserver_port), + ], + ) + worker = LightLLMWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_dir, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + args.conv_template, + httpserver_manager.tokenizer, + context_length, + ) + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/src/serve/mlx_worker.py b/src/serve/mlx_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e85f848eadbf2e492e0298f5a3609e50f12e59 --- /dev/null +++ b/src/serve/mlx_worker.py @@ -0,0 +1,288 @@ +""" +A model worker using Apple MLX + +https://github.com/ml-explore/mlx-examples/tree/main/llms + +Code based on vllm_worker https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py + +You must install the MLX Python package: + +pip install mlx-lm +""" + +import argparse +import asyncio +import atexit +import json +from typing import List +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.concurrency import run_in_threadpool +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + +import mlx.core as mx +from mlx_lm import
load, generate +from mlx_lm.utils import generate_step + +app = FastAPI() + + +class MLXWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + llm_engine: "MLX", + conv_template: str, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..." + ) + + self.model_name = model_path + self.mlx_model, self.mlx_tokenizer = load(model_path) + + self.tokenizer = self.mlx_tokenizer + # self.context_len = get_context_length( + # llm_engine.engine.model_config.hf_config) + self.context_len = 2048 # hard code for now -- not sure how to get in MLX + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + context = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + echo = params.get("echo", True) + use_beam_search = params.get("use_beam_search", False) + best_of = params.get("best_of", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) + + print("Stop patterns: ", stop) + + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + tokens = [] + skip = 0 + + context_mlx = mx.array(self.tokenizer.encode(context)) + + finish_reason = "length" + + iterator = await run_in_threadpool( + generate_step, context_mlx, self.mlx_model, temperature + ) + + for i in range(max_new_tokens): + (token, _) = await run_in_threadpool(next, iterator) + if token == self.mlx_tokenizer.eos_token_id: + finish_reason = "stop" + break + tokens.append(token.item()) + tokens_decoded = self.mlx_tokenizer.decode(tokens) + last_token_decoded = self.mlx_tokenizer.decode([token.item()]) + skip = len(tokens_decoded) + + partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop) + + if partial_stop: + finish_reason = "stop" + break + + ret = { + "text": tokens_decoded, + "error_code": 0, + "usage": { + "prompt_tokens": len(context), + "completion_tokens": len(tokens), + "total_tokens": len(context) + len(tokens), + }, + "cumulative_logprob": [], + "finish_reason": None, # hard code for now + } + # print(ret) + yield (json.dumps(ret) + "\0").encode() + ret = { + "text": self.mlx_tokenizer.decode(tokens), + "error_code": 0, + "usage": {}, + "cumulative_logprob": [], + "finish_reason": finish_reason, + } + yield (json.dumps(obj={**ret, **{"finish_reason": None}}) + "\0").encode() + yield (json.dumps(ret) + "\0").encode() + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) 
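+
+# Note (added comment, not from the original worker code): MLXWorker.generate_stream
+# yields JSON chunks terminated by a null byte (b"\0"), and MLXWorker.generate simply
+# drains that stream and returns the final chunk. A client is expected to split the
+# response body on the null byte. A minimal sketch, assuming the default worker
+# address http://localhost:21002:
+#
+#   import json, requests
+#
+#   resp = requests.post(
+#       "http://localhost:21002/worker_generate_stream",
+#       json={"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 64},
+#       stream=True,
+#   )
+#   for chunk in resp.iter_lines(delimiter=b"\0"):
+#       if chunk:
+#           data = json.loads(chunk.decode())
+#           print(data["text"], data.get("finish_reason"))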
+ + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + print("trying to abort but not implemented") + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = uuid.uuid4() + params["request_id"] = str(request_id) + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = uuid.uuid4() + params["request_id"] = str(request_id) + output = await worker.generate(params) + release_worker_semaphore() + # await engine.abort(request_id) + print("Trying to abort but not implemented") + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +worker = None + + +def cleanup_at_exit(): + global worker + print("Cleaning up...") + del worker + + +atexit.register(cleanup_at_exit) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="microsoft/phi-2") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust_remote_code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + + args, unknown = parser.parse_known_args() + + if args.model_path: + args.model = args.model_path + + worker = MLXWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + 1024, + False, + "MLX", + args.conv_template, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/src/serve/model_worker.py b/src/serve/model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..683a78556dd6062e395d23effa3faa77a422cf58 --- /dev/null +++ b/src/serve/model_worker.py @@ -0,0 +1,425 @@ +""" +A model worker that executes the model. 
+""" +import argparse +import base64 +import gc +import json +import os +from typing import List, Optional +import uuid + +import torch +import torch.nn.functional as F +from transformers import set_seed +import uvicorn + +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.serve.base_model_worker import BaseModelWorker, app +from fastchat.utils import ( + build_logger, + get_context_length, + str_to_torch_dtype, +) + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + + +class ModelWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + device: str, + num_gpus: int, + max_gpu_memory: str, + revision: str = None, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + stream_interval: int = 2, + conv_template: Optional[str] = None, + embed_in_truncate: bool = False, + seed: Optional[int] = None, + debug: bool = False, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...") + self.model, self.tokenizer = load_model( + model_path, + revision=revision, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + debug=debug, + ) + self.device = device + if self.tokenizer.pad_token == None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.context_len = get_context_length(self.model.config) + self.generate_stream_func = get_generate_stream_function(self.model, model_path) + self.stream_interval = stream_interval + self.embed_in_truncate = embed_in_truncate + self.seed = seed + + if not no_register: + self.init_heart_beat() + + def generate_stream_gate(self, params): + if self.device == "npu": + import torch_npu + + torch_npu.npu.set_device("npu:0") + self.call_ct += 1 + + try: + if self.seed is not None: + set_seed(self.seed) + for output in self.generate_stream_func( + self.model, + self.tokenizer, + params, + self.device, + self.context_len, + self.stream_interval, + ): + ret = { + "text": output["text"], + "error_code": 0, + } + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": 
f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): + if model_type_dict.get("is_bert"): + model_output = self.model(input_ids) + if model_type_dict.get("is_robert"): + data = model_output.last_hidden_state + else: + data = model_output[0] + elif model_type_dict.get("is_t5"): + model_output = self.model(input_ids, decoder_input_ids=input_ids) + data = model_output.encoder_last_hidden_state + else: + model_output = self.model(input_ids, output_hidden_states=True) + if model_type_dict.get("is_chatglm"): + data = model_output.hidden_states[-1].transpose(0, 1) + else: + data = model_output.hidden_states[-1] + + if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling: + sum_embeddings = data[:, 0] + else: + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings = torch.sum(masked_embeddings, dim=1) + token_num = torch.sum(attention_mask).item() + + return sum_embeddings, token_num + + def __encode_base64(self, embeddings: torch.Tensor) -> List[str]: + embeddings = embeddings.cpu() + return [ + base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings + ] + + @torch.inference_mode() + def get_embeddings(self, params): + self.call_ct += 1 + + try: + tokenizer = self.tokenizer + ret = {"embedding": [], "token_num": 0} + + model_type_dict = { + "is_llama": "llama" in str(type(self.model)), + "is_t5": "t5" in str(type(self.model)), + "is_chatglm": "chatglm" in str(type(self.model)), + "is_bert": "bert" in str(type(self.model)), + "is_robert": "robert" in str(type(self.model)), + } + + if self.embed_in_truncate: + encoding = tokenizer.batch_encode_plus( + params["input"], + padding=True, + truncation="longest_first", + return_tensors="pt", + max_length=self.context_len, + ) + else: + encoding = tokenizer.batch_encode_plus( + params["input"], padding=True, return_tensors="pt" + ) + input_ids = encoding["input_ids"].to(self.device) + attention_mask = input_ids != tokenizer.pad_token_id + + base64_encode = params.get("encoding_format", None) + + if self.embed_in_truncate: + embedding, token_num = self.__process_embed_chunk( + input_ids, attention_mask, **model_type_dict + ) + if ( + not hasattr(self.model, "use_cls_pooling") + or not self.model.use_cls_pooling + ): + embedding = embedding / token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + ret["token_num"] = token_num + else: + all_embeddings = [] + all_token_num = 0 + for i in range(0, input_ids.size(1), self.context_len): + chunk_input_ids = input_ids[:, i : i + self.context_len] + chunk_attention_mask = attention_mask[:, i : i + self.context_len] + + # add cls token and mask to get cls embedding + if ( + hasattr(self.model, "use_cls_pooling") + and self.model.use_cls_pooling + ): + cls_tokens = ( + torch.zeros( + (chunk_input_ids.size(0), 1), + dtype=chunk_input_ids.dtype, + device=chunk_input_ids.device, + ) + + tokenizer.cls_token_id + ) + chunk_input_ids = torch.cat( + [cls_tokens, chunk_input_ids], dim=-1 + ) + mask = torch.ones( + (chunk_attention_mask.size(0), 1), + dtype=chunk_attention_mask.dtype, + device=chunk_attention_mask.device, + ) + chunk_attention_mask = torch.cat( + [mask, chunk_attention_mask], dim=-1 + ) + + chunk_embeddings, token_num = 
self.__process_embed_chunk( + chunk_input_ids, chunk_attention_mask, **model_type_dict + ) + if ( + hasattr(self.model, "use_cls_pooling") + and self.model.use_cls_pooling + ): + all_embeddings.append(chunk_embeddings * token_num) + else: + all_embeddings.append(chunk_embeddings) + all_token_num += token_num + + all_embeddings_tensor = torch.stack(all_embeddings) + embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + + ret["token_num"] = all_token_num + + if base64_encode == "base64": + out_embeddings = self.__encode_base64(normalized_embeddings) + else: + out_embeddings = normalized_embeddings.tolist() + ret["embedding"] = out_embeddings + + gc.collect() + torch.cuda.empty_cache() + if self.device == "xpu": + torch.xpu.empty_cache() + if self.device == "npu": + torch.npu.empty_cache() + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return ret + + +def create_model_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument("--embed-in-truncate", action="store_true") + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + parser.add_argument( + "--debug", type=bool, default=False, help="Print debugging messages" + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. 
Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + + worker = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + revision=args.revision, + no_register=args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + dtype=str_to_torch_dtype(args.dtype), + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + stream_interval=args.stream_interval, + conv_template=args.conv_template, + embed_in_truncate=args.embed_in_truncate, + seed=args.seed, + debug=args.debug, + ) + return args, worker + + +if __name__ == "__main__": + args, worker = create_model_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/src/serve/monitor/basic_stats.py b/src/serve/monitor/basic_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1a8793d00ae2cd40da085e09448428f9ef5cff --- /dev/null +++ b/src/serve/monitor/basic_stats.py @@ -0,0 +1,220 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd # pandas>=2.0.3 +import plotly.express as px +import plotly.graph_objects as go +from tqdm import tqdm + + +NUM_SERVERS = 14 +LOG_ROOT_DIR = "~/fastchat_logs" + + +def get_log_files(max_num_files=None): + log_root = os.path.expanduser(LOG_ROOT_DIR) + filenames = [] + for i in range(NUM_SERVERS): + for filename in os.listdir(f"{log_root}/server{i}"): + if filename.endswith("-conv.json"): + filepath = f"{log_root}/server{i}/{filename}" + name_tstamp_tuple = (filepath, os.path.getmtime(filepath)) + filenames.append(name_tstamp_tuple) + # sort by tstamp + filenames = sorted(filenames, key=lambda x: x[1]) + filenames = [x[0] for x in filenames] + + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def load_log_files(filename): + data = [] + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + data.append( + dict( + type=row["type"], + tstamp=row["tstamp"], + model=row.get("model", ""), + models=row.get("models", ["", ""]), + ) + ) + return data + + +def load_log_files_parallel(log_files, num_threads=16): + data_all = [] + from multiprocessing import Pool + + with Pool(num_threads) as p: + ret_all = list(tqdm(p.imap(load_log_files, log_files), total=len(log_files))) + for ret in ret_all: + data_all.extend(ret) + return data_all + + +def get_anony_vote_df(df): + anony_vote_df = df[ + df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"]) + ] + anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")] + return anony_vote_df + + +def merge_counts(series, on, names): + ret = pd.merge(series[0], series[1], on=on) + for i in range(2, len(series)): + ret = pd.merge(ret, series[i], on=on) + ret = ret.reset_index() + old_names = list(ret.columns)[-len(series) :] + rename = {old_name: new_name for old_name, new_name in zip(old_names, names)} + ret = ret.rename(columns=rename) + return ret + + +def 
report_basic_stats(log_files): + df_all = load_log_files_parallel(log_files) + df_all = pd.DataFrame(df_all) + now_t = df_all["tstamp"].max() + df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)] + df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)] + anony_vote_df_all = get_anony_vote_df(df_all) + + # Chat trends + chat_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( + "%Y-%m-%d" + ) + for x in df_all[df_all["type"] == "chat"]["tstamp"] + ] + chat_dates_counts = pd.value_counts(chat_dates) + vote_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( + "%Y-%m-%d" + ) + for x in anony_vote_df_all["tstamp"] + ] + vote_dates_counts = pd.value_counts(vote_dates) + chat_dates_bar = go.Figure( + data=[ + go.Bar( + name="Anony. Vote", + x=vote_dates_counts.index, + y=vote_dates_counts, + text=[f"{val:.0f}" for val in vote_dates_counts], + textposition="auto", + ), + go.Bar( + name="Chat", + x=chat_dates_counts.index, + y=chat_dates_counts, + text=[f"{val:.0f}" for val in chat_dates_counts], + textposition="auto", + ), + ] + ) + chat_dates_bar.update_layout( + barmode="stack", + xaxis_title="Dates", + yaxis_title="Count", + height=300, + width=1200, + ) + + # Model call counts + model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts() + model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts() + model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts() + model_hist = merge_counts( + [model_hist_all, model_hist_1_day, model_hist_1_hour], + on="model", + names=["All", "Last Day", "Last Hour"], + ) + model_hist_md = model_hist.to_markdown(index=False, tablefmt="github") + + # Action counts + action_hist_all = df_all["type"].value_counts() + action_hist_1_day = df_1_day["type"].value_counts() + action_hist_1_hour = df_1_hour["type"].value_counts() + action_hist = merge_counts( + [action_hist_all, action_hist_1_day, action_hist_1_hour], + on="type", + names=["All", "Last Day", "Last Hour"], + ) + action_hist_md = action_hist.to_markdown(index=False, tablefmt="github") + + # Anony vote counts + anony_vote_hist_all = anony_vote_df_all["type"].value_counts() + anony_vote_df_1_day = get_anony_vote_df(df_1_day) + anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts() + # anony_vote_df_1_hour = get_anony_vote_df(df_1_hour) + # anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts() + anony_vote_hist = merge_counts( + [anony_vote_hist_all, anony_vote_hist_1_day], + on="type", + names=["All", "Last Day"], + ) + anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github") + + # Last 24 hours + chat_1_day = df_1_day[df_1_day["type"] == "chat"] + num_chats_last_24_hours = [] + base = df_1_day["tstamp"].min() + for i in range(24, 0, -1): + left = base + (i - 1) * 3600 + right = base + i * 3600 + num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum() + num_chats_last_24_hours.append(num) + times = [ + datetime.datetime.fromtimestamp( + base + i * 3600, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + for i in range(24, 0, -1) + ] + last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours}) + last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github") + + # Last update datetime + last_updated_tstamp = now_t + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + # 
code.interact(local=locals()) + + return { + "chat_dates_bar": chat_dates_bar, + "model_hist_md": model_hist_md, + "action_hist_md": action_hist_md, + "anony_vote_hist_md": anony_vote_hist_md, + "num_chats_last_24_hours": last_24_hours_md, + "last_updated_datetime": last_updated_datetime, + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + basic_stats = report_basic_stats(log_files) + + print(basic_stats["action_hist_md"] + "\n") + print(basic_stats["model_hist_md"] + "\n") + print(basic_stats["anony_vote_hist_md"] + "\n") + print(basic_stats["num_chats_last_24_hours"] + "\n") diff --git a/src/serve/monitor/clean_battle_data.py b/src/serve/monitor/clean_battle_data.py new file mode 100644 index 0000000000000000000000000000000000000000..270f981ccf62eb822d823ed299d864f6533fad17 --- /dev/null +++ b/src/serve/monitor/clean_battle_data.py @@ -0,0 +1,423 @@ +""" +Clean chatbot arena battle log. + +Usage: +python3 clean_battle_data.py --mode conv_release +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +from tqdm import tqdm +from multiprocessing import Pool +import tiktoken +from collections import Counter +import shortuuid + +from fastchat.serve.monitor.basic_stats import get_log_files, NUM_SERVERS +from fastchat.utils import detect_language + + +VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"] +IDENTITY_WORDS = [ + "vicuna", + "lmsys", + "koala", + "uc berkeley", + "open assistant", + "laion", + "chatglm", + "chatgpt", + "gpt-4", + "openai", + "anthropic", + "claude", + "bard", + "palm", + "lamda", + "google", + "gemini", + "llama", + "qianwan", + "qwen", + "alibaba", + "mistral", + "zhipu", + "KEG lab", + "01.AI", + "AI2", + "Tülu", + "Tulu", + "deepseek", + "hermes", + "cohere", + "DBRX", + "databricks", +] + +ERROR_WORDS = [ + "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.", + "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.", + "API REQUEST ERROR. Please increase the number of max tokens.", + "**API REQUEST ERROR** Reason: The response was blocked.", + "**API REQUEST ERROR**", +] + +UNFINISHED_WORDS = [ + "▌", + '', +] + +for i in range(len(IDENTITY_WORDS)): + IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower() + +for i in range(len(ERROR_WORDS)): + ERROR_WORDS[i] = ERROR_WORDS[i].lower() + + +def remove_html(raw): + if isinstance(raw, str) and raw.startswith("
<h3>Model A: "):
+        return raw[raw.find(": ") + 2 : -len("</h3>
\n")] + return raw + + +def to_openai_format(messages): + roles = ["user", "assistant"] + ret = [] + for i, x in enumerate(messages): + ret.append({"role": roles[i % 2], "content": x[1]}) + return ret + + +def replace_model_name(old_name, tstamp): + replace_dict = { + "bard": "palm-2", + "claude-v1": "claude-1", + "claude-instant-v1": "claude-instant-1", + "oasst-sft-1-pythia-12b": "oasst-pythia-12b", + "claude-2": "claude-2.0", + "StripedHyena-Nous-7B": "stripedhyena-nous-7b", + "gpt-4-turbo": "gpt-4-1106-preview", + "gpt-4-0125-assistants-api": "gpt-4-turbo-browsing", + } + if old_name in ["gpt-4", "gpt-3.5-turbo"]: + if tstamp > 1687849200: + return old_name + "-0613" + else: + return old_name + "-0314" + if old_name in replace_dict: + return replace_dict[old_name] + return old_name + + +def read_file(filename): + data = [] + for retry in range(5): + try: + # lines = open(filename).readlines() + for l in open(filename): + row = json.loads(l) + if row["type"] in VOTES: + data.append(row) + break + except FileNotFoundError: + time.sleep(2) + return data + + +def read_file_parallel(log_files, num_threads=16): + data_all = [] + with Pool(num_threads) as p: + ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files))) + for ret in ret_all: + data_all.extend(ret) + return data_all + + +def process_data( + data, + exclude_model_names, + sanitize_ip, + ban_ip_list, +): + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + convert_type = { + "leftvote": "model_a", + "rightvote": "model_b", + "tievote": "tie", + "bothbad_vote": "tie (bothbad)", + } + + all_ips = dict() + + count_dict = { + "anony": 0, + "invalid": 0, + "leaked_identity": 0, + "banned": 0, + "error": 0, + "unfinished": 0, + "none_msg": 0, + "exclude_model": 0, + } + count_leak = {} + + battles = [] + for row in data: + flag_anony = False + flag_leaked_identity = False + flag_error = False + flag_unfinished = False + flag_none_msg = False + + if row["models"][0] is None or row["models"][1] is None: + continue + + # Resolve model names + models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])] + if "model_name" in row["states"][0]: + models_hidden = [ + row["states"][0]["model_name"], + row["states"][1]["model_name"], + ] + if models_hidden[0] is None: + models_hidden = models_public + else: + models_hidden = models_public + + if (models_public[0] == "" and models_public[1] != "") or ( + models_public[1] == "" and models_public[0] != "" + ): + count_dict["invalid"] += 1 + continue + + if models_public[0] == "" or models_public[0] == "Model A": + flag_anony = True + models = models_hidden + else: + flag_anony = False + models = models_public + if ( + models_hidden[0] not in models_public[0] + or models_hidden[1] not in models_public[1] + ): + count_dict["invalid"] += 1 + continue + + # Detect langauge + state = row["states"][0] + if state["offset"] >= len(state["messages"]): + count_dict["invalid"] += 1 + continue + lang_code = detect_language(state["messages"][state["offset"]][1]) + + # Drop conversations if the model names are leaked + messages = "" + for i in range(2): + state = row["states"][i] + for _, (role, msg) in enumerate(state["messages"][state["offset"] :]): + if msg: + messages += msg.lower() + else: + flag_none_msg = True + + for word in IDENTITY_WORDS: + if word in messages: + if word not in count_leak: + count_leak[word] = 0 + count_leak[word] += 1 + flag_leaked_identity = True + break + + for word in ERROR_WORDS: + if word in messages: + flag_error = True + break + + for word 
in UNFINISHED_WORDS: + if word in messages: + flag_unfinished = True + break + + if flag_none_msg: + count_dict["none_msg"] += 1 + continue + if flag_leaked_identity: + count_dict["leaked_identity"] += 1 + continue + if flag_error: + count_dict["error"] += 1 + continue + if flag_unfinished: + count_dict["unfinished"] += 1 + continue + + # Replace bard with palm + models = [replace_model_name(m, row["tstamp"]) for m in models] + # Exclude certain models + if exclude_model_names and any(x in exclude_model_names for x in models): + count_dict["exclude_model"] += 1 + continue + + question_id = row["states"][0]["conv_id"] + conversation_a = to_openai_format( + row["states"][0]["messages"][row["states"][0]["offset"] :] + ) + conversation_b = to_openai_format( + row["states"][1]["messages"][row["states"][1]["offset"] :] + ) + + ip = row["ip"] + if ip not in all_ips: + all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": shortuuid.uuid()} + all_ips[ip]["count"] += 1 + if sanitize_ip: + user_id = f"{all_ips[ip]['sanitized_id']}" + else: + user_id = f"{all_ips[ip]['ip']}" + + if ban_ip_list is not None and ip in ban_ip_list: + count_dict["banned"] += 1 + continue + + if flag_anony: + count_dict["anony"] += 1 + + for conv in conversation_a: + conv["num_tokens"] = len( + encoding.encode(conv["content"], allowed_special="all") + ) + for conv in conversation_b: + conv["num_tokens"] = len( + encoding.encode(conv["content"], allowed_special="all") + ) + + # Save the results + battles.append( + dict( + question_id=question_id, + model_a=models[0], + model_b=models[1], + winner=convert_type[row["type"]], + judge=f"arena_user_{user_id}", + conversation_a=conversation_a, + conversation_b=conversation_b, + turn=len(conversation_a) // 2, + anony=flag_anony, + language=lang_code, + tstamp=row["tstamp"], + ) + ) + return battles, count_dict, count_leak, all_ips + + +def clean_battle_data( + log_files, + exclude_model_names, + ban_ip_list=None, + sanitize_ip=False, + anony_only=False, + num_threads=16, +): + data = read_file_parallel(log_files, num_threads=16) + + battles = [] + count_dict = {} + count_leak = {} + all_ips = {} + with Pool(num_threads) as p: + # split data into chunks + chunk_size = len(data) // min(100, len(data)) + data_chunks = [ + data[i : i + chunk_size] for i in range(0, len(data), chunk_size) + ] + + args_list = [ + (data_chunk, exclude_model_names, sanitize_ip, ban_ip_list) + for data_chunk in data_chunks + ] + ret_all = list(tqdm(p.starmap(process_data, args_list), total=len(data_chunks))) + + for ret in ret_all: + sub_battles, sub_count_dict, sub_count_leak, sub_all_ips = ret + battles.extend(sub_battles) + count_dict = dict(Counter(count_dict) + Counter(sub_count_dict)) + count_leak = dict(Counter(count_leak) + Counter(sub_count_leak)) + for ip in sub_all_ips: + if ip not in all_ips: + all_ips[ip] = sub_all_ips[ip] + else: + all_ips[ip]["count"] += sub_all_ips[ip]["count"] + battles.sort(key=lambda x: x["tstamp"]) + last_updated_tstamp = battles[-1]["tstamp"] + + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + print(f"#votes: {len(data)}") + print(count_dict) + print(f"#battles: {len(battles)}, #anony: {count_dict['anony']}") + print(f"last-updated: {last_updated_datetime}") + print(f"leaked_identity: {count_leak}") + + if ban_ip_list is not None: + for ban_ip in ban_ip_list: + if ban_ip in all_ips: + del all_ips[ban_ip] + print("Top 30 IPs:") + print(sorted(all_ips.values(), key=lambda x: 
x["count"], reverse=True)[:30]) + return battles + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + parser.add_argument( + "--mode", type=str, choices=["simple", "conv_release"], default="simple" + ) + parser.add_argument("--exclude-model-names", type=str, nargs="+") + parser.add_argument("--ban-ip-file", type=str) + parser.add_argument("--sanitize-ip", action="store_true", default=False) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None + + battles = clean_battle_data( + log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip + ) + last_updated_tstamp = battles[-1]["tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + if args.mode == "simple": + for x in battles: + for key in [ + "conversation_a", + "conversation_b", + "question_id", + ]: + del x[key] + print("Samples:") + for i in range(4): + print(battles[i]) + output = f"clean_battle_{cutoff_date}.json" + elif args.mode == "conv_release": + new_battles = [] + for x in battles: + if not x["anony"]: + continue + for key in []: + del x[key] + new_battles.append(x) + battles = new_battles + output = f"clean_battle_conv_{cutoff_date}.json" + + with open(output, "w", encoding="utf-8", errors="replace") as fout: + json.dump(battles, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/src/serve/monitor/clean_chat_data.py b/src/serve/monitor/clean_chat_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2bda0e2c3a5242f65fab00e05b1199afdf864cc9 --- /dev/null +++ b/src/serve/monitor/clean_chat_data.py @@ -0,0 +1,171 @@ +""" +Clean chatbot arena chat log. + +Usage: +python3 clean_chat_data.py +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +from tqdm import tqdm + +from fastchat.serve.monitor.basic_stats import NUM_SERVERS +from fastchat.serve.monitor.clean_battle_data import ( + to_openai_format, + replace_model_name, +) +from fastchat.utils import detect_language + + +NETWORK_ERROR_MSG = ( + "NETWORK ERROR DUE TO HIGH TRAFFIC. 
PLEASE REGENERATE OR REFRESH THIS PAGE.".lower() +) + + +def get_log_files(max_num_files=None): + dates = [] + for month in range(4, 12): + for day in range(1, 33): + dates.append(f"2023-{month:02d}-{day:02d}") + + filenames = [] + for d in dates: + for i in range(NUM_SERVERS): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + # filenames = list(reversed(filenames)) + filenames = filenames[-max_num_files:] + return filenames + + +def clean_chat_data(log_files, action_type): + raw_data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + if row["type"] == action_type: + raw_data.append(row) + + all_models = set() + all_ips = dict() + chats = [] + ct_invalid_conv_id = 0 + ct_invalid = 0 + ct_network_error = 0 + for row in raw_data: + try: + if action_type in ["chat", "upvote", "downvote"]: + state = row["state"] + model = row["model"] + elif action_type == "leftvote": + state = row["states"][0] + model = row["states"][0]["model_name"] + elif action_type == "rightvote": + state = row["states"][1] + model = row["states"][1]["model_name"] + conversation_id = state["conv_id"] + except KeyError: + ct_invalid_conv_id += 1 + continue + + if conversation_id is None: + ct_invalid_conv_id += 1 + continue + + conversation = to_openai_format(state["messages"][state["offset"] :]) + if not isinstance(model, str): + ct_invalid += 1 + continue + model = replace_model_name(model, row["tstamp"]) + + try: + lang_code = detect_language(state["messages"][state["offset"]][1]) + except IndexError: + ct_invalid += 1 + continue + + if not all(isinstance(x["content"], str) for x in conversation): + ct_invalid += 1 + continue + + messages = "".join([x["content"] for x in conversation]).lower() + if NETWORK_ERROR_MSG in messages: + ct_network_error += 1 + continue + + ip = row["ip"] + if ip not in all_ips: + all_ips[ip] = len(all_ips) + user_id = all_ips[ip] + + chats.append( + dict( + conversation_id=conversation_id, + model=model, + conversation=conversation, + turn=len(conversation) // 2, + language=lang_code, + user_id=user_id, + tstamp=row["tstamp"], + ) + ) + + all_models.update([model]) + + chats.sort(key=lambda x: x["tstamp"]) + last_updated_tstamp = chats[-1]["tstamp"] + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + # Deduplication + dedup_chats = [] + visited_conv_ids = set() + for i in reversed(range(len(chats))): + if chats[i]["conversation_id"] in visited_conv_ids: + continue + visited_conv_ids.add(chats[i]["conversation_id"]) + dedup_chats.append(chats[i]) + + print( + f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}" + ) + print( + f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}" + ) + print(f"#models: {len(all_models)}, {all_models}") + print(f"last-updated: {last_updated_datetime}") + + return list(reversed(dedup_chats)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--action-type", type=str, default="chat") + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + chats = clean_chat_data(log_files, 
args.action_type) + last_updated_tstamp = chats[-1]["tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + output = f"clean_{args.action_type}_conv_{cutoff_date}.json" + with open(output, "w") as fout: + json.dump(chats, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/src/serve/monitor/code_tagger.py b/src/serve/monitor/code_tagger.py new file mode 100644 index 0000000000000000000000000000000000000000..12eeaed4b25e0f7b4b6b31d562de905d4ac6a8d2 --- /dev/null +++ b/src/serve/monitor/code_tagger.py @@ -0,0 +1,180 @@ +import re +import json +import argparse +import multiprocessing as mp + +import nltk +from tqdm import tqdm +from nltk.tokenize import word_tokenize + + +def is_code_conversation(text: str) -> tuple[bool, list[str]]: + """Check if the text is a code conversation""" + + if "```plaintext" in text: + lines = text.split("\n") + line1_idx = [idx for idx, line in enumerate(lines) if "```plaintext" in line][0] + line2_idx = [ + line1_idx + 1 + idx + for idx, line in enumerate(lines) + if "```" in line[line1_idx + 1 :] + ] + if line2_idx: + line2_idx = line2_idx[0] + text = "\n".join(lines[:line1_idx]) + "\n".join(lines[line2_idx + 1 :]) + else: + text = "\n".join(lines[:line1_idx]) + return is_code_conversation(text) + + if "```markdown" in text: + otext = text + lines = text.split("\n") + line1_idx = [idx for idx, line in enumerate(lines) if "```markdown" in line][0] + line2_idx = [ + line1_idx + 1 + idx + for idx, line in enumerate(lines) + if "```" in line[line1_idx + 1 :] + ] + if line2_idx: + line2_idx = line2_idx[0] + text = "\n".join(lines[:line1_idx]) + "\n".join(lines[line2_idx + 1 :]) + else: + text = "\n".join(lines[:line1_idx]) + return is_code_conversation(text) + + if "ascii art" in text.lower(): + return False, [] + + # 1. Check for code formatting + if re.search(r"```", text): + return True, ["backticks"] + + # Tokenize the text + tokens = word_tokenize(text) + tokens = [token.lower() for token in tokens] + + # 2. Check for programming concepts + concepts = ["git", "github", "pull request", "dataframe", "nginx", "pip"] + if any(concept in tokens for concept in concepts): + matched_concepts = list(set(tokens).intersection(set(concepts))) + return True, matched_concepts + + # 3. Check for programming language name + languages = [ + "python", + "c++", + "cpp", + "java", + "javascript", + "typescript", + "html", + "css", + "sql", + "bash", + "powershell", + "matlab", + "golang", + "linux", + "ubuntu", + ] + if any(language in tokens for language in languages): + matched_languages = list(set(tokens).intersection(set(languages))) + return True, matched_languages + + # 4. Programming concept substrings + strings = [ + "import pandas", + "import numpy", + "import torch", + "jax", + "tensorflow", + "pytorch", + "keras", + "scikit-learn", + "sklearn", + " apt-get ", + ] + found_array = [string in text for string in strings] + if any(found_array): + matched_strings = [ + string for string, found in zip(strings, found_array) if found + ] + return True, matched_strings + + # 5. 
Programming concept regexes + regexes = [ + r"from \w+ import \w+", + r"conda install \w+", + r"pip install -r \w+", + r"conda install -c \w+ \w+", + r"#include <\w+>", + r"import \w+ as \w+", + r"#include \"\w+\.h\"", + ] + found_array = [re.search(regex, text) for regex in regexes] + if any(found_array): + matched_regexes = [regex for regex, found in zip(regexes, found_array) if found] + return True, matched_regexes + + return False, [] + + +def check_code_conv(conv) -> tuple[bool, list[str]]: + """Check if the conversation is a code conversation""" + for _, msg in enumerate(conv): + content = msg["content"] + if not isinstance(content, str): + continue + is_code_conv_res = is_code_conversation(content) + if is_code_conv_res[0]: + return is_code_conv_res + return False, [] + + +def check_conv_row(conv_row): + check_a, code_a = check_code_conv(conv_row["conversation_a"]) + check_b, code_b = check_code_conv(conv_row["conversation_b"]) + + return check_a or check_b, code_a + code_b + + +def process_battle_file(battle_file_path: str, n_cpus: int): + with open(battle_file_path, "r") as f: + data = json.load(f) + + with mp.Pool(n_cpus) as pool: + tagged_data = list(tqdm(pool.imap(check_conv_row, data), total=len(data))) + + output_data = [row for row, (is_code, _) in zip(data, tagged_data) if is_code] + + return output_data + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clean-battle-file", type=str) + parser.add_argument("--output-clean-battle-file", type=str, default=None) + parser.add_argument("--n-cpus", type=int, default=-1) + + args = parser.parse_args() + + if args.output_clean_battle_file is None: + args.output_clean_battle_file = args.clean_battle_file + + if args.n_cpus == -1: + args.n_cpus = mp.cpu_count() + + print( + f"Processing {args.clean_battle_file} and saving to {args.output_clean_battle_file} with {args.n_cpus} cpus" + ) + + output_data = process_battle_file(args.clean_battle_file, args.n_cpus) + + with open(args.output_clean_battle_file, "w") as f: + json.dump(output_data, f, indent=4) + + print(f"Total code conversations: {len(output_data)}") + print("Done!") + + with open(args.output_clean_battle_file, "r") as f: + data = json.load(f) diff --git a/src/serve/monitor/criteria_labeling.py b/src/serve/monitor/criteria_labeling.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b030d45bc7faff2a8c85e5923f701a94d4cc4 --- /dev/null +++ b/src/serve/monitor/criteria_labeling.py @@ -0,0 +1,214 @@ +import argparse +import json +import pandas as pd +import os +import re +import ast +import time +import concurrent.futures +import tqdm +import random +import threading + +LOCK = threading.RLock() + +## Configs +SYSTEM_PROMPT = "Your task is to evaluate how well the following input prompts can assess the capabilities of advanced AI assistants.\n\nFor the input prompt, please analyze it based on the following 7 criteria.\n1. Specificity: Does the prompt ask for a specific output, such as code, a mathematical solution, a logical simplification, a problem-solving strategy, or a hardware setup recommendation? This specificity allows the AI to demonstrate its ability to understand and generate precise responses.\n2. Domain Knowledge: Does the prompt cover a specific domain, such as programming, mathematics, logic, problem-solving, or hardware setup? Prompts spanning a range of topics test the AI's breadth of knowledge and its ability to apply that knowledge to different domains.\n3. 
Complexity: Does the prompt vary in complexity, from straightforward tasks to more complex, multi-step problems? This allows evaluators to assess the AI's capability to handle problems of varying difficulty.\n4. Problem-Solving Skills: Does the prompt directly involves the AI to demonstrate active problem-solving skills, such systemically coming up with a solution for a specific setup instead of regurgitating an existing fact? This tests the AI's ability to apply logical reasoning and provide practical solutions.\n5. Creativity: Does the prompt involve a level of creativity in approaching the problem? This criterion tests the AI's ability to provide tailored solutions that take into account the user's specific needs and limitations.\n6. Technical Accuracy: Does the prompt require technical accuracy in the response? This allows evaluators to assess the AI's precision and correctness in technical fields.\n7. Real-world Application: Does the prompt relate to real-world applications, such as setting up a functional system or writing code for a practical use case? This tests the AI's ability to provide practical and actionable information that could be implemented in real-life scenarios.\n\nYou must list the criteria numbers that the prompt satisfies in the format of a Python array. For example, \"[...]\". Do not explain your choice." + +ENDPOINT_INFO = { + "model_name": "META-LLAMA/LLAMA-3-70B-CHAT-HF", + "name": "llama-3-70b-instruct", + "endpoints": [{"api_base": "-", "api_key": "-"}], + "parallel": 8, + "temperature": 0.0, + "max_token": 512, +} # Modify this + +TAGS = { + 1: "specificity", + 2: "domain_knowledge", + 3: "complexity", + 4: "problem_solving", + 5: "creativity", + 6: "technical_accuracy", + 7: "real_world", +} + +# API setting constants +API_MAX_RETRY = 3 +API_RETRY_SLEEP = 10 +API_ERROR_OUTPUT = "$ERROR$" + + +def get_endpoint(endpoint_list): + if endpoint_list is None: + return None + assert endpoint_list is not None + # randomly pick one + api_dict = random.choices(endpoint_list)[0] + return api_dict + + +pattern = re.compile(r"(\[\d(?:\,\s\d)*\])") + + +def get_score(judgment): + matches = pattern.findall(judgment) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return [] + elif len(set(matches)) == 1: + try: + return ast.literal_eval(matches[0]) + except SyntaxError: + print(matches[0]) + return [] + else: + return [] + + +def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None): + import openai + + if api_dict: + client = openai.OpenAI( + base_url=api_dict["api_base"], + api_key=api_dict["api_key"], + ) + else: + client = openai.OpenAI() + + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + # print(messages) + completion = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + # extra_body={"guided_choice": GUIDED_CHOICES} if GUIDED_CHOICES else None, + ) + output = completion.choices[0].message.content + break + except openai.RateLimitError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except openai.BadRequestError as e: + print(messages) + print(type(e), e) + break + except openai.APIConnectionError as e: + print(messages) + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except openai.InternalServerError as e: + print(messages) + print(type(e), e) + time.sleep(1) + except KeyError: + print(type(e), e) + break + + return output + + +def get_answer( + question: dict, + max_tokens: int, + temperature: float, + 
answer_file: str, + api_dict: dict, +): + conv = [] + conv.append({"role": "system", "content": SYSTEM_PROMPT}) + + conv.append({"role": "user", "content": question["prompt"]}) + output = chat_completion_openai( + model=ENDPOINT_INFO["model_name"], + messages=conv, + temperature=temperature, + max_tokens=max_tokens, + api_dict=api_dict, + ) + + criteria = get_score(output) + + # Dump answers + question["criteria_tag"] = {name: bool(i in criteria) for i, name in TAGS.items()} + question.drop("prompt") + + with LOCK: + with open(answer_file, "a") as fout: + fout.write(json.dumps(question.to_dict()) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--cache-file", type=str, default=None) + parser.add_argument("--output-file", type=str, required=True) + parser.add_argument("--convert-to-json", action="store_true") + args = parser.parse_args() + + print("loading input data (might take min)") + input_data = pd.read_json(args.input_file) + print(f"{len(input_data)}# of input data just loaded") + if args.cache_file: + print("loading cache data") + cache_data = pd.read_json(args.cache_file) + print(f"{len(cache_data)}# of cache data just loaded") + + assert "criteria_tag" in cache_data.columns and len( + cache_data["criteria_tag"].dropna() + ) == len(cache_data) + + not_labeled = input_data[ + ~input_data["question_id"].isin(cache_data["question_id"]) + ].copy() + else: + not_labeled = input_data.copy() + + if os.path.isfile(args.output_file): + print("loading existing output") + output_data = pd.read_json(args.output_file, lines=True) + print(f"{len(output_data)}# of existing output just loaded") + + assert "criteria_tag" in output_data.columns and len( + output_data["criteria_tag"].dropna() + ) == len(output_data) + + not_labeled = not_labeled[ + ~not_labeled["question_id"].isin(output_data["question_id"]) + ] + + print(f"{len(not_labeled)} needs to be labeled") + + not_labeled["prompt"] = not_labeled.conversation_a.map( + lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)]) + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=ENDPOINT_INFO["parallel"] + ) as executor: + futures = [] + for index, row in tqdm.tqdm(not_labeled.iterrows()): + future = executor.submit( + get_answer, + row, + ENDPOINT_INFO["max_token"], + ENDPOINT_INFO["temperature"], + args.output_file, + get_endpoint(ENDPOINT_INFO["endpoints"]), + ) + futures.append(future) + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), total=len(futures) + ): + future.result() + + if args.convert_to_json: + temp = pd.read_json(args.output_file, lines=True) + temp.to_json( + args.output_file[:-1], orient="records", indent=4, force_ascii=False + ) diff --git a/src/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py b/src/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py new file mode 100644 index 0000000000000000000000000000000000000000..8e94cf2756203f207e82cc7f31ff544ecdcc80f0 --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py @@ -0,0 +1,25 @@ +"""Count the unique users in a battle log file.""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + args = parser.parse_args() + + lines = json.load(open(args.input)) + ct_anony_votes = 0 + all_users = set() + all_models = set() + for l in lines: + if not 
l["anony"]: + continue + all_users.add(l["judge"]) + all_models.add(l["model_a"]) + all_models.add(l["model_b"]) + ct_anony_votes += 1 + + print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}") + print(f"#model: {len(all_models)}") diff --git a/src/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py b/src/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..6d12d7c652bc02bb7b5c9f65bce0e1644f739c1b --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py @@ -0,0 +1,155 @@ +""" +Filter conversations for release. + +Usage: python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json +""" +import argparse +from collections import defaultdict +from enum import Enum, auto +import json +import os +import random + +from tqdm import tqdm + +BLOCKED_WORDS_FILENAME = "blocked_words.json" +blocked_words = [] +frequency = defaultdict(lambda: 0) + + +class TypeCode(Enum): + CORRECT = auto() + ANONYMIZED = auto() + REDACTED = auto() + BAD_FORMAT = auto() + BLOCKED_WORD = auto() + BLOCKED_MODEL = auto() + TOO_SHORT = auto() + TOO_FREQUENT = auto() + + +def detect_type(conv): + for key in ["conversation_a", "conversation_b"]: + messages = [row["content"] for row in conv[key]] + for msg in messages: + if not isinstance(msg, str): + return TypeCode.BAD_FORMAT + + user_prompts = [ + row["content"].lower().strip() for row in conv[key] if row["role"] == "user" + ] + if len(messages) <= 2 and all(len(x) < 16 for x in user_prompts): + return TypeCode.TOO_SHORT + + if all(x in frequent_prompts for x in user_prompts): + return TypeCode.TOO_FREQUENT + + for msg in messages: + msg = msg.lower() + if "" in msg: + return TypeCode.ANONYMIZED + if "" in msg: + return TypeCode.REDACTED + + for w in blocked_words: + if w in msg: + return TypeCode.BLOCKED_WORD + + for key in ["model_a", "model_b"]: + if conv[key] in ["vicuna-33b", "mpt-30b-chat"]: + return TypeCode.BLOCKED_MODEL + + return TypeCode.CORRECT + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Read blocked words + if os.path.exists(BLOCKED_WORDS_FILENAME): + blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) + + # Count frequency + for conv in convs: + for key in ["conversation_a", "conversation_b"]: + messages = [row["content"] for row in conv[key] if row["role"] == "user"] + for msg in messages: + if not isinstance(msg, str): + continue + msg = msg.lower().strip() + frequency[msg] += 1 + + keys = list(frequency.keys()) + keys.sort(key=lambda x: -frequency[x]) + frequent_prompts = keys[:10] + frequent_prompts = set(frequent_prompts) + frequent_prompts.add("") + + # Start filter + ct_bad_format = 0 + ct_anonymized = 0 + ct_redacted = 0 + ct_error = 0 + ct_lang_filter = 0 + ct_flagged = 0 + ct_blocked_word = 0 + ct_blocked_model = 0 + ct_too_short = 0 + ct_too_frequent = 0 + + new_convs = [] + for conv in tqdm(convs): + type_code = detect_type(conv) + + if type_code == TypeCode.BAD_FORMAT: + ct_bad_format += 1 + continue + + if type_code == TypeCode.ANONYMIZED: + ct_anonymized += 1 + continue + elif type_code == TypeCode.REDACTED: + ct_redacted += 1 + continue + elif type_code == TypeCode.BLOCKED_WORD: + ct_blocked_word += 1 + continue + elif type_code 
== TypeCode.BLOCKED_MODEL: + ct_blocked_model += 1 + continue + elif type_code == TypeCode.TOO_SHORT: + ct_too_short += 1 + continue + elif type_code == TypeCode.TOO_FREQUENT: + ct_too_frequent += 1 + continue + + if conv["openai_moderation"]["flagged"]: + ct_flagged += 1 + continue + + if type_code in [TypeCode.CORRECT]: + new_convs.append(conv) + + if args.sample: + # random.seed(0) + # random.shuffle(new_convs) + new_convs = new_convs[: args.sample] + + print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") + print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") + print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") + print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_anonymized}") + print(f"new_conv: {len(new_convs)}") + + out_file = args.in_file.replace(".json", ".out.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) diff --git a/src/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py b/src/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py new file mode 100644 index 0000000000000000000000000000000000000000..5a88209bfcb58cb2131ce94d6eba03c899e74a0a --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py @@ -0,0 +1,25 @@ +"""Count the unique users in a battle log file.""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--tag-file", type=str) + args = parser.parse_args() + + # build index + objs = json.load(open(args.tag_file)) + new_field_dict = {} + for obj in objs: + new_field_dict[obj["question_id"]] = obj["toxic_chat"] + + objs = json.load(open(args.input)) + for obj in objs: + obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]] + + output = args.input.replace(".json", "_added.json") + with open(output, "w") as fout: + json.dump(objs, fout, indent=2, ensure_ascii=False) diff --git a/src/serve/monitor/dataset_release_scripts/arena_33k/sample.py b/src/serve/monitor/dataset_release_scripts/arena_33k/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd78b71e95a3034bf3440aee3557a38426d0244 --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/arena_33k/sample.py @@ -0,0 +1,32 @@ +""" +Count the unique users in a battle log file. + +Usage: +python3 -input in.json --number 1000 +""" + +import argparse +import json +import random + +K = 1000 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--number", type=int, nargs="+") + args = parser.parse_args() + + convs = json.load(open(args.input)) + random.seed(0) + random.shuffle(convs) + + for number in args.number: + new_convs = convs[:number] + + output = args.input.replace(".json", f"_{number//K}k.json") + with open(output, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) + + print(f"#in: {len(convs)}, #out: {len(new_convs)}") + print(f"Write to file: {output}") diff --git a/src/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py b/src/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e37aadcea65df7ca605369b88c068aa57c8f35f2 --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py @@ -0,0 +1,9 @@ +""" +Upload to huggingface. 
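A note on the filter script above: the outcome tallies are kept in ten separate ct_* counters, and one summary print in this arena_33k variant mixes them up (ct_anonymized is reported under the ct_too_frequent label). A minimal alternative sketch, not part of this patch, which reuses the TypeCode enum and detect_type() from filter_bad_conv.py and keeps every tally in a single Counter:

```python
# Sketch only (assumes TypeCode and detect_type from filter_bad_conv.py above);
# a single Counter keyed by TypeCode cannot report the wrong tally by accident.
from collections import Counter

def filter_convs(convs, keep=(TypeCode.CORRECT,)):
    counts = Counter()
    kept = []
    for conv in convs:
        code = detect_type(conv)
        counts[code] += 1
        flagged = conv.get("openai_moderation", {}).get("flagged", False)
        if code in keep and not flagged:
            kept.append(conv)
    for code, n in counts.most_common():
        print(f"{code.name}: {n}")
    return kept
```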
+""" +import json +from datasets import Dataset, DatasetDict, load_dataset + +objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json")) +data = Dataset.from_list(objs) +data.push_to_hub("lmsys/chatbot_arena_conversations", private=True) diff --git a/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py new file mode 100644 index 0000000000000000000000000000000000000000..a7084207309907dcb8fa37eccf55fd2a6b62ca48 --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py @@ -0,0 +1,13 @@ +import requests + +headers = {"authorization": "Bearer hf_XXX"} + +url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/pending" +a = requests.get(url, headers=headers) + +for u in a.json(): + user = u["user"]["user"] + url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/grant" + ret = requests.post(url, headers=headers, json={"user": user}) + print(user, ret.status_code) + assert ret.status_code == 200 diff --git a/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..97abaaa0df053c93c3adb655f1b5c41af0aab00d --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py @@ -0,0 +1,119 @@ +""" +From colab: +https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +import kaleido +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from tqdm import tqdm + +import plotly.io as pio + +pio.kaleido.scope.mathjax = None + +parser = argparse.ArgumentParser() +parser.add_argument("--in-file", type=str, required=True) +parser.add_argument("--scale", type=int, required=True) +args = parser.parse_args() + +filename = args.in_file +scale = args.scale +convs = json.load(open(filename)) +df = pd.DataFrame(convs) +df + +print(f"#ips: {df['user_id'].nunique() * scale}") +print(f"#models: {df['model'].nunique()}") +print(f"#language: {df['language'].nunique()}") +print(f"#turns: {df['turn'].mean()}") + +model_counts = df["model"].value_counts() * scale +# print("model counts", model_counts) +fig = px.bar(x=model_counts.index, y=model_counts) +fig.update_layout( + xaxis_title=None, + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("model_count.pdf") + + +model_counts = df["language"].value_counts().head(25) * scale +fig = px.bar(x=model_counts.index, y=model_counts) +fig.update_layout( + xaxis_title=None, + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("language_count.pdf") + +chat_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d") + for x in df["tstamp"] +] + + +def to_remove(x): + for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]: + if d in x: + return True + return False + + +chat_dates = [x for x in chat_dates if not to_remove(x)] + +chat_dates_counts = pd.value_counts(chat_dates) * scale +print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}") + +fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts) +fig.update_layout( + xaxis_title="Dates", 
+ yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("daily_conversation_count.pdf") + +import transformers + +tokenizer = transformers.AutoTokenizer.from_pretrained( + "lmsys/vicuna-7b-v1.5", use_fast=False +) + +prompts = [] +responses = [] +for conv in df["conversation"]: + for row in conv: + if row["role"] == "user": + prompts.append(row["content"]) + else: + responses.append(row["content"]) + +print(f"#prompts: {len(prompts)}") +print(f"#responses: {len(responses)}") + + +prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)] +print() +print(f"mean prompt len: {np.mean(prompt_lens):.2f}") + +response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)] +print() +print(f"mean response len: {np.mean(response_lens):.2f}") diff --git a/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccde1ca57546acf5d1131cae14a499f1228a02c --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py @@ -0,0 +1,148 @@ +""" +Filter conversations for release. + +Dependency: +pip install opencc-python-reimplementedpip install opencc-python-reimplemented + +Usage: +python3 filter_bad_conv_lmsys_chat_1m.py --in clean_battle_conv_20230630_tagged_v1_pii.json +""" +import argparse +from concurrent.futures import ProcessPoolExecutor +from collections import defaultdict +from enum import Enum, auto +import json +import os +import random + +from tqdm import tqdm +import opencc + +BLOCKED_WORDS_FILENAME = "blocked_words.json" +blocked_words = [] +frequency = defaultdict(lambda: 0) + +cc_converter = opencc.OpenCC("t2s") + + +class TypeCode(Enum): + CORRECT = auto() + ANONYMIZED = auto() + REDACTED = auto() + BAD_FORMAT = auto() + BLOCKED_WORD = auto() + BLOCKED_MODEL = auto() + TOO_SHORT = auto() + TOO_FREQUENT = auto() + + +def detect_type(conv): + for key in ["conversation_a", "conversation_b", "conversation"]: + if key not in conv: + continue + + messages = [row["content"] for row in conv[key]] + for msg in messages: + if not isinstance(msg, str): + return TypeCode.BAD_FORMAT + + if len(messages) == 0: + return TypeCode.BAD_FORMAT + + user_prompts = [ + row["content"].lower().strip() for row in conv[key] if row["role"] == "user" + ] + + for msg in messages: + msg = cc_converter.convert(msg.lower()) + if "" in msg: + return TypeCode.ANONYMIZED + if "" in msg: + return TypeCode.REDACTED + + for w in blocked_words: + if w in msg: + return TypeCode.BLOCKED_WORD + + return TypeCode.CORRECT + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Read blocked words + if os.path.exists(BLOCKED_WORDS_FILENAME): + blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) + blocked_words = [cc_converter.convert(w) for w in blocked_words] + + # Start filter + ct_bad_format = 0 + ct_anonymized = 0 + ct_redacted = 0 + ct_error = 0 + ct_lang_filter = 0 + ct_flagged = 0 + ct_blocked_word = 0 + ct_blocked_model = 0 + ct_too_short = 0 + ct_too_frequent = 0 + + type_codes = [] + with ProcessPoolExecutor() as executor: + for result in tqdm(executor.map(detect_type, convs), total=len(convs)): + 
type_codes.append(result) + + new_convs = [] + for conv, type_code in zip(convs, type_codes): + if type_code == TypeCode.BAD_FORMAT: + ct_bad_format += 1 + continue + + if type_code == TypeCode.ANONYMIZED: + ct_anonymized += 1 + continue + elif type_code == TypeCode.REDACTED: + ct_redacted += 1 + continue + elif type_code == TypeCode.BLOCKED_WORD: + ct_blocked_word += 1 + continue + elif type_code == TypeCode.BLOCKED_MODEL: + ct_blocked_model += 1 + continue + elif type_code == TypeCode.TOO_SHORT: + ct_too_short += 1 + continue + elif type_code == TypeCode.TOO_FREQUENT: + ct_too_frequent += 1 + continue + + if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]: + ct_flagged += 1 + continue + + if type_code in [TypeCode.CORRECT]: + new_convs.append(conv) + + if args.sample: + random.seed(42) + random.shuffle(new_convs) + new_convs = new_convs[: args.sample] + + print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") + print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") + print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") + print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}") + print(f"new_conv: {len(new_convs)}") + + out_file = args.in_file.replace(".json", ".s1.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) diff --git a/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..e368e92a1dcf260ecb5b175b77e85c6971809a3c --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py @@ -0,0 +1,27 @@ +import argparse +import json + +from tqdm import tqdm +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Delete some fileds + for c in convs: + del c["tstamp"] + del c["user_id"] + + # Write + print(f"#out conv: {len(convs)}") + out_file = args.in_file.replace(".json", ".s2.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(convs, fout, indent=2, ensure_ascii=False) diff --git a/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md new file mode 100644 index 0000000000000000000000000000000000000000..4c439731f6aee43bd29e1a65576c5ae04ff59cfa --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md @@ -0,0 +1,23 @@ +``` +export BASE=clean_conv_20230809_100k_pii +export SCALE=10 + +# filter words +python3 filter_bad_conv.py --in $BASE.json + +# Clean up some fileds (e.g., timestamps) +python3 final_post_processing.py --in $BASE.s1.json + +# upload to hf +python3 upload_hf_dataset.py --in $BASE.s1.s2.json + +# Make another version with openai moderation tag +python3 merge_oai_tag.py --in $BASE.s1.s2.json + +# Make visualizations +python3 compute_stats.py --in $BASE.s1.json --scale $SCALE + +# Copy figures +scp "atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/*.pdf" . 
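# Hypothetical extra step, not in the original instructions: sanity-check the
# record count of an intermediate file before uploading.
python3 -c "import json, sys; print(len(json.load(open(sys.argv[1]))))" $BASE.s1.s2.json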
+``` + diff --git a/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py new file mode 100644 index 0000000000000000000000000000000000000000..18bef5f1962384d80f174aa22a7b6dcc867fe7c0 --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py @@ -0,0 +1,45 @@ +import argparse +import json +import time + +from tqdm import tqdm + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + tag_file = "clean_conv_20230809_1.5M_oai_filter_v2.json" + # tag_file = "clean_conv_20230809_1.5M_oai_filter_v2_100k.json" + in_file = args.in_file + tic = time.time() + + # Load tags + print("Load tags...") + tag_data = json.load(open(tag_file)) + tag_dict = {} + for c in tqdm(tag_data): + tag_dict[c["conversation_id"]] = [x["oai_filter"] for x in c["conversation"]] + print(f"elapsed: {time.time() - tic:.2f} s") + + # Append to input_file + print("Load inputs...") + input_data = json.load(open(in_file)) + for c in tqdm(input_data): + cid = c["conversation_id"] + if cid in tag_dict: + c["openai_moderation"] = tag_dict[cid] + else: + print(f"missing tag for conv {cid}") + exit() + print(f"elapsed: {time.time() - tic:.2f} s") + + # Write output + print("Write outputs...") + out_file = in_file.replace(".json", ".with_tag.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(input_data, fout, indent=2, ensure_ascii=False) + print(f"elapsed: {time.time() - tic:.2f} s") diff --git a/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..5bae9fbad221c57eba8f2cf5b7eb2779a6f040a8 --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh @@ -0,0 +1,18 @@ +export BASE=clean_conv_20230809_1.5M_pii +#export BASE=clean_conv_20230809_100k_pii +export SCALE=1 + +# Filter words +python3 filter_bad_conv.py --in $BASE.json --sample 1000000 + +# Clean up some fileds (e.g., timestamps) +python3 final_post_processing.py --in $BASE.s1.json + +# Upload to hf +python3 upload_hf_dataset.py --in $BASE.s1.s2.json + +# Make another version with openai moderation tag +python3 merge_oai_tag.py --in $BASE.s1.s2.json + +# Make visualizations +python3 compute_stats.py --in $BASE.s1.json --scale $SCALE diff --git a/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6da455fc7bf8af1ce473f80440bff280c9366e --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py @@ -0,0 +1,32 @@ +""" +Count the unique users in a battle log file. 
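(The docstring above is carried over from the user-counting script; this sample.py actually draws random subsamples of the requested sizes.) For the tag-merging step earlier in the pipeline, merge_oai_tag.py joins per-message moderation tags onto conversations by conversation_id and exits on the first missing id. A more forgiving variant is sketched below; it is an alternative, not a drop-in from this patch, and the field names simply follow that script:

```python
# Sketch: join moderation tags onto conversations by conversation_id,
# collecting missing ids instead of exiting on the first one (cf. merge_oai_tag.py).
def merge_tags(conversations, tag_dict):
    missing = []
    for c in conversations:
        cid = c["conversation_id"]
        if cid in tag_dict:
            c["openai_moderation"] = tag_dict[cid]
        else:
            missing.append(cid)
    if missing:
        print(f"{len(missing)} conversations have no tag, e.g. {missing[:5]}")
    return conversations, missing
```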
+ +Usage: +python3 -input in.json --number 1000 +""" + +import argparse +import json +import random + +K = 1000 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--number", type=int, nargs="+") + args = parser.parse_args() + + convs = json.load(open(args.input)) + random.seed(42) + random.shuffle(convs) + + for number in args.number: + new_convs = convs[:number] + + output = args.input.replace(".json", f"_{number//K}k.json") + with open(output, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) + + print(f"#in: {len(convs)}, #out: {len(new_convs)}") + print(f"Write to file: {output}") diff --git a/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..41d0fbdb59b4c7dc8385bef87a1bf0c8ea6e7401 --- /dev/null +++ b/src/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py @@ -0,0 +1,17 @@ +""" +Upload to huggingface. +""" +import argparse +import json +from datasets import Dataset, DatasetDict, load_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + args = parser.parse_args() + + objs = json.load(open(args.in_file)) + print(f"#convs: {len(objs)}") + data = Dataset.from_list(objs) + data.push_to_hub("lmsys/lmsys-chat-1m", private=True) diff --git a/src/serve/monitor/deduplication.py b/src/serve/monitor/deduplication.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4b5765d23549371d92c32b73d951ca58533844 --- /dev/null +++ b/src/serve/monitor/deduplication.py @@ -0,0 +1,85 @@ +import os +import json +import pandas as pd +import ast + +import matplotlib.pyplot as plt +from matplotlib import rcParams + +import argparse +import seaborn as sns +from tqdm import tqdm +import matplotlib.pyplot as plt + +import numpy as np + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", type=str, default="output") + parser.add_argument("--model", type=str, default=None) + parser.add_argument("--input_file", type=str, required=True) + parser.add_argument("--percentile", type=float, default=0.9999) + args = parser.parse_args() + output_dir = args.output_dir + input_file = args.input_file + + with open(input_file) as f: + data = json.load(f) + + os.makedirs(output_dir, exist_ok=True) + + # Preprocessing + all_convs_new = [] + convs = [] + for row in data: + conv = "" + for turns in row["conversation_a"]: + if turns["role"] == "user": + conv += f"{turns['content']}\n" + + convs.append(conv[:10000]) + row["post_process_conv"] = conv[:10000] + all_convs_new.append(row) + + df = pd.DataFrame(all_convs_new) + print("Number of conversations: ", len(df)) + + prompt_counts = df["post_process_conv"].value_counts() + # Select the top 20 most frequent prompts + top_prompts = prompt_counts.head(20) + print(top_prompts) + + # Determine the percentile count + percentile_cutoff = prompt_counts.quantile(args.percentile) + print(f"{args.percentile*100} percentile count: {percentile_cutoff}") + + # prompts that are more common than the percentile cutoff + high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index + print( + f"Number of high frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}" + ) + + # initialize a new column dedup_tag + dedup_tags 
= np.array( + [{"high_freq": False, "sampled": True} for _ in range(len(df))] + ) + high_freq_groups = df.groupby("post_process_conv") + for prompt in tqdm(high_frequency_prompts): + df_high_freq = high_freq_groups.get_group(prompt) + sampled_indices = df_high_freq.sample( + n=int(percentile_cutoff), random_state=42 + ).index + dedup_tags[df_high_freq.index] = {"high_freq": True, "sampled": False} + dedup_tags[sampled_indices] = {"high_freq": True, "sampled": True} + + df["dedup_tag"] = dedup_tags + + # drop intermediate columns (post_process_conv) + df = df.drop(columns=["post_process_conv"]) + + df.to_json( + os.path.join(output_dir, "dedup.json"), + orient="records", + indent=4, + force_ascii=False, + ) diff --git a/src/serve/monitor/elo_analysis.py b/src/serve/monitor/elo_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..4982b2f0db15a781ff2f2a73c9e22505d1752ce5 --- /dev/null +++ b/src/serve/monitor/elo_analysis.py @@ -0,0 +1,622 @@ +import argparse +import ast +from collections import defaultdict +import datetime +import json +import math +import pickle +from pytz import timezone +from functools import partial + +import numpy as np +import pandas as pd +import plotly.express as px +from tqdm import tqdm +from transformers import AutoTokenizer + +from fastchat.model.model_registry import get_model_info +from fastchat.serve.monitor.basic_stats import get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data + +pd.options.display.float_format = "{:.2f}".format + + +def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000): + rating = defaultdict(lambda: INIT_RATING) + + for rd, model_a, model_b, winner in battles[ + ["model_a", "model_b", "winner"] + ].itertuples(): + ra = rating[model_a] + rb = rating[model_b] + ea = 1 / (1 + BASE ** ((rb - ra) / SCALE)) + eb = 1 / (1 + BASE ** ((ra - rb) / SCALE)) + if winner == "model_a": + sa = 1 + elif winner == "model_b": + sa = 0 + elif winner == "tie" or winner == "tie (bothbad)": + sa = 0.5 + else: + raise Exception(f"unexpected vote {winner}") + rating[model_a] += K * (sa - ea) + rating[model_b] += K * (1 - sa - eb) + + return dict(rating) + + +def get_bootstrap_result(battles, func_compute_elo, num_round=1000): + rows = [] + for i in tqdm(range(num_round), desc="bootstrap"): + tmp_battles = battles.sample(frac=1.0, replace=True) + rows.append(func_compute_elo(tmp_battles)) + df = pd.DataFrame(rows) + return df[df.median().sort_values(ascending=False).index] + + +def compute_elo_mle_with_tie( + df, SCALE=400, BASE=10, INIT_RATING=1000, sample_weight=None +): + from sklearn.linear_model import LogisticRegression + + ptbl_a_win = pd.pivot_table( + df[df["winner"] == "model_a"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + ptbl_tie = pd.pivot_table( + df[df["winner"].isin(["tie", "tie (bothbad)"])], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + ptbl_tie = ptbl_tie + ptbl_tie.T + ptbl_b_win = pd.pivot_table( + df[df["winner"] == "model_b"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + ptbl_win = ptbl_a_win * 2 + ptbl_b_win.T * 2 + ptbl_tie + + models = pd.Series(np.arange(len(ptbl_win.index)), index=ptbl_win.index) + + p = len(models) + X = np.zeros([p * (p - 1) * 2, p]) + Y = np.zeros(p * (p - 1) * 2) + + cur_row = 0 + sample_weights = [] + for m_a in ptbl_win.index: + for m_b in ptbl_win.columns: + if m_a == m_b: + continue + # if nan skip + if math.isnan(ptbl_win.loc[m_a, 
m_b]) or math.isnan(ptbl_win.loc[m_b, m_a]): + continue + X[cur_row, models[m_a]] = +math.log(BASE) + X[cur_row, models[m_b]] = -math.log(BASE) + Y[cur_row] = 1.0 + sample_weights.append(ptbl_win.loc[m_a, m_b]) + + X[cur_row + 1, models[m_a]] = math.log(BASE) + X[cur_row + 1, models[m_b]] = -math.log(BASE) + Y[cur_row + 1] = 0.0 + sample_weights.append(ptbl_win.loc[m_b, m_a]) + cur_row += 2 + X = X[:cur_row] + Y = Y[:cur_row] + + lr = LogisticRegression(fit_intercept=False, penalty=None) + lr.fit(X, Y, sample_weight=sample_weights) + elo_scores = SCALE * lr.coef_[0] + INIT_RATING + if "mixtral-8x7b-instruct-v0.1" in models.index: + elo_scores += 1114 - elo_scores[models["mixtral-8x7b-instruct-v0.1"]] + return pd.Series(elo_scores, index=models.index).sort_values(ascending=False) + + +def get_median_elo_from_bootstrap(bootstrap_df): + median = dict(bootstrap_df.quantile(0.5)) + median = {k: int(v + 0.5) for k, v in median.items()} + return median + + +def compute_pairwise_win_fraction(battles, model_order, limit_show_number=None): + # Times each model wins as Model A + a_win_ptbl = pd.pivot_table( + battles[battles["winner"] == "model_a"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting times each model wins as Model B + b_win_ptbl = pd.pivot_table( + battles[battles["winner"] == "model_b"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting number of A-B pairs + num_battles_ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + + # Computing the proportion of wins for each model as A and as B + # against all other models + row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / ( + num_battles_ptbl + num_battles_ptbl.T + ) + + if model_order is None: + prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False) + model_order = list(prop_wins.keys()) + + if limit_show_number is not None: + model_order = model_order[:limit_show_number] + + # Arrange ordering according to proprition of wins + row_beats_col = row_beats_col_freq.loc[model_order, model_order] + return row_beats_col + + +def visualize_leaderboard_table(rating): + models = list(rating.keys()) + models.sort(key=lambda k: -rating[k]) + + emoji_dict = { + 1: "🥇", + 2: "🥈", + 3: "🥉", + } + + md = "" + md += "| Rank | Model | Elo Rating | Description |\n" + md += "| --- | --- | --- | --- |\n" + for i, model in enumerate(models): + rank = i + 1 + minfo = get_model_info(model) + emoji = emoji_dict.get(rank, "") + md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n" + + return md + + +def visualize_pairwise_win_fraction(battles, model_order, scale=1): + row_beats_col = compute_pairwise_win_fraction(battles, model_order) + fig = px.imshow( + row_beats_col, + color_continuous_scale="RdBu", + text_auto=".2f", + height=700 * scale, + width=700 * scale, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
<br>Model B: %{x}<br>
Fraction of A Wins: %{z}" + ) + + return fig + + +def visualize_battle_count(battles, model_order, scale=1): + ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + battle_counts = ptbl + ptbl.T + fig = px.imshow( + battle_counts.loc[model_order, model_order], + text_auto=True, + height=700 * scale, + width=700 * scale, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
<br>Model B: %{x}<br>
Count: %{z}" + ) + return fig + + +def visualize_average_win_rate(battles, limit_show_number, scale=1): + row_beats_col_freq = compute_pairwise_win_fraction( + battles, None, limit_show_number=limit_show_number + ) + fig = px.bar( + row_beats_col_freq.mean(axis=1).sort_values(ascending=False), + text_auto=".2f", + height=500 * scale, + width=700 * scale, + ) + fig.update_layout( + yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False + ) + return fig + + +def visualize_bootstrap_elo_rating(df, df_final, limit_show_number, scale=1): + bars = ( + pd.DataFrame( + dict( + lower=df.quantile(0.025), + rating=df_final, + upper=df.quantile(0.975), + ) + ) + .reset_index(names="model") + .sort_values("rating", ascending=False) + ) + bars = bars[:limit_show_number] + bars["error_y"] = bars["upper"] - bars["rating"] + bars["error_y_minus"] = bars["rating"] - bars["lower"] + bars["rating_rounded"] = np.round(bars["rating"]) + fig = px.scatter( + bars, + x="model", + y="rating", + error_y="error_y", + error_y_minus="error_y_minus", + text="rating_rounded", + height=700, + width=700 * scale, + ) + fig.update_layout(xaxis_title="Model", yaxis_title="Rating") + return fig + + +def limit_user_votes(battles, daily_vote_per_user): + from datetime import datetime + + print("Before limiting user votes: ", len(battles)) + # add date + battles["date"] = battles["tstamp"].apply( + lambda x: datetime.fromtimestamp(x).strftime("%Y-%m-%d") + ) + + battles_new = pd.DataFrame() + for date in battles["date"].unique(): + # only take the first daily_vote_per_user votes per judge per day + df_today = battles[battles["date"] == date] + df_sub = df_today.groupby("judge").head(daily_vote_per_user) + + # add df_sub to a new dataframe + battles_new = pd.concat([battles_new, df_sub]) + print("After limiting user votes: ", len(battles_new)) + return battles_new + + +def get_model_pair_stats(battles): + battles["ordered_pair"] = battles.apply( + lambda x: tuple(sorted([x["model_a"], x["model_b"]])), axis=1 + ) + + model_pair_stats = {} + + for index, row in battles.iterrows(): + pair = row["ordered_pair"] + if pair not in model_pair_stats: + model_pair_stats[pair] = {"win": 0, "loss": 0, "tie": 0} + + if row["winner"] in ["tie", "tie (bothbad)"]: + model_pair_stats[pair]["tie"] += 1 + elif row["winner"] == "model_a" and row["model_a"] == min(pair): + model_pair_stats[pair]["win"] += 1 + elif row["winner"] == "model_b" and row["model_b"] == min(pair): + model_pair_stats[pair]["win"] += 1 + else: + model_pair_stats[pair]["loss"] += 1 + + return model_pair_stats + + +def outlier_detect( + model_pair_stats, + battles, + max_vote=100, + randomized=False, + alpha=0.05, + c_param=0.5, + user_list=None, +): + if user_list is None: + # only check user who has >= 5 votes to save compute + user_vote_cnt = battles["judge"].value_counts() + user_list = user_vote_cnt[user_vote_cnt >= 5].index.tolist() + print("#User to be checked: ", len(user_list)) + + bad_user_list = [] + for user in user_list: + flag = False + p_upper = [] + p_lower = [] + df_2 = battles[battles["judge"] == user] + for row in df_2.iterrows(): + if len(p_upper) >= max_vote: + break + + model_pair = tuple(sorted([row[1]["model_a"], row[1]["model_b"]])) + + if row[1]["winner"] in ["tie", "tie (bothbad)"]: + vote = 0.5 + elif row[1]["winner"] == "model_a" and row[1]["model_a"] == model_pair[0]: + vote = 1 + elif row[1]["winner"] == "model_b" and row[1]["model_b"] == model_pair[0]: + vote = 1 + else: + vote = 0 + + stats = model_pair_stats[model_pair] + # count 
all votes + # ratings = np.array( + # [1] * stats["win"] + [0.5] * stats["tie"] + [0] * stats["loss"] + # ) + + # only count win and loss + ratings = np.array([1] * stats["win"] + [0] * stats["loss"]) + if randomized: + noise = np.random.uniform(-1e-5, 1e-5, len(ratings)) + ratings += noise + vote += np.random.uniform(-1e-5, 1e-5) + + p_upper += [(ratings <= vote).mean()] + p_lower += [(ratings >= vote).mean()] + + M_upper = np.prod(1 / (2 * np.array(p_upper))) + M_lower = np.prod(1 / (2 * np.array(p_lower))) + + # M_upper = np.prod((1 - c_param) / (c_param * np.array(p_upper) ** c_param)) + # M_lower = np.prod((1 - c_param) / (c_param * np.array(p_lower) ** c_param)) + if (M_upper > 1 / alpha) or (M_lower > 1 / alpha): + print(f"Identify bad user with {len(p_upper)} votes") + flag = True + break + if flag: + bad_user_list.append({"user_id": user, "votes": len(p_upper)}) + print("Bad user length: ", len(bad_user_list)) + print(bad_user_list) + + bad_user_id_list = [x["user_id"] for x in bad_user_list] + # remove bad users + battles = battles[~battles["judge"].isin(bad_user_id_list)] + return battles + + +def filter_long_conv(row): + threshold = 768 + for conversation_type in ["conversation_a", "conversation_b"]: + cur_conv = row[conversation_type] + num_tokens_all = sum([turn["num_tokens"] for turn in cur_conv]) + if num_tokens_all >= threshold: + return True + return False + + +def report_elo_analysis_results( + battles_json, + rating_system="bt", + num_bootstrap=100, + exclude_models=[], + langs=[], + exclude_tie=False, + exclude_unknown_lang=False, + daily_vote_per_user=None, + run_outlier_detect=False, + scale=1, + filter_func=lambda x: True, +): + battles = pd.DataFrame(battles_json) + + tqdm.pandas(desc=f"Processing using {filter_func.__name__}") + filtered_indices = battles.progress_apply(filter_func, axis=1) + battles = battles[filtered_indices] + + battles = battles.sort_values(ascending=True, by=["tstamp"]) + + if len(langs) > 0: + battles = battles[battles["language"].isin(langs)] + if exclude_unknown_lang: + battles = battles[~battles["language"].str.contains("unknown")] + + # remove excluded models + battles = battles[ + ~( + battles["model_a"].isin(exclude_models) + | battles["model_b"].isin(exclude_models) + ) + ] + + # Only use anonymous votes + battles = battles[battles["anony"]].reset_index(drop=True) + battles_no_ties = battles[~battles["winner"].str.contains("tie")] + if exclude_tie: + battles = battles_no_ties + + if daily_vote_per_user is not None: + battles = limit_user_votes(battles, daily_vote_per_user) + + if run_outlier_detect: + model_pair_stats = get_model_pair_stats(battles) + battles = outlier_detect(model_pair_stats, battles) + + print(f"Number of battles: {len(battles)}") + # Online update + elo_rating_online = compute_elo(battles) + + if rating_system == "bt": + bootstrap_df = get_bootstrap_result( + battles, compute_elo_mle_with_tie, num_round=num_bootstrap + ) + elo_rating_final = compute_elo_mle_with_tie(battles) + elif rating_system == "elo": + bootstrap_df = get_bootstrap_result( + battles, compute_elo, num_round=num_bootstrap + ) + elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df) + elo_rating_final = elo_rating_median + + model_order = list(elo_rating_final.keys()) + + model_rating_q025 = bootstrap_df.quantile(0.025) + model_rating_q975 = bootstrap_df.quantile(0.975) + + # compute ranking based on CI + ranking = {} + for i, model_a in enumerate(model_order): + ranking[model_a] = 1 + for j, model_b in enumerate(model_order): + if i == 
j: + continue + if model_rating_q025[model_b] > model_rating_q975[model_a]: + ranking[model_a] += 1 + + # leaderboard_table_df: elo rating, variance, 95% interval, number of battles + leaderboard_table_df = pd.DataFrame( + { + "rating": elo_rating_final, + "variance": bootstrap_df.var(), + "rating_q975": bootstrap_df.quantile(0.975), + "rating_q025": bootstrap_df.quantile(0.025), + "num_battles": battles["model_a"] + .value_counts() + .add(battles["model_b"].value_counts(), fill_value=0), + "final_ranking": pd.Series(ranking), + } + ) + + model_order.sort(key=lambda k: -elo_rating_final[k]) + limit_show_number = int(25 * scale) + model_order = model_order[:limit_show_number] + + # Plots + leaderboard_table = visualize_leaderboard_table(elo_rating_final) + win_fraction_heatmap = visualize_pairwise_win_fraction( + battles_no_ties, model_order, scale=scale + ) + battle_count_heatmap = visualize_battle_count( + battles_no_ties, model_order, scale=scale + ) + average_win_rate_bar = visualize_average_win_rate( + battles_no_ties, limit_show_number, scale=scale + ) + bootstrap_elo_rating = visualize_bootstrap_elo_rating( + bootstrap_df, elo_rating_final, limit_show_number, scale=scale + ) + + last_updated_tstamp = battles["tstamp"].max() + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + return { + "rating_system": rating_system, + "elo_rating_online": elo_rating_online, + "elo_rating_final": elo_rating_final, + "leaderboard_table": leaderboard_table, + "win_fraction_heatmap": win_fraction_heatmap, + "battle_count_heatmap": battle_count_heatmap, + "average_win_rate_bar": average_win_rate_bar, + "bootstrap_elo_rating": bootstrap_elo_rating, + "last_updated_datetime": last_updated_datetime, + "last_updated_tstamp": last_updated_tstamp, + "bootstrap_df": bootstrap_df, + "leaderboard_table_df": leaderboard_table_df, + } + + +def pretty_print_elo_rating(rating): + model_order = list(rating.keys()) + model_order.sort(key=lambda k: -rating[k]) + for i, model in enumerate(model_order): + print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clean-battle-file", type=str) + parser.add_argument("--max-num-files", type=int) + parser.add_argument("--num-bootstrap", type=int, default=100) + parser.add_argument( + "--rating-system", type=str, choices=["bt", "elo"], default="bt" + ) + parser.add_argument("--exclude-models", type=str, nargs="+", default=[]) + parser.add_argument("--exclude-tie", action="store_true", default=False) + parser.add_argument("--exclude-unknown-lang", action="store_true", default=False) + parser.add_argument("--exclude-url", action="store_true", default=False) + parser.add_argument("--langs", type=str, nargs="+", default=[]) + parser.add_argument("--daily-vote-per-user", type=int, default=None) + parser.add_argument("--run-outlier-detect", action="store_true", default=False) + parser.add_argument("--category", nargs="+", default=["full"]) + parser.add_argument("--scale", type=float, default=1) + args = parser.parse_args() + + np.random.seed(42) + + if args.clean_battle_file: + # Read data from a cleaned battle files + battles = pd.read_json(args.clean_battle_file) + else: + # Read data from all log files + log_files = get_log_files(args.max_num_files) + battles = clean_battle_data(log_files) + + filter_func_map = { + "full": lambda x: True, + "long": filter_long_conv, + "chinese": lambda x: 
x["language"] == "Chinese", + "english": lambda x: x["language"] == "English", + } + assert all( + [cat in filter_func_map for cat in args.category] + ), f"Invalid category: {args.category}" + + results = {} + for cat in args.category: + filter_func = filter_func_map[cat] + results[cat] = report_elo_analysis_results( + battles, + rating_system=args.rating_system, + num_bootstrap=args.num_bootstrap, + exclude_models=args.exclude_models, + langs=args.langs, + exclude_tie=args.exclude_tie, + exclude_unknown_lang=args.exclude_unknown_lang, + daily_vote_per_user=args.daily_vote_per_user, + run_outlier_detect=args.run_outlier_detect, + scale=args.scale, + filter_func=filter_func, + ) + + for cat in args.category: + print(f"# Results for {cat} conversations") + print("# Online Elo") + pretty_print_elo_rating(results[cat]["elo_rating_online"]) + print("# Median") + pretty_print_elo_rating(results[cat]["elo_rating_final"]) + print(f"last update : {results[cat]['last_updated_datetime']}") + + last_updated_tstamp = results[cat]["last_updated_tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + print(f"last update : {cutoff_date}") + + with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout: + pickle.dump(results, fout) diff --git a/src/serve/monitor/inspect_conv.py b/src/serve/monitor/inspect_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a680a419bd9d11d0db85afbc21c0063a2ae36df7 --- /dev/null +++ b/src/serve/monitor/inspect_conv.py @@ -0,0 +1,87 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd +from tqdm import tqdm + + +def get_log_files(max_num_files=None): + dates = [] + for month in [4, 5]: + for day in range(1, 32): + dates.append(f"2023-{month:02d}-{day:02d}") + + num_servers = 14 + filenames = [] + for d in dates: + for i in range(num_servers): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def pretty_print_conversation(messages): + for role, msg in messages: + print(f"[[{role}]]: {msg}") + + +def inspect_convs(log_files): + data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + + if "states" not in row: + continue + if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]: + continue + + model_names = row["states"][0]["model_name"], row["states"][1]["model_name"] + if row["type"] == "leftvote": + winner, loser = model_names[0], model_names[1] + winner_conv, loser_conv = row["states"][0], row["states"][1] + elif row["type"] == "rightvote": + loser, winner = model_names[0], model_names[1] + loser_conv, winner_conv = row["states"][0], row["states"][1] + + if loser == "bard" and winner == "vicuna-13b": + print("=" * 20) + print(f"Winner: {winner}") + pretty_print_conversation(winner_conv["messages"]) + print(f"Loser: {loser}") + pretty_print_conversation(loser_conv["messages"]) + print("=" * 20) + input() + + # if row["type"] == "bothbad_vote" and "gpt-4" in model_names: + # print("=" * 20) + # print(f"Model A: {model_names[0]}") + # pretty_print_conversation(row["states"][0]["messages"]) + # print(f"Model B: {model_names[1]}") + # 
pretty_print_conversation(row["states"][1]["messages"]) + # print("=" * 20) + # input() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + inspect_convs(log_files) diff --git a/src/serve/monitor/intersect_conv_file.py b/src/serve/monitor/intersect_conv_file.py new file mode 100644 index 0000000000000000000000000000000000000000..9eadd7cd57510ecbbd23798d55b079c69aac1a12 --- /dev/null +++ b/src/serve/monitor/intersect_conv_file.py @@ -0,0 +1,25 @@ +""" +Take the intersection of two conversation files. + +Usage: python3 -m fastchat.data.merge --input input.json --conv-id conv_id_file.json --out intersect.json +""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + parser.add_argument("--conv-id", type=str, required=True) + parser.add_argument("--out-file", type=str, default="intersect.json") + args = parser.parse_args() + + conv_id_objs = json.load(open(args.conv_id, "r")) + conv_ids = set(x["conversation_id"] for x in conv_id_objs) + + objs = json.load(open(args.input, "r")) + after_objs = [x for x in objs if x["conversation_id"] in conv_ids] + + print(f"#in: {len(objs)}, #out: {len(after_objs)}") + json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/src/serve/monitor/leaderboard_csv_to_html.py b/src/serve/monitor/leaderboard_csv_to_html.py new file mode 100644 index 0000000000000000000000000000000000000000..ad52e7b2b6e234ed33a51d516e9d682addd1e0eb --- /dev/null +++ b/src/serve/monitor/leaderboard_csv_to_html.py @@ -0,0 +1,51 @@ +""" +Convert a leaderboard csv file to html table used in the blog. + +Usage: +python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv +""" +import argparse + +import numpy as np + +from fastchat.serve.monitor.monitor import load_leaderboard_table_csv + + +def model_hyperlink(model_name, link): + return f' {model_name} ' + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + args = parser.parse_args() + + data = load_leaderboard_table_csv(args.input, add_hyperlink=False) + headers = [ + "Model", + "MT-bench (score)", + "Arena Elo rating", + "MMLU", + "License", + ] + values = [] + for item in data: + row = [] + for key in headers: + value = item[key] + row.append(value) + row[0] = model_hyperlink(item["Model"], item["Link"]) + values.append(row) + values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) + + for value in values: + row = "" + for x in value: + try: + if np.isnan(x): + x = "-" + except TypeError: + pass + row += f" {x} " + row += "" + print(row) diff --git a/src/serve/monitor/monitor.py b/src/serve/monitor/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..ee31a8718bac66079ff602cb75bb0a448067f5ea --- /dev/null +++ b/src/serve/monitor/monitor.py @@ -0,0 +1,886 @@ +""" +Live monitor of the website statistics and leaderboard. 
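monitor.py below renders the leaderboards produced by elo_analysis.py above. The core rating step there is the online Elo update in compute_elo(); a worked numeric example using that function's defaults (K=4, SCALE=400, BASE=10) is sketched below, with ratings chosen purely for illustration:

```python
# Worked example of the online Elo update in compute_elo() from elo_analysis.py
# above, using its defaults K=4, SCALE=400, BASE=10 (ratings here are made up).
K, SCALE, BASE = 4, 400, 10

ra, rb = 1100.0, 1000.0                     # current ratings of model_a / model_b
ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))  # expected score of model_a ~= 0.64
eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))  # expected score of model_b ~= 0.36
sa = 1.0                                    # this battle: model_a won
ra += K * (sa - ea)                         # 1100 -> ~1101.44
rb += K * ((1 - sa) - eb)                   # 1000 -> ~998.56
print(f"{ra:.2f} {rb:.2f}")
```

With bootstrap resampling (get_bootstrap_result), the same computation is repeated over resampled battle sets to obtain the 2.5%/97.5% quantiles that monitor.py later turns into the confidence-interval-based Rank (UB) column.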
+ +Dependency: +sudo apt install pkg-config libicu-dev +pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate +""" + +import argparse +import ast +import json +import pickle +import os +import threading +import time + +import pandas as pd +import gradio as gr +import numpy as np + +from fastchat.serve.monitor.basic_stats import report_basic_stats, get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data +from fastchat.serve.monitor.elo_analysis import report_elo_analysis_results +from fastchat.utils import build_logger, get_window_url_params_js + + +notebook_url = ( + "https://colab.research.google.com/drive/1KdwokPjirkTmpO_P1WByFNFiqxWQquwH" +) + +basic_component_values = [None] * 6 +leader_component_values = [None] * 5 + + +def make_default_md_1(arena_df, elo_results, mirror=False): + link_color = "#1976D2" # This color should be clear in both light and dark mode + leaderboard_md = f""" + # 🏆 LMSYS Chatbot Arena Leaderboard + Blog | + Paper | + GitHub | + Dataset | + Twitter | + Discord + """ + + return leaderboard_md + + +def make_default_md_2(arena_df, elo_results, mirror=False): + mirror_str = "This is a mirror of the live leaderboard created and maintained by the LMSYS Organization. Please link to leaderboard.lmsys.org for citation purposes." + leaderboard_md = f""" + {mirror_str if mirror else ""} + + LMSYS Chatbot Arena is a crowdsourced open platform for LLM evals. We've collected over 800,000 human pairwise comparisons to rank LLMs with the Bradley-Terry model and display the model ratings in Elo-scale. + You can find more details in our paper. **Chatbot arena is dependent on community participation, please contribute by casting your vote!** + """ + + return leaderboard_md + + +def make_arena_leaderboard_md(arena_df, last_updated_time): + total_votes = sum(arena_df["num_battles"]) // 2 + total_models = len(arena_df) + space = "   " + + leaderboard_md = f""" +Total #models: **{total_models}**.{space} Total #votes: **{"{:,}".format(total_votes)}**.{space} Last updated: {last_updated_time}. + +📣 **NEW!** View leaderboard for different categories (e.g., coding, long user query)! This is still in preview and subject to change. + +Code to recreate leaderboard tables and plots in this [notebook]({notebook_url}). You can contribute your vote at [chat.lmsys.org](https://chat.lmsys.org)! + +***Rank (UB)**: model's ranking (upper-bound), defined by one + the number of models that are statistically better than the target model. +Model A is statistically better than model B when A's lower-bound score is greater than B's upper-bound score (in 95% confidence interval). +See Figure 1 below for visualization of the confidence intervals of model scores. +""" + return leaderboard_md + + +def make_category_arena_leaderboard_md(arena_df, arena_subset_df, name="Overall"): + total_votes = sum(arena_df["num_battles"]) // 2 + total_models = len(arena_df) + space = "   " + total_subset_votes = sum(arena_subset_df["num_battles"]) // 2 + total_subset_models = len(arena_subset_df) + leaderboard_md = f"""### {cat_name_to_explanation[name]} +#### {space} #models: **{total_subset_models} ({round(total_subset_models/total_models *100)}%)** {space} #votes: **{"{:,}".format(total_subset_votes)} ({round(total_subset_votes/total_votes * 100)}%)**{space} +""" + return leaderboard_md + + +def make_full_leaderboard_md(elo_results): + leaderboard_md = """ +Three benchmarks are displayed: **Arena Elo**, **MT-Bench** and **MMLU**. 
+- [Chatbot Arena](https://chat.lmsys.org/?arena) - a crowdsourced, randomized battle platform. We use 500K+ user votes to compute model strength. +- [MT-Bench](https://arxiv.org/abs/2306.05685): a set of challenging multi-turn questions. We use GPT-4 to grade the model responses. +- [MMLU](https://arxiv.org/abs/2009.03300) (5-shot): a test to measure a model's multitask accuracy on 57 tasks. + +💻 Code: The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). +The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). +Higher values are better for all benchmarks. Empty cells mean not available. +""" + return leaderboard_md + + +def make_leaderboard_md_live(elo_results): + leaderboard_md = f""" +# Leaderboard +Last updated: {elo_results["last_updated_datetime"]} +{elo_results["leaderboard_table"]} +""" + return leaderboard_md + + +def update_elo_components( + max_num_files, elo_results_file, ban_ip_file, exclude_model_names +): + log_files = get_log_files(max_num_files) + + # Leaderboard + if elo_results_file is None: # Do live update + ban_ip_list = json.load(open(ban_ip_file)) if ban_ip_file else None + battles = clean_battle_data( + log_files, exclude_model_names, ban_ip_list=ban_ip_list + ) + elo_results = report_elo_analysis_results(battles, scale=2) + + leader_component_values[0] = make_leaderboard_md_live(elo_results) + leader_component_values[1] = elo_results["win_fraction_heatmap"] + leader_component_values[2] = elo_results["battle_count_heatmap"] + leader_component_values[3] = elo_results["bootstrap_elo_rating"] + leader_component_values[4] = elo_results["average_win_rate_bar"] + + # Basic stats + basic_stats = report_basic_stats(log_files) + md0 = f"Last updated: {basic_stats['last_updated_datetime']}" + + md1 = "### Action Histogram\n" + md1 += basic_stats["action_hist_md"] + "\n" + + md2 = "### Anony. Vote Histogram\n" + md2 += basic_stats["anony_vote_hist_md"] + "\n" + + md3 = "### Model Call Histogram\n" + md3 += basic_stats["model_hist_md"] + "\n" + + md4 = "### Model Call (Last 24 Hours)\n" + md4 += basic_stats["num_chats_last_24_hours"] + "\n" + + basic_component_values[0] = md0 + basic_component_values[1] = basic_stats["chat_dates_bar"] + basic_component_values[2] = md1 + basic_component_values[3] = md2 + basic_component_values[4] = md3 + basic_component_values[5] = md4 + + +def update_worker( + max_num_files, interval, elo_results_file, ban_ip_file, exclude_model_names +): + while True: + tic = time.time() + update_elo_components( + max_num_files, elo_results_file, ban_ip_file, exclude_model_names + ) + durtaion = time.time() - tic + print(f"update duration: {durtaion:.2f} s") + time.sleep(max(interval - durtaion, 0)) + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. 
params: {url_params}") + return basic_component_values + leader_component_values + + +def model_hyperlink(model_name, link): + return f'{model_name}' + + +def load_leaderboard_table_csv(filename, add_hyperlink=True): + lines = open(filename).readlines() + heads = [v.strip() for v in lines[0].split(",")] + rows = [] + for i in range(1, len(lines)): + row = [v.strip() for v in lines[i].split(",")] + for j in range(len(heads)): + item = {} + for h, v in zip(heads, row): + if h == "Arena Elo rating": + if v != "-": + v = int(ast.literal_eval(v)) + else: + v = np.nan + elif h == "MMLU": + if v != "-": + v = round(ast.literal_eval(v) * 100, 1) + else: + v = np.nan + elif h == "MT-bench (win rate %)": + if v != "-": + v = round(ast.literal_eval(v[:-1]), 1) + else: + v = np.nan + elif h == "MT-bench (score)": + if v != "-": + v = round(ast.literal_eval(v), 2) + else: + v = np.nan + item[h] = v + if add_hyperlink: + item["Model"] = model_hyperlink(item["Model"], item["Link"]) + rows.append(item) + + return rows + + +def build_basic_stats_tab(): + empty = "Loading ..." + basic_component_values[:] = [empty, None, empty, empty, empty, empty] + + md0 = gr.Markdown(empty) + gr.Markdown("#### Figure 1: Number of model calls and votes") + plot_1 = gr.Plot(show_label=False) + with gr.Row(): + with gr.Column(): + md1 = gr.Markdown(empty) + with gr.Column(): + md2 = gr.Markdown(empty) + with gr.Row(): + with gr.Column(): + md3 = gr.Markdown(empty) + with gr.Column(): + md4 = gr.Markdown(empty) + return [md0, plot_1, md1, md2, md3, md4] + + +def get_full_table(arena_df, model_table_df): + values = [] + for i in range(len(model_table_df)): + row = [] + model_key = model_table_df.iloc[i]["key"] + model_name = model_table_df.iloc[i]["Model"] + # model display name + row.append(model_name) + if model_key in arena_df.index: + idx = arena_df.index.get_loc(model_key) + row.append(round(arena_df.iloc[idx]["rating"])) + else: + row.append(np.nan) + row.append(model_table_df.iloc[i]["MT-bench (score)"]) + row.append(model_table_df.iloc[i]["MMLU"]) + # Organization + row.append(model_table_df.iloc[i]["Organization"]) + # license + row.append(model_table_df.iloc[i]["License"]) + + values.append(row) + values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) + return values + + +def create_ranking_str(ranking, ranking_difference): + if ranking_difference > 0: + return f"{int(ranking)} \u2191" + elif ranking_difference < 0: + return f"{int(ranking)} \u2193" + else: + return f"{int(ranking)}" + + +def recompute_final_ranking(arena_df): + # compute ranking based on CI + ranking = {} + for i, model_a in enumerate(arena_df.index): + ranking[model_a] = 1 + for j, model_b in enumerate(arena_df.index): + if i == j: + continue + if ( + arena_df.loc[model_b]["rating_q025"] + > arena_df.loc[model_a]["rating_q975"] + ): + ranking[model_a] += 1 + return list(ranking.values()) + + +def highlight_top_models(df): + def highlight_max_rank(s): + # Pastel Yellow with transparency, rgba(red, green, blue, alpha) + highlight_color = "rgba(255, 255, 128, 0.2)" # 50% transparent + if int(s["Rank* (UB)"].replace("↑", "").replace("↓", "")) == 1: + return [f"background-color: {highlight_color}" for _ in s] + else: + return ["" for _ in s] + + # Apply and return the styled DataFrame + return df.apply(highlight_max_rank, axis=1) + + +def get_arena_table(arena_df, model_table_df, arena_subset_df=None): + arena_df = arena_df.sort_values( + by=["final_ranking", "rating"], ascending=[True, False] + ) + arena_df["final_ranking"] = 
recompute_final_ranking(arena_df) + arena_df = arena_df.sort_values( + by=["final_ranking", "rating"], ascending=[True, False] + ) + + # sort by rating + if arena_subset_df is not None: + # filter out models not in the arena_df + arena_subset_df = arena_subset_df[arena_subset_df.index.isin(arena_df.index)] + arena_subset_df = arena_subset_df.sort_values(by=["rating"], ascending=False) + arena_subset_df["final_ranking"] = recompute_final_ranking(arena_subset_df) + # keep only the models in the subset in arena_df and recompute final_ranking + arena_df = arena_df[arena_df.index.isin(arena_subset_df.index)] + # recompute final ranking + arena_df["final_ranking"] = recompute_final_ranking(arena_df) + + # assign ranking by the order + arena_subset_df["final_ranking_no_tie"] = range(1, len(arena_subset_df) + 1) + arena_df["final_ranking_no_tie"] = range(1, len(arena_df) + 1) + # join arena_df and arena_subset_df on index + arena_df = arena_subset_df.join( + arena_df["final_ranking"], rsuffix="_global", how="inner" + ) + arena_df["ranking_difference"] = ( + arena_df["final_ranking_global"] - arena_df["final_ranking"] + ) + + arena_df = arena_df.sort_values( + by=["final_ranking", "rating"], ascending=[True, False] + ) + arena_df["final_ranking"] = arena_df.apply( + lambda x: create_ranking_str(x["final_ranking"], x["ranking_difference"]), + axis=1, + ) + + arena_df["final_ranking"] = arena_df["final_ranking"].astype(str) + + values = [] + for i in range(len(arena_df)): + row = [] + model_key = arena_df.index[i] + try: # this is a janky fix for where the model key is not in the model table (model table and arena table dont contain all the same models) + model_name = model_table_df[model_table_df["key"] == model_key][ + "Model" + ].values[0] + # rank + ranking = arena_df.iloc[i].get("final_ranking") or i + 1 + row.append(ranking) + if arena_subset_df is not None: + row.append(arena_df.iloc[i].get("ranking_difference") or 0) + # model display name + row.append(model_name) + # elo rating + row.append(round(arena_df.iloc[i]["rating"])) + upper_diff = round( + arena_df.iloc[i]["rating_q975"] - arena_df.iloc[i]["rating"] + ) + lower_diff = round( + arena_df.iloc[i]["rating"] - arena_df.iloc[i]["rating_q025"] + ) + row.append(f"+{upper_diff}/-{lower_diff}") + # num battles + row.append(round(arena_df.iloc[i]["num_battles"])) + # Organization + row.append( + model_table_df[model_table_df["key"] == model_key][ + "Organization" + ].values[0] + ) + # license + row.append( + model_table_df[model_table_df["key"] == model_key]["License"].values[0] + ) + cutoff_date = model_table_df[model_table_df["key"] == model_key][ + "Knowledge cutoff date" + ].values[0] + if cutoff_date == "-": + row.append("Unknown") + else: + row.append(cutoff_date) + values.append(row) + except Exception as e: + print(f"{model_key} - {e}") + return values + + +key_to_category_name = { + "full": "Overall", + "dedup": "De-duplicate Top Redundant Queries (soon to be default)", + "coding": "Coding", + "hard_6": "Hard Prompts (Overall)", + "hard_english_6": "Hard Prompts (English)", + "long_user": "Longer Query", + "english": "English", + "chinese": "Chinese", + "french": "French", + "german": "German", + "spanish": "Spanish", + "russian": "Russian", + "japanese": "Japanese", + "no_tie": "Exclude Ties", + "no_short": "Exclude Short Query (< 5 tokens)", + "no_refusal": "Exclude Refusal", + "overall_limit_5_user_vote": "overall_limit_5_user_vote", + "full_old": "Overall (Deprecated)", +} +cat_name_to_explanation = { + "Overall": "Overall 
Questions", + "De-duplicate Top Redundant Queries (soon to be default)": "De-duplicate top redundant queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).", + "Coding": "Coding: whether conversation contains code snippets", + "Hard Prompts (Overall)": "Hard Prompts (Overall): details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)", + "Hard Prompts (English)": "Hard Prompts (English), note: the delta is to English Category. details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)", + "Longer Query": "Longer Query (>= 500 tokens)", + "English": "English Prompts", + "Chinese": "Chinese Prompts", + "French": "French Prompts", + "German": "German Prompts", + "Spanish": "Spanish Prompts", + "Russian": "Russian Prompts", + "Japanese": "Japanese Prompts", + "Exclude Ties": "Exclude Ties and Bothbad", + "Exclude Short Query (< 5 tokens)": "Exclude Short User Query (< 5 tokens)", + "Exclude Refusal": 'Exclude model responses with refusal (e.g., "I cannot answer")', + "overall_limit_5_user_vote": "overall_limit_5_user_vote", + "Overall (Deprecated)": "Overall without De-duplicating Top Redundant Queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).", +} +cat_name_to_baseline = { + "Hard Prompts (English)": "English", +} + + +def build_leaderboard_tab( + elo_results_file, leaderboard_table_file, show_plot=False, mirror=False +): + arena_dfs = {} + category_elo_results = {} + if elo_results_file is None: # Do live update + default_md = "Loading ..." + p1 = p2 = p3 = p4 = None + else: + with open(elo_results_file, "rb") as fin: + elo_results = pickle.load(fin) + last_updated_time = None + if "full" in elo_results: + last_updated_time = elo_results["full"]["last_updated_datetime"].split( + " " + )[0] + for k in key_to_category_name.keys(): + if k not in elo_results: + continue + arena_dfs[key_to_category_name[k]] = elo_results[k][ + "leaderboard_table_df" + ] + category_elo_results[key_to_category_name[k]] = elo_results[k] + + p1 = category_elo_results["Overall"]["win_fraction_heatmap"] + p2 = category_elo_results["Overall"]["battle_count_heatmap"] + p3 = category_elo_results["Overall"]["bootstrap_elo_rating"] + p4 = category_elo_results["Overall"]["average_win_rate_bar"] + arena_df = arena_dfs["Overall"] + default_md = make_default_md_1( + arena_df, category_elo_results["Overall"], mirror=mirror + ) + default_md_2 = make_default_md_2( + arena_df, category_elo_results["Overall"], mirror=mirror + ) + + with gr.Row(): + with gr.Column(scale=4): + md_1 = gr.Markdown(default_md, elem_id="leaderboard_markdown") + with gr.Column(scale=1): + vote_button = gr.Button("Vote!", link="https://chat.lmsys.org") + md2 = gr.Markdown(default_md_2, elem_id="leaderboard_markdown") + if leaderboard_table_file: + data = load_leaderboard_table_csv(leaderboard_table_file) + model_table_df = pd.DataFrame(data) + + with gr.Tabs() as tabs: + # arena table + arena_table_vals = get_arena_table(arena_df, model_table_df) + with gr.Tab("Arena", id=0): + md = make_arena_leaderboard_md(arena_df, last_updated_time) + gr.Markdown(md, elem_id="leaderboard_markdown") + with gr.Row(): + with gr.Column(scale=2): + category_dropdown = gr.Dropdown( + choices=list(arena_dfs.keys()), + label="Category", + value="Overall", + ) + default_category_details = make_category_arena_leaderboard_md( + arena_df, arena_df, name="Overall" + ) + with 
gr.Column(scale=4, variant="panel"): + category_deets = gr.Markdown( + default_category_details, elem_id="category_deets" + ) + + arena_vals = pd.DataFrame( + arena_table_vals, + columns=[ + "Rank* (UB)", + "Model", + "Arena Elo", + "95% CI", + "Votes", + "Organization", + "License", + "Knowledge Cutoff", + ], + ) + elo_display_df = gr.Dataframe( + headers=[ + "Rank* (UB)", + "Model", + "Arena Elo", + "95% CI", + "Votes", + "Organization", + "License", + "Knowledge Cutoff", + ], + datatype=[ + "str", + "markdown", + "number", + "str", + "number", + "str", + "str", + "str", + ], + # value=highlight_top_models(arena_vals.style), + value=arena_vals.style, + elem_id="arena_leaderboard_dataframe", + height=700, + column_widths=[70, 190, 100, 100, 90, 130, 150, 100], + wrap=True, + ) + + gr.Markdown( + f"""Note: in each category, we exclude models with fewer than 300 votes as their confidence intervals can be large.""", + elem_id="leaderboard_markdown", + ) + + leader_component_values[:] = [default_md, p1, p2, p3, p4] + + if show_plot: + more_stats_md = gr.Markdown( + f"""## More Statistics for Chatbot Arena (Overall)""", + elem_id="leaderboard_header_markdown", + ) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 1: Confidence Intervals on Model Strength (via Bootstrapping)", + elem_id="plot-title", + ) + plot_3 = gr.Plot(p3, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 2: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)", + elem_id="plot-title", + ) + plot_4 = gr.Plot(p4, show_label=False) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 3: Fraction of Model A Wins for All Non-tied A vs. B Battles", + elem_id="plot-title", + ) + plot_1 = gr.Plot( + p1, show_label=False, elem_id="plot-container" + ) + with gr.Column(): + gr.Markdown( + "#### Figure 4: Battle Count for Each Combination of Models (without Ties)", + elem_id="plot-title", + ) + plot_2 = gr.Plot(p2, show_label=False) + with gr.Tab("Full Leaderboard", id=1): + md = make_full_leaderboard_md(elo_results) + gr.Markdown(md, elem_id="leaderboard_markdown") + full_table_vals = get_full_table(arena_df, model_table_df) + gr.Dataframe( + headers=[ + "Model", + "Arena Elo", + "MT-bench", + "MMLU", + "Organization", + "License", + ], + datatype=["markdown", "number", "number", "number", "str", "str"], + value=full_table_vals, + elem_id="full_leaderboard_dataframe", + column_widths=[200, 100, 100, 100, 150, 150], + height=700, + wrap=True, + ) + if not show_plot: + gr.Markdown( + """ ## Visit our [HF space](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) for more analysis! + If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model). 
+ """, + elem_id="leaderboard_markdown", + ) + else: + pass + + def update_leaderboard_df(arena_table_vals): + elo_datarame = pd.DataFrame( + arena_table_vals, + columns=[ + "Rank* (UB)", + "Delta", + "Model", + "Arena Elo", + "95% CI", + "Votes", + "Organization", + "License", + "Knowledge Cutoff", + ], + ) + + # goal: color the rows based on the rank with styler + def highlight_max(s): + # all items in S which contain up arrow should be green, down arrow should be red, otherwise black + return [ + "color: green; font-weight: bold" + if "\u2191" in v + else "color: red; font-weight: bold" + if "\u2193" in v + else "" + for v in s + ] + + def highlight_rank_max(s): + return [ + "color: green; font-weight: bold" + if v > 0 + else "color: red; font-weight: bold" + if v < 0 + else "" + for v in s + ] + + return elo_datarame.style.apply(highlight_max, subset=["Rank* (UB)"]).apply( + highlight_rank_max, subset=["Delta"] + ) + + def update_leaderboard_and_plots(category): + arena_subset_df = arena_dfs[category] + arena_subset_df = arena_subset_df[arena_subset_df["num_battles"] > 300] + elo_subset_results = category_elo_results[category] + + baseline_category = cat_name_to_baseline.get(category, "Overall") + arena_df = arena_dfs[baseline_category] + arena_values = get_arena_table( + arena_df, + model_table_df, + arena_subset_df=arena_subset_df if category != "Overall" else None, + ) + if category != "Overall": + arena_values = update_leaderboard_df(arena_values) + # arena_values = highlight_top_models(arena_values) + arena_values = gr.Dataframe( + headers=[ + "Rank* (UB)", + "Delta", + "Model", + "Arena Elo", + "95% CI", + "Votes", + "Organization", + "License", + "Knowledge Cutoff", + ], + datatype=[ + "str", + "number", + "markdown", + "number", + "str", + "number", + "str", + "str", + "str", + ], + value=arena_values, + elem_id="arena_leaderboard_dataframe", + height=700, + column_widths=[70, 70, 200, 90, 100, 90, 120, 150, 100], + wrap=True, + ) + else: + # not_arena_values = pd.DataFrame(arena_values, columns=["Rank* (UB)", + # "Model", + # "Arena Elo", + # "95% CI", + # "Votes", + # "Organization", + # "License", + # "Knowledge Cutoff",], + # ) + # arena_values = highlight_top_models(not_arena_values.style) + arena_values = gr.Dataframe( + headers=[ + "Rank* (UB)", + "Model", + "Arena Elo", + "95% CI", + "Votes", + "Organization", + "License", + "Knowledge Cutoff", + ], + datatype=[ + "str", + "markdown", + "number", + "str", + "number", + "str", + "str", + "str", + ], + value=arena_values, + elem_id="arena_leaderboard_dataframe", + height=700, + column_widths=[70, 190, 100, 100, 90, 140, 150, 100], + wrap=True, + ) + + p1 = elo_subset_results["win_fraction_heatmap"] + p2 = elo_subset_results["battle_count_heatmap"] + p3 = elo_subset_results["bootstrap_elo_rating"] + p4 = elo_subset_results["average_win_rate_bar"] + more_stats_md = f"""## More Statistics for Chatbot Arena - {category} + """ + leaderboard_md = make_category_arena_leaderboard_md( + arena_df, arena_subset_df, name=category + ) + return arena_values, p1, p2, p3, p4, more_stats_md, leaderboard_md + + category_dropdown.change( + update_leaderboard_and_plots, + inputs=[category_dropdown], + outputs=[ + elo_display_df, + plot_1, + plot_2, + plot_3, + plot_4, + more_stats_md, + category_deets, + ], + ) + + from fastchat.serve.gradio_web_server import acknowledgment_md + + with gr.Accordion( + "Citation", + open=True, + ): + citation_md = """ + ### Citation + Please cite the following paper if you find our leaderboard or dataset 
helpful. + ``` + @misc{chiang2024chatbot, + title={Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference}, + author={Wei-Lin Chiang and Lianmin Zheng and Ying Sheng and Anastasios Nikolas Angelopoulos and Tianle Li and Dacheng Li and Hao Zhang and Banghua Zhu and Michael Jordan and Joseph E. Gonzalez and Ion Stoica}, + year={2024}, + eprint={2403.04132}, + archivePrefix={arXiv}, + primaryClass={cs.AI} + } + """ + gr.Markdown(citation_md, elem_id="leaderboard_markdown") + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + if show_plot: + return [md_1, plot_1, plot_2, plot_3, plot_4] + return [md_1] + + +def build_demo(elo_results_file, leaderboard_table_file): + from fastchat.serve.gradio_web_server import block_css + + text_size = gr.themes.sizes.text_lg + # load theme from theme.json + theme = gr.themes.Default.load("theme.json") + # set text size to large + theme.text_size = text_size + theme.set( + button_large_text_size="40px", + button_small_text_size="40px", + button_large_text_weight="1000", + button_small_text_weight="1000", + button_shadow="*shadow_drop_lg", + button_shadow_hover="*shadow_drop_lg", + checkbox_label_shadow="*shadow_drop_lg", + button_shadow_active="*shadow_inset", + button_secondary_background_fill="*primary_300", + button_secondary_background_fill_dark="*primary_700", + button_secondary_background_fill_hover="*primary_200", + button_secondary_background_fill_hover_dark="*primary_500", + button_secondary_text_color="*primary_800", + button_secondary_text_color_dark="white", + ) + + with gr.Blocks( + title="Chatbot Arena Leaderboard", + # theme=gr.themes.Default(text_size=text_size), + theme=theme, + css=block_css, + ) as demo: + with gr.Tabs() as tabs: + with gr.Tab("Leaderboard", id=0): + leader_components = build_leaderboard_tab( + elo_results_file, + leaderboard_table_file, + show_plot=True, + mirror=False, + ) + + with gr.Tab("Basic Stats", id=1): + basic_components = build_basic_stats_tab() + + url_params = gr.JSON(visible=False) + demo.load( + load_demo, + [url_params], + basic_components + leader_components, + js=get_window_url_params_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--share", action="store_true") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--update-interval", type=int, default=300) + parser.add_argument("--max-num-files", type=int) + parser.add_argument("--elo-results-file", type=str) + parser.add_argument("--leaderboard-table-file", type=str) + parser.add_argument("--ban-ip-file", type=str) + parser.add_argument("--exclude-model-names", type=str, nargs="+") + args = parser.parse_args() + + logger = build_logger("monitor", "monitor.log") + logger.info(f"args: {args}") + + if args.elo_results_file is None: # Do live update + update_thread = threading.Thread( + target=update_worker, + args=( + args.max_num_files, + args.update_interval, + args.elo_results_file, + args.ban_ip_file, + args.exclude_model_names, + ), + ) + update_thread.start() + + demo = build_demo(args.elo_results_file, args.leaderboard_table_file) + demo.queue( + default_concurrency_limit=args.concurrency_count, + status_update_rate=10, + api_open=False, + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + ) diff --git a/src/serve/monitor/summarize_cluster.py 
b/src/serve/monitor/summarize_cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..b461a68b2bfeeaf1a660103b491edf7f0b255a21 --- /dev/null +++ b/src/serve/monitor/summarize_cluster.py @@ -0,0 +1,85 @@ +""" +Usage: +python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4 --num-prompts 100 +python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model azure-gpt-4-32k --num-prompts 200 +""" +import argparse +import pickle + +import pandas as pd + +from fastchat.llm_judge.common import ( + chat_completion_openai, + chat_completion_openai_azure, + chat_completion_anthropic, +) +from fastchat.conversation import get_conv_template + + +def truncate_string(s, l): + half = int(l // 2) + return s[:half] + s[-half:] if len(s) > l else s + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--model", type=str, default="gpt-3.5-turbo") + parser.add_argument("--num-prompts", type=int, default=100) + args = parser.parse_args() + + model = args.model + + cluster_infos = pickle.load(open(args.input_file, "rb")) + num_total_prompts = sum([x[0] for x in cluster_infos]) + + topics = [] + percentages = [] + for i, info in enumerate(cluster_infos): + num_samples, topk_prompts, random_prompts = info + percentage = num_samples / num_total_prompts + print( + f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%" + ) + instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific." + split = int(args.num_prompts * 0.8) + prompt = "\n".join( + [truncate_string(x, l=200) for x in topk_prompts[:split]] + + [ + truncate_string(x, l=200) + for x in random_prompts[: args.num_prompts - split] + ] + ) + prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST." + + if "azure-" in model: + template_name = "chatgpt" + completion_func = chat_completion_openai_azure + elif "gpt" in model: + template_name = "chatgpt" + completion_func = chat_completion_openai + elif "claude" in model: + template_name = "claude" + completion_func = chat_completion_anthropic + + conv = get_conv_template(template_name) + conv.set_system_message(instruct) + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + + topic = completion_func(model, conv, temperature=0, max_tokens=256) + print(topic) + + topics.append(topic) + percentages.append(round(percentage, 6)) + + print() + print(f"topics: {topics}") + print(f"percentages: {percentages}") + + # save the informations + df = pd.DataFrame() + df["topic"] = topics + df["percentage"] = percentages + + df.to_json(f"cluster_summary_{len(df)}.jsonl", lines=True, orient="records") diff --git a/src/serve/monitor/tag_openai_moderation.py b/src/serve/monitor/tag_openai_moderation.py new file mode 100644 index 0000000000000000000000000000000000000000..b80703388b2a47bf372a09bbed81d7bede2bd412 --- /dev/null +++ b/src/serve/monitor/tag_openai_moderation.py @@ -0,0 +1,63 @@ +""" +Add OpenAI moderation API results to all conversations. 
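+
+Usage (the input file name below is only an example):
+python3 tag_openai_moderation.py --input clean_battle_conv.json --parallel 8 --first-n 1000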
+""" +import argparse +from concurrent.futures import ThreadPoolExecutor +import json +import os +import time + +import openai +import requests +from tqdm import tqdm + + +API_MAX_RETRY = 16 +API_RETRY_SLEEP = 10 +API_ERROR_OUTPUT = "$ERROR$" + + +def tag_moderation(text): + result = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + result = openai.Moderation.create(input=text)["results"][0] + break + except openai.error.OpenAIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + + return result + + +def tag_openai_moderation(x): + conv = x["conversation_a"] + user_prompts = "\n".join([x["content"] for x in conv if x["role"] == "user"]) + result = tag_moderation(user_prompts) + x["openai_moderation"] = result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + parser.add_argument( + "--parallel", type=int, default=1, help="The number of concurrent API calls." + ) + parser.add_argument("--first-n", type=int) + args = parser.parse_args() + + battles = json.load(open(args.input)) + + if args.first_n: + battles = battles[: args.first_n] + + with ThreadPoolExecutor(args.parallel) as executor: + for line in tqdm( + executor.map(tag_openai_moderation, battles), total=len(battles) + ): + pass + + output = args.input.replace(".json", "_tagged.json") + with open(output, "w") as fout: + json.dump(battles, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/src/serve/monitor/topic_clustering.py b/src/serve/monitor/topic_clustering.py new file mode 100644 index 0000000000000000000000000000000000000000..3d58e56bf3a749fe13480ea0eb7f7d2d412accc7 --- /dev/null +++ b/src/serve/monitor/topic_clustering.py @@ -0,0 +1,292 @@ +""" + +Usage: +python3 topic_clustering.py --in arena.json --english-only --min-length 32 +python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1536 +""" +import argparse +import json +import pickle +import string +import time + +import numpy as np +from sentence_transformers import SentenceTransformer +from sentence_transformers.util import cos_sim +from sklearn.cluster import KMeans, AgglomerativeClustering +import torch +from tqdm import tqdm +from openai import OpenAI + +from fastchat.utils import detect_language + + +def remove_punctuation(input_string): + # Make a translator object to remove all punctuation + translator = str.maketrans("", "", string.punctuation) + + # Use the translator object to remove the punctuation + no_punct = input_string.translate(translator) + return no_punct + + +def read_texts(input_file, min_length, max_length, english_only): + visited = set() + texts = [] + + lines = json.load(open(input_file, "r")) + + for l in tqdm(lines): + if "text" in l: + line_texts = [l["text"]] + elif "conversation_a" in l: + line_texts = [ + x["content"] for x in l["conversation_a"] if x["role"] == "user" + ] + elif "conversation" in l: + line_texts = [ + x["content"] for x in l["conversation"] if x["role"] == "user" + ] + elif "turns" in l: + line_texts = l["turns"] + + for text in line_texts: + text = text.strip() + + # Filter language + if english_only: + lang = detect_language(text) + if lang != "English": + continue + + # Filter short or long prompts + if min_length: + if len(text) < min_length: + continue + + if max_length: + if len(text) > max_length: + continue + + # De-duplication + words = sorted([x.lower() for x in remove_punctuation(text).split(" ")]) + words = "".join(words) + 
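+            # Skip prompts whose normalized bag-of-words key has already been seen,
+            # so trivially reordered duplicates are dropped as well.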
if words in visited: + continue + + visited.add(words) + texts.append(text) + return np.array(texts) + + +def get_embeddings(texts, model_name, batch_size): + if model_name == "text-embedding-ada-002": + client = OpenAI() + texts = texts.tolist() + + embeddings = [] + for i in tqdm(range(0, len(texts), batch_size)): + text = texts[i : i + batch_size] + responses = client.embeddings.create(input=text, model=model_name).data + embeddings.extend([data.embedding for data in responses]) + embeddings = torch.tensor(embeddings) + else: + model = SentenceTransformer(model_name) + embeddings = model.encode( + texts, + batch_size=batch_size, + show_progress_bar=True, + device="cuda", + convert_to_tensor=True, + ) + + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + return embeddings.cpu() + + +def run_k_means(embeddings, num_clusters): + np.random.seed(42) + clustering_model = KMeans(n_clusters=num_clusters, n_init="auto") + clustering_model.fit(embeddings.numpy()) + centers = torch.from_numpy(clustering_model.cluster_centers_) + labels = torch.from_numpy(clustering_model.labels_) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + new_centers = torch.empty_like(centers) + for i, c in enumerate(classes): + new_labels[labels == c] = i + new_centers[i] = centers[c] + return new_centers, new_labels + + +def run_agg_cluster(embeddings, num_clusters): + np.random.seed(42) + clustering_model = AgglomerativeClustering(n_clusters=num_clusters) + clustering_model.fit(embeddings) + labels = torch.from_numpy(clustering_model.labels_) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + for i, c in enumerate(classes): + new_labels[labels == c] = i + + # Compute centers + centers = [] + for i in range(len(classes)): + centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True)) + centers = torch.cat(centers) + return centers, new_labels + + +def run_hdbscan_cluster(embeddings): + import hdbscan + + np.random.seed(42) + clusterer = hdbscan.HDBSCAN(min_cluster_size=10) + labels = torch.from_numpy(clusterer.fit_predict(embeddings)) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + for i, c in enumerate(classes): + new_labels[labels == c] = i + + # Compute centers + centers = [] + for i in range(len(classes)): + centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True)) + centers = torch.cat(centers) + return centers, new_labels + + +def get_topk_indices(centers, labels, embeddings, topk): + indices = [] + arange = torch.arange(len(labels)) + counts = torch.unique(labels, return_counts=True)[1] + topk = min(topk, counts.min().item()) + for i in range(len(centers)): + tmp_indices = labels == i + tmp_arange = arange[tmp_indices] + tmp_embeddings = embeddings[tmp_indices] + + scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] + sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) + indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0)) + return torch.cat(indices) + + +def print_topk(texts, labels, topk_indices, show_cut_off): + ret = "" + for k in range(len(topk_indices)): + num_samples = torch.sum(labels == k).item() + + 
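+        # Render a header for this cluster, then its top-k most central prompts.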
ret += "=" * 20 + f" cluster {k}, #samples: {num_samples} " + "=" * 20 + "\n" + for idx in topk_indices[k]: + ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n" + ret += "=" * 40 + "\n\n" + + return ret + + +def get_cluster_info(texts, labels, topk_indices): + np.random.seed(42) + + cluster_info = [] + for k in range(len(topk_indices)): + num_samples = torch.sum(labels == k).item() + topk_prompts = [] + for idx in topk_indices[k]: + topk_prompts.append(texts[idx]) + random_prompts = [] + for idx in range(len(topk_indices)): + random_prompts.append(np.random.choice(texts)) + cluster_info.append((num_samples, topk_prompts, random_prompts)) + + return cluster_info + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--model", type=str, default="all-mpnet-base-v2") + # default="all-MiniLM-L12-v2") + # default="multi-qa-distilbert-cos-v1") + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--min-length", type=int) + parser.add_argument("--max-length", type=int) + parser.add_argument("--english-only", action="store_true") + parser.add_argument("--num-clusters", type=int, default=20) + parser.add_argument( + "--cluster-alg", + type=str, + choices=["kmeans", "aggcls", "HDBSCAN"], + default="kmeans", + ) + parser.add_argument("--show-top-k", type=int, default=200) + parser.add_argument("--show-cut-off", type=int, default=512) + parser.add_argument("--save-embeddings", action="store_true") + parser.add_argument("--embeddings-file", type=str, default=None) + args = parser.parse_args() + + num_clusters = args.num_clusters + show_top_k = args.show_top_k + show_cut_off = args.show_cut_off + + texts = read_texts( + args.input_file, args.min_length, args.max_length, args.english_only + ) + print(f"#text: {len(texts)}") + + if args.embeddings_file is None: + embeddings = get_embeddings(texts, args.model, args.batch_size) + if args.save_embeddings: + # allow saving embedding to save time and money + torch.save(embeddings, "embeddings.pt") + else: + embeddings = torch.load(args.embeddings_file) + print(f"embeddings shape: {embeddings.shape}") + + if args.cluster_alg == "kmeans": + centers, labels = run_k_means(embeddings, num_clusters) + elif args.cluster_alg == "aggcls": + centers, labels = run_agg_cluster(embeddings, num_clusters) + elif args.cluster_alg == "HDBSCAN": + centers, labels = run_hdbscan_cluster(embeddings) + else: + raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}") + + topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k) + topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off) + num_clusters = len(centers) + + # Dump results + filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}" + print(topk_str) + with open(filename_prefix + "_topk.txt", "w") as fout: + fout.write(topk_str) + + with open(filename_prefix + "_all.jsonl", "w") as fout: + for i in range(len(centers)): + tmp_indices = labels == i + tmp_embeddings = embeddings[tmp_indices] + tmp_texts = texts[tmp_indices] + + scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] + sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) + + for text, score in zip(tmp_texts[sorted_indices], scores[sorted_indices]): + obj = {"cluster": i, "text": text, "sim": score.item()} + fout.write(json.dumps(obj, ensure_ascii=False) + "\n") + + cluster_info = get_cluster_info(texts, labels, topk_indices) + with open(filename_prefix + 
"_cluster.pkl", "wb") as fout: + pickle.dump(cluster_info, fout) diff --git a/src/serve/monitor/vote_time_stats/README.md b/src/serve/monitor/vote_time_stats/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7404839b98db34ded6a5fd9c8749fef27766b8a8 --- /dev/null +++ b/src/serve/monitor/vote_time_stats/README.md @@ -0,0 +1,5 @@ +# Instructions + +First run `analyze_data.py` to collect metadata of all votes. + +Then run `plot.py` to get the plot. You need to edit these files for proper input or output filename diff --git a/src/serve/monitor/vote_time_stats/analyze_data.py b/src/serve/monitor/vote_time_stats/analyze_data.py new file mode 100644 index 0000000000000000000000000000000000000000..4bdd18694f55fb2e29aed357dd2637e0477966e3 --- /dev/null +++ b/src/serve/monitor/vote_time_stats/analyze_data.py @@ -0,0 +1,120 @@ +import datetime +import glob +import json +from collections import deque +import tqdm + + +def _serialize_json(data): + # Serialize JSON with sorted keys and no whitespace + return json.dumps(data, sort_keys=True, separators=(",", ":")).encode("utf-8") + + +types = { + "share", + "chat", + "flag", + "bothbad_vote", + "downvote", + "leftvote", + "rightvote", + "upvote", + "tievote", +} + +chat_dict = {} +cache_queue = deque() + + +def process_record(r): + ip = r.pop("ip", None) + tstamp = r.pop("tstamp") + mtype = r.pop("type") + start = r.pop("start", None) + finish = r.pop("finish", None) + + # gabagge collect to save memory + while len(cache_queue) > 100000: + outdated = cache_queue.popleft() + poped_item = chat_dict.pop(outdated["key"], None) + if poped_item is None: + # TODO: this sometimes happens, need to investigate what happens. in theory the chat dict should be synced with the queue, unless there are duplicated items + print("Error: Key to GC does not exist.") + + assert mtype in types + if mtype == "chat": + key = _serialize_json(r["state"]) + # TODO: add the string length of the last reply for analyzing voting time per character. 
+ chat_dict[key] = { + "timestamp": tstamp, + "start": start, + "finish": finish, + "conv_id": r["state"]["conv_id"], + } + cache_queue.append({"key": key, "timestamp": tstamp}) + elif mtype in ("leftvote", "rightvote", "bothbad_vote", "tievote"): + left_key = _serialize_json(r["states"][0]) + right_key = _serialize_json(r["states"][1]) + if left_key not in chat_dict: + # TODO: this sometimes happens, it means we have the vote but we cannot find previous chat, need to investigate what happens + print( + f'WARNING: Cannot find vote context for conversation {r["states"][0]["conv_id"]}' + ) + return + if right_key not in chat_dict: + print( + f'WARNING: Cannot find vote context for conversation {r["states"][1]["conv_id"]}' + ) + return + vote_time_data = { + "timestamp": tstamp, + "type": mtype, + "left": chat_dict[left_key], + "right": chat_dict[right_key], + "ip": ip, + } + return vote_time_data + + return None + + +def process_file(infile: str, outfile: str): + with open(infile) as f: + records = [] + for l in f.readlines(): + l = l.strip() + if l: + try: + r = json.loads(l) + if r.get("tstamp") is not None: + records.append(r) + except Exception: + pass + # sort the record in case there are out-of-order records + records.sort(key=lambda x: x["tstamp"]) + + with open(outfile, "a") as outfile: + for r in records: + try: + output = process_record(r) + if output is not None: + outfile.write(json.dumps(output) + "\n") + except Exception as e: + import traceback + + print("Error:", e) + traceback.print_exc() + + +today = datetime.datetime.today().isoformat().split("T", 1)[0] +# sort it to make sure the date is continuous for each server +filelist = sorted(glob.glob("/mnt/disks/data/fastchat_logs/server*/202*-*-*-conv.json")) +filelist = [ + f for f in filelist if today not in f +] # skip today because date could be partial + +# TODO: change this to select different range of data +filelist = [f for f in filelist if "2024-03-" in f] + +for f in tqdm.tqdm(filelist): + process_file(f, "output.jsonl") diff --git a/src/serve/monitor/vote_time_stats/plot.py b/src/serve/monitor/vote_time_stats/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6fd5e37a4524dda3c7ae7a05313fefbe030a86 --- /dev/null +++ b/src/serve/monitor/vote_time_stats/plot.py @@ -0,0 +1,66 @@ +import json +import matplotlib.pyplot as plt +import seaborn as sns +import numpy as np + + +infile = "output.jsonl" +date = "2024-03" # used in the plot + +durations = [] + +with open(infile) as f: + for line in f: + data = json.loads(line) + l = data["left"]["finish"] + r = data["right"]["finish"] + v = data["timestamp"] + durations.append(v - max(l, r)) + +print( + f"Avg: {np.mean(durations)}, Median: {np.median(durations)}, Max: {np.max(durations)}" +) + +# Define the new cutoff and number of bins +cutoff = 200.0 # New cutoff value +num_bins_inside_cutoff = 20 # Number of bins from 0 to cutoff + +for i, n in enumerate(durations): + if n > cutoff: + durations[i] = cutoff + 0.5 * cutoff / num_bins_inside_cutoff + +# Create bin edges from 0 to cutoff, with the specified number of bins +bin_edges = np.linspace(0, cutoff, num_bins_inside_cutoff + 1) + +# Adjusting the overflow bin to end at 110 +overflow_cap = ( + cutoff + cutoff / num_bins_inside_cutoff +) # Adjust as needed based on distribution +bin_edges = np.append(bin_edges, overflow_cap) + +# Create the plot with custom bins +sns.histplot( + durations, bins=bin_edges, kde=False +) # Turn off KDE for clearer bar visibility +plt.title(f'Distribution of "time to 
vote" {date}') +plt.xlabel("Duration (seconds)") +plt.ylabel("Frequency") + +# Highlight the overflow bin +plt.axvline(x=cutoff, color="red", linestyle="--") +plt.text( + cutoff + 1, plt.ylim()[1] * 0.9, "Overflow", color="red", ha="left" +) # Adjust text alignment + +# Customizing x-axis labels to hide the "110" +ax = plt.gca() # Get current axis +labels = [item.get_text() for item in ax.get_xticklabels()] +if "110" in labels: + labels[labels.index("110")] = "" # Replace "110" with an empty string +ax.set_xticklabels(labels) + +# Ensure nothing is cut off in the plot +plt.tight_layout() + +# Save the plot to a file with high resolution +plt.savefig(f"duration_distribution_time_to_vote_{date}.png", dpi=300) diff --git a/src/serve/multi_model_worker.py b/src/serve/multi_model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6266fe0449b59529825c7e8982cccf7d871e0e --- /dev/null +++ b/src/serve/multi_model_worker.py @@ -0,0 +1,300 @@ +""" +A multi-model worker that contains multiple sub-works one for each model. This +supports running a list of models on the same machine so that they can +(potentially) share the same background weights. + +Each model can have one or more model names. + +This multi-model worker assumes the models shares some underlying weights and +thus reports the combined queue lengths for health checks. + +We recommend using this with multiple Peft models (with `peft` in the name) +where all Peft models are trained on the exact same base model. +""" +import argparse +import asyncio +import dataclasses +import logging +import json +import os +import time +from typing import List, Union +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +try: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + AutoModel, + ) +except ImportError: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LLaMATokenizer, + AutoModel, + ) +import torch +import torch.nn.functional as F +import uvicorn + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_conversation_template, +) +from fastchat.model.model_chatglm import generate_stream_chatglm +from fastchat.model.model_falcon import generate_stream_falcon +from fastchat.model.model_codet5p import generate_stream_codet5p +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.serve.inference import generate_stream +from fastchat.serve.model_worker import ModelWorker, worker_id, logger +from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length + + +# We store both the underlying workers and a mapping from their model names to +# the worker instance. This makes it easy to fetch the appropriate worker for +# each API call. +workers = [] +worker_map = {} +app = FastAPI() + + +def release_worker_semaphore(): + workers[0].semaphore.release() + + +def acquire_worker_semaphore(): + if workers[0].semaphore is None: + # Share the same semaphore for all workers because + # all workers share the same GPU. 
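+        # The shared semaphore is created lazily on the first request and then
+        # reused by every sub-worker.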
+ semaphore = asyncio.Semaphore(workers[0].limit_worker_concurrency) + for w in workers: + w.semaphore = semaphore + return workers[0].semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +# Note: for all the calls below, we make a hard assumption that the caller +# includes the model name in the payload, otherwise we can't figure out which +# underlying sub-worker to call. + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + output = worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + embedding = worker.get_embeddings(params) + background_tasks = create_background_tasks() + return JSONResponse(content=embedding, background=background_tasks) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + } + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return {"context_length": worker.context_len} + + +def create_multi_model_worker(): + # Note: Ensure we resolve arg conflicts. We let `add_model_args` add MOST + # of the model args but we'll override one to have an append action that + # supports multiple values. + parser = argparse.ArgumentParser(conflict_handler="resolve") + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + # Override the model path to be repeated and align it with model names. + parser.add_argument( + "--model-path", + type=str, + default=[], + action="append", + help="One or more paths to model weights to load. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + action="append", + help="One or more model names. Values must be aligned with `--model-path` values.", + ) + parser.add_argument( + "--conv-template", + type=str, + default=None, + action="append", + help="Conversation prompt template. 
Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + + if args.model_names is None: + args.model_names = [[x.split("/")[-1]] for x in args.model_path] + + if args.conv_template is None: + args.conv_template = [None] * len(args.model_path) + elif len(args.conv_template) == 1: # Repeat the same template + args.conv_template = args.conv_template * len(args.model_path) + + # Launch all workers + workers = [] + for conv_template, model_path, model_names in zip( + args.conv_template, args.model_path, args.model_names + ): + w = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + model_path, + model_names, + args.limit_worker_concurrency, + args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + exllama_config=exllama_config, + xft_config=xft_config, + stream_interval=args.stream_interval, + conv_template=conv_template, + ) + workers.append(w) + for model_name in model_names: + worker_map[model_name] = w + + # Register all models + url = args.controller_address + "/register_worker" + data = { + "worker_name": workers[0].worker_addr, + "check_heart_beat": not args.no_register, + "worker_status": { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + }, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + return args, workers + + +if __name__ == "__main__": + args, workers = create_multi_model_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/src/serve/openai_api_server.py b/src/serve/openai_api_server.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ffee96bf4f9ea39f49e9a309cab92a651fdfb9 --- /dev/null +++ b/src/serve/openai_api_server.py @@ -0,0 +1,939 @@ +"""A server that provides OpenAI-compatible RESTful APIs. 
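+Existing OpenAI clients (for example the official `openai` Python SDK or curl)
+can generally be pointed at this server's /v1 base URL and used as-is; the exact
+host and port depend on how the server is launched.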
It supports: + +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) +- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) +- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) + +Usage: +python3 -m fastchat.serve.openai_api_server +""" +import asyncio +import argparse +import json +import os +from typing import Generator, Optional, Union, Dict, List, Any + +import aiohttp +import fastapi +from fastapi import Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +import httpx + +from pydantic_settings import BaseSettings +import shortuuid +import tiktoken +import uvicorn + +from fastchat.constants import ( + WORKER_API_TIMEOUT, + WORKER_API_EMBEDDING_BATCH_SIZE, + ErrorCode, +) +from fastchat.conversation import Conversation, SeparatorStyle +from fastchat.protocol.openai_api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DeltaMessage, + CompletionResponseStreamChoice, + CompletionStreamResponse, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + LogProbs, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) +from fastchat.protocol.api_protocol import ( + APIChatCompletionRequest, + APITokenCheckRequest, + APITokenCheckResponse, + APITokenCheckResponseItem, +) +from fastchat.utils import build_logger + +logger = build_logger("openai_api_server", "openai_api_server.log") + +conv_template_map = {} + +fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600) + + +async def fetch_remote(url, pload=None, name=None): + async with aiohttp.ClientSession(timeout=fetch_timeout) as session: + async with session.post(url, json=pload) as response: + chunks = [] + if response.status != 200: + ret = { + "text": f"{response.reason}", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return json.dumps(ret) + + async for chunk, _ in response.content.iter_chunks(): + chunks.append(chunk) + output = b"".join(chunks) + + if name is not None: + res = json.loads(output) + if name != "": + res = res[name] + return res + + return output + + +class AppSettings(BaseSettings): + # The address of the model controller. 
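+    # Note: as a pydantic BaseSettings subclass, these defaults can typically be
+    # overridden through environment variables (e.g. CONTROLLER_ADDRESS).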
+ controller_address: str = "http://localhost:21001" + api_keys: Optional[List[str]] = None + + +app_settings = AppSettings() +app = fastapi.FastAPI() +headers = {"User-Agent": "FastChat API Server"} +get_bearer_token = HTTPBearer(auto_error=False) + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), +) -> str: + if app_settings.api_keys: + if auth is None or (token := auth.credentials) not in app_settings.api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +def create_error_response(code: int, message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message, code=code).model_dump(), status_code=400 + ) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + controller_address = app_settings.controller_address + ret = None + + models = await fetch_remote(controller_address + "/list_models", None, "models") + if request.model not in models: + ret = create_error_response( + ErrorCode.INVALID_MODEL, + f"Only {'&&'.join(models)} allowed now, your model {request.model}", + ) + return ret + + +async def check_length(request, prompt, max_tokens, worker_addr): + if ( + not isinstance(max_tokens, int) or max_tokens <= 0 + ): # model worker not support max_tokens=None + max_tokens = 1024 * 1024 + + context_len = await fetch_remote( + worker_addr + "/model_details", {"model": request.model}, "context_length" + ) + token_num = await fetch_remote( + worker_addr + "/count_token", + {"model": request.model, "prompt": prompt}, + "count", + ) + length = min(max_tokens, context_len - token_num) + + if length <= 0: + return None, create_error_response( + ErrorCode.CONTEXT_OVERFLOW, + f"This model's maximum context length is {context_len} tokens. However, your messages resulted in {token_num} tokens. 
Please reduce the length of the messages.", + ) + + return length, None + + +def check_requests(request) -> Optional[JSONResponse]: + # Check all params + if request.max_tokens is not None and request.max_tokens <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'", + ) + if request.n is not None and request.n <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.n} is less than the minimum of 1 - 'n'", + ) + if request.temperature is not None and request.temperature < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is less than the minimum of 0 - 'temperature'", + ) + if request.temperature is not None and request.temperature > 2: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is greater than the maximum of 2 - 'temperature'", + ) + if request.top_p is not None and request.top_p < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is less than the minimum of 0 - 'top_p'", + ) + if request.top_p is not None and request.top_p > 1: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is greater than the maximum of 1 - 'top_p'", + ) + if request.top_k is not None and (request.top_k > -1 and request.top_k < 1): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.", + ) + if request.stop is not None and ( + not isinstance(request.stop, str) and not isinstance(request.stop, list) + ): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.stop} is not valid under any of the given schemas - 'stop'", + ) + + return None + + +def process_input(model_name, inp): + if isinstance(inp, str): + inp = [inp] + elif isinstance(inp, list): + if isinstance(inp[0], int): + try: + decoding = tiktoken.model.encoding_for_model(model_name) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + decoding = tiktoken.get_encoding(model) + inp = [decoding.decode(inp)] + elif isinstance(inp[0], list): + try: + decoding = tiktoken.model.encoding_for_model(model_name) + except KeyError: + logger.warning("Warning: model not found. 
Using cl100k_base encoding.") + model = "cl100k_base" + decoding = tiktoken.get_encoding(model) + inp = [decoding.decode(text) for text in inp] + + return inp + + +def create_openai_logprobs(logprob_dict): + """Create OpenAI-style logprobs.""" + return LogProbs(**logprob_dict) if logprob_dict is not None else None + + +def _add_to_set(s, new_stop): + if not s: + return + if isinstance(s, str): + new_stop.add(s) + else: + new_stop.update(s) + + +async def get_gen_params( + model_name: str, + worker_addr: str, + messages: Union[str, List[Dict[str, str]]], + *, + temperature: float, + top_p: float, + top_k: Optional[int], + presence_penalty: Optional[float], + frequency_penalty: Optional[float], + max_tokens: Optional[int], + echo: Optional[bool], + logprobs: Optional[int] = None, + stop: Optional[Union[str, List[str]]], + best_of: Optional[int] = None, + use_beam_search: Optional[bool] = None, +) -> Dict[str, Any]: + conv = await get_conv(model_name, worker_addr) + conv = Conversation( + name=conv["name"], + system_template=conv["system_template"], + system_message=conv["system_message"], + roles=conv["roles"], + messages=list(conv["messages"]), # prevent in-place modification + offset=conv["offset"], + sep_style=SeparatorStyle(conv["sep_style"]), + sep=conv["sep"], + sep2=conv["sep2"], + stop_str=conv["stop_str"], + stop_token_ids=conv["stop_token_ids"], + ) + + if isinstance(messages, str): + prompt = messages + images = [] + else: + for message in messages: + msg_role = message["role"] + if msg_role == "system": + conv.set_system_message(message["content"]) + elif msg_role == "user": + if type(message["content"]) == list: + image_list = [ + item["image_url"]["url"] + for item in message["content"] + if item["type"] == "image_url" + ] + text_list = [ + item["text"] + for item in message["content"] + if item["type"] == "text" + ] + + # TODO(chris): This only applies to LLaVA model. Implement an image_token string in the conv template. + text = "\n" * len(image_list) + text += "\n".join(text_list) + conv.append_message(conv.roles[0], (text, image_list)) + else: + conv.append_message(conv.roles[0], message["content"]) + elif msg_role == "assistant": + conv.append_message(conv.roles[1], message["content"]) + else: + raise ValueError(f"Unknown role: {msg_role}") + + # Add a blank message for the assistant. 
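+        # With a trailing empty assistant turn, get_prompt() ends at the assistant
+        # role marker, which cues the model to generate the reply.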
+ conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + images = conv.get_images() + + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "logprobs": logprobs, + "top_p": top_p, + "top_k": top_k, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "max_new_tokens": max_tokens, + "echo": echo, + "stop_token_ids": conv.stop_token_ids, + } + + if len(images) > 0: + gen_params["images"] = images + + if best_of is not None: + gen_params.update({"best_of": best_of}) + if use_beam_search is not None: + gen_params.update({"use_beam_search": use_beam_search}) + + new_stop = set() + _add_to_set(stop, new_stop) + _add_to_set(conv.stop_str, new_stop) + + gen_params["stop"] = list(new_stop) + + logger.debug(f"==== request ====\n{gen_params}") + return gen_params + + +async def get_worker_address(model_name: str) -> str: + """ + Get worker address based on the requested model + + :param model_name: The worker's model name + :return: Worker address from the controller + :raises: :class:`ValueError`: No available worker for requested model + """ + controller_address = app_settings.controller_address + worker_addr = await fetch_remote( + controller_address + "/get_worker_address", {"model": model_name}, "address" + ) + + # No available worker + if worker_addr == "": + raise ValueError(f"No available worker for {model_name}") + logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}") + return worker_addr + + +async def get_conv(model_name: str, worker_addr: str): + conv_template = conv_template_map.get((worker_addr, model_name)) + if conv_template is None: + conv_template = await fetch_remote( + worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv" + ) + conv_template_map[(worker_addr, model_name)] = conv_template + return conv_template + + +@app.get("/v1/models", dependencies=[Depends(check_api_key)]) +async def show_available_models(): + controller_address = app_settings.controller_address + ret = await fetch_remote(controller_address + "/refresh_all_workers") + models = await fetch_remote(controller_address + "/list_models", None, "models") + + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) + return ModelList(data=model_cards) + + +@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)]) +async def create_chat_completion(request: ChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, 
request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if isinstance(content, str): + content = json.loads(content) + + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + if "usage" in content: + task_usage = UsageInfo.model_validate(content["usage"]) + for usage_key, usage_value in task_usage.model_dump().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +async def chat_completion_stream_generator( + model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str +) -> Generator[str, Any, None]: + """ + Event stream format: + https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + """ + id = f"chatcmpl-{shortuuid.random()}" + finish_stream_events = [] + for i in range(n): + # First chunk with role + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + + previous_text = "" + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + + if len(delta_text) == 0: + delta_text = None + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=content.get("finish_reason", None), + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + if delta_text is None: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". 
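+    # Finish chunks were buffered above so every choice streams its content before
+    # any finish_reason is emitted; flush them now and close the stream.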
+ for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.model_dump_json(exclude_none=True)}\n\n" + yield "data: [DONE]\n\n" + + +@app.post("/v1/completions", dependencies=[Depends(check_api_key)]) +async def create_completion(request: CompletionRequest): + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + request.prompt = process_input(request.model, request.prompt) + + worker_addr = await get_worker_address(request.model) + for text in request.prompt: + max_tokens, error_check_ret = await check_length( + request, text, request.max_tokens, worker_addr + ) + if error_check_ret is not None: + return error_check_ret + + if isinstance(max_tokens, int) and max_tokens < request.max_tokens: + request.max_tokens = max_tokens + + if request.stream: + generator = generate_completion_stream_generator( + request, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + else: + text_completions = [] + for text in request.prompt: + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + frequency_penalty=request.frequency_penalty, + presence_penalty=request.presence_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + best_of=request.best_of, + use_beam_search=request.use_beam_search, + ) + for i in range(request.n): + content = asyncio.create_task( + generate_completion(gen_params, worker_addr) + ) + text_completions.append(content) + + try: + all_tasks = await asyncio.gather(*text_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + + choices = [] + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + CompletionResponseChoice( + index=i, + text=content["text"], + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.model_validate(content["usage"]) + for usage_key, usage_value in task_usage.model_dump().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return CompletionResponse( + model=request.model, choices=choices, usage=UsageInfo.model_validate(usage) + ) + + +async def generate_completion_stream_generator( + request: CompletionRequest, n: int, worker_addr: str +): + model_name = request.model + id = f"cmpl-{shortuuid.random()}" + finish_stream_events = [] + for text in request.prompt: + for i in range(n): + previous_text = "" + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + ) + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + 
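                    # The worker streams the cumulative decoded text, so only the new
                    # suffix (delta_text) is forwarded to the client; previous_text tracks
                    # how much has already been sent, and U+FFFD replacement characters
                    # from partially decoded byte sequences are stripped before the diff.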
decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + # todo: index is not apparent + choice_data = CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", None), + ) + chunk = CompletionStreamResponse( + id=id, + object="text_completion", + choices=[choice_data], + model=model_name, + ) + if len(delta_text) == 0: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.model_dump_json(exclude_unset=True)}\n\n" + yield "data: [DONE]\n\n" + + +async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + delimiter = b"\0" + async with client.stream( + "POST", + worker_addr + "/worker_generate_stream", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) as response: + # content = await response.aread() + buffer = b"" + async for raw_chunk in response.aiter_raw(): + buffer += raw_chunk + while (chunk_end := buffer.find(delimiter)) >= 0: + chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :] + if not chunk: + continue + yield json.loads(chunk.decode()) + + +async def generate_completion(payload: Dict[str, Any], worker_addr: str): + return await fetch_remote(worker_addr + "/worker_generate", payload, "") + + +@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)]) +@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)]) +async def create_embeddings(request: EmbeddingsRequest, model_name: str = None): + """Creates embeddings for the text""" + if request.model is None: + request.model = model_name + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + request.input = process_input(request.model, request.input) + + data = [] + token_num = 0 + batch_size = WORKER_API_EMBEDDING_BATCH_SIZE + batches = [ + request.input[i : min(i + batch_size, len(request.input))] + for i in range(0, len(request.input), batch_size) + ] + for num_batch, batch in enumerate(batches): + payload = { + "model": request.model, + "input": batch, + "encoding_format": request.encoding_format, + } + embedding = await get_embedding(payload) + if "error_code" in embedding and embedding["error_code"] != 0: + return create_error_response(embedding["error_code"], embedding["text"]) + data += [ + { + "object": "embedding", + "embedding": emb, + "index": num_batch * batch_size + i, + } + for i, emb in enumerate(embedding["embedding"]) + ] + token_num += embedding["token_num"] + return EmbeddingsResponse( + data=data, + model=request.model, + usage=UsageInfo( + prompt_tokens=token_num, + total_tokens=token_num, + completion_tokens=None, + ), + ).model_dump(exclude_none=True) + + +async def get_embedding(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + model_name = payload["model"] + worker_addr = await get_worker_address(model_name) + + embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload) + return json.loads(embedding) + + +### GENERAL API - NOT OPENAI COMPATIBLE ### + + +@app.post("/api/v1/token_check") 
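# Batch token check: for each prompt, the endpoint asks the serving worker for the
# model's context length and the prompt's token count, then reports whether
# prompt tokens + max_tokens fit. An illustrative request body (values made up):
#   {"prompts": [{"model": "vicuna-7b-v1.5", "prompt": "Hello!", "max_tokens": 256}]}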
+async def count_tokens(request: APITokenCheckRequest): + """ + Checks the token count for each message in your list + This is not part of the OpenAI API spec. + """ + checkedList = [] + for item in request.prompts: + worker_addr = await get_worker_address(item.model) + + context_len = await fetch_remote( + worker_addr + "/model_details", + {"prompt": item.prompt, "model": item.model}, + "context_length", + ) + + token_num = await fetch_remote( + worker_addr + "/count_token", + {"prompt": item.prompt, "model": item.model}, + "count", + ) + + can_fit = True + if token_num + item.max_tokens > context_len: + can_fit = False + + checkedList.append( + APITokenCheckResponseItem( + fits=can_fit, contextLength=context_len, tokenCount=token_num + ) + ) + + return APITokenCheckResponse(prompts=checkedList) + + +@app.post("/api/v1/chat/completions") +async def create_chat_completion(request: APIChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + + if request.repetition_penalty is not None: + gen_params["repetition_penalty"] = request.repetition_penalty + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.model_validate(content["usage"]) + for usage_key, usage_value in task_usage.model_dump().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +### END GENERAL API - NOT OPENAI COMPATIBLE ### + + +def create_openai_api_server(): + parser = argparse.ArgumentParser( + description="FastChat ChatGPT-Compatible RESTful API server." 
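        # Typical launch, assuming the controller and at least one model worker are
        # already running and this file is importable as src.serve.openai_api_server
        # (command shown for illustration only):
        #   python3 -m src.serve.openai_api_server --host 0.0.0.0 --port 8000 \
        #       --controller-address http://localhost:21001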
+ ) + parser.add_argument("--host", type=str, default="localhost", help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + app_settings.controller_address = args.controller_address + app_settings.api_keys = args.api_keys + + logger.info(f"args: {args}") + return args + + +if __name__ == "__main__": + args = create_openai_api_server() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/src/serve/register_worker.py b/src/serve/register_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..aa57117b9106730b0731df10dbfb0a2b1bbe381b --- /dev/null +++ b/src/serve/register_worker.py @@ -0,0 +1,28 @@ +""" +Manually register workers. + +Usage: +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str) + parser.add_argument("--worker-name", type=str) + parser.add_argument("--check-heart-beat", action="store_true") + parser.add_argument("--multimodal", action="store_true") + args = parser.parse_args() + + url = args.controller_address + "/register_worker" + data = { + "worker_name": args.worker_name, + "check_heart_beat": args.check_heart_beat, + "worker_status": None, + "multimodal": args.multimodal, + } + r = requests.post(url, json=data) + assert r.status_code == 200 diff --git a/src/serve/remote_logger.py b/src/serve/remote_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..549d49811848d83da27eb117d63b2e9e60e59c3f --- /dev/null +++ b/src/serve/remote_logger.py @@ -0,0 +1,59 @@ +# A JSON logger that sends data to remote endpoint. +# Architecturally, it hosts a background thread that sends logs to a remote endpoint. 
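# Usage sketch (illustrative): enable the logger by exporting REMOTE_LOGGER_URL,
# then obtain the process-wide instance and log plain dicts:
#   from src.serve.remote_logger import get_remote_logger
#   get_remote_logger().log({"type": "chat", "model": "vicuna-7b-v1.5"})
# When REMOTE_LOGGER_URL is unset, get_remote_logger() returns an EmptyLogger that
# silently drops everything, so call sites do not need to guard the call.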
+import os +import json +import requests +import threading +import queue +import logging + +_global_logger = None + + +def get_remote_logger(): + global _global_logger + if _global_logger is None: + if url := os.environ.get("REMOTE_LOGGER_URL"): + logging.info(f"Remote logger enabled, sending data to {url}") + _global_logger = RemoteLogger(url=url) + else: + _global_logger = EmptyLogger() + return _global_logger + + +class EmptyLogger: + """Dummy logger that does nothing.""" + + def __init__(self): + pass + + def log(self, _data: dict): + pass + + +class RemoteLogger: + """A JSON logger that sends data to remote endpoint.""" + + def __init__(self, url: str): + self.url = url + + self.logs = queue.Queue() + self.thread = threading.Thread(target=self._send_logs, daemon=True) + self.thread.start() + + def log(self, data: dict): + self.logs.put_nowait(data) + + def _send_logs(self): + while True: + data = self.logs.get() + + # process the data by keep only the top level fields, and turn any nested dict into a string + for key, value in data.items(): + if isinstance(value, (dict, list, tuple)): + data[key] = json.dumps(value, ensure_ascii=False) + + try: + requests.post(self.url, json=data) + except Exception: + logging.exception("Failed to send logs to remote endpoint") diff --git a/src/serve/sglang_worker.py b/src/serve/sglang_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..b306684332ce06013ec3951eba042897f2ac861f --- /dev/null +++ b/src/serve/sglang_worker.py @@ -0,0 +1,313 @@ +""" +A model worker that executes the model based on SGLang. + +Usage: +python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 --worker-address http://localhost:30000 +""" + +import argparse +import asyncio +import json +import multiprocessing +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +import sglang as sgl +from sglang.srt.hf_transformers_utils import get_tokenizer, get_config +from sglang.srt.utils import load_image, is_multimodal_model + +from fastchat.conversation import IMAGE_PLACEHOLDER_STR +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + +app = FastAPI() + + +@sgl.function +def pipeline(s, prompt, max_tokens): + for p in prompt: + if isinstance(p, str): + s += p + else: + s += sgl.image(p) + s += sgl.gen("response", max_tokens=max_tokens) + + +class SGLWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + tokenizer_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: str, + runtime: sgl.Runtime, + trust_remote_code: bool, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + is_multimodal_model(model_path), + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: SGLang worker..." 
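            # After this log line the worker loads the tokenizer, derives the context
            # length from the HuggingFace config, and, unless --no-register was given,
            # starts the heartbeat loop that keeps it registered with the controller.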
+ ) + + self.tokenizer = get_tokenizer(tokenizer_path) + self.context_len = get_context_length( + get_config(model_path, trust_remote_code=trust_remote_code) + ) + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + prompt = params.pop("prompt") + images = params.get("images", []) + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + presence_penalty = float(params.get("presence_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + echo = params.get("echo", True) + + # Handle stop_str + stop = [] + if isinstance(stop_str, str) and stop_str != "": + stop.append(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.extend(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.append(s) + + # make sampling params for sgl.gen + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + # split prompt by image token + split_prompt = prompt.split(IMAGE_PLACEHOLDER_STR) + if prompt.count(IMAGE_PLACEHOLDER_STR) != len(images): + raise ValueError( + "The number of images passed in does not match the number of tokens in the prompt!" + ) + prompt = [] + for i in range(len(split_prompt)): + prompt.append(split_prompt[i]) + if i < len(images): + prompt[-1] = prompt[-1].strip() + prompt.append(load_image(images[i])) + + state = pipeline.run( + prompt, + max_new_tokens, + stop=stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stream=True, + ) + + entire_output = prompt if echo else "" + async for out, meta_info in state.text_async_iter( + var_name="response", return_meta_data=True + ): + partial_stop = any(is_partial_stop(out, i) for i in stop) + + # prevent yielding partial stop sequence + if partial_stop: + continue + + entire_output += out + prompt_tokens = meta_info["prompt_tokens"] + completion_tokens = meta_info["completion_tokens"] + + ret = { + "text": entire_output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "error_code": 0, + } + yield ret + + async def generate_stream_gate(self, params): + try: + async for ret in self.generate_stream(params): + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + async def generate_gate(self, params): + async for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + 
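    # generate_stream_gate yields NUL-terminated JSON chunks (b"\0" delimiter); the
    # OpenAI-compatible server reassembles them by splitting on that byte. The
    # background task below releases the concurrency semaphore once streaming ends.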
background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5") + parser.add_argument("--tokenizer-path", type=str, default="") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust-remote-code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + parser.add_argument( + "--mem-fraction-static", + type=float, + default=0.9, + help="The ratio (between 0 and 1) of GPU memory to" + "reserve for the model weights, activations, and KV cache. Higher" + "values will increase the KV cache size and thus improve the model's" + "throughput. 
However, if the value is too high, it may cause out-of-" + "memory (OOM) errors.", + ) + parser.add_argument( + "--multimodal", + action="store_true", + required=False, + default=False, + help="Register this worker as serving a multimodal model.", + ) + + args = parser.parse_args() + + args.tp_size = args.num_gpus if args.num_gpus > 1 else 1 + args.tokenizer_path = ( + args.model_path if args.tokenizer_path == "" else args.tokenizer_path + ) + + multiprocessing.set_start_method("spawn", force=True) + runtime = sgl.Runtime( + model_path=args.model_path, + tokenizer_path=args.tokenizer_path, + trust_remote_code=args.trust_remote_code, + mem_fraction_static=args.mem_fraction_static, + tp_size=args.tp_size, + log_level="info", + ) + sgl.set_default_backend(runtime) + + worker = SGLWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.tokenizer_path, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + args.conv_template, + runtime, + args.trust_remote_code, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/src/serve/shutdown_serve.py b/src/serve/shutdown_serve.py new file mode 100644 index 0000000000000000000000000000000000000000..95e2b704f0b65584c5be15ce14b40bc150bd6009 --- /dev/null +++ b/src/serve/shutdown_serve.py @@ -0,0 +1,24 @@ +""" +Usage: +python shutdown_serve.py --down all +options: "all","controller","model_worker","openai_api_server", `all` means to stop all related servers +""" + +import argparse +import os +import subprocess + +parser = argparse.ArgumentParser() +parser.add_argument( + "--down", choices=["all", "controller", "model_worker", "openai_api_server"] +) +args = parser.parse_args() +base_shell = "ps -eo user,pid,cmd|grep fastchat.serve{}|grep -v grep|awk '{{print $2}}'|xargs kill -9" +if args.down == "all": + shell_script = base_shell.format("") +else: + serve = f".{args.down}" + shell_script = base_shell.format(serve) +print(f"execute shell cmd: {shell_script}") +subprocess.run(shell_script, shell=True, check=True) +print(f"{args.down} has been shutdown!") diff --git a/src/serve/test_message.py b/src/serve/test_message.py new file mode 100644 index 0000000000000000000000000000000000000000..203a44901c10c5526f198c8e9dbb4e32d15ed7aa --- /dev/null +++ b/src/serve/test_message.py @@ -0,0 +1,81 @@ +"""Send a test message.""" +import argparse +import json + +import requests + +from fastchat.model.model_adapter import get_conversation_template + + +def main(): + model_name = args.model_name + + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + print(f"No available workers for {model_name}") + return + + conv = get_conversation_template(model_name) + conv.append_message(conv.roles[0], args.message) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + headers = {"User-Agent": "FastChat Client"} + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": args.temperature, + "max_new_tokens": args.max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": 
False, + } + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + ) + + print(f"{conv.roles[0]}: {args.message}") + print(f"{conv.roles[1]}: ", end="") + prev = 0 + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, required=True) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--max-new-tokens", type=int, default=32) + parser.add_argument( + "--message", type=str, default="Tell me a story with more than 1000 words." + ) + args = parser.parse_args() + + main() diff --git a/src/serve/test_throughput.py b/src/serve/test_throughput.py new file mode 100644 index 0000000000000000000000000000000000000000..3796a6e2a7cb53dc6921674fc4c488246e0b93c7 --- /dev/null +++ b/src/serve/test_throughput.py @@ -0,0 +1,115 @@ +"""Benchmarking script to test the throughput of serving workers.""" +import argparse +import json + +import requests +import threading +import time + +from fastchat.conversation import get_conv_template + + +def main(): + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + return + + conv = get_conv_template("vicuna_v1.1") + conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words") + prompt_template = conv.get_prompt() + prompts = [prompt_template for _ in range(args.n_thread)] + + headers = {"User-Agent": "fastchat Client"} + ploads = [ + { + "model": args.model_name, + "prompt": prompts[i], + "max_new_tokens": args.max_new_tokens, + "temperature": 0.0, + # "stop": conv.sep, + } + for i in range(len(prompts)) + ] + + def send_request(results, i): + if args.test_dispatch: + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + thread_worker_addr = ret.json()["address"] + else: + thread_worker_addr = worker_addr + print(f"thread {i} goes to {thread_worker_addr}") + response = requests.post( + thread_worker_addr + "/worker_generate_stream", + headers=headers, + json=ploads[i], + stream=False, + ) + k = list( + response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0") + ) + # print(k) + response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + error_code = json.loads(k[-2].decode("utf-8"))["error_code"] + # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}") + results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + + # use N threads to prompt the backend + tik = time.time() + threads = [] + results = [None] * args.n_thread + for i in range(args.n_thread): + t = threading.Thread(target=send_request, args=(results, i)) + t.start() + # time.sleep(0.5) + 
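        # Each send_request reads the full NUL-delimited stream; the trailing
        # delimiter leaves an empty final element, so k[-2] is the last complete
        # chunk, and the word-count delta between its text and the prompt feeds
        # the words/s figure printed after all threads join.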
threads.append(t) + + for t in threads: + t.join() + + print(f"Time (POST): {time.time() - tik} s") + # n_words = 0 + # for i, response in enumerate(results): + # # print(prompt[i].replace(conv.sep, "\n"), end="") + # # make sure the streaming finishes at EOS or stopping criteria + # k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) + # response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + # # print(response_new_words) + # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + n_words = sum(results) + time_seconds = time.time() - tik + print( + f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, " + f"throughput: {n_words / time_seconds} words/s." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, default="vicuna") + parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument("--n-thread", type=int, default=8) + parser.add_argument("--test-dispatch", action="store_true") + args = parser.parse_args() + + main() diff --git a/src/serve/vision/create_vqa_examples_dir.py b/src/serve/vision/create_vqa_examples_dir.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2630bd985cba2c02f08ec59d1fb237d5d47545 --- /dev/null +++ b/src/serve/vision/create_vqa_examples_dir.py @@ -0,0 +1,127 @@ +import datasets +from datasets import load_dataset +from PIL import Image +from pathlib import Path +import pandas as pd +import os +import json +import tqdm +import argparse +import shutil +import numpy as np + +np.random.seed(0) + +""" +Creates a directory with images and JSON files for VQA examples. 
Final json is located in metadata_sampled.json +""" + + +def download_images_and_create_json( + dataset_info, cache_dir="~/vqa_examples_cache", base_dir="./vqa_examples" +): + for dataset_name, info in dataset_info.items(): + dataset_cache_dir = os.path.join(cache_dir, dataset_name) + os.makedirs(dataset_cache_dir, exist_ok=True) + + if info["subset"]: + dataset = load_dataset( + info["path"], + info["subset"], + cache_dir=dataset_cache_dir, + split=info["split"], + ) + else: + dataset = load_dataset( + info["path"], cache_dir=dataset_cache_dir, split=info["split"] + ) + dataset_dir = os.path.join(base_dir, dataset_name) + os.makedirs(dataset_dir, exist_ok=True) + + json_data = [] + for i, item in enumerate(tqdm.tqdm(dataset)): + id_key = i if info["id_key"] == "index" else item[info["id_key"]] + image_pil = item[info["image_key"]].convert("RGB") + image_path = os.path.join(dataset_dir, f"{id_key}.jpg") + image_pil.save(image_path) + json_entry = { + "dataset": dataset_name, + "question": item[info["question_key"]], + "path": image_path, + } + json_data.append(json_entry) + + with open(os.path.join(dataset_dir, "data.json"), "w") as json_file: + json.dump(json_data, json_file, indent=4) + # Delete the cache directory for the dataset + shutil.rmtree(dataset_cache_dir, ignore_errors=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, default="~/.cache") + parser.add_argument("--output_dir", type=str, default="./vqa_examples") + args = parser.parse_args() + + datasets_info = { + "DocVQA": { + "path": "lmms-lab/DocVQA", + "image_key": "image", + "question_key": "question", + "id_key": "questionId", + "subset": "DocVQA", + "split": "test", + }, + "ChartQA": { + "path": "HuggingFaceM4/ChartQA", + "image_key": "image", + "question_key": "query", + "id_key": "index", + "subset": False, + "split": "test", + }, + "realworldqa": { + "path": "visheratin/realworldqa", + "image_key": "image", + "question_key": "question", + "id_key": "index", + "subset": False, + "split": "test", + }, + "NewYorker": { + "path": "jmhessel/newyorker_caption_contest", + "image_key": "image", + "question_key": "questions", + "id_key": "index", + "subset": "explanation", + "split": "train", + }, + "WikiArt": { + "path": "huggan/wikiart", + "image_key": "image", + "question_key": "artist", + "id_key": "index", + "subset": False, + "split": "train", + }, + "TextVQA": { + "path": "facebook/textvqa", + "image_key": "image", + "question_key": "question", + "id_key": "question_id", + "subset": False, + "split": "train", + }, + } + + download_images_and_create_json( + datasets_info, cache_dir=args.data_dir, base_dir=args.output_dir + ) + dataset_json = [] + for dataset_name in datasets_info.keys(): + with open(f"{args.output_dir}/{dataset_name}/data.json") as f: + data = json.load(f) + dataset_json.extend(np.random.choice(data, 500)) + + with open(f"{args.output_dir}/metadata_sampled.json", "w") as f: + json.dump(dataset_json, f, indent=4) diff --git a/src/serve/vllm_worker.py b/src/serve/vllm_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..0af680bb5f1cabcf77b9a67bd9d542b53550f89f --- /dev/null +++ b/src/serve/vllm_worker.py @@ -0,0 +1,302 @@ +""" +A model worker that executes the model based on vLLM. 
+ +See documentations at docs/vllm_integration.md +""" + +import argparse +import asyncio +import json +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from vllm import AsyncLLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + + +app = FastAPI() + + +class VLLMWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + llm_engine: AsyncLLMEngine, + conv_template: str, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..." + ) + self.tokenizer = llm_engine.engine.tokenizer + # This is to support vllm >= 0.2.7 where TokenizerGroup was introduced + # and llm_engine.engine.tokenizer was no longer a raw tokenizer + if hasattr(self.tokenizer, "tokenizer"): + self.tokenizer = llm_engine.engine.tokenizer.tokenizer + self.context_len = get_context_length(llm_engine.engine.model_config.hf_config) + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + context = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + echo = params.get("echo", True) + use_beam_search = params.get("use_beam_search", False) + best_of = params.get("best_of", None) + + request = params.get("request", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) + + # make sampling params in vllm + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + sampling_params = SamplingParams( + n=1, + temperature=temperature, + top_p=top_p, + use_beam_search=use_beam_search, + stop=list(stop), + stop_token_ids=stop_token_ids, + max_tokens=max_new_tokens, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + best_of=best_of, + ) + results_generator = engine.generate(context, sampling_params, request_id) + + async for request_output in results_generator: + prompt = request_output.prompt + if echo: + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + else: + text_outputs = [output.text for output in request_output.outputs] + text_outputs = " ".join(text_outputs) + + partial_stop = 
any(is_partial_stop(text_outputs, i) for i in stop) + # prevent yielding partial stop sequence + if partial_stop: + continue + + aborted = False + if request and await request.is_disconnected(): + await engine.abort(request_id) + request_output.finished = True + aborted = True + for output in request_output.outputs: + output.finish_reason = "abort" + + prompt_tokens = len(request_output.prompt_token_ids) + completion_tokens = sum( + len(output.token_ids) for output in request_output.outputs + ) + ret = { + "text": text_outputs, + "error_code": 0, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "cumulative_logprob": [ + output.cumulative_logprob for output in request_output.outputs + ], + "finish_reason": request_output.outputs[0].finish_reason + if len(request_output.outputs) == 1 + else [output.finish_reason for output in request_output.outputs], + } + # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response. + # This aligns with the behavior of model_worker. + if request_output.finished: + yield (json.dumps({**ret, **{"finish_reason": None}}) + "\0").encode() + yield (json.dumps(ret) + "\0").encode() + + if aborted: + break + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + await engine.abort(request_id) + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = random_uuid() + params["request_id"] = request_id + params["request"] = request + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = random_uuid() + params["request_id"] = request_id + params["request"] = request + output = await worker.generate(params) + release_worker_semaphore() + await engine.abort(request_id) + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + 
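        # Illustrative worker launch (module path and values are assumptions; adjust
        # to the local package layout):
        #   python3 -m src.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.5 \
        #       --controller-address http://localhost:21001 --port 21002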
"--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust_remote_code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + parser.add_argument( + "--gpu_memory_utilization", + type=float, + default=0.9, + help="The ratio (between 0 and 1) of GPU memory to" + "reserve for the model weights, activations, and KV cache. Higher" + "values will increase the KV cache size and thus improve the model's" + "throughput. However, if the value is too high, it may cause out-of-" + "memory (OOM) errors.", + ) + + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + if args.model_path: + args.model = args.model_path + if args.num_gpus > 1: + args.tensor_parallel_size = args.num_gpus + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) + worker = VLLMWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + engine, + args.conv_template, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58a10efa19ae13cf348735b6da1dfd0cb5165ee9 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,502 @@ +""" +Common utilities. +""" +from asyncio import AbstractEventLoop +from io import BytesIO +import base64 +import json +import logging +import logging.handlers +import os +import platform +import sys +import time +from typing import AsyncGenerator, Generator +import warnings + +import requests + +from src.constants import LOGDIR + + +handler = None +visited_loggers = set() + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + if sys.version_info[1] >= 9: + # This is for windows + logging.basicConfig(level=logging.INFO, encoding="utf-8") + else: + if platform.system() == "Windows": + warnings.warn( + "If you are running on Windows, " + "we recommend you use Python >= 3.9 for UTF-8 encoding." 
+ ) + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Avoid httpx flooding POST logs + logging.getLogger("httpx").setLevel(logging.WARNING) + + # if LOGDIR is empty, then don't try output log to local file + if LOGDIR != "": + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when="D", utc=True, encoding="utf-8" + ) + handler.setFormatter(formatter) + + for l in [stdout_logger, stderr_logger, logger]: + if l in visited_loggers: + continue + visited_loggers.add(l) + l.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == "\n": + encoded_message = line.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != "": + encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + self.linebuf = "" + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_gpu_memory(max_gpus=None): + """Get available memory for each GPU.""" + import torch + + gpu_memory = [] + num_gpus = ( + torch.cuda.device_count() + if max_gpus is None + else min(max_gpus, torch.cuda.device_count()) + ) + + for gpu_id in range(num_gpus): + with torch.cuda.device(gpu_id): + device = torch.cuda.current_device() + gpu_properties = torch.cuda.get_device_properties(device) + total_memory = gpu_properties.total_memory / (1024**3) + allocated_memory = torch.cuda.memory_allocated() / (1024**3) + available_memory = total_memory - allocated_memory + gpu_memory.append(available_memory) + return gpu_memory + + +def oai_moderation(text, custom_thresholds=None): + """ + Check whether the text violates OpenAI moderation API. 
+ """ + import openai + + client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + + # default to true to be conservative + flagged = True + MAX_RETRY = 3 + for _ in range(MAX_RETRY): + try: + res = client.moderations.create(input=text) + flagged = res.results[0].flagged + if custom_thresholds is not None: + for category, threshold in custom_thresholds.items(): + if getattr(res.results[0].category_scores, category) > threshold: + flagged = True + break + except (openai.OpenAIError, KeyError, IndexError) as e: + print(f"MODERATION ERROR: {e}\nInput: {text}") + return flagged + + +def moderation_filter(text, model_list, do_moderation=False): + # Apply moderation for below models + MODEL_KEYWORDS = ["claude", "gpt", "bard", "mistral-large", "command-r", "dbrx"] + + custom_thresholds = {"sexual": 0.3} + # set a stricter threshold for claude + for model in model_list: + if "claude" in model: + custom_thresholds = {"sexual": 0.2} + + for keyword in MODEL_KEYWORDS: + for model in model_list: + if keyword in model: + do_moderation = True + break + + if do_moderation: + return oai_moderation(text, custom_thresholds) + return False + + +def clean_flant5_ckpt(ckpt_path): + """ + Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, + Use this function to make sure it can be correctly loaded. + """ + import torch + + index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + index_json = json.load(open(index_file, "r")) + + weightmap = index_json["weight_map"] + + share_weight_file = weightmap["shared.weight"] + share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[ + "shared.weight" + ] + + for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]: + weight_file = weightmap[weight_name] + weight = torch.load(os.path.join(ckpt_path, weight_file)) + weight[weight_name] = share_weight + torch.save(weight, os.path.join(ckpt_path, weight_file)) + + +def pretty_print_semaphore(semaphore): + """Print a semaphore in better format.""" + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +"""A javascript function to get url parameters for the gradio web server.""" +get_window_url_params_js = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log("url_params", url_params); + return url_params; + } +""" + + +get_window_url_params_with_tos_js = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log("url_params", url_params); + + msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nPlease do not upload any private information.\\nThe service collects user dialogue data, including both text and images, and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license." 
+ alert(msg); + return url_params; + } +""" + + +def iter_over_async( + async_gen: AsyncGenerator, event_loop: AbstractEventLoop +) -> Generator: + """ + Convert async generator to sync generator + + :param async_gen: the AsyncGenerator to convert + :param event_loop: the event loop to run on + :returns: Sync generator + """ + ait = async_gen.__aiter__() + + async def get_next(): + try: + obj = await ait.__anext__() + return False, obj + except StopAsyncIteration: + return True, None + + while True: + done, obj = event_loop.run_until_complete(get_next()) + if done: + break + yield obj + + +def detect_language(text: str) -> str: + """Detect the langauge of a string.""" + import polyglot # pip3 install polyglot pyicu pycld2 + from polyglot.detect import Detector + from polyglot.detect.base import logger as polyglot_logger + import pycld2 + + polyglot_logger.setLevel("ERROR") + + try: + lang_code = Detector(text).language.name + except (pycld2.error, polyglot.detect.base.UnknownLanguage): + lang_code = "unknown" + return lang_code + + +def parse_gradio_auth_creds(filename: str): + """Parse a username:password file for gradio authorization.""" + gradio_auth_creds = [] + with open(filename, "r", encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(",") if x.strip()] + if gradio_auth_creds: + auth = [tuple(cred.split(":")) for cred in gradio_auth_creds] + else: + auth = None + return auth + + +def is_partial_stop(output: str, stop_str: str): + """Check whether the output contains a partial stop str.""" + for i in range(0, min(len(output), len(stop_str))): + if stop_str.startswith(output[-i:]): + return True + return False + + +def run_cmd(cmd: str): + """Run a bash command.""" + print(cmd) + return os.system(cmd) + + +def is_sentence_complete(output: str): + """Check whether the output is a complete sentence.""" + end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”") + return output.endswith(end_symbols) + + +# Models don't use the same configuration key for determining the maximum +# sequence length. Store them here so we can sanely check them. +# NOTE: The ordering here is important. Some models have two of these and we +# have a preference for which value gets used. 
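# For example (hypothetical config): max_position_embeddings=4096 together with
# rope_scaling={"factor": 2.0} yields a context length of 8192 in
# get_context_length() below; if none of these keys are present, the fallback is 2048.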
+SEQUENCE_LENGTH_KEYS = [ + "max_position_embeddings", + "max_sequence_length", + "seq_length", + "max_seq_len", + "model_max_length", +] + + +def get_context_length(config): + """Get the context length of a model from a huggingface model config.""" + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling: + rope_scaling_factor = config.rope_scaling["factor"] + else: + rope_scaling_factor = 1 + + for key in SEQUENCE_LENGTH_KEYS: + val = getattr(config, key, None) + if val is not None: + return int(rope_scaling_factor * val) + return 2048 + + +def str_to_torch_dtype(dtype: str): + import torch + + if dtype is None: + return None + elif dtype == "float32": + return torch.float32 + elif dtype == "float16": + return torch.float16 + elif dtype == "bfloat16": + return torch.bfloat16 + else: + raise ValueError(f"Unrecognized dtype: {dtype}") + + +def load_image(image_file): + from PIL import Image + import requests + + image = None + + if image_file.startswith("http://") or image_file.startswith("https://"): + timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) + response = requests.get(image_file, timeout=timeout) + image = Image.open(BytesIO(response.content)) + elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): + image = Image.open(image_file) + elif image_file.startswith("data:"): + image_file = image_file.split(",")[1] + image = Image.open(BytesIO(base64.b64decode(image_file))) + else: + image = Image.open(BytesIO(base64.b64decode(image_file))) + + return image + + +def upload_image_file_to_gcs(image, filename): + from google.cloud import storage + import io + + storage_client = storage.Client() + # upload file to GCS + bucket = storage_client.get_bucket("arena_user_content") + + blob = bucket.blob(f"{filename}") + if not blob.exists(): + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + blob.upload_from_file(buffer, content_type="image/png") + + return blob.public_url + + +def get_image_file_from_gcs(filename): + from google.cloud import storage + + storage_client = storage.Client() + bucket = storage_client.get_bucket("arena_user_content") + blob = bucket.blob(f"{filename}") + contents = blob.download_as_bytes() + + return contents + + +def resize_image_and_return_image_in_bytes(image, max_image_size_mb): + from PIL import Image + import math + + image_bytes = BytesIO() + if not max_image_size_mb is None: + image.save(image_bytes, format="PNG") + target_size_bytes = max_image_size_mb * 1024 * 1024 + current_size_bytes = image_bytes.tell() + + if current_size_bytes > target_size_bytes: + resize_factor = (target_size_bytes / current_size_bytes) ** 0.5 + new_width = math.floor(image.width * resize_factor) + new_height = math.floor(image.height * resize_factor) + resized_image = image.resize((new_width, new_height)) + + image_bytes = BytesIO() + resized_image.save(image_bytes, format="PNG") + + image_bytes.seek(0) + else: + image.save(image_bytes, format="PNG") + + return image_bytes + + +def convert_image_to_byte_array(image, max_image_size_mb): + from PIL import Image + + if type(image) == str: + pil_image = Image.open(image).convert("RGB") + image_bytes = resize_image_and_return_image_in_bytes( + pil_image, max_image_size_mb + ) + else: + image_bytes = resize_image_and_return_image_in_bytes(image, max_image_size_mb) + + image_byte_array = image_bytes.getvalue() + return image_byte_array + + +def image_moderation_request(image_bytes, endpoint, api_key): + headers = {"Content-Type": "image/jpeg", "Ocp-Apim-Subscription-Key": 
api_key}
+
+    MAX_RETRIES = 3
+    for _ in range(MAX_RETRIES):
+        response = requests.post(endpoint, headers=headers, data=image_bytes).json()
+        try:
+            if response["Status"]["Code"] == 3000:
+                break
+        except (KeyError, TypeError):
+            # The response did not contain the expected status structure;
+            # back off briefly and retry.
+            time.sleep(0.5)
+    return response
+
+
+def image_moderation_provider(image, api_type):
+    if api_type == "nsfw":
+        endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"]
+        api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"]
+        response = image_moderation_request(image, endpoint, api_key)
+        return response["IsImageAdultClassified"]
+    elif api_type == "csam":
+        endpoint = (
+            "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false"
+        )
+        api_key = os.environ["PHOTODNA_API_KEY"]
+        response = image_moderation_request(image, endpoint, api_key)
+        return response["IsMatch"]
+
+
+def image_moderation_filter(image):
+    print(f"moderating image: {image}")
+    MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 4
+    image_bytes = convert_image_to_byte_array(image, MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB)
+
+    nsfw_flagged = image_moderation_provider(image_bytes, "nsfw")
+    csam_flagged = False
+
+    if nsfw_flagged:
+        csam_flagged = image_moderation_provider(image_bytes, "csam")
+
+    return nsfw_flagged, csam_flagged
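The helpers in src/utils.py that do not touch the network can be exercised directly. A minimal, self-contained sketch (the import path assumes the repository root is on PYTHONPATH, and the sample strings are made up):

import asyncio

from src.utils import is_partial_stop, is_sentence_complete, iter_over_async


async def count_up(n):
    # Stand-in async generator, analogous to a streaming model response.
    for i in range(n):
        yield i


if __name__ == "__main__":
    # Partial-stop detection used by the vLLM/SGLang workers while streaming.
    assert is_partial_stop("Hello, wor", "world")
    assert not is_partial_stop("Hello!", "###")

    # Sentence-completeness heuristic.
    assert is_sentence_complete("This is done.")
    assert not is_sentence_complete("This is not")

    # Bridge an async generator into a plain synchronous generator.
    loop = asyncio.new_event_loop()
    print(list(iter_over_async(count_up(3), loop)))  # -> [0, 1, 2]
    loop.close()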