Spaces:
Sleeping
Sleeping
File size: 13,231 Bytes
ffa493c d47ab3c ffa493c d47ab3c ffa493c d47ab3c ffa493c d47ab3c ffa493c d47ab3c ffa493c d47ab3c ffa493c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
import datetime
from google.protobuf import message
import torch
import time
import threading
import streamlit as st
import random
from typing import Iterable
from unsloth import FastLanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, PreTrainedTokenizerFast
from datetime import datetime
from threading import Thread
fine_tuned_model_name = "jed-tiotuico/twitter-llama"
sota_model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
# fine_tuned_model_name = "MBZUAI/LaMini-GPT-124M"
# sota_model_name = "MBZUAI/LaMini-GPT-124M"
# Alpaca-style prompt template; the single {} slot receives the customer message.
alpaca_input_text_format = "### Instruction:\n{}\n\n### Response:\n"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Without CUDA, fall back to Apple's MPS backend when it is available.
# Compare device.type: a torch.device never compares equal to a plain string,
# so `device == "cpu"` would always be False and skip this fallback.
if device.type == "cpu":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
def get_model_tokenizer(sota_model_name):
    """Load the fine-tuned reply model and its tokenizer, ready for inference.

    Args:
        sota_model_name: accepted for call-site compatibility but ignored; the
            checkpoint always comes from the module-level ``fine_tuned_model_name``.

    Returns:
        ``(model, tokenizer)`` as returned by ``FastLanguageModel.from_pretrained``,
        with the model switched to inference mode.
    """
    # Use the module constant instead of repeating the literal, so switching
    # fine_tuned_model_name at the top of the file actually changes the model.
    # NOTE(review): weights are reloaded on every call (every button click).
    # Caching would help, but confirm unsloth inference is safe to share
    # across concurrent Streamlit sessions first.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=fine_tuned_model_name,
        max_seq_length=200,
        dtype=None,
        load_in_4bit=True,
        cache_dir="/data/.cache/hf-models",
        token=st.secrets["HF_TOKEN"],  # Hugging Face token from Streamlit secrets
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
def write_user_chat_message(user_chat, customer_msg):
    """Write ``customer_msg`` into the user chat bubble, creating one if needed.

    Args:
        user_chat: an existing user chat container, or None to create one with
            ``st.chat_message("user")``.
        customer_msg: the text to show; when falsy nothing is written.
    """
    if not customer_msg:
        return
    if user_chat is None:
        user_chat = st.chat_message("user")
    user_chat.write(customer_msg)
def write_stream_user_chat_message(user_chat, model, token, prompt):
    """Stream a model-generated customer message into the user chat bubble.

    Args:
        user_chat: the user chat container to write into, or None to create
            one with ``st.chat_message("user")``.
        model: the model handed to ``stream_generation``.
        token: the tokenizer for ``model`` (parameter name kept for existing callers).
        prompt: the full generation prompt; when falsy nothing is generated.

    Returns:
        The value returned by ``user_chat.write_stream`` for the streamed text,
        or None when ``prompt`` is falsy.
    """
    if not prompt:
        return None
    if user_chat is None:
        user_chat = st.chat_message("user")
    # Use the tokenizer the caller passed in. A bare ``tokenizer`` here would
    # silently pick up whatever module-level ``tokenizer`` happens to exist.
    return user_chat.write_stream(
        stream_generation(
            prompt,
            show_prompt=False,
            tokenizer=token,
            model=model,
        )
    )
def get_mistral_model_tokenizer(
    sota_model_name,
    max_seq_length=2048,
    dtype=torch.float16,
    load_in_4bit=True,
):
    """Load the instruction-tuned model used to synthesise customer messages.

    Args:
        sota_model_name: Hugging Face id of the checkpoint to load.
        max_seq_length: forwarded to ``FastLanguageModel.from_pretrained``.
        dtype: forwarded to ``FastLanguageModel.from_pretrained``.
        load_in_4bit: whether to load 4-bit quantised weights (less memory).

    Returns:
        ``(model, tokenizer)`` with the model switched to inference mode.
    """
    # The loader settings are explicit parameters. Their defaults match the
    # values the few-shot button sets, so no module globals need to exist first.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=sota_model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        cache_dir="/data/.cache/hf-models",
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
class DeckPicker:
    """Deal items one at a time in shuffled order, reshuffling when exhausted.

    Within one pass through the deck every item is dealt exactly once; after
    the last card the same items are reshuffled and dealing starts over.
    """

    def __init__(self, items):
        # Untouched copy, kept in the caller's original order.
        self.original_items = items[:]
        # Working copy that gets shuffled in place; the caller's list is not mutated.
        self.items = items[:]
        random.shuffle(self.items)
        # Position of the most recently dealt card; -1 means nothing dealt yet.
        self.index = -1

    def pick(self):
        """Deal the next item, reshuffling first if the current pass is used up."""
        self.index += 1
        if self.index < len(self.items):
            return self.items[self.index]
        # Pass exhausted: start a new one over a fresh shuffle.
        self.index = 0
        random.shuffle(self.items)
        return self.items[self.index]

    def get_state(self):
        """Return ``(items, index)``: the current deck order and last dealt position."""
        return (self.items, self.index)
