SandLogicTechnologies committed
Commit 7e59f2d
1 Parent(s): 2e33ec7

Update app.py

Files changed (1)
app.py  +66 -229
app.py CHANGED
@@ -1,43 +1,17 @@
 import os
 from threading import Thread
 from typing import Iterator
+
 import gradio as gr
 import spaces
 import torch
 import json
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
-DESCRIPTION = """\
-Shakti LLMs (Large Language Models) are a group of compact language models specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT (Internet of Things) systems. These models provide support for vernacular languages and domain-specific tasks, making them particularly suitable for industries such as healthcare, finance, and customer service.
-For more details, please check [here](https://arxiv.org/pdf/2410.11331v1)
-"""
-
-
-# """\
-# Shakti LLMs are a group of small language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service.
-# For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
-# """
-
-
-# Custom CSS for the send button
-CUSTOM_CSS = """
-.send-btn {
-    padding: 0.5rem !important;
-    width: 55px !important;
-    height: 55px !important;
-    border-radius: 50% !important;
-    margin-top: 1rem;
-    cursor: pointer;
-}
 
-.send-btn svg {
-    width: 20px !important;
-    height: 20px !important;
-    position: absolute;
-    top: 50%;
-    left: 50%;
-    transform: translate(-50%, -50%);
-}
+DESCRIPTION = """\
+Shakti is a 2.5 billion parameter language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service
+For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
 """
 
 MAX_MAX_NEW_TOKENS = 2048
@@ -46,63 +20,37 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Model configurations
-model_options = {
-    "Shakti-100M": "SandLogicTechnologies/Shakti-100M",
-    "Shakti-250M": "SandLogicTechnologies/Shakti-250M",
-    "Shakti-2.5B": "SandLogicTechnologies/Shakti-2.5B"
-}
-
-# Initialize tokenizer and model variables
-tokenizer = None
-model = None
-current_model = "Shakti-2.5B"  # Keep track of current model
-
-
-def load_model(selected_model: str):
-    global tokenizer, model, current_model
-    model_id = model_options[selected_model]
-    tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-        token=os.getenv("SHAKTI")
-    )
-    model.eval()
-    print("Selected Model: ", selected_model)
-    current_model = selected_model
-
-
-# Initial model load
-load_model("Shakti-2.5B")
+model_id = "SandLogicTechnologies/Shakti-2.5B"
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+    token=os.getenv("SHAKTI")
+)
+model.eval()
 
 
+@spaces.GPU(duration=90)
 def generate(
-        message: str,
-        chat_history: list[tuple[str, str]],
-        max_new_tokens: int = 1024,
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        top_k: int = 50,
-        repetition_penalty: float = 1.2,
+    message: str,
+    chat_history: list[tuple[str, str]],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = []
-
-    if current_model == "Shakti-2.5B":
-        for user, assistant in chat_history:
-            conversation.extend([
+    for user, assistant in chat_history:
+        conversation.extend(
+            [
                 json.loads(os.getenv("PROMPT")),
                 {"role": "user", "content": user},
                 {"role": "assistant", "content": assistant},
-            ])
-    else:
-        for user, assistant in chat_history:
-            conversation.extend([
-                {"role": "user", "content": user},
-                {"role": "assistant", "content": assistant},
-            ])
-
+            ]
+        )
     conversation.append({"role": "user", "content": message})
 
     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
@@ -132,167 +80,56 @@ def generate(
     yield "".join(outputs)
 
 
-def respond(message, chat_history, max_new_tokens, temperature):
-    bot_message = ""
-    for chunk in generate(message, chat_history, max_new_tokens, temperature):
-        bot_message += chunk
-    chat_history.append((message, bot_message))
-    return "", chat_history
-
-
-def get_examples(selected_model):
-    examples = {
-        "Shakti-100M": [
-            ["Tell me a story"],
-            ["Write a short poem on Rose"],
-            ["What are computers"]
-        ],
-        "Shakti-250M": [
-            ["Can you explain the pathophysiology of hypertension and its impact on the cardiovascular system?"],
-            ["What are the potential side effects of beta-blockers in the treatment of arrhythmias?"],
-            ["What foods are good for boosting the immune system?"],
-            ["What is the difference between a stock and a bond?"],
-            ["How can I start saving for retirement?"],
-            ["What are some low-risk investment options?"]
-        ],
-        "Shakti-2.5B": [
-            ["Tell me a story"],
-            ["write a short poem which is hard to sing"],
-            ['मुझे भारतीय इतिहास के बारे में बताएं']
-        ]
-    }
-    return examples.get(selected_model, [])
-
-
-def on_model_select(selected_model):
-    load_model(selected_model)  # Load the selected model
-    # Return the message and chat history updates
-    return gr.update(value=""), gr.update(value=[])  # Clear message and chat history
-
-
-def update_examples_visibility(selected_model):
-    # Return individual updates for each example section
-    return (
-        gr.update(visible=selected_model == "Shakti-100M"),
-        gr.update(visible=selected_model == "Shakti-250M"),
-        gr.update(visible=selected_model == "Shakti-2.5B")
-    )
-
-
-def example_selector(example):
-    return example
-
-
-with gr.Blocks(css=CUSTOM_CSS) as demo:
-    gr.Markdown(DESCRIPTION)
-
-    with gr.Row():
-        model_dropdown = gr.Dropdown(
-            label="Select Model",
-            choices=list(model_options.keys()),
-            value="Shakti-2.5B",
-            interactive=True
-        )
-
-    chatbot = gr.Chatbot()
-
-    with gr.Row():
-        with gr.Column(scale=20):
-            msg = gr.Textbox(
-                label="Message",
-                placeholder="Enter your message here",
-                lines=2,
-                show_label=False
-            )
-        with gr.Column(scale=1, min_width=50):
-            send_btn = gr.Button(
-                value="➤",
-                variant="primary",
-                elem_classes=["send-btn"]
-            )
-
-    with gr.Accordion("Parameters", open=False):
-        max_tokens_slider = gr.Slider(
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(
             label="Max new tokens",
             minimum=1,
             maximum=MAX_MAX_NEW_TOKENS,
             step=1,
             value=DEFAULT_MAX_NEW_TOKENS,
-        )
-        temperature_slider = gr.Slider(
+        ),
+        gr.Slider(
             label="Temperature",
             minimum=0.1,
             maximum=4.0,
             step=0.1,
             value=0.6,
-        )
-
-    # Add submit action handlers
-    submit_click = send_btn.click(
-        respond,
-        inputs=[msg, chatbot, max_tokens_slider, temperature_slider],
-        outputs=[msg, chatbot]
-    )
-
-    submit_enter = msg.submit(
-        respond,
-        inputs=[msg, chatbot, max_tokens_slider, temperature_slider],
-        outputs=[msg, chatbot]
-    )
-
-    # Create separate example sections for each model
-    with gr.Row():
-        with gr.Column(visible=False) as examples_100m:
-            gr.Examples(
-                examples=get_examples("Shakti-100M"),
-                inputs=msg,
-                label="Example prompts for Shakti-100M",
-                fn=example_selector
-            )
-
-        with gr.Column(visible=False) as examples_250m:
-            gr.Examples(
-                examples=get_examples("Shakti-250M"),
-                inputs=msg,
-                label="Example prompts for Shakti-250M",
-                fn=example_selector
-            )
-
-        with gr.Column(visible=True) as examples_2_5b:
-            gr.Examples(
-                examples=get_examples("Shakti-2.5B"),
-                inputs=msg,
-                label="Example prompts for Shakti-2.5B",
-                fn=example_selector
-            )
-
-
-    # Update model selection and examples visibility
-    def combined_update(selected_model):
-        msg_update, chat_update = on_model_select(selected_model)
-        examples_100m_update, examples_250m_update, examples_2_5b_update = update_examples_visibility(
-            selected_model)
-        return [
-            msg_update,
-            chat_update,
-            examples_100m_update,
-            examples_250m_update,
-            examples_2_5b_update
-        ]
-
-
-    # Updated change event handler
-    model_dropdown.change(
-        combined_update,
-        inputs=[model_dropdown],
-        outputs=[
-            msg,
-            chatbot,
-            examples_100m,
-            examples_250m,
-            examples_2_5b
-        ]
-    )
+        ),
+        # gr.Slider(
+        #     label="Top-p (nucleus sampling)",
+        #     minimum=0.05,
+        #     maximum=1.0,
+        #     step=0.05,
+        #     value=0.9,
+        # ),
+        # gr.Slider(
+        #     label="Top-k",
+        #     minimum=1,
+        #     maximum=1000,
+        #     step=1,
+        #     value=50,
+        # ),
+        # gr.Slider(
+        #     label="Repetition penalty",
+        #     minimum=1.0,
+        #     maximum=2.0,
+        #     step=0.05,
+        #     value=1.2,
+        # ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Tell me a story"], ["write a short poem which is hard to sing"], ['मुझे भारतीय इतिहास के बारे में बताएं']
+    ],
+    cache_examples=False,
+)
+
+with gr.Blocks(css="style.css", fill_height=True) as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    chat_interface.render()
 
 if __name__ == "__main__":
     demo.queue(max_size=20).launch()
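
Note: the `@@ -132,167 +80,56 @@` hunk elides the unchanged middle of `generate` (old lines 109-131 / new lines 57-79), between building `input_ids` with `apply_chat_template` and the final `yield "".join(outputs)`. As a reading aid, here is a minimal sketch of the standard transformers streaming pattern those unchanged lines presumably follow, inferred from the file's visible imports (`Thread`, `TextIteratorStreamer`) and the `MAX_INPUT_TOKEN_LENGTH` constant. It is an assumption, not a copy of the hidden code; the exact kwargs and warning text in app.py may differ, and `stream_completion` is a hypothetical name.

# Hypothetical sketch of the elided streaming body of generate().
# `tokenizer`, `model`, and MAX_INPUT_TOKEN_LENGTH stand in for the
# module-level globals defined earlier in app.py.
from threading import Thread
from typing import Iterator

import gradio as gr
import torch
from transformers import TextIteratorStreamer


def stream_completion(
    input_ids: torch.Tensor,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Trim prompts that exceed the configured context budget.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # The streamer turns generated tokens into decoded text pieces.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so this generator can consume
    # the streamer and yield growing partial completions to the Gradio UI.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

Running `model.generate` on a worker thread is what lets the outer function stay a plain Python generator: `gr.ChatInterface` iterates it and repaints the chat bubble on every yield.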