# Hugging Face Spaces page status captured with this file: Sleeping
import datetime
import json
import random
import re
import threading
import time
from datetime import datetime  # rebinds ``datetime`` to the class, as in the original import order
from threading import Thread
from typing import Iterable

import streamlit as st
import torch
from google.protobuf import message
# unsloth is imported before transformers, preserving the original order.
from unsloth import FastLanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, PreTrainedTokenizerFast
# Hugging Face model ids used by the app.
fine_tuned_model_name = "jed-tiotuico/twitter-llama"
sota_model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
# fine_tuned_model_name = "MBZUAI/LaMini-GPT-124M"
# sota_model_name = "MBZUAI/LaMini-GPT-124M"

# Prompt wrapper for the reply model; "{}" receives the customer message.
alpaca_input_text_format = "</s>### Instruction:\n{}\n\n### Response:\n"

# Device preference: CUDA, then Apple MPS, then CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Compare on ``.type``: ``device`` is a ``torch.device``, not a str, so the
# previous ``device == "cpu"`` check could not be relied on to ever match.
if device.type == "cpu":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Printer model names substituted for the "<|hp-printer|>" placeholder in
# seed instructions.
printer_models = [
    "HP Smart Tank 750",
    "HP LaserJet Pro",
    "HP LaserJet 4100",
    "HP LaserJet 4000",
    "HP Photosmart C4635",
    "HP OfficeJet Pro 9015",
    "HP Envy 6055",
    "HP DeskJet 3755",
    "HP Color LaserJet MFP M283fdw",
    "HP DesignJet T630",
    "HP PageWide Pro 477dw",
    "HP LaserJet Enterprise M506",
    "HP OfficeJet 5255",
    "HP Envy Photo 7855",
    "HP LaserJet Pro M404dn",
    "HP DeskJet Plus 4155",
    "HP LaserJet Enterprise MFP M528f",
    "HP Neverstop Laser 1001nw",
    "HP Tango X",
    "HP Color LaserJet Pro M255dw",
    "HP Smart Tank Plus 651",
    "HP LaserJet Pro MFP M428fdw",
    "HP OfficeJet Pro 8035",
    "HP Envy 6075",
    "HP DeskJet 2622",
    "HP LaserJet Pro M15w",
]
def generate_printer_prompt(prompt_instructions):
    """Encode multiple prompt instructions into a single few-shot prompt.

    Each instruction has its whitespace collapsed to single spaces, is
    stripped, and loses any trailing ":". Every "<|hp-printer|>" placeholder
    in an instruction is replaced by one randomly chosen entry of
    ``printer_models`` (the same model for all occurrences within that
    instruction).

    Args:
        prompt_instructions: iterable of example instruction strings.

    Returns:
        The prompt text: a fixed header, one "Q: ..." entry per instruction,
        and a closing request for a new question.
    """
    placeholder = "<|hp-printer|>"
    parts = [
        "\n"
        "Come up with a printer related task or question that a person might ask for support.\n"
        "no further text/explanation, no additional information.\n"
        "Ensure the tasks/questions should follow the same style and complexity\n"
        "Examples:\n"
    ]
    for instruction in prompt_instructions:
        instruction = re.sub(r"\s+", " ", instruction).strip().rstrip(":")
        # The placeholder is a literal, so a plain replace is enough; a model
        # is only drawn when the placeholder actually occurs.
        if placeholder in instruction:
            instruction = instruction.replace(placeholder, random.choice(printer_models))
        parts.append(f"Q: {instruction}\n\n")
    parts.append(
        "Now it's your turn, come up with a printer task/question that a person might ask for support.\n"
    )
    parts.append("Q: (your task/question)")
    return "".join(parts)
def get_model_tokenizer(sota_model_name):
    """Load the fine-tuned reply model and its tokenizer, ready for inference.

    Args:
        sota_model_name: Ignored. The function always loads
            "jed-tiotuico/twitter-llama"; the parameter is kept only so
            existing call sites keep working.

    Returns:
        The ``(model, tokenizer)`` pair from
        ``FastLanguageModel.from_pretrained``, after
        ``FastLanguageModel.for_inference`` has been applied to the model.

    Raises:
        Whatever ``st.secrets["HF_TOKEN"]`` raises when the secret is missing.
    """
    # NOTE(review): this reloads the model on every call; consider Streamlit's
    # resource cache if reload time becomes a problem.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="jed-tiotuico/twitter-llama",
        max_seq_length=200,
        dtype=None,
        load_in_4bit=True,
        cache_dir="/data/.cache/hf-models",
        token=st.secrets["HF_TOKEN"],
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
def write_user_chat_message(user_chat, customer_msg):
    """Write ``customer_msg`` into the user chat container.

    Does nothing when ``customer_msg`` is empty. When ``user_chat`` is None a
    new ``st.chat_message("user")`` container is created for the message.

    Args:
        user_chat: object exposing ``write``; may be None.
        customer_msg: text to display.
    """
    if not customer_msg:
        return
    # Identity check: ``== None`` would defer to the container's own __eq__.
    if user_chat is None:
        user_chat = st.chat_message("user")
    user_chat.write(customer_msg)