# Vocabulary for the few-shot prompt: each synthetic customer message is asked
# to contain one noun, one verb and one adjective dealt from these lists.
nouns = """
service issue account support problem help team
request response email ticket update error system
connection downtime billing charge refund password
outage agent feature access status interface network
subscription upgrade notification data server log message
renewal setup security feedback confirmation printer
""".split()
verbs = """
have print need help update resolve access contact
receive reset support experience report request process
check confirm explain manage handle disconnect renew
change fix cancel complete notify respond fail restore
review escalate submit configure troubleshoot log operate
suspend pay adjust
""".split()
adjectives = """
quick immediate urgent unable detailed frequent technical
possible slow helpful unresponsive secure successful necessary
available scheduled regular interrupted automatic manual last
online offline new current prior due related temporary
permanent next previous complicated easy difficult major
minor alternative additional expired
""".split()
def create_few_shots(noun_picker, verb_picker, adjective_picker):
    """Build the few-shot instruction asking for one negative customer tweet.

    One word is drawn from each picker (noun, then verb, then adjective, via
    their ``pick()`` methods). The words are embedded in a fixed prompt
    containing five example tweets.

    Returns:
        The prompt text (it starts with a newline and ends with ``"1/1:\\n"``).
    """
    noun, verb, adjective = (
        picker.pick() for picker in (noun_picker, verb_picker, adjective_picker)
    )
    return f"""
Write a short realistic customer support tweet message by a customer for another company.
Avoid adding hashtags or mentions in the message.
Ensure that the sentiment is negative.
Ensure that the word count is around 15 to 25 words.
Ensure the message contains the noun: {noun}, verb: {verb}, and adjective: {adjective}.
Example of return messages 5/5:
1/5: your website is straight up garbage. how do you sell high end technology but you cant get a website right?
2/5: my phone is all static during calls and when i plug in headphones any audio still comes thru the speaks wtf
3/5: hi, i'm having trouble logging into my groceries account it keeps refreshing back to the log in page, any ideas?
4/5: please check you dms asap if you're really about customer service. 2 weeks since my accident and nothing.
5/5: I'm extremely disappointed with your service. You charged me for a temporary solution, and there's no adjustment in sight.
Now it's your turn, ensure to only generate one message
1/1:
"""
# Page title and one-line pitch.
st.header("ReplyCaddy")
st.write(
    "AI-powered customer support assistant. "
    "Reduces anxiety when responding to customer support on social media."
)
# TODO: show the "Made with Unsloth" badge:
# https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true
def stream_generation(
    prompt: str,
    tokenizer: PreTrainedTokenizerFast,
    model: AutoModelForCausalLM,
    max_new_tokens: int = 2048,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 100,
    repetition_penalty: float = 1.1,
    penalty_alpha: float = 0.25,
    no_repeat_ngram_size: int = 3,
    show_prompt: bool = False,
) -> Iterable[str]:
    """
    Stream the generation of a prompt.

    ``model.generate`` runs in a background thread and feeds a
    ``TextIteratorStreamer``; this generator yields the decoded text chunks
    as they arrive.

    Args:
        prompt (str): the prompt
        tokenizer (PreTrainedTokenizerFast): encodes the prompt and decodes the output
        model (AutoModelForCausalLM): the model whose ``generate`` is called
        max_new_tokens (int, optional): the maximum number of tokens to generate. Defaults to 2048.
        temperature (float, optional): the sampling temperature. Defaults to 0.7.
        top_p (float, optional): the top-p value of the generation. Defaults to 0.9.
        top_k (int, optional): the top-k value of the generation. Defaults to 100.
        repetition_penalty (float, optional): the repetition penalty of the generation. Defaults to 1.1.
        penalty_alpha (float, optional): forwarded to ``generate``. Defaults to 0.25.
            NOTE(review): sampling is enabled (``do_sample=True``), and in
            transformers ``penalty_alpha`` drives contrastive search, which is
            not selected when sampling. This is likely a no-op here; confirm.
        no_repeat_ngram_size (int, optional): the no repeat ngram size of the generation. Defaults to 3.
        show_prompt (bool, optional): whether to stream the prompt text as well. Defaults to False.

    Yields:
        str: decoded chunks of generated text. A chunk containing a blacklisted
        token (``<|url|>``) is dropped whole.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
        re-raised here after the stream has been closed.
    """
    # The prompt is echoed only when show_prompt is set; special tokens are never shown.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=not show_prompt, skip_special_tokens=True)  # type: ignore

    encoded = tokenizer(prompt, return_tensors="pt")
    generation_kwargs = dict(
        input_ids=encoded["input_ids"].to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        penalty_alpha=penalty_alpha,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_new_tokens,
    )
    # Pass the tokenizer's attention mask explicitly, so generate() does not
    # have to infer one from the pad token.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        generation_kwargs["attention_mask"] = attention_mask.to(device)

    generation_errors = []

    def _generate():
        """Run generate() in the worker thread, closing the stream if it fails."""
        try:
            model.generate(**generation_kwargs)  # type: ignore
        except Exception as exc:  # handed back to the consuming thread below
            generation_errors.append(exc)
            # Without an end signal the streamer would keep the loop below
            # waiting for more text indefinitely.
            streamer.end()

    generation_thread = threading.Thread(target=_generate)
    generation_thread.start()

    blacklisted_tokens = ["<|url|>"]
    for new_text in streamer:
        # Filter out chunks containing blacklisted tokens.
        if any(token in new_text for token in blacklisted_tokens):
            continue
        yield new_text

    # The stream is finished; wait for the worker and surface any failure.
    generation_thread.join()
    if generation_errors:
        raise generation_errors[0]
