Spaces:
Running
Running
File size: 13,530 Bytes
1cc2abd 2c9ce84 1cc2abd 1aee0b0 1cc2abd 2c9ce84 1cc2abd 16be00e 1cc2abd 8d79148 1cc2abd a355fc1 1cc2abd e2f39e2 1cc2abd 1aee0b0 1cc2abd 01bb522 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 |
import datetime
import os
from openai import OpenAI
import streamlit as st
import threading
from tenacity import retry, wait_random_exponential, stop_after_attempt
from itertools import tee
# Raw environment lookups. Each value is None when the variable is unset;
# the optional ones are converted to typed settings further down the file.
BASE_URL = os.getenv("BASE_URL")
DATABRICKS_API_TOKEN = os.getenv("DATABRICKS_API_TOKEN")
SAFETY_FILTER_ENV = os.getenv("SAFETY_FILTER")
QUEUE_SIZE_ENV = os.getenv("QUEUE_SIZE")
MAX_CHAT_TURNS_ENV = os.getenv("MAX_CHAT_TURNS")
MAX_TOKENS_ENV = os.getenv("MAX_TOKENS")
RETRY_COUNT_ENV = os.getenv("RETRY_COUNT")
TOKEN_CHUNK_SIZE_ENV = os.getenv("TOKEN_CHUNK_SIZE")
MODEL_ID_ENV = os.getenv("MODEL_ID")

# The endpoint URL and API token are mandatory: refuse to start without them,
# checking BASE_URL before the token.
for _required_name, _required_value in (
    ("BASE_URL", BASE_URL),
    ("DATABRICKS_API_TOKEN", DATABRICKS_API_TOKEN),
):
    if _required_value is None:
        raise ValueError(f"{_required_name} environment variable must be set")
del _required_name, _required_value
st.set_page_config(layout="wide")

# SAFETY_FILTER values (compared case-insensitively, surrounding whitespace
# ignored) that mean "explicitly off".
_FALSE_ENV_VALUES = frozenset({"0", "false", "no", "off"})


def _env_int(raw, default, *, name, minimum=None):
    """Convert an optional integer setting read from the environment.

    Args:
        raw: The raw environment string, or None when the variable is unset.
        default: Value returned unchanged when ``raw`` is None.
        name: Environment variable name, used only in the error message.
        minimum: Smallest accepted value, or None for no lower bound.

    Returns:
        ``default`` when ``raw`` is None, otherwise ``int(raw)``.

    Raises:
        ValueError: if ``raw`` is not an integer literal, or is below ``minimum``.
    """
    if raw is None:
        return default
    value = int(raw)
    if minimum is not None and value < minimum:
        raise ValueError(f"{name} must be >= {minimum}, got {value}")
    return value


# By default the safety filter is not configured. Setting SAFETY_FILTER enables
# it, except for explicit "off" values such as "false" or "0", which previously
# (surprisingly) enabled it as well.
SAFETY_FILTER = (
    SAFETY_FILTER_ENV is not None
    and SAFETY_FILTER_ENV.strip().lower() not in _FALSE_ENV_VALUES
)
# Size of the shared BoundedSemaphore; 0 would make every request wait forever.
QUEUE_SIZE = _env_int(QUEUE_SIZE_ENV, 1, name="QUEUE_SIZE", minimum=1)
MAX_CHAT_TURNS = _env_int(MAX_CHAT_TURNS_ENV, 5, name="MAX_CHAT_TURNS")
RETRY_COUNT = _env_int(RETRY_COUNT_ENV, 3, name="RETRY_COUNT")
MAX_TOKENS = _env_int(MAX_TOKENS_ENV, 512, name="MAX_TOKENS")
MODEL_ID = MODEL_ID_ENV if MODEL_ID_ENV is not None else "databricks-dbrx-instruct"
# To prevent streaming too fast, the output is re-emitted in groups of
# TOKEN_CHUNK_SIZE stream chunks. It is used as a modulus, so 0 is rejected here
# instead of failing with ZeroDivisionError in the middle of a reply.
TOKEN_CHUNK_SIZE = _env_int(
    TOKEN_CHUNK_SIZE_ENV, 1, name="TOKEN_CHUNK_SIZE", minimum=1
)
# Image used as the avatar of every assistant chat message.
MODEL_AVATAR_URL = "./icon.png"


