SandLogicTechnologies committed on
Commit
bb12293
1 Parent(s): d816a8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -57
app.py CHANGED
@@ -29,9 +29,11 @@ model_options = {
29
  # Initialize tokenizer and model variables
30
  tokenizer = None
31
  model = None
 
 
32
 
33
  def load_model(selected_model: str):
34
- global tokenizer, model
35
  model_id = model_options[selected_model]
36
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
37
  model = AutoModelForCausalLM.from_pretrained(
@@ -41,29 +43,44 @@ def load_model(selected_model: str):
41
  token=os.getenv("SHAKTI")
42
  )
43
  model.eval()
 
 
44
 
45
  # Initial model load (default to 2.5B)
46
  load_model("Shakti-2.5B")
47
 
 
48
  @spaces.GPU(duration=90)
49
  def generate(
50
- message: str,
51
- chat_history: list[tuple[str, str]],
52
- max_new_tokens: int = 1024,
53
- temperature: float = 0.6,
54
- top_p: float = 0.9,
55
- top_k: int = 50,
56
- repetition_penalty: float = 1.2,
57
  ) -> Iterator[str]:
58
  conversation = []
59
- for user, assistant in chat_history:
60
- conversation.extend(
61
- [
62
- json.loads(os.getenv("PROMPT")),
63
- {"role": "user", "content": user},
64
- {"role": "assistant", "content": assistant},
65
- ]
66
- )
 
 
 
 
 
 
 
 
 
 
 
 
67
  conversation.append({"role": "user", "content": message})
68
 
69
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
@@ -92,51 +109,34 @@ def generate(
92
  outputs.append(text)
93
  yield "".join(outputs)
94
 
 
95
  def update_examples(selected_model):
96
  if selected_model == "Shakti-100M":
97
  return [["Tell me a story"],
98
- ["Write a short poem on Rose"],
99
- ["What are computers"]]
100
  elif selected_model == "Shakti-250M":
101
  return [["Can you explain the pathophysiology of hypertension and its impact on the cardiovascular system?"],
102
- ["What are the potential side effects of beta-blockers in the treatment of arrhythmias?"],
103
- ["What foods are good for boosting the immune system?"],
104
- ["What is the difference between a stock and a bond?"],
105
- ["How can I start saving for retirement?"],
106
- ["What are some low-risk investment options?"],
107
- ["What is a power of attorney and when is it used?"],
108
- ["What are the key differences between a will and a trust?"],
109
- ["How do I legally protect my business name?"]]
110
  else:
111
- return [["Tell me a story"], ["write a short poem which is hard to sing"], ['मुझे भारतीय इतिहास के बारे में बताएं']]
 
 
112
 
113
  def on_model_select(selected_model):
114
  load_model(selected_model) # Load the selected model
115
- return update_examples(selected_model) # Return new examples based on the selected model
116
-
117
-
118
- chat_interface = gr.ChatInterface(
119
- fn=generate,
120
- additional_inputs=[
121
- gr.Slider(
122
- label="Max new tokens",
123
- minimum=1,
124
- maximum=MAX_MAX_NEW_TOKENS,
125
- step=1,
126
- value=DEFAULT_MAX_NEW_TOKENS,
127
- ),
128
- gr.Slider(
129
- label="Temperature",
130
- minimum=0.1,
131
- maximum=4.0,
132
- step=0.1,
133
- value=0.6,
134
- ),
135
- ],
136
- stop_btn=None,
137
- examples=update_examples("Shakti-2.5B"), # Set initial examples for 2.5B model
138
- cache_examples=False,
139
- )
140
 
141
  with gr.Blocks(css="style.css", fill_height=True) as demo:
142
  gr.Markdown(DESCRIPTION)
@@ -150,10 +150,32 @@ with gr.Blocks(css="style.css", fill_height=True) as demo:
150
  interactive=True,
151
  )
152
 
153
- # Function to handle model change and update examples dynamically
154
- model_dropdown.change(on_model_select, inputs=model_dropdown, outputs=[chat_interface])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- chat_interface.render()
 
157
 
158
- if __name__ == "__main__":
159
- demo.queue(max_size=20).launch()
 
29
  # Initialize tokenizer and model variables
30
  tokenizer = None
31
  model = None
32
+ current_model = "Shakti-2.5B" # Keep track of current model
33
+
34
 
35
  def load_model(selected_model: str):
36
+ global tokenizer, model, current_model
37
  model_id = model_options[selected_model]
38
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
39
  model = AutoModelForCausalLM.from_pretrained(
 
43
  token=os.getenv("SHAKTI")
44
  )
45
  model.eval()
46
+ current_model = selected_model # Update the current model
47
+
48
 
49
  # Initial model load (default to 2.5B)
50
  load_model("Shakti-2.5B")
51
 
52
+
53
  @spaces.GPU(duration=90)
54
  def generate(
55
+ message: str,
56
+ chat_history: list[tuple[str, str]],
57
+ max_new_tokens: int = 1024,
58
+ temperature: float = 0.6,
59
+ top_p: float = 0.9,
60
+ top_k: int = 50,
61
+ repetition_penalty: float = 1.2,
62
  ) -> Iterator[str]:
63
  conversation = []
64
+
65
+ # Conditional logic for adding prompt based on model
66
+ if current_model == "Shakti-2.5B":
67
+ for user, assistant in chat_history:
68
+ conversation.extend(
69
+ [
70
+ json.loads(os.getenv("PROMPT")),
71
+ {"role": "user", "content": user},
72
+ {"role": "assistant", "content": assistant},
73
+ ]
74
+ )
75
+ else:
76
+ for user, assistant in chat_history:
77
+ conversation.extend(
78
+ [
79
+ {"role": "user", "content": user},
80
+ {"role": "assistant", "content": assistant},
81
+ ]
82
+ )
83
+
84
  conversation.append({"role": "user", "content": message})
85
 
86
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
 
109
  outputs.append(text)
110
  yield "".join(outputs)
111
 
112
+
113
  def update_examples(selected_model):
114
  if selected_model == "Shakti-100M":
115
  return [["Tell me a story"],
116
+ ["Write a short poem on Rose"],
117
+ ["What are computers"]]
118
  elif selected_model == "Shakti-250M":
119
  return [["Can you explain the pathophysiology of hypertension and its impact on the cardiovascular system?"],
120
+ ["What are the potential side effects of beta-blockers in the treatment of arrhythmias?"],
121
+ ["What foods are good for boosting the immune system?"],
122
+ ["What is the difference between a stock and a bond?"],
123
+ ["How can I start saving for retirement?"],
124
+ ["What are some low-risk investment options?"],
125
+ ["What is a power of attorney and when is it used?"],
126
+ ["What are the key differences between a will and a trust?"],
127
+ ["How do I legally protect my business name?"]]
128
  else:
129
+ return [["Tell me a story"], ["write a short poem which is hard to sing"],
130
+ ['मुझे भारतीय इतिहास के बारे में बताएं']]
131
+
132
 
133
  def on_model_select(selected_model):
134
  load_model(selected_model) # Load the selected model
135
+ examples = update_examples(selected_model) # Update examples
136
+ return gr.update(examples=examples), gr.update(value=[]) # Clear the chat space and update examples
137
+
138
+
139
+ chat_history = gr.Chatbot()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  with gr.Blocks(css="style.css", fill_height=True) as demo:
142
  gr.Markdown(DESCRIPTION)
 
150
  interactive=True,
151
  )
152
 
153
+ # Create the interface with dynamic inputs and chat history
154
+ max_tokens_slider = gr.Slider(
155
+ label="Max new tokens",
156
+ minimum=1,
157
+ maximum=MAX_MAX_NEW_TOKENS,
158
+ step=1,
159
+ value=DEFAULT_MAX_NEW_TOKENS,
160
+ )
161
+
162
+ temperature_slider = gr.Slider(
163
+ label="Temperature",
164
+ minimum=0.1,
165
+ maximum=4.0,
166
+ step=0.1,
167
+ value=0.6,
168
+ )
169
+
170
+ chat_interface = gr.Interface(
171
+ fn=generate,
172
+ inputs=[gr.Textbox(lines=2, placeholder="Enter your message here"), chat_history, max_tokens_slider,
173
+ temperature_slider],
174
+ outputs=chat_history,
175
+ live=True,
176
+ )
177
 
178
+ # Function to handle model change and update examples dynamically
179
+ model_dropdown.change(on_model_select, inputs=model_dropdown, outputs=[chat_interface, chat_history])
180
 
181
+ demo.launch()