# app.py
import os
from pathlib import Path

import torch
from threading import Event, Thread
from typing import List, Tuple

# Importing necessary packages
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from optimum.intel.openvino import OVModelForCausalLM
from langchain_community.tools import DuckDuckGoSearchRun  # Web search tool used below
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams

from gradio_helper import make_demo  # UI logic import
from llm_config import SUPPORTED_LLM_MODELS

# Model configuration setup
max_new_tokens = 256
model_language_value = "English"
model_id_value = "qwen2.5-0.5b-instruct"
prepare_int4_model_value = True
enable_awq_value = False
device_value = "CPU"
model_to_run_value = "INT4"

pt_model_id = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]["model_id"]
pt_model_name = model_id_value.split("-")[0]
int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
int4_weights = int4_model_dir / "openvino_model.bin"

model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
model_name = model_configuration["model_id"]
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})


# Helper for accumulating streamed text chunks; falls back to plain concatenation
# when the model configuration does not provide its own partial_text_processor.
def default_partial_text_processor(partial_text: str, new_text: str) -> str:
    return partial_text + new_text


text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor)

# Model loading
core = ov.Core()
ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): "",
}

tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    device=device_value,
    ov_config=ov_config,
    config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
    trust_remote_code=True,
)


# Stopping criteria for token generation
class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)


# Convert configured stop tokens (strings) to ids and wrap them in StopOnTokens
# so they can be passed to generate() via StoppingCriteriaList later on.
if stop_tokens is not None:
    if isinstance(stop_tokens[0], str):
        stop_tokens = tok.convert_tokens_to_ids(stop_tokens)
    stop_tokens = [StopOnTokens(stop_tokens)]


# Functions for chatbot logic
def convert_history_to_token(history: List[Tuple[str, str]]):
    """
    Convert history, stored as a list of (user, assistant) message pairs, to tokens
    according to the conversation template expected by the model.
    Params:
      history: dialogue history
    Returns:
      history in token format
    """
    if pt_model_name == "baichuan2":
        # Baichuan2 uses reserved token ids 195/196 to mark user/assistant turns
        system_tokens = tok.encode(start_message)
        history_tokens = []
        for old_query, response in history[:-1]:
            round_tokens = []
            round_tokens.append(195)
            round_tokens.extend(tok.encode(old_query))
            round_tokens.append(196)
            round_tokens.extend(tok.encode(response))
            history_tokens = round_tokens + history_tokens
        input_tokens = system_tokens + history_tokens
        input_tokens.append(195)
        input_tokens.extend(tok.encode(history[-1][0]))
        input_tokens.append(196)
        input_token = torch.LongTensor([input_tokens])
    elif history_template is None or has_chat_template:
        messages = [{"role": "system", "content": start_message}]
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        input_token = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")
    else:
        text = start_message + "".join(
            ["".join([history_template.format(num=round, user=item[0], assistant=item[1])]) for round, item in enumerate(history[:-1])]
        )
        text += "".join(
            [
                "".join(
                    [
                        current_message_template.format(
                            num=len(history) + 1,
                            user=history[-1][0],
                            assistant=history[-1][1],
                        )
                    ]
                )
            ]
        )
        input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
    return input_token


# Initialize the search tool
search = DuckDuckGoSearchRun()


# Function to retrieve and format search results based on user input
def fetch_search_results(query: str) -> str:
    search_results = search.invoke(query)
    # Displaying search results for debugging
    print("Search results: ", search_results)
    return f"Relevant and recent information:\n{search_results}"


# Function to decide if a search is needed based on the user query
def should_use_search(query: str) -> bool:
    # Simple heuristic; can be extended with more advanced intent analysis
    search_keywords = [
        "latest", "news", "update", "which", "who", "what", "when", "why", "how",
        "recent", "result", "tell", "explain", "announcement", "bulletin", "report",
        "brief", "insight", "disclosure", "release", "memo", "headline", "current",
        "ongoing", "fresh", "upcoming", "immediate", "recently", "new", "now",
        "in-progress", "inquiry", "query", "ask", "investigate", "explore", "seek",
        "clarify", "confirm", "discover", "learn", "describe", "define", "illustrate",
        "outline", "interpret", "expound", "detail", "summarize", "elucidate",
        "break down", "outcome", "effect", "consequence", "finding", "achievement",
        "conclusion", "product", "performance", "resolution",
    ]
    return any(keyword in query.lower() for keyword in search_keywords)


# Generate prompt for model with optional search context
def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
    # Simple instruction for the model to prioritize search information if available
    instructions = (
        "If relevant information is provided below, use it to give an accurate and concise answer. "
        "If there is no relevant information available, please rely on your general knowledge and "
        "indicate that no recent or specific information is available to answer."
    )
    # Build the prompt with instructions, search context, and user query
    prompt = f"{instructions}\n\n"
    if search_context:
        prompt += f"{search_context}\n\n"  # Include search context prominently at the top
    # Add the user's query
    prompt += f"{user_query} ?\n\n"
    # Optionally add recent history for context, without labels
    # if history:
    #     prompt += "Recent conversation:\n"
    #     for user_msg, assistant_msg in history[:-1]:  # Exclude the last message to prevent duplication
    #         prompt += f"{user_msg}\n{assistant_msg}\n"
    return prompt


def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """
    Main callback function for running the chatbot on submit button click.
""" user_query = history[-1][0] search_context = "" # Decide if search is required based on the user query if should_use_search(user_query): search_context = fetch_search_results(user_query) prompt = construct_model_prompt(user_query, search_context, history) input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids else: # If no search context, use the original logic with tokenization prompt = construct_model_prompt(user_query, "", history) input_ids = convert_history_to_token(history) # Ensure input length does not exceed a threshold (e.g., 2000 tokens) if input_ids.shape[1] > 2000: # If input exceeds the limit, only use the most recent conversation history = [history[-1]] # Streamer for model response generation streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( input_ids=input_ids, max_new_tokens=256, # Adjust this as needed temperature=temperature, do_sample=temperature > 0.0, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, streamer=streamer, ) if stop_tokens is not None: generate_kwargs["stopping_criteria"] = StoppingCriteriaList(stop_tokens) # Event to signal when streaming is complete stream_complete = Event() def generate_and_signal_complete(): ov_model.generate(**generate_kwargs) stream_complete.set() t1 = Thread(target=generate_and_signal_complete) t1.start() # Initialize an empty string to store the generated text partial_text = "" for new_text in streamer: partial_text = text_processor(partial_text, new_text) # Update the last entry in the original history with the response history[-1] = (user_query, partial_text) yield history def request_cancel(): ov_model.request.cancel() # Gradio setup and launch demo = make_demo(run_fn=bot, stop_fn=request_cancel, title=f"OpenVINO Search & Reasoning Chatbot", language=model_language_value) if __name__ == "__main__": demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)