lightmate committed
Commit 10be371
1 Parent: a2c455a

Update app.py

This commit wires DuckDuckGo web search into the chatbot: a keyword heuristic decides when a query needs fresh information, and the search results are injected into the model prompt ahead of the user's question.

Files changed (1): app.py (+83, -9)
app.py CHANGED
```diff
@@ -118,17 +118,82 @@ def convert_history_to_token(history: List[Tuple[str, str]]):
     input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
     return input_token
 
+# Initialize the search tool
+search = DuckDuckGoSearchRun()
+
+# Retrieve and format search results based on user input
+def fetch_search_results(query: str) -> str:
+    search_results = search.invoke(query)
+    # Display search results for debugging
+    print("Search results: ", search_results)
+    return f"Relevant and recent information:\n{search_results}"
+
+# Decide if a search is needed based on the user query
+def should_use_search(query: str) -> bool:
+    # Simple heuristic; can be extended with more advanced intent analysis
+    search_keywords = ["latest", "news", "update", "which", "who", "what", "when", "why", "how",
+                       "recent", "result", "tell", "explain", "announcement", "bulletin", "report",
+                       "brief", "insight", "disclosure", "release", "memo", "headline", "current",
+                       "ongoing", "fresh", "upcoming", "immediate", "recently", "new", "now",
+                       "in-progress", "inquiry", "query", "ask", "investigate", "explore", "seek",
+                       "clarify", "confirm", "discover", "learn", "describe", "define", "illustrate",
+                       "outline", "interpret", "expound", "detail", "summarize", "elucidate",
+                       "break down", "outcome", "effect", "consequence", "finding", "achievement",
+                       "conclusion", "product", "performance", "resolution"]
+    return any(keyword in query.lower() for keyword in search_keywords)
+
+# Generate the prompt for the model, with optional search context
+def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
+    # Simple instruction telling the model to prioritize search information when available
+    instructions = (
+        "If relevant information is provided below, use it to give an accurate and concise answer. "
+        "If there is no relevant information available, rely on your general knowledge and indicate "
+        "that no recent or specific information is available to answer."
+    )
+
+    # Build the prompt with instructions, search context, and the user query
+    prompt = f"{instructions}\n\n"
+    if search_context:
+        prompt += f"{search_context}\n\n"  # Include search context prominently at the top
+
+    # Add the user's query
+    prompt += f"{user_query}?\n\n"
+
+    # Optionally add recent history for context, without labels
+    # if history:
+    #     prompt += "Recent conversation:\n"
+    #     for user_msg, assistant_msg in history[:-1]:  # Exclude the last message to prevent duplication
+    #         prompt += f"{user_msg}\n{assistant_msg}\n"
+
+    return prompt
+
+
 def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
-    # Callback function for running chatbot on submit button click
-    input_ids = convert_history_to_token(history)
+    """
+    Main callback function for running the chatbot on submit button click.
+    """
+    user_query = history[-1][0]
+    search_context = ""
+
+    # Decide if a search is required based on the user query
+    if should_use_search(user_query):
+        search_context = fetch_search_results(user_query)
+        prompt = construct_model_prompt(user_query, search_context, history)
+        input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids
+    else:
+        # If there is no search context, keep the original history tokenization
+        prompt = construct_model_prompt(user_query, "", history)
+        input_ids = convert_history_to_token(history)
+
+    # Ensure input length does not exceed a threshold (e.g., 2000 tokens)
     if input_ids.shape[1] > 2000:
+        # If the input exceeds the limit, keep only the most recent exchange
         history = [history[-1]]
-        input_ids = convert_history_to_token(history)
 
-    streamer = TextIteratorStreamer(tok, timeout=3600.0, skip_prompt=True, skip_special_tokens=True)
+    # Streamer for model response generation
+    streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
+
     generate_kwargs = dict(
         input_ids=input_ids,
-        max_new_tokens=256,
+        max_new_tokens=256,  # Adjust as needed
         temperature=temperature,
         do_sample=temperature > 0.0,
         top_p=top_p,
@@ -136,23 +201,32 @@ def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id)
         repetition_penalty=repetition_penalty,
         streamer=streamer,
     )
+
+    if stop_tokens is not None:
+        generate_kwargs["stopping_criteria"] = StoppingCriteriaList(stop_tokens)
+
+    # Event to signal when streaming is complete
     stream_complete = Event()
 
     def generate_and_signal_complete():
         ov_model.generate(**generate_kwargs)
         stream_complete.set()
 
-    Thread(target=generate_and_signal_complete).start()
+    t1 = Thread(target=generate_and_signal_complete)
+    t1.start()
+
+    # Accumulate the generated text as it streams in
     partial_text = ""
     for new_text in streamer:
-        partial_text += new_text
-        history[-1][1] = partial_text
+        partial_text = text_processor(partial_text, new_text)
+        # Update the last history entry with the partial response
+        history[-1] = (user_query, partial_text)
         yield history
 
 def request_cancel():
     ov_model.request.cancel()
 
 # Gradio setup and launch
-demo = make_demo(run_fn=bot, stop_fn=request_cancel, title=f"OpenVINO {model_id_value} Chatbot", language=model_language_value)
+demo = make_demo(run_fn=bot, stop_fn=request_cancel, title="OpenVINO Search & Reasoning Chatbot", language=model_language_value)
 if __name__ == "__main__":
     demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)
```
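A few notes and sketches on the new code follow; anything not shown in the diff is flagged as an assumption.

The hunk never shows where `DuckDuckGoSearchRun` is imported from. Assuming it is the LangChain community tool (backed by the `duckduckgo-search` package), a minimal self-contained sketch of the search helper looks like this:

```python
# Sketch of the search helper. Assumption: DuckDuckGoSearchRun comes from
# langchain_community.tools; the import is not visible in the hunk.
from langchain_community.tools import DuckDuckGoSearchRun

search = DuckDuckGoSearchRun()

def fetch_search_results(query: str) -> str:
    # invoke() is the standard LangChain tool entry point; for this tool it
    # returns a single string of concatenated result snippets.
    search_results = search.invoke(query)
    return f"Relevant and recent information:\n{search_results}"

if __name__ == "__main__":
    print(fetch_search_results("latest OpenVINO release"))
```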
 
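`should_use_search` does raw substring matching, so short keywords such as "now" or "ask" also fire inside unrelated words ("snowboarding", "basket"). A word-boundary variant, sketched below rather than taken from the commit, avoids those false positives:

```python
import re

# Trimmed keyword list for the sketch; the commit ships a much longer one.
SEARCH_KEYWORDS = ["latest", "news", "update", "who", "what", "when", "why",
                   "how", "recent", "now", "ask", "new"]

# Compile once; \b pins every keyword to whole-word matches.
_KEYWORD_RE = re.compile(
    r"\b(?:" + "|".join(map(re.escape, SEARCH_KEYWORDS)) + r")\b",
    re.IGNORECASE,
)

def should_use_search(query: str) -> bool:
    return _KEYWORD_RE.search(query) is not None

assert should_use_search("What is the latest release?")
assert not should_use_search("I like snowboarding")  # substring "now" no longer matches
```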
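One behavioral quirk worth flagging: the updated `bot` trims `history` when `input_ids.shape[1] > 2000`, but unlike the removed code it no longer re-tokenizes after trimming, so the oversized `input_ids` tensor is still the one handed to `generate`. A sketch of the guard with the rebuild step kept (the helper name `clamp_input` is hypothetical; `convert_history_to_token` is the function from app.py):

```python
from typing import Callable, List, Tuple

def clamp_input(history: List[Tuple[str, str]],
                to_tokens: Callable,  # e.g. convert_history_to_token
                limit: int = 2000):
    """Trim history to the latest exchange when the token budget is exceeded."""
    input_ids = to_tokens(history)
    if input_ids.shape[1] > limit:
        history = [history[-1]]          # keep only the most recent exchange
        input_ids = to_tokens(history)   # rebuild ids so the trim takes effect
    return history, input_ids
```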
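The generation flow relies on `TextIteratorStreamer`: `generate()` runs in a background thread and pushes decoded text into the streamer, while the main thread iterates over it and yields partial history updates to Gradio. A self-contained sketch of the same pattern with a small Hugging Face model (the app uses an OpenVINO model, but the streaming contract is identical):

```python
from threading import Event, Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tok("OpenVINO is", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
stream_complete = Event()

def generate_and_signal_complete():
    model.generate(input_ids=input_ids, max_new_tokens=32, streamer=streamer)
    stream_complete.set()  # mirrors the diff: tell the consumer we are done

Thread(target=generate_and_signal_complete).start()

partial_text = ""
for new_text in streamer:  # blocks until the worker thread decodes more text
    partial_text += new_text
    print(partial_text)
```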
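Neither `text_processor` nor `stop_tokens` is defined in this hunk; both presumably live elsewhere in app.py. Minimal sketches of what they plausibly look like (assumptions, not the committed code):

```python
import torch
from transformers import StoppingCriteria

def text_processor(partial_text: str, new_text: str) -> str:
    # Assumed behavior: plain accumulation; model-specific cleanup could go here.
    return partial_text + new_text

class StopOnTokens(StoppingCriteria):
    """Stop generation once the last sampled token is a designated stop id."""

    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids[0][-1] in self.token_ids

# Hypothetical wiring; tok is the app's tokenizer:
# stop_tokens = [StopOnTokens([tok.eos_token_id])]
# generate_kwargs["stopping_criteria"] = StoppingCriteriaList(stop_tokens)
```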