def write_stream_user_chat_message(user_chat, model, token, prompt):
    """Stream a generated customer message into the user chat container.

    Args:
        user_chat: object exposing ``write_stream``; when None a new
            ``st.chat_message("user")`` container is created.
        model: model handed to ``stream_generation``.
        token: tokenizer handed to ``stream_generation`` (the parameter name
            is kept for existing callers).
        prompt: prompt text to generate from.

    Returns:
        Whatever ``write_stream`` returns for the streamed text, or None when
        ``prompt`` is empty.
    """
    # Previously an empty prompt fell through to ``return new_customer_msg``
    # with that name never bound.
    if not prompt:
        return None
    if user_chat is None:
        user_chat = st.chat_message("user")
    return user_chat.write_stream(
        stream_generation(
            prompt,
            show_prompt=False,
            # Use the tokenizer that was passed in instead of relying on a
            # module-level ``tokenizer`` happening to exist.
            tokenizer=token,
            model=model,
            temperature=0.5,
        )
    )
def get_mistral_model_tokenizer(sota_model_name):
    """Load the base instruct model and its tokenizer, ready for inference.

    Args:
        sota_model_name: model id to load. An empty or None value falls back
            to "unsloth/mistral-7b-instruct-v0.2-bnb-4bit", the id this
            function previously hard-coded.

    Returns:
        The ``(model, tokenizer)`` pair from
        ``FastLanguageModel.from_pretrained``, after
        ``FastLanguageModel.for_inference`` has been applied to the model.
    """
    model_name = sota_model_name or "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=2048,
        dtype=torch.float16,
        load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False.
        cache_dir="/data/.cache/hf-models",
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
class DeckPicker:
    """Hand out items in shuffled order, reshuffling after each full pass."""

    def __init__(self, items):
        """Copy ``items`` (any iterable) and shuffle the working copy.

        The caller's collection is never modified.
        """
        self.original_items = list(items)  # Keep the original order
        self.items = list(self.original_items)  # Working copy that gets shuffled
        random.shuffle(self.items)
        self.index = -1  # Nothing picked yet

    def pick(self):
        """Pick the next item from the deck. If all items have been picked, reshuffle.

        Raises:
            IndexError: if the deck is empty.
        """
        if not self.items:
            raise IndexError("cannot pick from an empty deck")
        self.index += 1
        if self.index >= len(self.items):
            self.index = 0
            random.shuffle(self.items)  # Reshuffle if at the end
        return self.items[self.index]

    def get_state(self):
        """Return the current (shuffled) deck and the last picked index.

        The returned list is the live internal list, not a copy.
        """
        return self.items, self.index
