import threading
import random
from typing import Iterable

import torch
import streamlit as st
from unsloth import FastLanguageModel
from transformers import AutoModelForCausalLM, TextIteratorStreamer, PreTrainedTokenizerFast
fine_tuned_model_name = "jed-tiotuico/twitter-llama"
sota_model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
# fine_tuned_model_name = "MBZUAI/LaMini-GPT-124M"
# sota_model_name = "MBZUAI/LaMini-GPT-124M"

alpaca_input_text_format = "### Instruction:\n{}\n\n### Response:\n"

# prefer CUDA; fall back to Apple MPS, then CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cpu" and torch.backends.mps.is_available():
    device = torch.device("mps")
def get_model_tokenizer(sota_model_name):
    # note: always loads the fine-tuned Twitter model; the argument is ignored
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=fine_tuned_model_name,
        max_seq_length=200,
        dtype=None,
        load_in_4bit=True,
        cache_dir="/data/.cache/hf-models",
        token=st.secrets["HF_TOKEN"],
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
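
# Minimal usage sketch (illustrative only; assumes the HF_TOKEN secret is set and
# stream_generation, defined further below, is available). The loader returns an
# inference-ready (model, tokenizer) pair that plugs straight into the streamer:
#
#   model, tokenizer = get_model_tokenizer(fine_tuned_model_name)
#   prompt = alpaca_input_text_format.format("My order arrived damaged, what now?")
#   for chunk in stream_generation(prompt, tokenizer=tokenizer, model=model):
#       print(chunk, end="")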
def write_user_chat_message(user_chat, customer_msg):
    if customer_msg:
        if user_chat is None:
            user_chat = st.chat_message("user")
        user_chat.write(customer_msg)
def write_stream_user_chat_message(user_chat, model, tokenizer, prompt):
    if prompt:
        if user_chat is None:
            user_chat = st.chat_message("user")
        new_customer_msg = user_chat.write_stream(
            stream_generation(
                prompt,
                show_prompt=False,
                tokenizer=tokenizer,
                model=model,
            )
        )
        return new_customer_msg
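
# Illustrative note: the return value is the fully streamed text, which the
# few-shot button handler below stores back into session state as the next
# customer message, e.g.:
#
#   msg = write_stream_user_chat_message(user_chat, model, tokenizer, few_shot_prompt)
#   st.session_state["user_msg_as_prompt"] = msg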
def get_mistral_model_tokenizer(sota_model_name):
    # relies on max_seq_length, dtype, and load_in_4bit being defined at module
    # scope before this is called (see the few-shot button handler below)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=sota_model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        cache_dir="/data/.cache/hf-models",
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
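
# Illustrative call sequence (assumes the globals are set first, exactly as the
# "Generate Customer Message using Few Shots" handler does):
#
#   max_seq_length = 2048
#   dtype = torch.float16
#   load_in_4bit = True
#   model, tokenizer = get_mistral_model_tokenizer(sota_model_name)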
class DeckPicker:
    def __init__(self, items):
        self.items = items[:]            # copy of the items to shuffle
        self.original_items = items[:]   # keep the original order
        random.shuffle(self.items)       # shuffle the working copy
        self.index = -1                  # initialize the index

    def pick(self):
        """Pick the next item from the deck. If all items have been picked, reshuffle."""
        self.index += 1
        if self.index >= len(self.items):
            self.index = 0
            random.shuffle(self.items)   # reshuffle when the deck is exhausted
        return self.items[self.index]

    def get_state(self):
        """Return the current state of the deck and the last picked index."""
        return self.items, self.index
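
# Quick sketch of the deck behaviour (illustrative): every item is drawn once per
# cycle, then the deck reshuffles, so repeated prompts still vary without
# repeating a word within a pass.
#
#   picker = DeckPicker(["a", "b", "c"])
#   draws = [picker.pick() for _ in range(6)]   # two full passes, each a reshuffled permutation
#   items, last_index = picker.get_state()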
# Word pools drawn by DeckPicker to seed the few-shot prompt
nouns = [
    "service", "issue", "account", "support", "problem", "help", "team",
    "request", "response", "email", "ticket", "update", "error", "system",
    "connection", "downtime", "billing", "charge", "refund", "password",
    "outage", "agent", "feature", "access", "status", "interface", "network",
    "subscription", "upgrade", "notification", "data", "server", "log", "message",
    "renewal", "setup", "security", "feedback", "confirmation", "printer"
]

verbs = [
    "have", "print", "need", "help", "update", "resolve", "access", "contact",
    "receive", "reset", "support", "experience", "report", "request", "process",
    "check", "confirm", "explain", "manage", "handle", "disconnect", "renew",
    "change", "fix", "cancel", "complete", "notify", "respond", "fail", "restore",
    "review", "escalate", "submit", "configure", "troubleshoot", "log", "operate",
    "suspend", "pay", "adjust"
]

adjectives = [
    "quick", "immediate", "urgent", "unable", "detailed", "frequent", "technical",
    "possible", "slow", "helpful", "unresponsive", "secure", "successful", "necessary",
    "available", "scheduled", "regular", "interrupted", "automatic", "manual", "last",
    "online", "offline", "new", "current", "prior", "due", "related", "temporary",
    "permanent", "next", "previous", "complicated", "easy", "difficult", "major",
    "minor", "alternative", "additional", "expired"
]
def create_few_shots(noun_picker, verb_picker, adjective_picker):
    noun = noun_picker.pick()
    verb = verb_picker.pick()
    adjective = adjective_picker.pick()
    context = f"""
Write a short realistic customer support tweet message by a customer for another company.
Avoid adding hashtags or mentions in the message.
Ensure that the sentiment is negative.
Ensure that the word count is around 15 to 25 words.
Ensure the message contains the noun: {noun}, verb: {verb}, and adjective: {adjective}.
Example of return messages 5/5:
1/5: your website is straight up garbage. how do you sell high end technology but you cant get a website right?
2/5: my phone is all static during calls and when i plug in headphones any audio still comes thru the speaks wtf
3/5: hi, i'm having trouble logging into my groceries account it keeps refreshing back to the log in page, any ideas?
4/5: please check you dms asap if you're really about customer service. 2 weeks since my accident and nothing.
5/5: I'm extremely disappointed with your service. You charged me for a temporary solution, and there's no adjustment in sight.
Now it's your turn, ensure to only generate one message
1/1:
"""
    return context
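
# Illustrative wiring (mirrors the few-shot button handler further down): each
# DeckPicker contributes one word, and the resulting context is wrapped in the
# Mistral [INST] ... [/INST] instruction template before generation.
#
#   noun_picker, verb_picker, adjective_picker = DeckPicker(nouns), DeckPicker(verbs), DeckPicker(adjectives)
#   few_shots = create_few_shots(noun_picker, verb_picker, adjective_picker)
#   few_shot_prompt = f"<s>[INST]{few_shots}[/INST]\n"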
st.header("ReplyCaddy")
st.write("AI-powered customer support assistant. Reduces anxiety when responding to customers on social media.")
# image: https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true
# st.write("Made with [Unsloth](https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true)")
def stream_generation(
    prompt: str,
    tokenizer: PreTrainedTokenizerFast,
    model: AutoModelForCausalLM,
    max_new_tokens: int = 2048,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 100,
    repetition_penalty: float = 1.1,
    penalty_alpha: float = 0.25,
    no_repeat_ngram_size: int = 3,
    show_prompt: bool = False,
) -> Iterable[str]:
    """
    Stream the generation of a prompt.

    Args:
        prompt (str): the prompt
        tokenizer (PreTrainedTokenizerFast): the tokenizer
        model (AutoModelForCausalLM): the model
        max_new_tokens (int, optional): the maximum number of tokens to generate. Defaults to 2048.
        temperature (float, optional): the temperature of the generation. Defaults to 0.7.
        top_p (float, optional): the top-p value of the generation. Defaults to 0.9.
        top_k (int, optional): the top-k value of the generation. Defaults to 100.
        repetition_penalty (float, optional): the repetition penalty of the generation. Defaults to 1.1.
        penalty_alpha (float, optional): the penalty alpha of the generation. Defaults to 0.25.
        no_repeat_ngram_size (int, optional): the no-repeat n-gram size of the generation. Defaults to 3.
        show_prompt (bool, optional): whether to include the prompt in the streamed output. Defaults to False.

    Yields:
        str: the generated text, one chunk at a time
    """
    # stream tokens as they are generated; skip the prompt unless show_prompt is True
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=not show_prompt, skip_special_tokens=True)  # type: ignore

    # set up kwargs for generation
    generation_kwargs = dict(
        input_ids=tokenizer(prompt, return_tensors="pt")["input_ids"].to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        penalty_alpha=penalty_alpha,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_new_tokens,
    )

    # run generation in a separate thread so this generator can consume the stream
    generation_thread = threading.Thread(
        target=model.generate, kwargs=generation_kwargs  # type: ignore
    )
    generation_thread.start()

    blacklisted_tokens = ["<|url|>"]
    for new_text in streamer:
        # filter out blacklisted tokens
        if any(token in new_text for token in blacklisted_tokens):
            continue
        yield new_text

    # wait for the generation to finish
    generation_thread.join()
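
# Consumption sketch (illustrative): the generator can be drained directly or
# handed to Streamlit's write_stream, which is how the button handlers below use it.
#
#   text = "".join(stream_generation(prompt, tokenizer=tokenizer, model=model))
#   # or, inside a chat bubble:
#   st.chat_message("assistant").write_stream(stream_generation(prompt, tokenizer=tokenizer, model=model))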
twitter_llama_model = None
twitter_llama_tokenizer = None
streamer = None

# define session state and the chat messages
def init_session_states():
    if "user_msg_as_prompt" not in st.session_state:
        st.session_state["user_msg_as_prompt"] = ""

    user_chat = None
    if "user_msg_as_prompt" in st.session_state:
        user_chat = st.chat_message("user")

    assistant_chat = st.chat_message("assistant")
    if "greet" not in st.session_state:
        st.session_state["greet"] = False
        greeting_text = "Hello! I'm here to help. Copy and paste your customer's message, or generate one using AI."
        assistant_chat.write(greeting_text)

    return assistant_chat, user_chat

assistant_chat, user_chat = init_session_states()
# Generate a response tweet for the current customer message
if user_chat:
    if st.button("Generate Polite and Friendly Response"):
        if "user_msg_as_prompt" in st.session_state:
            customer_msg = st.session_state["user_msg_as_prompt"]
            if customer_msg:
                write_user_chat_message(user_chat, customer_msg)
                model, tokenizer = get_model_tokenizer(sota_model_name)
                input_text = alpaca_input_text_format.format(customer_msg)
                st.markdown(f"""```\n{input_text}```""", unsafe_allow_html=True)
                response_tweet = assistant_chat.write_stream(
                    stream_generation(
                        input_text,
                        show_prompt=False,
                        tokenizer=tokenizer,
                        model=model,
                    )
                )
            else:
                st.error("Please enter a customer message, or generate one for the AI to respond to.")
# main UI prompt
# - text box
# - submit button
with st.form(key="my_form"):
    prompt = st.text_area("Customer Message")
    write_user_chat_message(user_chat, prompt)
    if st.form_submit_button("Submit"):
        assistant_chat.write("Hi, Human.")
# below the main prompt
# - example customer message
# st.markdown("<b>Example:</b>", unsafe_allow_html=True)
if st.button("your website is straight up garbage. how do you sell high end technology but you cant get a website right?"):
    customer_msg = "your website is straight up garbage. how do you sell high end technology but you cant get a website right?"
    st.session_state["user_msg_as_prompt"] = customer_msg
    write_user_chat_message(user_chat, customer_msg)
    model, tokenizer = get_model_tokenizer(sota_model_name)
    input_text = alpaca_input_text_format.format(customer_msg)
    st.write(f"```\n{input_text}```")
    assistant_chat.write_stream(
        stream_generation(
            input_text,
            show_prompt=False,
            tokenizer=tokenizer,
            model=model,
        )
    )
# generate a synthetic customer tweet with few-shot prompting
if st.button("Generate Customer Message using Few Shots"):
    max_seq_length = 2048
    dtype = torch.float16
    load_in_4bit = True  # use 4-bit quantization to reduce memory usage; can be False
    model, tokenizer = get_mistral_model_tokenizer(sota_model_name)

    noun_picker = DeckPicker(nouns)
    verb_picker = DeckPicker(verbs)
    adjective_picker = DeckPicker(adjectives)
    few_shots = create_few_shots(noun_picker, verb_picker, adjective_picker)
    few_shot_prompt = f"<s>[INST]{few_shots}[/INST]\n"

    st.markdown("Prompt:")
    st.markdown(f"""```\n{few_shot_prompt}```""", unsafe_allow_html=True)
    new_customer_msg = write_stream_user_chat_message(user_chat, model, tokenizer, few_shot_prompt)
    st.session_state["user_msg_as_prompt"] = new_customer_msg
st.markdown("------------")
st.markdown("<p>Thanks to:</p>", unsafe_allow_html=True)
st.markdown("""Unsloth - https://github.com/unslothai (check out the [wiki](https://github.com/unslothai/unsloth/wiki))""")
st.markdown("""Georgi Gerganov's ggml - https://github.com/ggerganov/ggml""")
st.markdown("""Meta's Llama - https://github.com/meta-llama""")
st.markdown("""Mistral AI - https://github.com/mistralai""")
st.markdown("""Zhang Peiyuan's TinyLlama - https://github.com/jzhang38/TinyLlama""")
st.markdown("""Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois,
Xuechen Li, Carlos Guestrin, Percy Liang, Tatsunori B. Hashimoto
- [Alpaca: A Strong, Replicable Instruction-Following Model](https://crfm.stanford.edu/2023/03/13/alpaca.html)""")
if device.type == "cuda":
    gpu_stats = torch.cuda.get_device_properties(0)
    max_memory = gpu_stats.total_memory / 1024 ** 3
    start_gpu_memory = torch.cuda.memory_reserved(0) / 1024 ** 3
    st.write(f"GPU = {gpu_stats.name}. Max memory = {max_memory:.3f} GB.")
    st.write(f"{start_gpu_memory:.3f} GB of memory reserved.")

st.write("Packages:")
st.write(f"pytorch: {torch.__version__}")