# Adapted from the Gradio tutorials:
# https://www.gradio.app/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face

import gradio as gr
import torch

# Get cpu, gpu or mps device for inference.
# See: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import StoppingCriteria
from transformers import StoppingCriteriaList
from transformers import TextIteratorStreamer
from threading import Thread

MODEL_ID = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
model = model.to(device)  # move the model to the selected device


class StopOnTokens(StoppingCriteria):
    """
    Class used as `stopping_criteria` in `generate_kwargs` that provides an
    additional way of stopping the generation loop (if this class returns
    `True` on a token, the generation is stopped).
    """
    # note: Python now supports type hints, see: https://realpython.com/lessons/type-hinting/
    # (for the **kwargs see also: https://realpython.com/python-kwargs-and-args/)
    # this could also be written: def __call__(self, input_ids, scores, **kwargs):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [29, 0]  # see the cell below to understand where these come from
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


def predict(message, history):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # useful to debug
    # msg = "history"
    # print(msg)
    # print(*history_transformer_format, sep="\n")
    # print("***")

    # at each step, we feed the entire history in string format,
    # restoring the format used in their dataset: new lines with
    # "<human>:" or "<bot>:" added before the messages
    messages = "".join(
        ["".join(
            ["\n<human>:" + item[0], "\n<bot>:" + item[1]]
        ) for item in history_transformer_format]
    )

    # # to see what we feed to our net:
    # msg = "string prompt"
    # print(msg)
    # print("-" * len(msg))
    # print(messages)
    # print("-" * 40)

    # convert the string into tensors & move to the selected device
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer,
        # timeout=30.,  # no timeout until I implement error handling for the empty stream
        skip_prompt=True,
        skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,  # mute annoying warning: https://stackoverflow.com/a/71397707
        num_beams=1,  # this is for beam search (disabled), see: https://huggingface.co/blog/how-to-generate#beam-search
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        # remember the "\n<human>:" / "\n<bot>:" format above (where 'messages' is built)?
        # we stream every token except '<', which starts the stop marker at the end of the reply
        if new_token != '<':
            partial_message += new_token
            yield partial_message


gr.ChatInterface(predict).queue().launch(debug=True, share=True)
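
# --- Hypothetical sanity check (a sketch, run separately, e.g. in another cell; not part of the app above) ---
# Where might the stop ids [29, 0] used in StopOnTokens come from? The assumption here is that
# they correspond to "<" (the first character of the next "\n<human>:" turn) and to the
# tokenizer's end-of-text token for this model; printing them lets you confirm that.
print(tokenizer("<").input_ids)       # id(s) the tokenizer assigns to "<"
print(tokenizer.eos_token_id)         # id of the end-of-text token
print(repr(tokenizer.decode([29])), repr(tokenizer.decode([0])))  # decode the stop ids back to text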