# Module-level placeholders, all initialised to None. In the code shown they
# are never reassigned; stream_generation uses its own local ``streamer``.
twitter_llama_model = twitter_llama_tokenizer = streamer = None
# define state and the chat messages
def init_session_states(assistant_chat, user_chat):
    """Seed the session-state keys this app relies on.

    Only ``"user_msg_as_prompt"`` is initialised (to an empty string), and
    only when it is missing. Both parameters are accepted but unused.
    """
    if "user_msg_as_prompt" in st.session_state:
        return
    st.session_state["user_msg_as_prompt"] = ""
# Create this run's chat bubbles.
# user_chat stays None until "user_msg_as_prompt" is in session state. On the
# first run of a session that key is only created by init_session_states()
# below, so no user bubble (and no response button) appears until a rerun.
# NOTE(review): confirm whether calling init_session_states() first was intended.
user_chat = None
if "user_msg_as_prompt" in st.session_state:
    user_chat = st.chat_message("user")
assistant_chat = st.chat_message("assistant")
# Greet on the first run of a session: only the presence of the "greet" key is
# checked (its False value is not read in the code shown).
# NOTE(review): Streamlit redraws elements on every rerun, so this greeting
# disappears after the first interaction; confirm that is intended.
if "greet" not in st.session_state:
    st.session_state["greet"] = False
    greeting_text = "Hello! I'm here to help. Copy and paste your customer's message, or generate using AI."
    assistant_chat.write(greeting_text)
init_session_states(assistant_chat, user_chat)
# Generate Response Tweet
# Offered only once a user chat bubble exists (see user_chat above).
if user_chat:
    if st.button("Generate Polite and Friendly Response"):
        if "user_msg_as_prompt" in st.session_state:
            customer_msg = st.session_state["user_msg_as_prompt"]
            if customer_msg:
                write_user_chat_message(user_chat, customer_msg)
                model, tokenizer = get_model_tokenizer(sota_model_name)
                input_text = alpaca_input_text_format.format(customer_msg)
                # Show the exact prompt sent to the model. customer_msg can be
                # model-generated text, so keep HTML disabled: a stray ``` in
                # it would close the code fence and expose raw HTML.
                st.markdown(f"""```\n{input_text}```""")
                response_tweet = assistant_chat.write_stream(
                    stream_generation(
                        input_text,
                        show_prompt=False,
                        tokenizer=tokenizer,
                        model=model,
                    )
                )
            else:
                st.error("Please enter a customer message, or generate one for the ai to respond")