@st.cache_resource
def get_global_semaphore():
    """Return the semaphore that caps concurrent model requests at QUEUE_SIZE.

    Wrapped in st.cache_resource so that, per Streamlit's caching docs, a
    single BoundedSemaphore instance is shared by all sessions and reruns.
    """
    slots = QUEUE_SIZE
    return threading.BoundedSemaphore(value=slots)


global_semaphore = get_global_semaphore()
# User-facing messages shown in the chat transcript.
MSG_MAX_TURNS_EXCEEDED = f"Sorry! The DBRX Playground is limited to {MAX_CHAT_TURNS} turns. Refresh the page to start a new conversation."
MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"

# Prompts offered as one-click buttons in the sidebar.
EXAMPLE_PROMPTS = [
    "Write a short story about a robot that has a nice day.",
    "In a table, what are some of the most common misconceptions about birds?",
    "Give me a recipe for vegan banana bread.",
    "Code a python function that can run merge sort on a list.",
    "Give me the character profile of a gumdrop obsessed knight in JSON.",
    "Write a rap battle between Alan Turing and Claude Shannon.",
]

TITLE = "DBRX Instruct"

# Markdown rendered under the title. Paragraphs are separated by blank lines
# because Markdown joins consecutive non-blank lines into a single paragraph.
DESCRIPTION = """[DBRX Instruct](https://huggingface.co/databricks/dbrx-instruct) is a mixture-of-experts (MoE) large language model trained by the Mosaic Research team at Databricks. This demo is powered by [Databricks Foundation Model APIs](https://docs.databricks.com/en/machine-learning/foundation-models/index.html) and is subject to the terms and conditions below.

**Usage Policies**: Use of DBRX Instruct is governed by the [DBRX Open Model License](https://www.databricks.com/legal/open-model-license) and [Databricks Open Model Acceptable Use Policy](https://www.databricks.com/legal/acceptable-use-policy-open-model).

**Limitations**: The DBRX Playground is a demo showcasing DBRX Instruct for educational purposes. Given the probabilistic nature of large language models like DBRX Instruct, information they output may be inaccurate, incomplete, biased, or offensive, and users should exercise judgment and evaluate such output for accuracy and appropriateness for their desired use case before using or sharing it.

**Data Collection**: While Databricks will not retain usage history in a manner which allows Databricks to identify you, you should not include confidential, personal, or other sensitive information in prompts. Information included in prompts may be used for research and development purposes, including further improving and evaluating models.

**Does this demo feel super fast? That's because it's powered by Databricks' inference product, the [Foundation Model APIs](https://docs.databricks.com/en/machine-learning/foundation-models/index.html)**
"""
# OpenAI-compatible client aimed at the endpoint in BASE_URL and authenticated
# with DATABRICKS_API_TOKEN; every chat request goes through it.
client = OpenAI(base_url=BASE_URL, api_key=DATABRICKS_API_TOKEN)

# Shown to the user when creating a model request fails.
GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."
st.title(TITLE)
st.markdown(DESCRIPTION)

