saffr0n commited on
Commit
5d492b5
1 Parent(s): 574638c

Radical update with gr.Chatbot instead to actively append history

Browse files
Files changed (1) hide show
  1. app.py +37 -54
app.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  MAX_MAX_NEW_TOKENS = 1024
11
- DEFAULT_MAX_NEW_TOKENS = 512
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
 
14
  DESCRIPTION = """\
@@ -86,62 +86,45 @@ def generate(
86
  outputs.append(text)
87
  yield "".join(outputs)
88
 
89
- chat_interface = gr.ChatInterface(
90
- fn=generate,
91
- fill_height=True,
92
- additional_inputs=[
93
- gr.Textbox(label="System prompt", lines=6),
94
- gr.Slider(
95
- label="Max new tokens",
96
- minimum=1,
97
- maximum=MAX_MAX_NEW_TOKENS,
98
- step=1,
99
- value=DEFAULT_MAX_NEW_TOKENS,
100
- ),
101
- gr.Slider(
102
- label="Temperature",
103
- minimum=0.1,
104
- maximum=4.0,
105
- step=0.1,
106
- value=0.6,
107
- ),
108
- gr.Slider(
109
- label="Top-p (nucleus sampling)",
110
- minimum=0.05,
111
- maximum=1.0,
112
- step=0.05,
113
- value=0.9,
114
- ),
115
- gr.Slider(
116
- label="Top-k",
117
- minimum=1,
118
- maximum=1000,
119
- step=1,
120
- value=50,
121
- ),
122
- gr.Slider(
123
- label="Repetition penalty",
124
- minimum=1.0,
125
- maximum=2.0,
126
- step=0.05,
127
- value=1.2,
128
- ),
129
- ],
130
- stop_btn=None,
131
- examples=[
132
- ["நான் எப்படி வேகமாக தூங்க முடியும்?"],
133
- ["என் முதலாளி மிகவும் கட்டுப்படுத்துகிறார், நான் என்ன செய்ய வேண்டும்?"],
134
- ["திருமணத்திற்கு நான் என்ன அணிய வேண்டும்?"],
135
- ["வரலாற்றில் தெரிந்து கொள்ள வேண்டிய சில முக்கியமான காலங்கள் யாவை?"],
136
- ["நான் பணம் சம்பாதிக்க வேண்டும் ஆனால் வேடிக்கையாக இருக்க வேண்டும் என்றால் நல்ல தொழில் எது?"],
137
- ],
138
- )
139
-
140
 
141
  with gr.Blocks(css="style.css") as demo:
142
  gr.Markdown(DESCRIPTION)
143
- gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
144
- chat_interface.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  gr.Markdown(LICENSE)
146
 
147
  if __name__ == "__main__":
 
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  MAX_MAX_NEW_TOKENS = 1024
11
+ DEFAULT_MAX_NEW_TOKENS = 256
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
 
14
  DESCRIPTION = """\
 
86
  outputs.append(text)
87
  yield "".join(outputs)
88
 
89
+ examples = [
90
+ ["நான் எப்படி வேகமாக தூங்க முடியும்?"],
91
+ ["என் முதலாளி மிகவும் கட்டுப்படுத்துகிறார், நான் என்ன செய்ய வேண்டும்?"],
92
+ ["திருமணத்திற்கு நான் என்ன அணிய வேண்டும்?"],
93
+ ["வரலாற்றில் தெரிந்து கொள்ள வேண்டிய சில முக்கியமான காலங்கள் யாவை?"],
94
+ ["நான் பணம் சம்பாதிக்க வேண்டும் ஆனால் வேடிக்கையாக இருக்க வேண்டும் என்றால் நல்ல தொழில் எது?"],
95
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  with gr.Blocks(css="style.css") as demo:
98
  gr.Markdown(DESCRIPTION)
99
+
100
+ chatbot = gr.Chatbot()
101
+ msg = gr.Textbox(label="Enter your message")
102
+ clear = gr.Button("Clear")
103
+
104
+ def user(user_message, history):
105
+ return "", history + [[user_message, None]]
106
+
107
+ def bot(history, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
108
+ user_message = history[-1][0]
109
+ chat_history = [(msg[0], msg[1]) for msg in history[:-1]]
110
+ bot_message = ""
111
+ for response in generate(user_message, chat_history, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
112
+ bot_message = response
113
+ history[-1][1] = bot_message
114
+ yield history
115
+
116
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
117
+ bot,
118
+ [chatbot, gr.Textbox(label="System prompt", lines=6, value=SYSTEM_PROMPT),
119
+ gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
120
+ gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
121
+ gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
122
+ gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
123
+ gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)],
124
+ chatbot,
125
+ )
126
+ clear.click(lambda: None, None, chatbot, queue=False)
127
+
128
  gr.Markdown(LICENSE)
129
 
130
  if __name__ == "__main__":