# Vocabulary the few-shot prompt draws from: one noun, one verb and one
# adjective are injected into each generated-message request.
nouns = [
    "service", "issue", "account", "support", "problem", "help", "team",
    "request", "response", "email", "ticket", "update", "error", "system",
    "connection", "downtime", "billing", "charge", "refund", "password",
    "outage", "agent", "feature", "access", "status", "interface", "network",
    "subscription", "upgrade", "notification", "data", "server", "log", "message",
    "renewal", "setup", "security", "feedback", "confirmation", "printer",
]
verbs = [
    "have", "print", "need", "help", "update", "resolve", "access", "contact",
    "receive", "reset", "support", "experience", "report", "request", "process",
    "check", "confirm", "explain", "manage", "handle", "disconnect", "renew",
    "change", "fix", "cancel", "complete", "notify", "respond", "fail", "restore",
    "review", "escalate", "submit", "configure", "troubleshoot", "log", "operate",
    "suspend", "pay", "adjust",
]
adjectives = [
    "quick", "immediate", "urgent", "unable", "detailed", "frequent", "technical",
    "possible", "slow", "helpful", "unresponsive", "secure", "successful", "necessary",
    "available", "scheduled", "regular", "interrupted", "automatic", "manual", "last",
    "online", "offline", "new", "current", "prior", "due", "related", "temporary",
    "permanent", "next", "previous", "complicated", "easy", "difficult", "major",
    "minor", "alternative", "additional", "expired",
]
def create_few_shots(noun_picker, verb_picker, adjective_picker): | |
noun = noun_picker.pick() | |
verb = verb_picker.pick() | |
adjective = adjective_picker.pick() | |
context = f""" | |
Write a short realistic customer support tweet message by a customer for another company. | |
Avoid adding hashtags or mentions in the message. | |
Ensure that the sentiment is negative. | |
Ensure that the word count is around 15 to 25 words. | |
Ensure the message contains the noun: {noun}, verb: {verb}, and adjective: {adjective}. | |
Example of return messages 5/5: | |
1/5: your website is straight up garbage. how do you sell high end technology but you cant get a website right? | |
2/5: my phone is all static during calls and when i plug in headphones any audio still comes thru the speaks wtf | |
3/5: hi, i'm having trouble logging into my groceries account it keeps refreshing back to the log in page, any ideas? | |
4/5: please check you dms asap if you're really about customer service. 2 weeks since my accident and nothing. | |
5/5: I'm extremely disappointed with your service. You charged me for a temporary solution, and there's no adjustment in sight. | |
Now it's your turn, ensure to only generate one message | |
1/1: | |
""" | |
return context | |
# Page title, tagline and usage instructions shown at the top of the app.
st.header("ReplyCaddy")
st.write("AI-powered customer support assistant. Reduces anxiety when responding to customer support on social media.")
st.markdown("""
Instructions:
1. Click the Generate Customer Message using Few Shots button to generate a custom message
2. Then click Generate Polite and Friendly Response
3. Or Enter a custom message in the text box and click Generate Polite and Friendly Response
""")
# image https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true
# st.write("Made with [Unsloth](https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true")
def stream_generation(
    prompt: str,
    tokenizer: PreTrainedTokenizerFast,
    model: AutoModelForCausalLM,
    max_new_tokens: int = 2048,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 100,
    repetition_penalty: float = 1.1,
    penalty_alpha: float = 0.25,
    no_repeat_ngram_size: int = 3,
    show_prompt: bool = False,
) -> Iterable[str]:
    """
    Stream the generation of a prompt.

    ``model.generate`` runs on a background thread and pushes decoded text
    into a ``TextIteratorStreamer``; this generator yields that text as it
    arrives and joins the thread once the stream is exhausted.

    Args:
        prompt (str): the prompt
        tokenizer (PreTrainedTokenizerFast): the tokenizer
        model (AutoModelForCausalLM): the model
        max_new_tokens (int, optional): the maximum number of tokens to generate. Defaults to 2048.
        temperature (float, optional): the temperature of the generation. Defaults to 0.7.
        top_p (float, optional): the top-p value of the generation. Defaults to 0.9.
        top_k (int, optional): the top-k value of the generation. Defaults to 100.
        repetition_penalty (float, optional): the repetition penalty of the generation. Defaults to 1.1.
        penalty_alpha (float, optional): the penalty alpha of the generation. Defaults to 0.25.
        no_repeat_ngram_size (int, optional): the no repeat ngram size of the generation. Defaults to 3.
        show_prompt (bool, optional): whether to show the prompt or not. Defaults to False.

    Yields:
        str: the generated text, with blacklisted tokens removed
    """
    # init the streaming object with tokenizer
    # skip_prompt = not show_prompt, skip_special_tokens = True
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=not show_prompt, skip_special_tokens=True)  # type: ignore

    # setup kwargs for generation; input ids are moved to the module-level ``device``
    generation_kwargs = dict(
        input_ids=tokenizer(prompt, return_tensors="pt")["input_ids"].to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        penalty_alpha=penalty_alpha,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_new_tokens,
    )

    # start the generation in a separate thread
    generation_thread = threading.Thread(
        target=model.generate, kwargs=generation_kwargs  # type: ignore
    )
    generation_thread.start()

    blacklisted_tokens = ["<|url|>"]
    for new_text in streamer:
        # Strip blacklisted tokens but keep the rest of the chunk; skipping
        # the whole chunk also discarded any text that arrived alongside.
        for token in blacklisted_tokens:
            new_text = new_text.replace(token, "")
        if new_text:
            yield new_text

    # wait for the generation to finish
    generation_thread.join()
# Module-level placeholders; not referenced again in the visible part of this file.
twitter_llama_model = None
twitter_llama_tokenizer = None
streamer = None


