# Hugging Face Spaces status at the time this file was captured: "Runtime error"
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer | |
import time | |
import numpy as np | |
from torch.nn import functional as F | |
import os | |
from threading import Thread | |
title = "🦅Falcon 🗨️ChatBot"
description = "Falcon-RW-1B is a 1B parameters causal decoder-only model built by TII and trained on 350B tokens of RefinedWeb."
# One example conversation opener shown under the chat box.
examples = [["How are you?"]]

# float16 is only worthwhile on a GPU; many CPU kernels have historically
# lacked half-precision support (e.g. matmul on older PyTorch builds), so use
# float32 when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-rw-1b",
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation when a stop token is produced.

    Only the most recent token of the first sequence in the batch is
    inspected; generation stops when it equals one of the stop ids.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # NOTE(review): token id 0 is treated as a terminator here; confirm
        # that this is meaningful for the falcon-rw-1b tokenizer.
        newest_token = input_ids[0][-1]
        return any(newest_token == candidate for candidate in [0])
def user(message, history):
    """Record *message* as a new, not-yet-answered chat turn.

    Returns a pair: an empty string (used to clear the input textbox) and a
    new history list with ``[message, ""]`` appended. The caller's *history*
    list is not modified.
    """
    pending_turn = [message, ""]
    cleared_textbox = ""
    return cleared_textbox, history + [pending_turn]
def _build_prompt(message, history):
    """Render past turns plus the new message as one prompt ending in an open bot turn."""
    # One "<user>: ...\n<chatbot>: ...\n" pair per completed turn; a missing
    # (None) bot reply is rendered as an empty string.
    past_turns = "".join(
        f"<user>: {turn[0]}\n<chatbot>: {turn[1] or ''}\n" for turn in history
    )
    # Ending on "<chatbot>:" cues the model to answer the new message rather
    # than continue the user's text.
    return f"{past_turns}<user>: {message}\n<chatbot>:"


class _StopWhenSet(StoppingCriteria):
    """Stopping criterion that halts generation once *event* is set."""

    def __init__(self, event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return self.event.is_set()


def chat(curr_system_message, history):
    """Stream the model's reply to the newest user message.

    This is the ``fn`` of ``gr.ChatInterface``, which calls it as
    ``fn(message, history)``.

    Args:
        curr_system_message: The user's newest message. The parameter name is
            kept for backward compatibility.
        history: The previous ``[user, bot]`` pairs. They do NOT include the
            new message, and the list may be empty on the first turn.

    Yields:
        The reply generated so far, as a growing string. ChatInterface
        replaces the displayed bot message with each yielded value.
    """
    from threading import Event

    prompt = _build_prompt(curr_system_message, history)
    # Keep the inputs on the same device as the model weights.
    tokens = tokenizer([prompt], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    # Set by this generator to end generation early once the reply is complete.
    finished = Event()
    generate_kwargs = dict(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        streamer=streamer,
        max_length=2048,
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        temperature=0.7,
        stopping_criteria=StoppingCriteriaList(
            [StopOnTokens(), _StopWhenSet(finished)]
        ),
    )
    # model.generate() blocks until it finishes, so run it in a worker thread
    # while this generator relays decoded text from the streamer.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        # The model may go on to invent the next user turn: show only the text
        # before it and signal the worker thread to stop generating.
        turn_end = partial_text.find("<user>:")
        if turn_end != -1:
            finished.set()
            yield partial_text[:turn_end].strip()
            return
        yield partial_text.strip()
# Chat UI wrapping chat(). With cache_examples=True, Gradio runs chat() on the
# examples ahead of time so their answers can be shown instantly.
demo = gr.ChatInterface(
    chat,
    title=title,
    description=description,
    examples=examples,
    cache_examples=True,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Chat with me"),
)
# Enable the request queue. Older Gradio versions require it for streaming
# (generator) functions.
demo.queue()
demo.launch()