"""Gradio chat assistant that routes a message to one of three back ends.

Routing, in order:
1. Messages containing "search the web" go to the Serper search API.
2. Messages judged FDA-related (by a GPT-4o yes/no check) go to the FAISS
   memory lookup and, failing that, to the PubMedBERT handler.
3. Everything else is streamed from a Hugging Face chat-completion model.
"""
import json
import os

import faiss
import gradio as gr
import numpy as np
import openai
import requests
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)

# NOTE(review): `ds` is loaded at import time but nothing in this file reads
# it yet; kept so the module-level name stays available.
ds = load_dataset("epfl-llm/guidelines")

# API keys are read from the environment (e.g. Hugging Face Space secrets).
openai.api_key = os.getenv("OPENAI_API_KEY")
serper_api_key = os.getenv("SERPER_API_KEY")

# PubMedBERT tokenizer and model used for embeddings and FDA-query handling.
# NOTE(review): the checkpoint name suggests a base (non-fine-tuned) model, so
# the 2-label classification head is presumably freshly initialised — confirm
# before relying on its logits.
PUBMEDBERT_CHECKPOINT = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
tokenizer = BertTokenizer.from_pretrained(PUBMEDBERT_CHECKPOINT)
model = BertForSequenceClassification.from_pretrained(
    PUBMEDBERT_CHECKPOINT, num_labels=2
)

# Chat-completion client used by `respond` for general conversation.
# NOTE(review): the original file never created this object; the default
# model below is an assumption — set HF_CHAT_MODEL to override it.
client = InferenceClient(os.getenv("HF_CHAT_MODEL", "HuggingFaceH4/zephyr-7b-beta"))

# FAISS setup for vector search (embedding-based memory).
dimension = 768  # PubMedBERT embedding size
index = faiss.IndexFlatL2(dimension)


def embed_text(text: str) -> np.ndarray:
    """Return a mean-pooled PubMedBERT embedding of *text*.

    The result is a float32 array with one row per input text, built from the
    model's last hidden layer. Only real tokens (per the attention mask)
    contribute to the mean, so padding never dilutes the vector.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
    pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return np.ascontiguousarray(pooled.numpy(), dtype=np.float32)


# Example: embed a past conversation and store it in FAISS.
past_conversation = "FDA approval for companion diagnostics requires careful documentation."
past_embedding = embed_text(past_conversation)
index.add(past_embedding)


def search_memory(query: str) -> np.ndarray:
    """Return the FAISS index array for the single nearest stored memory.

    The array has one row for the query and one column (k=1). Callers must
    compare entries against 0 explicitly rather than rely on truthiness,
    because index 0 is a valid hit.
    """
    query_embedding = embed_text(query)
    _distances, indices = index.search(query_embedding, k=1)
    return indices


def handle_fda_query(query: str) -> str:
    """Run *query* through PubMedBERT and return a fixed acknowledgement."""
    inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        # The logits are computed but not consumed yet; the reply is constant.
        _logits = model(**inputs).logits
    return "Processed FDA-related query via PubMedBERT"


def handle_openai_query(prompt: str) -> str:
    """Send *prompt* to GPT-4o as a single user message and return the reply.

    Uses the chat-completions endpoint (gpt-4o is a chat model); this assumes
    the openai>=1.0 client interface.
    """
    response = openai.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=100,
    )
    content = response.choices[0].message.content
    return (content or "").strip()


def web_search(query: str) -> dict:
    """Query the Serper search API and return the decoded JSON response."""
    url = "https://google.serper.dev/search"
    headers = {"X-API-KEY": serper_api_key}
    params = {"q": query}
    # A timeout keeps a stalled request from hanging the chat handler.
    response = requests.get(url, headers=headers, params=params, timeout=15)
    return response.json()


def _format_search_results(results, limit: int = 5) -> str:
    """Render a search response as chat-friendly text.

    Assumes the response may carry an "organic" list of items with
    "title", "link" and "snippet" keys (TODO confirm against the Serper
    schema); anything else is shown as pretty-printed JSON.
    """
    organic = results.get("organic") if isinstance(results, dict) else None
    if not organic:
        return json.dumps(results, ensure_ascii=False, indent=2)
    lines = [
        f"- [{item.get('title', '')}]({item.get('link', '')}): {item.get('snippet', '')}"
        for item in organic[:limit]
    ]
    return "\n".join(lines)


def _is_fda_related(message: str) -> bool:
    """Ask GPT-4o for a yes/no verdict on whether *message* is FDA-related.

    A substring test for "FDA" in a free-form answer matches negative
    answers too ("No, this is not FDA-related"), so the prompt requests a
    one-word answer and only a leading "yes" counts.
    """
    verdict = handle_openai_query(
        "Answer with a single word, 'yes' or 'no'. Is the following query about "
        "the FDA or medical-product regulation?\n\n" + message
    )
    return verdict.lower().startswith("yes")


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Yield the assistant's reply to *message* for the Gradio chat UI.

    This is a generator: every branch yields its text, because a value passed
    to ``return`` inside a generator is never delivered to the caller's loop.

    Args:
        message: The latest user message.
        history: Earlier (user, assistant) turns.
        system_message: System prompt for the general chat model.
        max_tokens: Generation cap for the general chat model.
        temperature: Sampling temperature for the general chat model.
        top_p: Nucleus-sampling value for the general chat model.

    Yields:
        The reply text; in the streaming branch, the reply accumulated so far.
    """
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    # An explicit web-search request is honoured first; it needs no model call.
    if "search the web" in message.lower():
        yield _format_search_results(web_search(message))
        return

    if _is_fda_related(message):
        memory_index = search_memory(message)
        # FAISS reports a missing neighbour as -1; index 0 is a real hit.
        # NOTE(review): with a non-empty index and k=1 a neighbour always
        # exists, so consider adding a distance threshold here.
        if memory_index.size and memory_index[0][0] >= 0:
            yield f"Found relevant past memory: {past_conversation}"
            return
        # If no memory match, proceed with PubMedBERT.
        yield handle_fda_query(message)
        return

    # General conversation: stream tokens from the chat-completion client.
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token:  # some chunks carry no text
            response += token
            yield response


# Create Gradio ChatInterface for interaction.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are Ferris2.0, an FDA expert.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch(share=True)