LouisStability ysharma HF staff commited on
Commit
14e5da3
1 Parent(s): 015885c

updated and cleaned up the streaming code (#15)

Browse files

- updated and cleaned up the streaming code (63c0b01fef5a1fddd405ffccfe51a81e493ca700)


Co-authored-by: yuvraj sharma <[email protected]>

Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -29,18 +29,18 @@ class StopOnTokens(StoppingCriteria):
29
  return True
30
  return False
31
 
32
- def user(user_message, history):
 
33
  history = history + [[user_message, ""]]
34
- return "", history, history
35
-
36
-
37
- def bot(history, curr_system_message):
38
  stop = StopOnTokens()
 
 
39
  messages = curr_system_message + \
40
  "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
41
  for item in history])
42
 
43
- #model_inputs = tok([messages], return_tensors="pt")['input_ids'].cuda()[:, :4096-1024]
44
  model_inputs = tok([messages], return_tensors="pt").to("cuda")
45
  streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
46
  generate_kwargs = dict(
@@ -58,16 +58,18 @@ def bot(history, curr_system_message):
58
  t.start()
59
 
60
  print(history)
 
 
61
  for new_text in streamer:
62
  print(new_text)
63
- history[-1][1] += new_text
64
- yield history, history
65
-
66
- return history, history
67
 
68
 
69
  with gr.Blocks() as demo:
70
- history = gr.State([])
71
  gr.Markdown("## StableLM-Tuned-Alpha-7b Chat")
72
  gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
73
  chatbot = gr.Chatbot().style(height=500)
@@ -81,10 +83,9 @@ with gr.Blocks() as demo:
81
  system_msg = gr.Textbox(
82
  start_message, label="System Message", interactive=False, visible=False)
83
 
84
- msg.submit(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
85
- fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
86
- submit.click(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
87
- fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
88
- clear.click(lambda: [None, []], None, [chatbot, history], queue=False)
89
  demo.queue(concurrency_count=2)
90
  demo.launch()
 
29
  return True
30
  return False
31
 
32
+ def chat(curr_system_message, user_message, history):
33
+ # Append the user's message to the conversation history
34
  history = history + [[user_message, ""]]
35
+ # Initialize a StopOnTokens object
 
 
 
36
  stop = StopOnTokens()
37
+
38
+ # Construct the input message string for the model by concatenating the current system message and conversation history
39
  messages = curr_system_message + \
40
  "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
41
  for item in history])
42
 
43
+ # Tokenize the messages string
44
  model_inputs = tok([messages], return_tensors="pt").to("cuda")
45
  streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
46
  generate_kwargs = dict(
 
58
  t.start()
59
 
60
  print(history)
61
+ # Initialize an empty string to store the generated text
62
+ partial_text = ""
63
  for new_text in streamer:
64
  print(new_text)
65
+ partial_text += new_text
66
+ history[-1][1] = partial_text
67
+ # Yield an empty string to cleanup the message textbox and the updated conversation history
68
+ yield "", history
69
 
70
 
71
  with gr.Blocks() as demo:
72
+ #history = gr.State([])
73
  gr.Markdown("## StableLM-Tuned-Alpha-7b Chat")
74
  gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
75
  chatbot = gr.Chatbot().style(height=500)
 
83
  system_msg = gr.Textbox(
84
  start_message, label="System Message", interactive=False, visible=False)
85
 
86
+ msg.submit(fn=chat, inputs=[system_msg, msg, chatbot], outputs=[msg, chatbot], queue=True)
87
+ submit.click(fn=chat, inputs=[system_msg, msg, chatbot], outputs=[msg, chatbot], queue=True)
88
+ clear.click(lambda: [None, []], None, [chatbot], queue=False)
89
+
 
90
  demo.queue(concurrency_count=2)
91
  demo.launch()