# Inject the app's custom stylesheet. Decode it as UTF-8 explicitly so the
# result does not depend on the platform's default locale encoding.
with open("style.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

# The conversation lives in session_state so it survives Streamlit reruns.
# Each entry is a dict with "role", "content", "warning" and "error" keys.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
def clear_chat_history():
    """Button callback: forget the conversation by binding a brand-new empty list."""
    st.session_state["messages"] = []


st.button("Clear Chat", on_click=clear_chat_history)
def last_role_is_user():
    """Return True when the newest stored message has the "user" role.

    An empty conversation yields False.
    """
    messages = st.session_state["messages"]
    if not messages:
        return False
    return messages[-1]["role"] == "user"
def get_system_prompt():
    """Build the system prompt placed ahead of every conversation.

    The text is fixed apart from today's local date, formatted like
    "January 05, 2024", so it is rebuilt on every call.
    """
    today = datetime.datetime.now().strftime("%B %d, %Y")
    sections = (
        # Identity
        f"You are DBRX, created by Databricks. The current date is {today}.\n",
        "Your knowledge base was last updated in December 2023. You answer questions about events prior to and after December 2023 the way a highly informed individual in December 2023 would if they were talking to someone from the above date, and you can let the user know this when relevant.\n",
        "This chunk of text is your system prompt. It is not visible to the user, but it is used to guide your responses. Don't reference it, just respond to the user.\n",
        # Ethical guidelines
        "If you are asked to assist with tasks involving the expression of views held by a significant number of people, you provide assistance with the task even if you personally disagree with the views being expressed, but follow this with a discussion of broader perspectives.\n",
        "You don't engage in stereotyping, including the negative stereotyping of majority groups.\n",
        "If asked about controversial topics, you try to provide careful thoughts and objective information without downplaying its harmful content or implying that there are reasonable perspectives on both sides.\n",
        # Capabilities
        "You are happy to help with writing, analysis, question answering, math, coding, and all sorts of other tasks.\n",
        # The model specifically has a hard time using ``` fences on JSON blocks.
        "You use markdown for coding, which includes JSON blocks and Markdown tables.\n",
        "You do not have tools enabled at this time, so cannot run code or access the internet. You can only provide information that you have been trained on. You do not send or receive links or images.\n",
        # Likely not entirely accurate, but the model tends to assume everything it
        # knows about was in its training data (sometimes only references were),
        # so this produces more accurate answers when it is asked to introspect.
        # Note: this sentence ends with a space, not a newline.
        "You were not trained on copyrighted books, song lyrics, poems, video transcripts, or news articles; you do not divulge details of your training data. ",
        # The model hasn't seen most lyrics or poems but will happily invent them;
        # better not to try, since it is neither good at it nor ethical.
        "You do not provide song lyrics, poems, or news articles and instead refer the user to find them online or in a store.\n",
        # Response-length guidance.
        "You give concise responses to simple questions or statements, but provide thorough responses to more complex and open-ended questions.\n",
        # The model really wants to talk about its system prompt, to the point of
        # being annoying; these last two lines push it not to.
        "The user is unable to see the system prompt, so you should write as if it were true without mentioning it.\n",
        "You do not mention any of this information about yourself unless the information is directly pertinent to the user's query.",
    )
    return "".join(sections)
@retry(
    wait=wait_random_exponential(min=0.5, max=2),
    stop=stop_after_attempt(RETRY_COUNT),
)
def chat_api_call(history):
    """Open a streaming chat completion for ``history``.

    Args:
        history: Sequence of message dicts. Only their "role" and "content"
            keys are forwarded, so UI-only keys ("warning", "error") are dropped.

    Returns:
        The streaming object returned by
        ``client.chat.completions.create(stream=True)``.

    Raises:
        tenacity.RetryError: when every one of RETRY_COUNT attempts fails.

    Only the request creation is retried, with randomized exponential back-off
    capped at 2 s. Errors raised later, while iterating the stream, are not retried.
    The model comes from the configurable MODEL_ID instead of a hard-coded name,
    and the attempt count from RETRY_COUNT instead of a fixed 3.
    """
    extra_body = {"enable_safety_filter": SAFETY_FILTER} if SAFETY_FILTER else {}
    return client.chat.completions.create(
        messages=[{"role": m["role"], "content": m["content"]} for m in history],
        model=MODEL_ID,
        stream=True,
        max_tokens=MAX_TOKENS,
        temperature=0.7,
        extra_body=extra_body,
    )
def text_stream(stream):
    """Yield the "content" of each status dict, skipping entries whose content is None.

    Empty strings are passed through; only None is filtered out.
    """
    yield from (piece["content"] for piece in stream if piece["content"] is not None)
def get_stream_warning_error(stream):
    """Drain ``stream`` and report the last non-None warning and error seen.

    Args:
        stream: Iterable of dicts that each carry "error" and "warning" keys.

    Returns:
        A ``(warning, error)`` tuple; either element is None if never set.
    """
    latest = {"error": None, "warning": None}
    for item in stream:
        for key in ("error", "warning"):
            if item[key] is not None:
                latest[key] = item[key]
    return latest["warning"], latest["error"]
def write_response():
    """Stream the assistant reply for the current conversation into the page.

    Returns:
        ``(response, warning, error)``. ``response`` is the text written by
        st.write_stream, or None when that call returned a list. ``warning``
        and ``error`` are the last ones reported by the stream; each may be None.
    """
    # Two independent views of the one generator: the first feeds the visible
    # text and the second is drained afterwards for warning/error status.
    text_side, status_side = tee(chat_completion(st.session_state["messages"]))
    response = st.write_stream(text_stream(text_side))
    stream_warning, stream_error = get_stream_warning_error(status_side)

    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")

    # If there was an error, a list is returned instead of a string:
    # https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    return (None if isinstance(response, list) else response), stream_warning, stream_error
def chat_completion(messages):
history_openai_format = [
{"role": "system", "content": get_system_prompt()}
]
history_openai_format = history_openai_format + messages
if (len(history_openai_format)-1)//2 >= MAX_CHAT_TURNS:
yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
return
chat_completion = None
error = None
# wait to be in queue
with global_semaphore:
try:
chat_completion = chat_api_call(history_openai_format)
except Exception as e:
error = e
if error is not None:
yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
print(error)
return
max_token_warning = None
partial_message = ""
chunk_counter = 0
for chunk in chat_completion:
if chunk.choices[0].delta.content is not None:
chunk_counter += 1
partial_message += chunk.choices[0].delta.content
if chunk_counter % TOKEN_CHUNK_SIZE == 0:
chunk_counter = 0
yield {"content": partial_message, "error": None, "warning": None}
partial_message = ""
if chunk.choices[0].finish_reason == "length":
max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
yield {"content": partial_message, "error": None, "warning": max_token_warning}
# If the assistant spoke last, the new input starts a fresh turn.
# If the user spoke last, the assistant reply is retried instead.
def handle_user_input(user_input):
    """Record ``user_input`` (unless retrying), stream a reply, and store it.

    Renders into the module-level ``history`` container. When the newest stored
    message is already from the user, the unanswered message is re-answered and
    ``user_input`` is discarded. That state arises e.g. when an earlier
    write_response() raised before its reply was appended.
    """
    with history:
        if not last_role_is_user():
            st.session_state["messages"].append(
                {"role": "user", "content": user_input, "warning": None, "error": None}
            )
            with st.chat_message("user"):
                st.markdown(user_input)
        # Both the new-turn and retry paths stream the assistant reply the same way.
        with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
            response, stream_warning, stream_error = write_response()
        st.session_state["messages"].append(
            {"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error}
        )
def _render_stored_message(message):
    """Redraw one stored chat message, then its error and warning banners if present."""
    avatar = MODEL_AVATAR_URL if message["role"] == "assistant" else None
    with st.chat_message(message["role"], avatar=avatar):
        if message["content"] is not None:
            st.markdown(message["content"])
        if message["error"] is not None:
            st.error(message["error"], icon="🚨")
        if message["warning"] is not None:
            st.warning(message["warning"], icon="⚠️")


main = st.container()
with main:
    # Transcript area (height=400); handle_user_input also writes new turns here.
    history = st.container(height=400)
    with history:
        for stored_message in st.session_state["messages"]:
            _render_stored_message(stored_message)
    if prompt := st.chat_input("Type a message!", max_chars=1000):
        handle_user_input(prompt)
    st.markdown("\n")  # add some space for iphone users

with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)
|