gabrielclark3330 committed on
Commit 2a15330
1 Parent(s): bcab068
Files changed (1)
  1. app.py +0 -283
app.py CHANGED
@@ -1,286 +1,3 @@
- '''
- import os
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch
-
- model_name_2_7B_instruct = "Zyphra/Zamba2-2.7B-instruct"
- model_name_7B_instruct = "Zyphra/Zamba2-7B-instruct"
- max_context_length = 4096
-
- tokenizer_2_7B_instruct = AutoTokenizer.from_pretrained(model_name_2_7B_instruct)
- model_2_7B_instruct = AutoModelForCausalLM.from_pretrained(
-     model_name_2_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
- )
-
- tokenizer_7B_instruct = AutoTokenizer.from_pretrained(model_name_7B_instruct)
- model_7B_instruct = AutoModelForCausalLM.from_pretrained(
-     model_name_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
- )
-
- def extract_assistant_response(generated_text):
-     assistant_token = '<|im_start|> assistant'
-     end_token = '<|im_end|>'
-     start_idx = generated_text.rfind(assistant_token)
-     if start_idx == -1:
-         # Assistant token not found
-         return generated_text.strip()
-     start_idx += len(assistant_token)
-     end_idx = generated_text.find(end_token, start_idx)
-     if end_idx == -1:
-         # End token not found, return from start_idx to end
-         return generated_text[start_idx:].strip()
-     else:
-         return generated_text[start_idx:end_idx].strip()
-
- def generate_response(chat_history, max_new_tokens, model, tokenizer):
-     sample = []
-     for turn in chat_history:
-         if turn[0]:
-             sample.append({'role': 'user', 'content': turn[0]})
-         if turn[1]:
-             sample.append({'role': 'assistant', 'content': turn[1]})
-     chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)
-     input_ids = tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).to(model.device)
-
-     max_new_tokens = int(max_new_tokens)
-     max_input_length = max_context_length - max_new_tokens
-     if input_ids['input_ids'].size(1) > max_input_length:
-         input_ids['input_ids'] = input_ids['input_ids'][:, -max_input_length:]
-         if 'attention_mask' in input_ids:
-             input_ids['attention_mask'] = input_ids['attention_mask'][:, -max_input_length:]
-
-     with torch.no_grad():
-         outputs = model.generate(**input_ids, max_new_tokens=int(max_new_tokens), return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
-     """
-     outputs = model.generate(
-         input_ids=input_ids,
-         max_new_tokens=int(max_new_tokens),
-         do_sample=True,
-         use_cache=True,
-         temperature=temperature,
-         top_k=int(top_k),
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
-         num_beams=int(num_beams),
-         length_penalty=length_penalty,
-         num_return_sequences=1
-     )
-     """
-     generated_text = tokenizer.decode(outputs[0])
-     assistant_response = extract_assistant_response(generated_text)
-
-     del input_ids
-     del outputs
-     torch.cuda.empty_cache()
-
-     return assistant_response
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Zamba2 Model Selector")
-     with gr.Tabs():
-         with gr.TabItem("7B Instruct Model"):
-             gr.Markdown("### Zamba2-7B Instruct Model")
-             with gr.Column():
-                 chat_history_7B_instruct = gr.State([])
-                 chatbot_7B_instruct = gr.Chatbot()
-                 message_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
-                 with gr.Accordion("Generation Parameters", open=False):
-                     max_new_tokens_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
-                     # temperature_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
-                     # top_k_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
-                     # top_p_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
-                     # repetition_penalty_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
-                     # num_beams_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
-                     # length_penalty_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
-
-                 def user_message_7B_instruct(message, chat_history):
-                     chat_history = chat_history + [[message, None]]
-                     return gr.update(value=""), chat_history, chat_history
-
-                 def bot_response_7B_instruct(chat_history, max_new_tokens):
-                     response = generate_response(chat_history, max_new_tokens, model_7B_instruct, tokenizer_7B_instruct)
-                     chat_history[-1][1] = response
-                     return chat_history, chat_history
-
-                 send_button_7B_instruct = gr.Button("Send")
-                 send_button_7B_instruct.click(
-                     fn=user_message_7B_instruct,
-                     inputs=[message_7B_instruct, chat_history_7B_instruct],
-                     outputs=[message_7B_instruct, chat_history_7B_instruct, chatbot_7B_instruct]
-                 ).then(
-                     fn=bot_response_7B_instruct,
-                     inputs=[
-                         chat_history_7B_instruct,
-                         max_new_tokens_7B_instruct
-                     ],
-                     outputs=[chat_history_7B_instruct, chatbot_7B_instruct]
-                 )
-         with gr.TabItem("2.7B Instruct Model"):
-             gr.Markdown("### Zamba2-2.7B Instruct Model")
-             with gr.Column():
-                 chat_history_2_7B_instruct = gr.State([])
-                 chatbot_2_7B_instruct = gr.Chatbot()
-                 message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
-                 with gr.Accordion("Generation Parameters", open=False):
-                     max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
-                     # temperature_2_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
-                     # top_k_2_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
-                     # top_p_2_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
-                     # repetition_penalty_2_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
-                     # num_beams_2_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
-                     # length_penalty_2_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
-
-                 def user_message_2_7B_instruct(message, chat_history):
-                     chat_history = chat_history + [[message, None]]
-                     return gr.update(value=""), chat_history, chat_history
-
-                 def bot_response_2_7B_instruct(chat_history, max_new_tokens):
-                     response = generate_response(chat_history, max_new_tokens, model_2_7B_instruct, tokenizer_2_7B_instruct)
-                     chat_history[-1][1] = response
-                     return chat_history, chat_history
-
-                 send_button_2_7B_instruct = gr.Button("Send")
-                 send_button_2_7B_instruct.click(
-                     fn=user_message_2_7B_instruct,
-                     inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
-                     outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
-                 ).then(
-                     fn=bot_response_2_7B_instruct,
-                     inputs=[
-                         chat_history_2_7B_instruct,
-                         max_new_tokens_2_7B_instruct
-                     ],
-                     outputs=[chat_history_2_7B_instruct, chatbot_2_7B_instruct]
-                 )
-
- if __name__ == "__main__":
-     demo.queue().launch(max_threads=1)
- '''
-
- '''
- import os
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
- import torch
- import threading
- import re
-
- model_name_2_7B_instruct = "Zyphra/Zamba2-2.7B-instruct"
- model_name_7B_instruct = "Zyphra/Zamba2-7B-instruct"
- max_context_length = 4096
-
- tokenizer_2_7B_instruct = AutoTokenizer.from_pretrained(model_name_2_7B_instruct)
- model_2_7B_instruct = AutoModelForCausalLM.from_pretrained(
-     model_name_2_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
- )
-
- tokenizer_7B_instruct = AutoTokenizer.from_pretrained(model_name_7B_instruct)
- model_7B_instruct = AutoModelForCausalLM.from_pretrained(
-     model_name_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
- )
-
- def generate_response(chat_history, max_new_tokens, model, tokenizer):
-     sample = []
-     for turn in chat_history:
-         if turn[0]:
-             sample.append({'role': 'user', 'content': turn[0]})
-         if turn[1]:
-             sample.append({'role': 'assistant', 'content': turn[1]})
-     chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)
-     input_ids = tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).to(model.device)
-
-     max_new_tokens = int(max_new_tokens)
-     max_input_length = max_context_length - max_new_tokens
-     if input_ids['input_ids'].size(1) > max_input_length:
-         input_ids['input_ids'] = input_ids['input_ids'][:, -max_input_length:]
-         if 'attention_mask' in input_ids:
-             input_ids['attention_mask'] = input_ids['attention_mask'][:, -max_input_length:]
-
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = dict(**input_ids, max_new_tokens=int(max_new_tokens), streamer=streamer)
-
-     thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     assistant_response = ""
-
-     for new_text in streamer:
-         new_text = re.sub(r'^\s*(?i:assistant)[:\s]*', '', new_text)
-         assistant_response += new_text
-         yield assistant_response
-
-     thread.join()
-     del input_ids
-     torch.cuda.empty_cache()
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Zamba2 Model Selector")
-     with gr.Tabs():
-         with gr.TabItem("7B Instruct Model"):
-             gr.Markdown("### Zamba2-7B Instruct Model")
-             with gr.Column():
-                 chat_history_7B_instruct = gr.State([])
-                 chatbot_7B_instruct = gr.Chatbot()
-                 message_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
-                 with gr.Accordion("Generation Parameters", open=False):
-                     max_new_tokens_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
-
-                 def user_message_7B_instruct(message, chat_history):
-                     chat_history = chat_history + [[message, None]]
-                     return gr.update(value=""), chat_history, chat_history
-
-                 def bot_response_7B_instruct(chat_history, max_new_tokens):
-                     assistant_response_generator = generate_response(chat_history, max_new_tokens, model_7B_instruct, tokenizer_7B_instruct)
-                     for assistant_response in assistant_response_generator:
-                         chat_history[-1][1] = assistant_response
-                         yield chat_history
-
-                 send_button_7B_instruct = gr.Button("Send")
-                 send_button_7B_instruct.click(
-                     fn=user_message_7B_instruct,
-                     inputs=[message_7B_instruct, chat_history_7B_instruct],
-                     outputs=[message_7B_instruct, chat_history_7B_instruct, chatbot_7B_instruct]
-                 ).then(
-                     fn=bot_response_7B_instruct,
-                     inputs=[chat_history_7B_instruct, max_new_tokens_7B_instruct],
-                     outputs=chatbot_7B_instruct,
-                 )
-
-         with gr.TabItem("2.7B Instruct Model"):
-             gr.Markdown("### Zamba2-2.7B Instruct Model")
-             with gr.Column():
-                 chat_history_2_7B_instruct = gr.State([])
-                 chatbot_2_7B_instruct = gr.Chatbot()
-                 message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
-                 with gr.Accordion("Generation Parameters", open=False):
-                     max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
-
-                 def user_message_2_7B_instruct(message, chat_history):
-                     chat_history = chat_history + [[message, None]]
-                     return gr.update(value=""), chat_history, chat_history
-
-                 def bot_response_2_7B_instruct(chat_history, max_new_tokens):
-                     assistant_response_generator = generate_response(chat_history, max_new_tokens, model_2_7B_instruct, tokenizer_2_7B_instruct)
-                     for assistant_response in assistant_response_generator:
-                         chat_history[-1][1] = assistant_response
-                         yield chat_history
-
-                 send_button_2_7B_instruct = gr.Button("Send")
-                 send_button_2_7B_instruct.click(
-                     fn=user_message_2_7B_instruct,
-                     inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
-                     outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
-                 ).then(
-                     fn=bot_response_2_7B_instruct,
-                     inputs=[chat_history_2_7B_instruct, max_new_tokens_2_7B_instruct],
-                     outputs=chatbot_2_7B_instruct,
-                 )
-
- if __name__ == "__main__":
-     demo.queue().launch(max_threads=1)
- '''
-
  import os
  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
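
For reference, the second removed script streams tokens with the standard transformers TextIteratorStreamer recipe: model.generate() blocks until finished, so it runs on a background thread while the calling thread iterates over the streamer and yields the growing reply. A minimal sketch of that pattern, using the same checkpoint as the removed code (the stream_reply helper and its defaults are illustrative, not part of the original app):

import threading
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "Zyphra/Zamba2-2.7B-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="cuda", torch_dtype=torch.bfloat16
)

def stream_reply(prompt, max_new_tokens=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt drops the echoed input; skip_special_tokens drops markers like <|im_end|>
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
    # generate() runs on a worker thread; the streamer feeds decoded chunks back here
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text  # the accumulated reply, so a UI can redraw the whole message
    thread.join()

Gradio renders a generator's successive yields incrementally, which is how the removed bot_response_* callbacks turned this pattern into a streaming chat UI.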