# main ui prompt
# - text box
# - submit
with st.form(key="my_form"):
    prompt = st.text_area("Customer Message")
    write_user_chat_message(user_chat, prompt)
    if st.form_submit_button("Submit"):
        # Store the pasted message so "Generate Polite and Friendly Response"
        # replies to it; otherwise that button only ever sees the example or
        # AI-generated message. An empty submission keeps the current one.
        if prompt:
            st.session_state["user_msg_as_prompt"] = prompt
        assistant_chat.write("Hi, Human.")
# below ui prompt
# - examples
# st.markdown("<b>Example:</b>", unsafe_allow_html=True)
# Single source of truth: the same text is the button label and the message replied to.
example_customer_msg = "your website is straight up garbage. how do you sell high end technology but you cant get a website right?"
if st.button(example_customer_msg):
    customer_msg = example_customer_msg
    # Remember it so the response button can reuse it on later reruns.
    st.session_state["user_msg_as_prompt"] = customer_msg
    write_user_chat_message(user_chat, customer_msg)
    model, tokenizer = get_model_tokenizer(sota_model_name)
    input_text = alpaca_input_text_format.format(customer_msg)
    st.write(f"```\n{input_text}```")
    assistant_chat.write_stream(
        stream_generation(
            input_text,
            show_prompt=False,
            tokenizer=tokenizer,
            model=model,
        )
    )
# - Generate Customer Tweet
if st.button("Generate Customer Message using Few Shots"):
    # Loader settings; get_mistral_model_tokenizer may read these as module globals.
    max_seq_length = 2048
    dtype = torch.float16
    load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.
    model, tokenizer = get_mistral_model_tokenizer(sota_model_name)
    # Keep the word decks in session state. Streamlit reruns the script on each
    # click, so decks built here every time would only ever deal their first
    # card, defeating DeckPicker's no-repeat-until-exhausted cycling.
    if "few_shot_word_pickers" not in st.session_state:
        st.session_state["few_shot_word_pickers"] = (
            DeckPicker(nouns),
            DeckPicker(verbs),
            DeckPicker(adjectives),
        )
    noun_picker, verb_picker, adjective_picker = st.session_state["few_shot_word_pickers"]
    few_shots = create_few_shots(noun_picker, verb_picker, adjective_picker)
    # Mistral-instruct style wrapper around the few-shot instruction.
    few_shot_prompt = f"<s>[INST]{few_shots}[/INST]\n"
    st.markdown("Prompt:")
    # Plain markdown is enough: text inside a code fence is shown literally.
    st.markdown(f"""```\n{few_shot_prompt}```""")
    new_customer_msg = write_stream_user_chat_message(user_chat, model, tokenizer, few_shot_prompt)
    # Make the generated message the input for the response button.
    st.session_state["user_msg_as_prompt"] = new_customer_msg
st.markdown("------------")
st.markdown("<p>Thanks to:</p>", unsafe_allow_html=True)
# One markdown element per credited project, rendered in this order.
for credit_markdown in (
    "Unsloth https://github.com/unslothai check out the [wiki](https://github.com/unslothai/unsloth/wiki)",
    "Georgi Gerganov's ggml https://github.com/ggerganov/ggml",
    "Meta's Llama https://github.com/meta-llama",
    "Mistral AI - https://github.com/mistralai",
    "Zhang Peiyuan's TinyLlama https://github.com/jzhang38/TinyLlama",
    """Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois,
Xuechen Li, Carlos Guestrin, Percy Liang, Tatsunori B. Hashimoto
- [Alpaca: A Strong, Replicable Instruction-Following Model](https://crfm.stanford.edu/2023/03/13/alpaca.html)""",
):
    st.markdown(credit_markdown)
# Report GPU details when running on CUDA. Compare device.type: a torch.device
# never equals a plain string, and the CUDA device here is "cuda:0", not "cuda".
if device.type == "cuda":
    gpu_stats = torch.cuda.get_device_properties(0)
    # Bytes -> GiB, rounded for display.
    max_memory = round(gpu_stats.total_memory / 1024 ** 3, 3)
    start_gpu_memory = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 3)
    st.write(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    st.write(f"{start_gpu_memory} GB of memory reserved.")
st.write("Packages:")
st.write(f"pytorch: {torch.__version__}")
|