# define state and the chat messages
def init_session_states(assistant_chat, user_chat):
    """Make sure ``st.session_state["user_msg_as_prompt"]`` exists.

    The key is created with an empty string when absent and left untouched
    otherwise.

    Args:
        assistant_chat: accepted but not used.
        user_chat: accepted but not used.
    """
    if "user_msg_as_prompt" not in st.session_state:
        st.session_state["user_msg_as_prompt"] = ""
# The user container only exists once the session has the message key, i.e.
# from the second script run onward; until then ``user_chat`` stays None.
user_chat = None
if "user_msg_as_prompt" in st.session_state:
    user_chat = st.chat_message("user")

assistant_chat = st.chat_message("assistant")

# Greet once per session: the "greet" key marks that the greeting was shown.
if "greet" not in st.session_state:
    st.session_state["greet"] = False
    greeting_text = "Hello! I'm here to help. Copy and paste your customer's message, or generate using AI."
    assistant_chat.write(greeting_text)

init_session_states(assistant_chat, user_chat)
# Generate Response Tweet
# The button is only rendered once a user chat container exists
# (short-circuit keeps st.button from being called otherwise).
if user_chat and st.button("Generate Polite and Friendly Response"):
    if "user_msg_as_prompt" in st.session_state:
        customer_msg = st.session_state["user_msg_as_prompt"]
        if customer_msg:
            write_user_chat_message(user_chat, customer_msg)
            model, tokenizer = get_model_tokenizer(sota_model_name)
            input_text = alpaca_input_text_format.format(customer_msg)
            # Show the exact prompt that is sent to the model.
            st.markdown(f"""```\n{input_text}```""", unsafe_allow_html=True)
            assistant_chat.write_stream(
                stream_generation(
                    input_text,
                    show_prompt=False,
                    tokenizer=tokenizer,
                    model=model,
                    temperature=0.5,
                )
            )
        else:
            st.error("Please enter a customer message, or generate one for the ai to respond")
# below ui prompt
# - examples
# st.markdown("<b>Example:</b>", unsafe_allow_html=True)
# The example text doubles as the button label, so it is defined once.
example_customer_msg = "your website is straight up garbage. how do you sell high end technology but you cant get a website right?"
if st.button(example_customer_msg):
    customer_msg = example_customer_msg
    st.session_state["user_msg_as_prompt"] = customer_msg
    write_user_chat_message(user_chat, customer_msg)
    model, tokenizer = get_model_tokenizer(sota_model_name)
    input_text = alpaca_input_text_format.format(customer_msg)
    # Show the exact prompt that is sent to the model.
    st.write(f"```\n{input_text}```")
    assistant_chat.write_stream(
        stream_generation(
            input_text,
            show_prompt=False,
            tokenizer=tokenizer,
            model=model,
            temperature=0.5,
        )
    )
if st.button("Generate printer task/question"):
    # Read the seed tasks (one JSON object per line) and close the file promptly.
    with open("printer-seed.jsonl", "r", encoding="utf-8") as seed_file:
        seed_tasks = [json.loads(line) for line in seed_file if line.strip()]
    seed_instructions = [t["text"] for t in seed_tasks]

    # NOTE(review): ``num_prompt_instructions`` was referenced but never
    # defined in this file; 3 is an assumed value — confirm the intended count.
    num_prompt_instructions = 3
    # Never ask for more examples than the seed file provides.
    sample_size = min(num_prompt_instructions, len(seed_instructions))
    prompt_instructions = random.sample(seed_instructions, sample_size)
    random.shuffle(prompt_instructions)

    customer_msg = generate_printer_prompt(prompt_instructions)
    st.session_state["user_msg_as_prompt"] = customer_msg
    write_user_chat_message(user_chat, customer_msg)
    model, tokenizer = get_model_tokenizer(sota_model_name)
    input_text = alpaca_input_text_format.format(customer_msg)
    # Show the exact prompt that is sent to the model.
    st.write(f"```\n{input_text}```")
    assistant_chat.write_stream(
        stream_generation(
            input_text,
            show_prompt=False,
            tokenizer=tokenizer,
            model=model,
            temperature=0.5,
        )
    )
# - Generate Customer Tweet
if st.button("Generate Customer Message using Few Shots"):
    model, tokenizer = get_mistral_model_tokenizer(sota_model_name)
    # Fresh decks on every click: one random noun/verb/adjective per request.
    noun_picker = DeckPicker(nouns)
    verb_picker = DeckPicker(verbs)
    adjective_picker = DeckPicker(adjectives)
    few_shots = create_few_shots(noun_picker, verb_picker, adjective_picker)
    few_shot_prompt = f"<s>[INST]{few_shots}[/INST]\n"
    st.markdown("Prompt:")
    st.markdown(f"""```\n{few_shot_prompt}```""", unsafe_allow_html=True)
    new_customer_msg = write_stream_user_chat_message(user_chat, model, tokenizer, few_shot_prompt)
    # Keep the generated message so the response button can answer it on the next run.
    st.session_state["user_msg_as_prompt"] = new_customer_msg
