# Adapted from the Gradio tutorials:
# https://www.gradio.app/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
import gradio as gr
import torch
# Get cpu, gpu or mps device (used here for inference).
# See: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import StoppingCriteria
from transformers import StoppingCriteriaList
from transformers import TextIteratorStreamer
from threading import Thread
MODEL_ID = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
model = model.to(device)  # move model to the selected device (GPU, MPS or CPU)
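# note: float16 weights are really intended for GPU/MPS; if `device` ends up
# being "cpu", loading in float32 may be safer, e.g. (untested sketch):
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     torch_dtype=torch.float16 if device != "cpu" else torch.float32,
# )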
class StopOnTokens(StoppingCriteria):
"""
Class used `stopping_criteria` in `generate_kwargs` that provides an additional
way of stopping the generation loop (if this class returns `True` on a token,
the generation is stopped)).
"""
# note: Python now supports type hints, see this: https://realpython.com/lessons/type-hinting/
# (for the **kwargs see also: https://realpython.com/python-kwargs-and-args/)
# this could also be written: def __call__(self, input_ids, scores, **kwargs):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [29, 0]  # ids for '<' and the end-of-text token in this tokenizer (see the check after the class)
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
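# where do the stop ids above come from? they can be checked against the
# tokenizer (with this GPT-NeoX-style tokenizer, '<' is expected to map to 29
# and the end-of-text token to 0, but re-check if you swap models):
# print(tokenizer("<").input_ids)  # expected: [29]
# print(tokenizer.eos_token_id)    # expected: 0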
def predict(message, history):
history_transformer_format = history + [[message, ""]]
stop = StopOnTokens()
# useful to debug
# msg = "history"
# print(msg)
# print(*history_transformer_format, sep="\n")
# print("***")
    # at each step, we feed the entire history as one string,
    # restoring the format of the dataset the model was fine-tuned on: new lines
    # with <human>: or <bot>: prepended to each message
messages = "".join(
["".join(
["\n<human>:"+item[0], "\n<bot>:"+item[1]]
)
for item in history_transformer_format]
)
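    # an illustrative (made-up) example of the resulting string for a two-turn chat:
    # "\n<human>:Hi there!\n<bot>:Hello!\n<human>:How are you?\n<bot>:"
    # the trailing "\n<bot>:" comes from the empty reply appended to the last
    # turn above, which is what prompts the model to answer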
# # to see what we feed to our net:
# msg = "string prompt"
# print(msg)
# print("-" * len(msg))
# print(messages)
# print("-" * 40)
    # convert the string into tensors & move them to the selected device
model_inputs = tokenizer([messages], return_tensors="pt").to(device)
streamer = TextIteratorStreamer(
tokenizer,
# timeout=30., # no timeout until I implement error handling for the empty stream
skip_prompt=True,
skip_special_tokens=True
)
generate_kwargs = dict(
model_inputs,
streamer=streamer,
max_new_tokens=1024,
do_sample=True,
top_p=0.95,
top_k=1000,
temperature=1.0,
pad_token_id=tokenizer.eos_token_id, # mute annoying warning: https://stackoverflow.com/a/71397707
num_beams=1, # this is for beam search (disabled), see: https://huggingface.co/blog/how-to-generate#beam-search
stopping_criteria=StoppingCriteriaList([stop])
)
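    # note: `dict(model_inputs, ...)` above unpacks the tokenizer output
    # (input_ids, attention_mask) and merges it with the generation parameters,
    # so everything can be passed to `model.generate` as keyword arguments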
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
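    # generation now runs on the background thread; iterating over the streamer
    # below yields decoded text chunks as soon as they are produced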
partial_message = ""
for new_token in streamer:
        # remember the "\n<human>:" / "\n<bot>:" markers used when building `messages`?
        # we keep streaming tokens *until* we encounter '<', which signals the start
        # of the next turn (it is also one of the stop ids used by StopOnTokens)
if new_token != '<':
partial_message += new_token
yield partial_message
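# since `predict` is a generator, gr.ChatInterface streams the partial message
# into the UI as each new value is yielded (`.queue()` enables this behaviour)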
gr.ChatInterface(predict).queue().launch(debug=True, share=True)