# main ui prompt
# - text box
# - submit
with st.form(key="my_form"):
    customer_msg = st.text_area("Customer Message")
    # Echo the current text into the user chat. This already covers the submit
    # run, so the message is not written a second time below.
    write_user_chat_message(user_chat, customer_msg)
    if st.form_submit_button("Submit and Generate Response"):
        st.session_state["user_msg_as_prompt"] = customer_msg
        if customer_msg:
            model, tokenizer = get_model_tokenizer(sota_model_name)
            input_text = alpaca_input_text_format.format(customer_msg)
            # Show the exact prompt that is sent to the model.
            st.write(f"```\n{input_text}```")
            assistant_chat.write_stream(
                stream_generation(
                    input_text,
                    show_prompt=False,
                    tokenizer=tokenizer,
                    model=model,
                    temperature=0.5,
                )
            )
        else:
            # Same guard and wording as the response button: do not send an
            # empty instruction to the model.
            st.error("Please enter a customer message, or generate one for the ai to respond")
# Credits footer.
st.markdown("------------")
st.markdown("<p>Thanks to:</p>", unsafe_allow_html=True)
st.markdown("""Unsloth https://github.com/unslothai check out the [wiki](https://github.com/unslothai/unsloth/wiki)""")
st.markdown("""Georgi Gerganov's ggml https://github.com/ggerganov/ggml""")
st.markdown("""Meta's Llama https://github.com/meta-llama""")
st.markdown("""Mistral AI - https://github.com/mistralai""")
st.markdown("""Zhang Peiyuan's TinyLlama https://github.com/jzhang38/TinyLlama""")
st.markdown("""Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois,
Xuechen Li, Carlos Guestrin, Percy Liang, Tatsunori B. Hashimoto
- [Alpaca: A Strong, Replicable Instruction-Following Model](https://crfm.stanford.edu/2023/03/13/alpaca.html)""")
# Runtime diagnostics. GPU properties can only be queried when CUDA is
# present, so the page no longer crashes on CPU-only or MPS hosts.
if torch.cuda.is_available():
    gpu_stats = torch.cuda.get_device_properties(0)
    max_memory = gpu_stats.total_memory / 1024 ** 3  # bytes -> GiB
    start_gpu_memory = torch.cuda.memory_reserved(0) / 1024 ** 3  # bytes -> GiB
    st.write(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    st.write(f"{start_gpu_memory} GB of memory reserved.")
st.write("Packages:")
st.write(f"pytorch: {torch.__version__}")
# Project write-up rendered at the bottom of the page. Blank lines separate
# the Markdown paragraphs so they do not run together when rendered.
st.markdown("""
## Overview

Small business owners, app developers, and freelance talent professionals need more time and money to compete for social media presence. Social media apps provide excellent reach on X, FaceBook, and Instagram, so customers sometimes request initial support from those. The correct response has always been to redirect them to the proper support funnels, which then we run a process of account validation, issue classification, and prioritization.

However, composing the right tone for the first response message has been challenging and time-consuming, especially when the business owner has yet to experience customer support in public. One has to match the right tone and clarity and instruct the customer to redirect to the proper support channels.

By providing diverse responses, we can ensure each support message feels unique and tailored to the customer's message, avoiding the impersonal feel of a canned response.

## Problems with existing solutions

We prompted ChatGPT to respond an irate, emotionally charged, and informal tone to match generation criteria which we will provide below.

"You are a customer support representative. compose a customer response to this tweet:
Your website is straight up garbage. how do you sell high end technology but you cant get a website right?"

[GenerationCriteria]

"Thank you for reaching out and sharing your feedback. We apologize for the trouble you're experiencing with our website. Please DM us or contact our support team at [support link] so we can help resolve any issues you're encountering."

First, the response is wordily composed and does not contain a request to DM; second, it has an apology and its context. Third, from our experience, ChatGPT values politeness more than a human would.

As a business, we want our brands to hold a high standard for these responses.

We present ReplyCaddy, an AI-powered first-response text message generator that will help us compose the right first-response message that composes personal messages and matches the customer's tone.

We tested the prompt above to ReplyCaddy, and it generated these examples:

"hi! let's talk about it."

"we'd love to help. we're here to help!"

"we understand that you are not happy with the website. please send us an email at <|url|>"
""")